evaluate_loss

selfeeg.ssl.base.evaluate_loss(loss_fun: Callable, arguments: Tensor, loss_arg: list | dict = None)

Evaluates a custom loss function.

It takes arguments as the required arguments and loss_arg as the optional ones. It is simply SSLBase's evaluate_loss method exported as a standalone function.

Parameters:
  • loss_fun (Callable) –

    The custom loss function. It can be any Callable object that accepts as input:

    1. the model’s prediction (or predictions) and the true labels as required arguments;

    2. any element included in loss_arg as optional arguments.

    Note that, for the fine_tune method, the loss function must take exactly 2 required arguments, i.e. the model’s prediction and the true labels.

  • arguments (torch.Tensor or list[torch.Tensor]) – the required arguments. Depending on how this function is used in a training pipeline, it can be a single tensor or a list of tensors.

  • loss_arg (Union[list, dict], optional) –

    The optional arguments to pass to the function. It can be a list or a dict.

    Default = None
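The parameters above suggest how the call to loss_fun is assembled: the required arguments come first and the optional ones follow. The following is a minimal sketch of that convention, not selfeeg's actual source. It assumes that a list passed as arguments is unpacked positionally, that a list loss_arg is appended as extra positional arguments, and that a dict loss_arg is forwarded as keyword arguments. The names evaluate_loss_sketch and weighted_abs_error are hypothetical, and plain floats stand in for tensors so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Optional, Union


def evaluate_loss_sketch(
    loss_fun: Callable[..., Any],
    arguments: Any,
    loss_arg: Optional[Union[list, dict]] = None,
) -> Any:
    """Hypothetical sketch of the documented call convention (assumed, not selfeeg's code)."""
    # Required arguments: a list is unpacked, anything else is passed as a single argument.
    required = list(arguments) if isinstance(arguments, list) else [arguments]
    # No optional arguments (None or empty): call with the required ones only.
    if not loss_arg:
        return loss_fun(*required)
    # Assumption: a dict of optional arguments is forwarded as keyword arguments.
    if isinstance(loss_arg, dict):
        return loss_fun(*required, **loss_arg)
    # Assumption: a list of optional arguments is appended positionally.
    return loss_fun(*required, *loss_arg)


# Hypothetical loss with two required arguments and one optional argument.
def weighted_abs_error(pred, true, weight=1.0):
    return weight * abs(pred - true)


print(evaluate_loss_sketch(weighted_abs_error, [3.0, 1.0]))                   # 2.0
print(evaluate_loss_sketch(weighted_abs_error, [3.0, 1.0], [0.5]))            # 1.0
print(evaluate_loss_sketch(weighted_abs_error, [3.0, 1.0], {"weight": 2.0}))  # 4.0
```

Keeping the required arguments separate from loss_arg lets the same training loop call losses with different extra options without changing the call site.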

Returns:

loss (‘loss_fun output’) – The output of the given loss function. It is expected to be a torch.Tensor.

Example

>>> import torch
>>> from selfeeg import ssl
>>> _ = torch.manual_seed(1234)
>>> ytrue = torch.randn(64, 1)
>>> yhat  = torch.randn(64, 1)
>>> loss = ssl.evaluate_loss(torch.nn.functional.mse_loss, [yhat, ytrue])
>>> print(loss) # tensor(1.9893)