SSLBase

class selfeeg.ssl.base.SSLBase(encoder: Module)

Baseline Self-Supervised Learning nn.Module.

It is used as the parent class of the other SSL methods implemented in the library.

Parameters:

encoder (nn.Module) – The encoder part of the module. It is the part of the model you wish to pretrain and transfer to the new model.
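The parent-class relationship can be sketched without PyTorch. In this framework-free toy, `ToyEncoder`, `BaseSketch` and `ProjectionSketch` are hypothetical names, not selfeeg classes. The base class only stores the encoder. A child method adds a head that is used only during pretraining. Afterwards, only `encoder` is meant to be transferred to the new model.

```python
class ToyEncoder:
    """Hypothetical stand-in for an nn.Module encoder with one scalar weight."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def __call__(self, x: float) -> float:
        return self.weight * x


class BaseSketch:
    """Hypothetical stand-in for SSLBase: it only stores the shared encoder."""

    def __init__(self, encoder: ToyEncoder) -> None:
        self.encoder = encoder


class ProjectionSketch(BaseSketch):
    """Hypothetical child SSL method adding a pretraining-only projection head."""

    def __init__(self, encoder: ToyEncoder, head_weight: float) -> None:
        super().__init__(encoder)  # the parent class keeps the encoder
        self.head_weight = head_weight  # discarded after pretraining

    def forward(self, x: float) -> float:
        return self.head_weight * self.encoder(x)


model = ProjectionSketch(ToyEncoder(2.0), head_weight=3.0)
print(model.forward(1.0))  # 6.0 (encoder followed by head)
print(model.encoder(1.0))  # 2.0 (the part that would be transferred)
```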

Example

>>> import selfeeg
>>> import torch
>>> from selfeeg import losses
>>> enc = selfeeg.models.ShallowNetEncoder(8)
>>> base = selfeeg.ssl.SSLBase(enc)
>>> torch.manual_seed(1234)
>>> a = torch.randn(64,32)
>>> print(base.evaluate_loss(losses.simclr_loss, [a])) # should return 9.4143
>>> enc2 = base.get_encoder()
>>> def check_models(model1, model2):
...     for p1, p2 in zip(model1.parameters(), model2.parameters()):
...         if p1.data.ne(p2.data).sum() > 0:
...             return False
...     return True
>>> print(check_models(base.encoder,enc2)) # should return True
>>> # scale enc2's weights to show that it is a separate object from base.encoder
>>> enc2.conv1.weight = torch.nn.Parameter(enc2.conv1.weight*10)
>>> print(check_models(base.encoder,enc2)) # should return False

evaluate_loss(loss_fun: Callable, arguments: Tensor, loss_arg: list | dict = None) → Tensor

evaluate_loss evaluates a custom loss function, using arguments as the required arguments and loss_arg as the optional ones.

Parameters:
  • loss_fun (function) –

    The custom loss function. It can be any loss function which accepts as input:

    1. the model’s prediction (or predictions)

    2. any element included in loss_arg, passed as optional arguments.

    Note that the number of required arguments depends on the pretraining method. For example, SimCLR accepts 1 or 2 required arguments, while BYOL requires 4.

  • arguments (torch.Tensor or list[torch.Tensor]) – The required arguments. Depending on how this function is used in a training pipeline, it can be a single tensor or a list of tensors.

  • loss_arg (Union[list, dict], optional) –

    The optional arguments to pass to the function. It can be a list or a dict.

    Default = None
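A minimal, framework-free sketch of how required and optional arguments might be combined into a single call. It assumes (this page does not state it) that a list of required arguments is unpacked positionally, an optional list is appended positionally, and an optional dict is passed as keyword arguments. `evaluate_loss_sketch` and `toy_loss` are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Optional, Sequence, Union


def evaluate_loss_sketch(
    loss_fun: Callable[..., Any],
    arguments: Any,
    loss_arg: Optional[Union[Sequence[Any], dict]] = None,
) -> Any:
    """Assumed dispatch: list -> positional, dict -> keyword arguments."""
    # A single tensor-like value is wrapped so both cases unpack the same way.
    required = list(arguments) if isinstance(arguments, list) else [arguments]
    if not loss_arg:  # None, [] or {}: only the required arguments
        return loss_fun(*required)
    if isinstance(loss_arg, dict):
        return loss_fun(*required, **loss_arg)
    return loss_fun(*required, *loss_arg)


def toy_loss(pred, target=0.0, scale=1.0):
    """Hypothetical loss: scaled squared error of a single number."""
    return scale * (pred - target) ** 2


print(evaluate_loss_sketch(toy_loss, 3.0))                    # 9.0
print(evaluate_loss_sketch(toy_loss, [3.0, 1.0]))             # 4.0
print(evaluate_loss_sketch(toy_loss, 3.0, [1.0, 0.5]))        # 2.0
print(evaluate_loss_sketch(toy_loss, [3.0], {"scale": 2.0}))  # 18.0
```

Accepting both a list and a dict lets the same call site serve losses whose optional parameters are easier to pass by position or by name.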

Returns:

loss (torch.Tensor) – The output of the given loss function. It is expected to be a torch.Tensor.

forward(x)

get_encoder(device='cpu')

Returns a copy of the encoder on the selected device.

Parameters:

device (torch.device or str, optional) –

The PyTorch device to which the encoder is moved.

Default = ‘cpu’
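get_encoder returns a copy, which is why scaling enc2's weights in the example does not change base.encoder. The toy below sketches that behaviour under an assumption not stated on this page: the encoder is deep-copied and the copy is then moved with a `.to()`-style call. `ToyEncoder` and `get_encoder_sketch` are hypothetical stand-ins for an nn.Module and the real method, and no actual device transfer happens.

```python
import copy


class ToyEncoder:
    """Hypothetical stand-in for an nn.Module: a list of weights and a device tag."""

    def __init__(self, weights, device="cpu"):
        self.weights = list(weights)
        self.device = device

    def to(self, device):
        # Mimics nn.Module.to(): updates the object in place and returns it.
        self.device = device
        return self


def get_encoder_sketch(encoder, device="cpu"):
    """Assumed behaviour: deep-copy the encoder, then move only the copy."""
    return copy.deepcopy(encoder).to(device)


original = ToyEncoder([1.0, 2.0])
clone = get_encoder_sketch(original, device="cuda")
clone.weights[0] *= 10  # like scaling enc2.conv1.weight in the example
print(original.weights, original.device)  # [1.0, 2.0] cpu
print(clone.weights, clone.device)        # [10.0, 2.0] cuda
```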

save_encoder(path: str = None)

A method for saving the pretrained encoder.

Parameters:

path (str, optional) –

The save path, passed to torch.save(). If None, the encoder is saved in a newly created SSL_encoders subdirectory, under a file name containing the pretraining method used (e.g. SimCLR, MoCo) and the current time.

Default = None
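When path is None, a default location has to be built. The helper below is a hypothetical sketch of that step. The SSL_encoders directory and the use of the method name and current time come from the description above. The exact file-name format (`<method>_encoder_<timestamp>.pt`) is an assumption and not necessarily what selfeeg writes.

```python
import os
from datetime import datetime
from typing import Optional


def resolve_save_path(method_name: str, path: Optional[str] = None,
                      now: Optional[datetime] = None) -> str:
    """Hypothetical helper: keep an explicit path, otherwise build one inside
    SSL_encoders from the method name and a timestamp (format assumed)."""
    if path is not None:
        return path
    now = now or datetime.now()
    stamp = now.strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"{method_name}_encoder_{stamp}.pt"
    return os.path.join("SSL_encoders", filename)


fixed = datetime(2024, 1, 31, 9, 5, 0)
print(resolve_save_path("SimCLR", now=fixed))
# SSL_encoders/SimCLR_encoder_2024-01-31_09-05-00.pt on POSIX systems
print(resolve_save_path("MoCo", path="my_encoder.pt"))  # my_encoder.pt
```

The helper only computes the path. Before calling torch.save(), the caller would also create the directory, e.g. with os.makedirs(os.path.dirname(p), exist_ok=True).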