PredictiveSSL

class selfeeg.ssl.predictive.PredictiveSSL(encoder: Module, head: list[int] | Module, return_logits: bool = True)

Implementation of a standard predictive pretraining. Unlike contrastive pretraining, it solves a classification or regression task on generated pseudo-labels. A trivial example is a model trying to predict which random augmentation from a given set was applied to each sample of the batch.
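To make the idea concrete, the following is a minimal pure-PyTorch sketch of one predictive pretraining step. It is not selfeeg's internal training loop: the toy encoder, the head, and the helper make_pseudo_labeled_batch are hypothetical, and time reversal is assumed to be the only augmentation, so the pseudo-label simply says whether each sample was reversed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy encoder and head: any nn.Module pair with matching sizes works
encoder = nn.Sequential(nn.Flatten(), nn.Linear(8 * 128, 16))
head = nn.Linear(16, 2)
model = nn.Sequential(encoder, head)
optimizer = torch.optim.Adam(model.parameters())


def make_pseudo_labeled_batch(x):
    """Time-reverse a random subset of samples; pseudo-label 1 marks reversed ones."""
    y = torch.randint(0, 2, (x.shape[0],))  # one int64 label per sample
    reversed_mask = y.view(-1, 1, 1).bool()  # broadcast over channels and time
    x_aug = torch.where(reversed_mask, torch.flip(x, dims=[-1]), x)
    return x_aug, y


x = torch.randn(32, 8, 128)  # 32 windows, 8 channels, 128 samples (1 s at 128 Hz)
x_aug, y = make_pseudo_labeled_batch(x)

logits = model(x_aug)  # shape (32, 2): one logit per pseudo-class
loss = F.cross_entropy(logits, y)  # softmax is applied inside cross_entropy

optimizer.zero_grad()
loss.backward()
optimizer.step()  # encoder and head are updated without any real label
```

Only the encoder is usually kept after this stage; the head exists solely to solve the pretext task.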

Parameters:
  • encoder (nn.Module) – The encoder part of the model, i.e., the one you wish to pretrain and transfer to the new model.

  • head (Union[list[int], nn.Module]) –

    The predictive head to use. It can be:

    1. an nn.Module

    2. a list of ints.

    In case a list of ints is given, an nn.Sequential module with Dense, BatchNorm and ReLU layers will be automatically created. The list sets the input and output dimensions of each Dense layer. For instance, if [128, 64, 2] is given, two Dense layers will be created: the first with input 128 and output 64, the second with input 64 and output 2.

  • return_logits (bool, optional) –

    Whether to return the output as logits or as probabilities. Keeping the default (True) is recommended, since the PyTorch cross-entropy loss applies the softmax internally.

    Default = True
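The automatic head construction from a list of ints can be sketched as follows. build_head is a hypothetical helper, not selfeeg's code, and placing BatchNorm1d and ReLU after every Linear layer except the last one is an assumption about the exact layer order.

```python
import torch
import torch.nn as nn


def build_head(dims):
    """Hypothetical helper: stack Linear layers following the sizes in dims.

    Assumption: BatchNorm1d + ReLU follow every Linear layer except the last,
    so the final layer outputs raw logits.
    """
    layers = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:  # no normalization/activation after the output layer
            layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


head = build_head([128, 64, 2])  # Linear(128, 64) -> BN -> ReLU -> Linear(64, 2)
out = head(torch.randn(4, 128))  # a batch of 4 encoder outputs
```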

Warning

This class does not check the compatibility between the encoder's output and the predictive head's input. Make sure that the encoder's output size matches the head's input size.
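One way to check compatibility yourself is to pass a dummy batch through the encoder before building the model. The sketch below assumes the encoder accepts a (batch, channels, samples) tensor; encoder_output_size is a hypothetical helper and the toy encoder is only illustrative.

```python
import torch
import torch.nn as nn


def encoder_output_size(encoder, input_shape):
    """Return the flattened feature size the encoder produces for one sample."""
    was_training = encoder.training
    encoder.eval()  # avoid updating BatchNorm statistics with the dummy input
    with torch.no_grad():
        features = encoder(torch.zeros(1, *input_shape))
    encoder.train(was_training)  # restore the original mode
    return features.flatten(start_dim=1).shape[1]


# Hypothetical encoder for 8-channel, 128-sample windows
encoder = nn.Sequential(nn.Flatten(), nn.Linear(8 * 128, 16))
head = nn.Sequential(nn.Linear(16, 2))

size = encoder_output_size(encoder, (8, 128))
assert size == head[0].in_features, f"encoder gives {size}, head expects {head[0].in_features}"
```

When a list of ints is passed as head, the same check tells you which value to put first in the list.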

Example

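>>> import pickle, torch, selfeeg
>>> import torch.nn.functional as F
>>> import selfeeg.dataloading as dl
>>> import selfeeg.augmentation as aug
>>> import selfeeg.utils as utils
>>> utils.create_dataset()
>>> def loadEEG(path, return_label=False):
...     with open(path, 'rb') as handle:
...         EEG = pickle.load(handle)
...     x, y = EEG['data'], EEG['label']
...     return (x, y) if return_label else x
>>> def loss_fineTuning(yhat, ytrue):
...     return F.binary_cross_entropy_with_logits(torch.squeeze(yhat), ytrue + 0.)
>>> def augment(x):
...     # Hypothetical augmenter: time-reverse the whole sub-batch with probability 0.5
...     # and return the chosen index as a per-sample pseudo-label
...     k = int(torch.randint(0, 2, (1,)))
...     x = torch.flip(x, dims=[-1]) if k == 1 else x
...     return x, torch.full((x.shape[0],), k, dtype=torch.long, device=x.device)
>>> EEGlen = dl.get_eeg_partition_number('Simulated_EEG', freq=128, window=1,
...                                      overlap=0.3, load_function=loadEEG)
>>> EEGsplit = dl.get_eeg_split_table(EEGlen, seed=1234)
>>> TrainSet = dl.EEGDataset(EEGlen, EEGsplit, [128, 1, 0.3], 'train', False, loadEEG)
>>> Loader = torch.utils.data.DataLoader(TrainSet, batch_size=32)
>>> enc = selfeeg.models.ShallowNetEncoder(8)
>>> # the first value of the head list must match the encoder's output size
>>> pred = selfeeg.ssl.PredictiveSSL(enc, [16, 16, 2])
>>> loss_train = pred.fit(Loader, 1, augmenter=augment, return_loss_info=True)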

fit(train_dataloader, epochs=1, optimizer=None, augmenter=None, loss_func: Callable = None, loss_args: list = [], lr_scheduler=None, EarlyStopper=None, validation_dataloader=None, augmenter_batch_calls=2, labels_on_dataloader=False, verbose=True, device: str = None, return_loss_info: bool = False)

fit is a custom fit function designed to pretrain the model on the given dataloader.

Parameters:
  • train_dataloader (Dataloader) – The PyTorch DataLoader used to get the training batches. Each batch is expected to be a single tensor X (no pseudo-labels), unless labels_on_dataloader is set to True.

  • epochs (int, optional) –

    The number of training epochs. Must be an integer greater than 0.

    Default = 1

  • optimizer (torch Optimizer, optional) –

    The optimizer used to update the weights. It can be any optimizer provided in the torch.optim module. If not given, Adam with default parameters will be instantiated.

    Default = torch.optim.Adam

  • augmenter (Callable, optional) –

    Any function (or callable object) used to perform data augmentation on the batch and to generate the pseudo-labels (if the dataloader does not provide them). Unless labels_on_dataloader is set to True, augmenter is expected to take as input a batch tensor X and return both the augmented version of X and the pseudo-label tensor Y. Using selfeeg's augmentation module is highly recommended, as it implements several data augmentation functions and classes to combine them. RandomAug, for example, can also return the index of the chosen augmentation to be used as a pseudo-label.

    Default = None

    Note

    This argument is optional only because pseudo-labels can alternatively be provided through the labels_on_dataloader argument; it must be given if the dataloader does not directly provide the pseudo-labels.

  • loss_func (Callable, optional) –

    The custom loss function. It can be any loss function that accepts the model's predictions and the pseudo-labels as required arguments and loss_args as optional arguments. If not given, the cross-entropy loss will be used.

    Default = None

  • loss_args (Union[list, dict], optional) –

    The optional arguments to pass to the loss function. It can be a list or a dict.

    Default = []

  • lr_scheduler (torch Scheduler, optional) –

    A PyTorch learning rate scheduler used to update the learning rate during training.

    Default = None

  • EarlyStopper (EarlyStopping, optional) –

    An instance of the provided EarlyStopping class.

    Default = None

    Note

    If an EarlyStopping instance is given with the monitored loss set to the validation loss, but no validation dataloader is given, the monitored loss will automatically be set to the training loss.

  • validation_dataloader (Dataloader, optional) –

    The PyTorch DataLoader used to get the validation batches. Each batch is expected to be a single tensor X (no pseudo-labels), unless labels_on_dataloader is set to True. If not given, no validation loss will be calculated.

    Default = None

  • augmenter_batch_calls (int, optional) –

    The number of times the augmenter is called for a single batch. Each call receives an equal portion of the samples in the batch.

    Default = 2

    Note

    To better understand how this argument works, suppose you design a task where the model must predict which augmentation from a predefined set was applied to each sample of the batch. The classes in selfeeg's compose submodule operate at the batch level, so a single call assigns the same pseudo-label to the whole batch, while one might want batches containing multiple labels. augmenter_batch_calls solves this problem by splitting the batch and calling the augmenter on each portion.

  • labels_on_dataloader (bool, optional) –

    Set this to True if the dataloader already provides the pseudo-labels. If True, augmenter and augmenter_batch_calls will be ignored.

    Default = False

    Note

    If you want to pretrain the model by simply solving another task and you need more functionalities, consider using the fine_tune function, which acts as a generic supervised training.

  • verbose (bool, optional) –

    Whether to print a progress bar or not.

    Default = True

  • device (torch.device or str, optional) –

    The device to use for training. If given as a string, it will be converted to a torch.device instance. If not given, the ‘cpu’ device will be used.

    Default = None

  • return_loss_info (bool, optional) –

    Whether to return the training and validation losses calculated at each epoch.

    Default = False
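The augmenter contract and the batch splitting described by augmenter_batch_calls can be sketched in plain PyTorch as follows. The augmentation set, the helper names, and the splitting rule (contiguous portions of equal size, with the remainder assigned to the last one) are assumptions for illustration, not selfeeg's internal implementation. Each call applies one augmentation to its whole portion, so a batch ends up with several pseudo-labels.

```python
import torch

torch.manual_seed(0)


def flip_time(x):
    return torch.flip(x, dims=[-1])  # reverse each window in time


def add_noise(x):
    return x + 0.1 * torch.randn_like(x)  # small additive Gaussian noise


AUGMENTATIONS = [flip_time, add_noise]  # hypothetical set; index = pseudo-label


def augmenter(x):
    """Apply ONE randomly chosen augmentation to the whole (sub-)batch.

    Returns the augmented tensor and a per-sample pseudo-label tensor.
    """
    idx = int(torch.randint(len(AUGMENTATIONS), (1,)))
    return AUGMENTATIONS[idx](x), torch.full((x.shape[0],), idx, dtype=torch.long)


def augment_in_chunks(x, augmenter, batch_calls=2):
    """Split the batch into batch_calls portions and augment each one separately."""
    n = x.shape[0]
    size = n // batch_calls
    xs, ys = [], []
    for k in range(batch_calls):
        start = k * size
        end = n if k == batch_calls - 1 else start + size  # last portion takes the remainder
        xk, yk = augmenter(x[start:end])
        xs.append(xk)
        ys.append(yk)
    return torch.cat(xs), torch.cat(ys)


x = torch.randn(32, 8, 128)
x_aug, y = augment_in_chunks(x, augmenter, batch_calls=4)  # 4 portions of 8 samples
```

With batch_calls=1 the whole batch would share a single pseudo-label, which is exactly the situation the note above describes.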

Returns:

loss_info (dict, optional) – A dictionary whose keys are the epoch numbers (as integers) and whose values are two-element lists with the epoch's average training and validation loss.
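A custom loss_func that takes predictions and pseudo-labels, together with one plausible way loss_args could be forwarded to it, is sketched below. smoothed_cross_entropy and call_loss are hypothetical; unpacking a list positionally and a dict as keyword arguments is an assumption about how fit uses loss_args.

```python
import torch
import torch.nn.functional as F


def smoothed_cross_entropy(yhat, ytrue, smoothing=0.0):
    """Hypothetical custom loss: cross-entropy with optional label smoothing."""
    return F.cross_entropy(yhat, ytrue, label_smoothing=smoothing)


def call_loss(loss_func, yhat, ytrue, loss_args):
    """Assumed dispatch: a list is unpacked positionally, a dict as keywords."""
    if isinstance(loss_args, dict):
        return loss_func(yhat, ytrue, **loss_args)
    return loss_func(yhat, ytrue, *loss_args)


torch.manual_seed(0)
yhat = torch.randn(8, 2)  # logits for 8 samples and 2 pseudo-classes
ytrue = torch.randint(0, 2, (8,))  # pseudo-labels
loss_list = call_loss(smoothed_cross_entropy, yhat, ytrue, [0.1])
loss_dict = call_loss(smoothed_cross_entropy, yhat, ytrue, {"smoothing": 0.1})
```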

forward(x)
test(test_dataloader, augmenter=None, loss_func=None, loss_args: list = [], augmenter_batch_calls=2, labels_on_dataloader=False, verbose: bool = True, device: str = None)

Evaluate the loss on a test dataloader. Parameters are the same as described in the fit method, except for those related to model training, which are removed.

Evaluating the pretraining loss on a test set is rare. Nevertheless, this method provides a way to do so, for example to assess the quality of the learned features on the fine-tuning dataset.
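The evaluation performed by test can be approximated with the plain PyTorch loop below: pseudo-labels are generated as in training, but gradients are disabled and no weights are updated. The toy model, the flip_augmenter helper, and the sample-weighted averaging are assumptions; selfeeg may average the loss per batch instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def evaluate_pretext_loss(model, batches, augmenter, loss_func=F.cross_entropy):
    """Sample-weighted average pretext loss; no gradients, no weight updates."""
    was_training = model.training
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x in batches:
            x_aug, y = augmenter(x)  # pseudo-labels are generated as in training
            total += loss_func(model(x_aug), y).item() * x.shape[0]
            count += x.shape[0]
    model.train(was_training)  # restore the original mode
    return total / count


def flip_augmenter(x):
    """Hypothetical augmenter: reverse every sample in time, pseudo-label 1."""
    return torch.flip(x, dims=[-1]), torch.ones(x.shape[0], dtype=torch.long)


torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(8 * 128, 2))  # toy encoder + head
test_batches = [torch.randn(16, 8, 128) for _ in range(3)]  # stand-in for a DataLoader
test_loss = evaluate_pretext_loss(model, test_batches, flip_augmenter)
```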