SSLBase
- class selfeeg.ssl.base.SSLBase(encoder: Module)[source]
Baseline Self-Supervised Learning nn.Module.
It is used as the parent class of the other implemented SSL methods.
- Parameters:
encoder (nn.Module) – The encoder part of the module. It is the part of the model you wish to pretrain and transfer to the new model.
Example
>>> import selfeeg
>>> import torch
>>> from selfeeg import losses
>>> enc = selfeeg.models.ShallowNetEncoder(8)
>>> base = selfeeg.ssl.SSLBase(enc)
>>> torch.manual_seed(1234)
>>> a = torch.randn(64,32)
>>> print(base.evaluate_loss(losses.simclr_loss, [a])) # should return 9.4143
>>> enc2 = base.get_encoder()
>>> def check_models(model1, model2):
...     for p1, p2 in zip(model1.parameters(), model2.parameters()):
...         if p1.data.ne(p2.data).sum() > 0:
...             return False
...     return True
>>> print(check_models(base.encoder,enc2)) # should return True
>>> # assert that they are different objects
>>> enc2.conv1.weight = torch.nn.Parameter(enc2.conv1.weight*10)
>>> print(check_models(base.encoder,enc2)) # should return False
- evaluate_loss(loss_fun: Callable, arguments: Tensor, loss_arg: list | dict = None) → Tensor[source]
Evaluates a custom loss function, passing arguments as the required arguments and loss_arg as the optional ones.
- Parameters:
loss_fun (function) –
The custom loss function. It can be any loss function which accepts as input:
the model’s prediction (or predictions)
any element included in loss_arg as optional arguments.
Note that the number of required arguments can change based on the specific pretraining method used. For example, SimCLR accepts 1 or 2 required arguments, while BYOL must take 4.
arguments (torch.Tensor or list[torch.Tensor]) – The required arguments. Depending on how this function is used in a training pipeline, it can be a single tensor or a list of tensors.
loss_arg (Union[list, dict], optional) –
The optional arguments to pass to the function. It can be a list or a dict.
Default = None
- Returns:
loss (torch.Tensor) – The output of the given loss function. It is expected to be a torch.Tensor.
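The calling convention described above can be illustrated with a minimal sketch. It is not the library's implementation: call_loss and toy_loss are hypothetical names, the unpacking rules (a list of required arguments is unpacked positionally, a dict of optional arguments is passed by keyword, a list of optional arguments is passed by position) are assumptions drawn from the parameter descriptions, and plain floats stand in for tensors so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Union


def call_loss(
    loss_fun: Callable,
    arguments: Any,
    loss_arg: Union[list, dict, None] = None,
):
    """Hypothetical stand-in for the dispatch that evaluate_loss describes."""
    # Wrap a single required argument so both cases unpack the same way.
    required = arguments if isinstance(arguments, list) else [arguments]
    if not loss_arg:
        # None or an empty container: no optional arguments are passed.
        return loss_fun(*required)
    if isinstance(loss_arg, dict):
        # dict: optional arguments are passed by keyword (assumed behaviour).
        return loss_fun(*required, **loss_arg)
    # list: optional arguments are passed by position (assumed behaviour).
    return loss_fun(*required, *loss_arg)


def toy_loss(pred, target=0.0, scale=1.0):
    """Toy squared-error loss on plain floats, standing in for tensors."""
    return scale * (pred - target) ** 2


single = call_loss(toy_loss, 3.0)                       # one required argument
pair = call_loss(toy_loss, [3.0, 1.0])                  # two required arguments
by_keyword = call_loss(toy_loss, [3.0, 1.0], {"scale": 0.5})
by_position = call_loss(toy_loss, [3.0, 1.0], [0.5])
```

Under these assumptions, a dict for loss_arg is the safer choice when the loss function has several optional parameters, since positional passing depends on their order in the signature.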
- get_encoder(device='cpu')[source]
Returns a copy of the encoder on the selected device.
- Parameters:
device (torch.device or str, optional) –
The PyTorch device to which the encoder copy is moved.
Default = ‘cpu’
- save_encoder(path: str = None)[source]
A method for saving the pretrained encoder.
- Parameters:
path (str, optional) –
The saving path, which is passed to the torch.save() method. If None is given, the encoder is saved in an automatically created SSL_encoders subdirectory. The file name will contain the pretraining method used (e.g. SimCLR, MoCo, etc.) and the current time.
Default = None
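As a rough illustration of the default behaviour described for path=None, the sketch below builds a file path inside an SSL_encoders subdirectory from a method name and a timestamp. The helper default_encoder_path, the "encoder_" prefix, the ".pt" extension and the timestamp format are all assumptions made for this example; the exact file name produced by the library is not stated here.

```python
import os
import tempfile
from datetime import datetime
from typing import Optional


def default_encoder_path(
    method_name: str,
    root: str = ".",
    now: Optional[datetime] = None,
) -> str:
    """Hypothetical helper: build a default save path under SSL_encoders."""
    now = now or datetime.now()
    folder = os.path.join(root, "SSL_encoders")
    # Create the subdirectory if it does not exist yet.
    os.makedirs(folder, exist_ok=True)
    # Assumed timestamp format and file name pattern, for illustration only.
    stamp = now.strftime("%Y%m%d_%H%M%S")
    return os.path.join(folder, f"encoder_{method_name}_{stamp}.pt")


# Demonstrate in a temporary directory with a fixed time for reproducibility.
with tempfile.TemporaryDirectory() as tmp:
    example_path = default_encoder_path(
        "SimCLR", root=tmp, now=datetime(2024, 1, 2, 3, 4, 5)
    )
    folder_created = os.path.isdir(os.path.join(tmp, "SSL_encoders"))
    file_name = os.path.basename(example_path)
```

A string built this way is the kind of value that would then be handed to torch.save(). Passing an explicit path instead keeps the file name under the caller's control.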