fine_tune
- selfeeg.ssl.base.fine_tune(model: Module, train_dataloader: DataLoader, epochs=1, optimizer=None, augmenter=None, loss_func: Callable = None, loss_args: list = [], validation_loss_func: Callable = None, validation_loss_args: list = [], label_encoder: Callable = None, lr_scheduler=None, EarlyStopper=None, validation_dataloader: DataLoader = None, verbose=True, device: str = None, return_loss_info: bool = False) → dict | None
Performs fine-tuning of a given model.
- Parameters:
model (nn.Module) – The PyTorch model to fine-tune. It must be an nn.Module.
train_dataloader (Dataloader) –
The PyTorch Dataloader used to get the training batches. The Dataloader must return a batch as a tuple (X, Y), with X the input tensor and Y the label tensor.
Note
From version 0.2.0, X and Y can also be lists of Tensors. This can be useful for multi-branch or multi-head models.
epochs (int, optional) –
The number of training epochs. It must be an integer greater than 0.
Default = 1
optimizer (torch Optimizer, optional) –
The optimizer used to update the weights. It can be any optimizer provided in the torch.optim module. If not given, Adam with default parameters will be instantiated.
Default = None
augmenter (Callable or list of Callables, optional) –
Any function (or callable object) used to perform data augmentation on the batch. It is highly recommended to use the augmentation module, which implements different data augmentation functions and classes to combine them. Note that data augmentation is not performed on the validation set, since its goal is to increase the size and variety of the training set.
Default = None
Note
From version 0.2.0, augmenter can be a list of Callables. This applies when X is also a list of Tensors and you want to apply a specific augmentation to each of its elements. Augmentations are performed with the command
X[i] = augmenter[i](X[i]). It is possible to have len(augmenter) < len(X).
loss_func (Callable, optional) –
The custom loss function. It can be any loss function which accepts as inputs the model’s prediction and the true labels as required arguments and loss_args as optional arguments.
Default = None
loss_args (Union[list, dict], optional) –
The optional arguments to pass to the loss function. It can be a list or a dict.
Default = []
validation_loss_func (Callable, optional) –
A custom validation loss function. It can be any loss function which accepts as inputs the model’s prediction and the true labels as required arguments, and loss_args as optional arguments. If None, loss_func will be used.
Default = None
validation_loss_args (Union[list, dict], optional) –
The optional arguments to pass to the validation loss function. It can be a list or a dict. If None, loss_args will be used.
Default = None
label_encoder (Callable or list of Callables, optional) –
A custom function used to encode the true labels returned by the Dataloader. If None, the Dataloader’s true label is used directly. It can be any function which accepts the batch label tensor Y as input.
Note
From version 0.2.0, label_encoder can be a list of Callables. This applies when Y is also a list of Tensors and you want to apply a specific encoder to each of its elements. Label encoding is performed with the command
Y[i] = label_encoder[i](Y[i]). It is possible to have len(label_encoder) < len(Y).
Default = None
lr_scheduler (torch Scheduler, optional) –
A PyTorch learning rate scheduler used to update the learning rate during fine-tuning.
Default = None
EarlyStopper (EarlyStopping, optional) –
An instance of the provided EarlyStopping class.
Default = None
validation_dataloader (Dataloader, optional) –
The PyTorch Dataloader used to get the validation batches. It must return a batch as a tuple (X, Y), with X the feature tensor and Y the label tensor. If not given, no validation loss will be calculated.
Default = None
verbose (bool, optional) –
Whether to print a progress bar or not.
Default = True
device (torch.device or str, optional) –
The device to use for fine-tuning. If given as a string, it will be converted to a torch.device instance. If not given, the ‘cpu’ device will be used.
Default = None
return_loss_info (bool, optional) –
Whether to return the training and validation losses calculated at each epoch.
Default = False
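The per-element rule described for list-valued augmenter and label_encoder arguments can be sketched in plain Python. This is only an illustration of the documented commands X[i] = augmenter[i](X[i]) and Y[i] = label_encoder[i](Y[i]); the helper apply_per_element is hypothetical and not part of selfeeg, and plain lists stand in for Tensors.

```python
# Hypothetical helper, NOT part of selfeeg: it mirrors the documented rule
# items[i] = funcs[i](items[i]) and leaves trailing items untouched
# when len(funcs) < len(items).
def apply_per_element(items, funcs):
    out = list(items)
    for i, func in enumerate(funcs):
        out[i] = func(out[i])
    return out


# Plain lists stand in for the Tensors of a two-branch, two-head batch.
X = [[1.0, 2.0], [3.0, 4.0]]
Y = [[0, 1], [1, 0]]

# Only one augmentation is given, so only X[0] is augmented.
augmenter = [lambda x: [2 * v for v in x]]

# One encoder per label element.
label_encoder = [
    lambda y: [float(v) for v in y],  # cast the first label set to float
    lambda y: [1 - v for v in y],     # flip the second label set
]

X_aug = apply_per_element(X, augmenter)
Y_enc = apply_per_element(Y, label_encoder)

print(X_aug)  # [[2.0, 4.0], [3.0, 4.0]]
print(Y_enc)  # [[0.0, 1.0], [0, 1]]
```

Because the list of callables may be shorter than the batch list, any element without a matching callable passes through unchanged, as X[1] does here.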
- Returns:
loss_info (dict, optional) – A dictionary whose keys are the epoch numbers (as integers) and whose values are two-element lists with the epoch’s average training and validation loss.
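A short sketch of how a dictionary with this layout might be post-processed. The loss values and the epoch numbering below are made up for illustration; only the {epoch: [training_loss, validation_loss]} shape comes from the description above.

```python
# Made-up values with the documented shape {epoch: [train_loss, val_loss]}.
# Real values come from fine_tune(..., return_loss_info=True).
loss_info = {0: [0.69, 0.71], 1: [0.52, 0.60], 2: [0.41, 0.63]}

epochs = sorted(loss_info)
train_curve = [loss_info[e][0] for e in epochs]
val_curve = [loss_info[e][1] for e in epochs]

# Epoch with the lowest validation loss.
best_epoch = min(epochs, key=lambda e: loss_info[e][1])

print(train_curve)  # [0.69, 0.52, 0.41]
print(val_curve)    # [0.71, 0.6, 0.63]
print(best_epoch)   # 1
```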
Note
If an EarlyStopping instance is given with monitoring loss set to validation loss, but no validation dataloader is given, monitoring loss will be automatically set to training loss.
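For the loss_func and loss_args parameters described above, the sketch below shows one plausible way a list or a dict of extra arguments could be forwarded to a custom loss. The dispatch in call_loss (a list is unpacked positionally, a dict as keywords) is an assumption made for illustration and is not taken from the selfeeg source; weighted_mse is a hypothetical loss that works on plain lists instead of Tensors.

```python
# Assumed dispatch, for illustration only: a dict is forwarded as keyword
# arguments, a list as extra positional arguments after (yhat, ytrue).
def call_loss(loss_func, yhat, ytrue, loss_args):
    if isinstance(loss_args, dict):
        return loss_func(yhat, ytrue, **loss_args)
    return loss_func(yhat, ytrue, *loss_args)


# Hypothetical custom loss: prediction and true labels are the required
# arguments, 'weight' is an optional one supplied through loss_args.
def weighted_mse(yhat, ytrue, weight=1.0):
    squared = sum((p - t) ** 2 for p, t in zip(yhat, ytrue))
    return weight * squared / len(ytrue)


yhat, ytrue = [1.0, 2.0], [0.0, 2.0]

plain = call_loss(weighted_mse, yhat, ytrue, [])                   # no extras
as_list = call_loss(weighted_mse, yhat, ytrue, [2.0])              # positional
as_dict = call_loss(weighted_mse, yhat, ytrue, {"weight": 2.0})    # keyword

print(plain, as_list, as_dict)  # 0.5 1.0 1.0
```

Passing a dict keeps the call readable when the loss has several optional arguments, while a list is enough for a single extra value.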
Example
>>> import random, pickle
>>> import torch
>>> import torch.nn.functional as F
>>> import selfeeg.dataloading as dl
>>> from selfeeg import models, ssl
>>> def loadEEG(path, return_label=False):
...     with open(path, 'rb') as handle:
...         EEG = pickle.load(handle)
...     x, y = EEG['data'], EEG['label']
...     return (x, y) if return_label else x
>>> def loss_fineTuning(yhat, ytrue):
...     return F.binary_cross_entropy_with_logits(torch.squeeze(yhat), ytrue + 0.)
>>> random.seed(1234)
>>> EEGlen = dl.get_eeg_partition_number(
...     'Simulated_EEG', 128, 2, 0.3, load_function=loadEEG)
>>> EEGsplit = dl.get_eeg_split_table(EEGlen, seed=1234)
>>> TrainSet = dl.EEGDataset(EEGlen, EEGsplit, [128, 2, 0.3], 'train', True, loadEEG,
...                          optional_load_fun_args=[True], label_on_load=True)
>>> TrainLoader = torch.utils.data.DataLoader(TrainSet, batch_size=32)
>>> shanet = models.ShallowNet(2, 8, 256)
>>> loss_info = ssl.fine_tune(shanet, TrainLoader, loss_func=loss_fineTuning,
...                           return_loss_info=True)