MoCo
- class selfeeg.ssl.contrastive.MoCo(encoder: Module, projection_head: list[int] | Module, predictor: list[int] | Module = None, feat_size: int = -1, bank_size: int = 0, m: float = 0.999)[source]
Implementation of the MoCo SSL method.
To check how MoCo works, read the following papers [moco21] [moco31].
- Parameters:
encoder (nn.Module) – The encoder part of the module. It is the one you wish to pretrain and transfer to the new model.
projection_head (Union[list[int], nn.Module]) –
The projection head to use. It can be:
an nn.Module
a list of ints.
If a list is given, an nn.Sequential module with Dense, BatchNorm and ReLU layers will be automatically created. The list sets the input and output dimensions of each Dense layer. For instance, if [64, 128, 64] is given, two hidden layers will be created: the first with input 64 and output 128, the second with input 128 and output 64.
predictor (Union[list[int], nn.Module], optional) –
The predictor to put after the projection head. Use it with a bank size of 0 to set up MoCo v3. Accepted arguments are the same as for the projection_head.
Default = None
feat_size (int, optional) –
The size of the feature vector (the last dimension of the projector’s output). It is used to initialize the queue for MoCo v2. If not given, the last element of the projection_head list is used. It must be given if a custom projection head is used.
Default = -1
bank_size (int, optional) –
The size of the queue, i.e. the number of projections to keep in memory. If not given, fit will trigger the calculation of the MoCo v3 loss.
Default = 0
m (float, optional) –
The value of the momentum coefficient. Suggested values are in the range [0.9960, 0.9999].
Default = 0.999
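The roles of m and bank_size can be sketched in plain Python. This is a minimal illustration of the momentum update and first-in-first-out queue described in [moco21], not the code this class runs: the helper names momentum_update and enqueue, the scalar "weights", and the string placeholders for projections are all assumptions made for the example.

```python
from collections import deque


def momentum_update(key_params, query_params, m=0.999):
    """Return key <- m * key + (1 - m) * query, element by element."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


def enqueue(bank, new_keys):
    """Append the newest projections; a bounded deque drops the oldest ones."""
    bank.extend(new_keys)
    return bank


# Toy scalar "weights" standing in for encoder parameters (illustrative values).
key = [1.0, 0.0]
query = [0.0, 1.0]
key = momentum_update(key, query, m=0.999)

# A hypothetical bank of size 4 keeps only the 4 most recent projections.
bank = deque(maxlen=4)
enqueue(bank, ["k0", "k1", "k2"])
enqueue(bank, ["k3", "k4"])
```

With m close to 1 the key weights move only slightly toward the query weights at each step, which is why values in the suggested range change the momentum encoder slowly.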
Warning
This class will not check the compatibility of the encoder’s output and the projection head’s input (as well as between the projection head and the predictor). Make sure that they have the same size.
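Since the class does not perform this check, it can be done by hand before building the model. The sketch below works only on list-of-ints specifications; layer_pairs and check_chain are hypothetical helpers, not part of selfeeg, and the encoder output size of 16 is an assumed value for illustration.

```python
def layer_pairs(sizes):
    """Turn [64, 128, 64] into [(64, 128), (128, 64)]: one (in, out) pair per dense layer."""
    return list(zip(sizes[:-1], sizes[1:]))


def check_chain(encoder_out, projection_head, predictor=None):
    """Raise ValueError when consecutive blocks disagree on the feature size."""
    if projection_head[0] != encoder_out:
        raise ValueError(
            f"encoder outputs {encoder_out} features, "
            f"but the projection head expects {projection_head[0]}"
        )
    if predictor is not None and predictor[0] != projection_head[-1]:
        raise ValueError(
            f"projection head outputs {projection_head[-1]} features, "
            f"but the predictor expects {predictor[0]}"
        )
    return True


# Assumed encoder output size of 16 (hypothetical value for this example).
pairs = layer_pairs([64, 128, 64])
ok = check_chain(16, [16, 32, 32], predictor=[32, 32])
```

For custom nn.Module heads the same idea applies, but the sizes have to be read from the modules themselves or verified with a dummy forward pass.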
Warning
Using the Adam optimizer with MoCo v2 (i.e. with a nonzero bank size) can prevent the training loss from decreasing. SGD is strongly recommended.
References
[moco21] K. He, H. Fan, Y. Wu, S. Xie, and R. Girshick, “Momentum contrast for unsupervised visual representation learning,” in Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 9729–9738, 2020.
[moco31] X. Chen, H. Fan, R. Girshick, and K. He, “Improved baselines with momentum contrastive learning,” arXiv preprint arXiv:2003.04297, 2020.
Example
>>> import pickle, torch, selfeeg
>>> import torch.nn.functional as F
>>> import selfeeg.dataloading as dl
>>> selfeeg.utils.create_dataset()
>>> def loadEEG(path, return_label=False):
...     with open(path, 'rb') as handle:
...         EEG = pickle.load(handle)
...     x, y = EEG['data'], EEG['label']
...     return (x, y) if return_label else x
>>> def loss_fineTuning(yhat, ytrue):
...     return F.binary_cross_entropy_with_logits(torch.squeeze(yhat), ytrue + 0.)
>>> torch.manual_seed(1234)
>>> EEGlen = dl.get_eeg_partition_number('Simulated_EEG', freq=128, window=1,
...                                      overlap=0.3, load_function=loadEEG)
>>> EEGsplit = dl.get_eeg_split_table(EEGlen, seed=1234)
>>> TrainSet = dl.EEGDataset(EEGlen, EEGsplit, [128, 1, 0.3], 'train', False, loadEEG)
>>> Loader = torch.utils.data.DataLoader(TrainSet, batch_size=32)
>>> enc = selfeeg.models.ShallowNetEncoder(8)
MoCo v2
>>> moco2 = selfeeg.ssl.MoCo(enc, [16, 32, 32], bank_size=4096)
>>> print(moco2(torch.randn(32, 8, 128)).shape)  # should return [32, 32]
>>> loss_train = moco2.fit(Loader, 1, return_loss_info=True)
>>> print(loss_train[0][0])  # should return 79.2622
>>> loss_test = moco2.test(Loader)  # just to show it works
>>> print(loss_test)  # should return 85.8382
MoCo v3
>>> moco3 = selfeeg.ssl.MoCo(enc, [16, 32, 32], [32, 32])
>>> print(moco3(torch.randn(32, 8, 128)).shape)  # should return [32, 32]
>>> loss_train = moco3.fit(Loader, 1, return_loss_info=True)
>>> print(loss_train[0][0])  # should return 0.93120
>>> loss_test = moco3.test(Loader)  # just to show it works
>>> print(loss_test)  # should return 0.8531
- fit(train_dataloader, epochs=1, optimizer=None, augmenter=None, loss_func: Callable = None, loss_args: list = [], lr_scheduler=None, EarlyStopper=None, validation_dataloader=None, verbose: bool = True, device: str = None, return_loss_info: bool = False)[source]
fit is a custom fit function designed to perform pretraining on a given model with the given dataloader.
- Parameters:
train_dataloader (Dataloader) – The pytorch Dataloader used to get the training batches. It must return a batch as a single tensor X, thus without label tensor Y.
epochs (int, optional) –
The number of training epochs. Must be an integer greater than 0.
Default = 1
optimizer (torch Optimizer, optional) –
The optimizer used for weight updates. It can be any optimizer provided in the torch.optim module. If not given:
SGD with learning rate 0.01 will be used for MoCo v2
Adam with default parameters will be used for MoCo v3.
Default = None
augmenter (function, optional) –
Any function (or callable object) used to perform data augmentation on the batch. It is highly suggested to resort to the augmentation module, which implements different data augmentation functions and classes to combine them. If none is given, a default augmentation with random vertical flip + random noise is applied. Note that in this case data augmentation is also performed on the validation set, since it is part of the SSL algorithm.
Default = None
loss_func (Callable, optional) –
The custom loss function. It can be any loss function that accepts the model’s predictions (two torch Tensors) as its only required arguments and loss_args as optional arguments. Check the input arguments of moco_loss to see how to design custom loss functions for this method.
Default = None
loss_args (Union[list, dict], optional) –
The optional arguments to pass to the loss function. It can be a list or a dict.
Default = []
label_encoder (function, optional) –
A custom function used to encode the true labels returned by the Dataloader. If None, the Dataloader’s true labels are used directly.
Default = None
lr_scheduler (torch Scheduler) –
A pytorch learning rate scheduler used to update the learning rate during pretraining.
Default = None
EarlyStopper (EarlyStopping, optional) –
An instance of the provided EarlyStopping class.
Default = None
validation_dataloader (Dataloader, optional) –
The pytorch Dataloader used to get the validation batches. It must return a batch as a single tensor X, thus without label tensor Y. If not given, no validation loss will be calculated.
Default = None
verbose (bool, optional) –
Whether to print a progress bar or not.
Default = True
device (torch.device or str, optional) –
The device to use for pretraining. If given as a string, it will be converted to a torch.device instance. If not given, the ‘cpu’ device will be used.
Default = None
return_loss_info (bool, optional) –
Whether to return the calculated training and validation losses at each epoch.
Default = False
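The contract described for loss_func and loss_args (two predictions as required arguments, extra arguments given as a list or a dict) can be sketched without torch. Everything here is hypothetical: neg_cosine_loss is a toy loss on nested lists rather than tensors, and call_loss only illustrates one plausible way the extra arguments could be forwarded, which is an assumption and not a description of what fit does internally.

```python
import math


def neg_cosine_loss(q, k, scale=1.0):
    """Toy loss: negative cosine similarity between paired rows, averaged over the batch."""
    total = 0.0
    for qi, ki in zip(q, k):
        dot = sum(a * b for a, b in zip(qi, ki))
        norm = math.sqrt(sum(a * a for a in qi)) * math.sqrt(sum(b * b for b in ki))
        total += -scale * dot / norm
    return total / len(q)


def call_loss(loss_func, q, k, loss_args=None):
    """Forward extra arguments positionally (list) or by keyword (dict)."""
    if isinstance(loss_args, dict):
        return loss_func(q, k, **loss_args)
    return loss_func(q, k, *(loss_args or []))


# Two identical toy "projections", so the cosine similarity of each pair is 1.
q = [[1.0, 0.0], [0.0, 2.0]]
k = [[1.0, 0.0], [0.0, 2.0]]

loss_default = call_loss(neg_cosine_loss, q, k)
loss_from_list = call_loss(neg_cosine_loss, q, k, [2.0])
loss_from_dict = call_loss(neg_cosine_loss, q, k, {"scale": 2.0})
```

A dict is usually the safer choice when the loss has several optional arguments, since it does not depend on their order.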
- Returns:
loss_info (dict, optional) – A dictionary whose keys are the epoch numbers (as integers) and whose values are two-element lists with the epoch’s average training and validation losses.
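A dictionary shaped like loss_info is easy to post-process. The sketch below picks the epoch with the lowest loss; best_epoch is a hypothetical helper, the loss values are made up for illustration, and it assumes a validation dataloader was given so that both entries of each list are numbers.

```python
def best_epoch(loss_info, use_validation=True):
    """Return the epoch key with the lowest validation (or training) loss."""
    idx = 1 if use_validation else 0
    return min(loss_info, key=lambda epoch: loss_info[epoch][idx])


# Made-up values in the documented layout: {epoch: [train_loss, val_loss]}.
loss_info = {0: [0.95, 0.97], 1: [0.80, 0.85], 2: [0.72, 0.88]}

best_val = best_epoch(loss_info)
best_train = best_epoch(loss_info, use_validation=False)
```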
Note
If an EarlyStopping instance is given with monitoring loss set to validation loss, but no validation dataloader is given, monitoring loss will be automatically set to training loss.
- test(test_dataloader, augmenter=None, loss_func: Callable = None, loss_args: list = [], verbose: bool = True, device: str = None)[source]
A method to evaluate the loss on a test dataloader. Parameters are the same as described in the fit method, aside from those related to model training, which are removed.
It is rare to evaluate the pretraining loss function on a test set. Nevertheless, this function provides a way to do that. An example of usage could be to assess the quality of the learned features on the fine-tuning dataset.