evaluate_loss
- selfeeg.ssl.base.evaluate_loss(loss_fun: Callable, arguments: Tensor, loss_arg: list | dict = None)
Evaluates a custom loss function.
It takes arguments as the required arguments and loss_arg as the optional ones. It is simply the
SSLBase's evaluate_loss method exported as a function.
- Parameters:
loss_fun (Callable) –
The custom loss function. It can be any Callable object that accepts as input:
  - the model’s prediction (or predictions) and the true labels as required arguments;
  - any element included in loss_arg as optional arguments.
Note that for the fine_tune method the number of required arguments must be 2, i.e. the model’s prediction and the true labels.
arguments (torch.Tensor or list[torch.Tensor]) – the required arguments. Depending on how this function is used in a training pipeline, it can be a single tensor or a list of tensors.
loss_arg (Union[list, dict], optional) –
The optional arguments to pass to the function. It can be a list or a dict.
Default = None
- Returns:
loss (‘loss_fun output’) – The output of the given loss function. It is expected to be a torch.Tensor.
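The way required and optional arguments reach loss_fun can be sketched in plain Python. This is a hypothetical re-implementation for illustration, not selfeeg's source: the names evaluate_loss_sketch and scaled_mse are invented here, and the unpacking rules (a list of required arguments is unpacked positionally, a list loss_arg is appended positionally, a dict loss_arg is passed as keyword arguments) are assumptions consistent with the parameter descriptions above.

```python
from typing import Any, Callable, Optional, Sequence, Union


def evaluate_loss_sketch(
    loss_fun: Callable[..., Any],
    arguments: Any,
    loss_arg: Optional[Union[Sequence[Any], dict]] = None,
) -> Any:
    """Hypothetical sketch of evaluate_loss's argument dispatch (assumed rules)."""
    # A list holds several required arguments; anything else (e.g. one tensor)
    # is treated as the single required argument.
    required = list(arguments) if isinstance(arguments, list) else [arguments]
    if not loss_arg:  # None, [] or {}: no optional arguments
        return loss_fun(*required)
    if isinstance(loss_arg, dict):
        # Assumption: dict entries become keyword arguments.
        return loss_fun(*required, **loss_arg)
    # Assumption: list entries become extra positional arguments.
    return loss_fun(*required, *loss_arg)


def scaled_mse(yhat, ytrue, scale=1.0):
    # Hypothetical loss: mean squared error over two equal-length sequences,
    # multiplied by an optional `scale` factor.
    return scale * sum((a - b) ** 2 for a, b in zip(yhat, ytrue)) / len(yhat)


yhat, ytrue = [1.0, 2.0, 3.0], [1.0, 2.0, 5.0]
print(evaluate_loss_sketch(scaled_mse, [yhat, ytrue]))                  # 1.3333333333333333
print(evaluate_loss_sketch(scaled_mse, [yhat, ytrue], {"scale": 0.5}))  # 0.6666666666666666
print(evaluate_loss_sketch(scaled_mse, [yhat, ytrue], [3.0]))           # 4.0
```

Because scaled_mse keeps exactly two required parameters and gives its extra one a default, it would also satisfy the two-required-arguments constraint mentioned for fine_tune.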
Example
>>> import torch
>>> import selfeeg.ssl as ssl
>>> _ = torch.manual_seed(1234)
>>> ytrue = torch.randn(64, 1)
>>> yhat = torch.randn(64, 1)
>>> loss = ssl.evaluate_loss(torch.nn.functional.mse_loss, [yhat, ytrue])
>>> print(loss) # will print 1.9893