fine_tune

selfeeg.ssl.base.fine_tune(model: Module, train_dataloader: DataLoader, epochs=1, optimizer=None, augmenter=None, loss_func: Callable = None, loss_args: list = [], validation_loss_func: Callable = None, validation_loss_args: list = [], label_encoder: Callable = None, lr_scheduler=None, EarlyStopper=None, validation_dataloader: DataLoader = None, verbose=True, device: str = None, return_loss_info: bool = False) → dict | None

Performs fine-tuning of a given model.

Parameters:
  • model (nn.Module) – The PyTorch model to fine-tune. It must be an nn.Module.

  • train_dataloader (Dataloader) –

    The PyTorch DataLoader used to get the training batches. The DataLoader must return each batch as a tuple (X, Y), with X the input tensor and Y the label tensor.

    Note

    From version 0.2.0, X and Y can also be lists of Tensors. This can be useful for multi-branch or multi-head models.

  • epochs (int, optional) –

    The number of training epochs. It must be an integer greater than 0.

    Default = 1

  • optimizer (torch Optimizer, optional) –

    The optimizer used to update the weights. It can be any optimizer provided in the torch.optim module. If not given, Adam with default parameters will be instantiated.

    Default = None

  • augmenter (Callable or list of Callables, optional) –

    Any function (or callable object) used to perform data augmentation on the batch. Using the augmentation module is strongly recommended, since it implements several data augmentation functions as well as classes to combine them. Note that data augmentation is not performed on the validation set, since its goal is to enlarge and diversify the training set.

    Default = None

    Note

    From version 0.2.0, augmenter can be a list of Callables. This is intended for scenarios where X is also a list of Tensors and you want to apply a specific augmentation to each of its elements. Augmentations are performed with the command X[i] = augmenter[i](X[i]). It is possible to have len(augmenter)<len(X).

  • loss_func (Callable, optional) –

    The custom loss function. It can be any loss function which accepts as inputs the model’s prediction and the true labels as required arguments and loss_args as optional arguments.

    Default = None

  • loss_args (Union[list, dict], optional) –

    The optional arguments to pass to the loss function. It can be a list or a dict.

    Default = []

  • validation_loss_func (Callable, optional) –

    A custom validation loss function. It can be any loss function which accepts as inputs the model’s prediction and the true labels as required arguments, and loss_args as optional arguments. If None, loss_func will be used.

    Default = None

  • validation_loss_args (Union[list, dict], optional) –

    The optional arguments to pass to the validation loss function. It can be a list or a dict. If None, loss_args will be used.

    Default = []

  • label_encoder (Callable or list of Callables, optional) –

    A custom function used to encode the true labels returned by the DataLoader. If None, the DataLoader’s true labels are used directly. It can be any function which accepts the batch label tensor Y as input.

    Note

    From version 0.2.0, label_encoder can be a list of Callables. This is intended for scenarios where Y is also a list of Tensors and you want to apply a specific encoder to each of its elements. Label encoding is performed with the command Y[i] = label_encoder[i](Y[i]). It is possible to have len(label_encoder)<len(Y).

    Default = None

  • lr_scheduler (torch Scheduler, optional) –

    A PyTorch learning rate scheduler used to update the learning rate during fine-tuning.

    Default = None

  • EarlyStopper (EarlyStopping, optional) –

    An instance of the provided EarlyStopping class.

    Default = None

  • validation_dataloader (Dataloader, optional) –

    The PyTorch DataLoader used to get the validation batches. It must return each batch as a tuple (X, Y), with X the feature tensor and Y the label tensor. If not given, no validation loss will be calculated.

    Default = None

  • verbose (bool, optional) –

    Whether to print a progress bar or not.

    Default = True

  • device (torch.device or str, optional) –

    The device to use for fine-tuning. If given as a string, it will be converted to a torch.device instance. If not given, the ‘cpu’ device will be used.

    Default = None

  • return_loss_info (bool, optional) –

    Whether to return the training and validation losses calculated at each epoch.

    Default = False
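
The list forms of augmenter and label_encoder described above can be pictured with a minimal pure-Python sketch. Here apply_per_element is a hypothetical helper, not a selfeeg function, and plain lists stand in for the tensors of a real batch. It assumes that elements without a matching callable are passed through unchanged, which is one reading of the len(augmenter)<len(X) case.

```python
from typing import Any, Callable, List, Sequence


def apply_per_element(items: List[Any], funcs: Sequence[Callable]) -> List[Any]:
    """Apply funcs[i] to items[i]; hypothetical helper, not a selfeeg function."""
    out = list(items)  # shallow copy so the caller's list is not modified
    for i, func in enumerate(funcs):
        out[i] = func(out[i])  # mirrors the documented X[i] = augmenter[i](X[i])
    # Assumption: items beyond len(funcs) are passed through unchanged.
    return out


def flip(x):
    """Toy 'augmentation': reverse the order of the samples."""
    return x[::-1]


# Two-branch input: plain lists stand in for the tensors of a real batch.
X = [[1.0, 2.0, 3.0], [10.0, 20.0]]
X_aug = apply_per_element(X, [flip])  # one augmenter, two inputs
print(X_aug)  # [[3.0, 2.0, 1.0], [10.0, 20.0]]
```

The same pattern would apply to Y and label_encoder. In this sketch, passing more callables than elements raises an IndexError.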

Returns:

loss_info (dict, optional) – A dictionary whose keys are the epoch numbers (as integers) and whose values are two-element lists with the epoch’s average training and validation loss. It is returned only if return_loss_info is True.
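
As a rough illustration of the returned structure, the sketch below builds a dictionary in the documented layout and extracts the two loss curves from it. The numbers are placeholders rather than real results, and the epoch numbering starting at 0 is an assumption. What the validation entry holds when no validation dataloader is given is not specified here, so the sketch only covers the case where both losses are present.

```python
# Illustrative values only (not real results), in the documented layout:
# {epoch: [average training loss, average validation loss]}
loss_info = {0: [0.71, 0.69], 1: [0.64, 0.66], 2: [0.58, 0.67]}

# Split the per-epoch pairs into two curves, ordered by epoch.
epochs = sorted(loss_info)
train_curve = [loss_info[e][0] for e in epochs]
val_curve = [loss_info[e][1] for e in epochs]

# Epoch with the lowest validation loss.
best_epoch = min(epochs, key=lambda e: loss_info[e][1])
print(best_epoch, loss_info[best_epoch][1])  # 1 0.66
```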

Note

If an EarlyStopping instance is given with the monitored loss set to the validation loss, but no validation dataloader is given, the monitored loss will automatically be set to the training loss.
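
The fallback described in this note amounts to a small decision rule, sketched below. The function resolve_monitored_loss and the strings "validation" and "train" are hypothetical placeholders chosen for illustration. They are not the attribute names or values used by the EarlyStopping class.

```python
def resolve_monitored_loss(monitored: str, has_validation: bool) -> str:
    """Return which loss an early stopper would effectively monitor.

    Hypothetical helper: "validation" and "train" are placeholder
    labels, not the actual values used by selfeeg.
    """
    if monitored == "validation" and not has_validation:
        # No validation batches exist, so fall back to the training loss.
        return "train"
    return monitored


print(resolve_monitored_loss("validation", has_validation=False))  # train
print(resolve_monitored_loss("validation", has_validation=True))   # validation
```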

Example

>>> import pickle, random
>>> import torch
>>> import torch.nn.functional as F
>>> import selfeeg.dataloading as dl
>>> from selfeeg import models, ssl
>>> def loadEEG(path, return_label=False):
...     with open(path, 'rb') as handle:
...         EEG = pickle.load(handle)
...     x, y = EEG['data'], EEG['label']
...     return (x, y) if return_label else x
>>> def loss_fineTuning(yhat, ytrue):
...     # ytrue + 0. promotes integer labels to float, as the loss expects
...     return F.binary_cross_entropy_with_logits(torch.squeeze(yhat), ytrue + 0.)
>>> random.seed(1234)
>>> EEGlen = dl.get_eeg_partition_number(
...     'Simulated_EEG', 128, 2, 0.3, load_function=loadEEG)
>>> EEGsplit = dl.get_eeg_split_table(EEGlen, seed=1234)
>>> TrainSet = dl.EEGDataset(EEGlen, EEGsplit, [128, 2, 0.3], 'train', True, loadEEG,
...                          optional_load_fun_args=[True], label_on_load=True)
>>> TrainLoader = torch.utils.data.DataLoader(TrainSet, batch_size=32)
>>> shanet = models.ShallowNet(2, 8, 256)
>>> loss_info = ssl.fine_tune(shanet, TrainLoader, loss_func=loss_fineTuning,
...                           return_loss_info=True)