byol_loss

selfeeg.losses.losses.byol_loss(p1: Tensor, z1: Tensor, p2: Tensor, z2: Tensor, projections_norm: bool = True) -> Tensor

A simple PyTorch implementation of the BYOL loss function presented in [BYOL].
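The objective below is a hedged sketch written in this page's notation. It follows the BYOL paper: each prediction is L2-normalized and compared with the L2-normalized projection of the other view under a stop-gradient operator sg(·). Two details are assumptions about this implementation rather than facts confirmed by this page: that p1 is paired with z2 (and p2 with z1), and that the per-sample terms are averaged over the N samples of the batch.

```latex
\ell(p, z) = \left\lVert \frac{p}{\lVert p \rVert_2} - \frac{z}{\lVert z \rVert_2} \right\rVert_2^2
           = 2 - 2\,\frac{\langle p, z \rangle}{\lVert p \rVert_2 \, \lVert z \rVert_2}

\mathcal{L}_{\mathrm{BYOL}} = \frac{1}{N} \sum_{i=1}^{N}
  \Big[ \ell\big(p_1^{(i)}, \operatorname{sg}(z_2^{(i)})\big)
      + \ell\big(p_2^{(i)}, \operatorname{sg}(z_1^{(i)})\big) \Big]
```

Each term lies in [0, 4], so the total lies in [0, 8]. For unrelated random vectors the cosine similarity is close to 0, which puts the loss close to 4. This is consistent with the value reported in the example at the end of this page.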

Parameters:
  • p1 (torch.Tensor) – 2D Tensor with the predictor output for one augmented batch.

  • z1 (torch.Tensor) – 2D Tensor with the projection output for one augmented batch.

  • p2 (torch.Tensor) – Same as p1 but for the other augmented batch.

  • z2 (torch.Tensor) – Same as z1 but for the other augmented batch.

  • projections_norm (bool, optional) –

    Whether to normalize the projections or not.

    Default = True
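In BYOL, the four inputs come from the two branches of the model. Following the BYOL paper, p1 and p2 come from the online branch (encoder, projector, predictor). z1 and z2 come from a target branch (encoder, projector) that receives no gradients and is updated as an exponential moving average (EMA) of the online weights. The sketch below is a minimal illustration under stated assumptions. The following are all hypothetical and not part of selfeeg: the toy `nn.Linear` networks, the additive-noise augmentation, the batch and feature sizes, and the helper `ema_update` with its `tau` value.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks; a real model would use an EEG encoder instead
encoder = nn.Sequential(nn.Linear(128, 64), nn.ReLU())
projector = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 32))
predictor = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))

# Target branch: a copy of the online encoder and projector that is never
# updated by gradient descent
target_encoder = copy.deepcopy(encoder)
target_projector = copy.deepcopy(projector)
for module in (target_encoder, target_projector):
    module.requires_grad_(False)


@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, tau: float = 0.99) -> None:
    """Hypothetical helper: target <- tau * target + (1 - tau) * online."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1 - tau)


x = torch.randn(64, 128)               # batch of 64 flattened samples
view1 = x + 0.1 * torch.randn_like(x)  # hypothetical augmentation: additive noise
view2 = x + 0.1 * torch.randn_like(x)

# Online branch: encoder -> projector -> predictor (receives gradients)
p1 = predictor(projector(encoder(view1)))
p2 = predictor(projector(encoder(view2)))

# Target branch: encoder -> projector, without gradient tracking
with torch.no_grad():
    z1 = target_projector(target_encoder(view1))
    z2 = target_projector(target_encoder(view2))

print(p1.shape, z1.shape)  # torch.Size([64, 32]) torch.Size([64, 32])

# After each optimizer step on the online branch, the target weights are
# moved towards the online weights
ema_update(encoder, target_encoder)
ema_update(projector, target_projector)
```

All four tensors share the shape (batch size, projection dimension). Only p1 and p2 carry gradients back into the online networks.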

Returns:

loss (torch.Tensor) – The calculated loss.
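The following minimal re-implementation in plain PyTorch makes the returned value concrete. The name `byol_loss_sketch` is hypothetical, and this is not the selfeeg source. It makes four assumptions. First, each prediction is paired with the projection of the other view. Second, the projections are detached (stop-gradient). Third, the per-sample losses are averaged over the batch. Fourth, `projections_norm=False` simply skips the L2 normalization and compares the raw vectors with the same squared distance.

```python
import torch
import torch.nn.functional as F


def byol_loss_sketch(p1, z1, p2, z2, projections_norm=True):
    """Hypothetical re-implementation of a symmetric BYOL loss."""
    # Stop-gradient: the projections are treated as constant targets
    z1, z2 = z1.detach(), z2.detach()
    if projections_norm:
        # L2-normalize every row (one row per sample)
        p1, z1 = F.normalize(p1, dim=1), F.normalize(z1, dim=1)
        p2, z2 = F.normalize(p2, dim=1), F.normalize(z2, dim=1)
    # Per-sample squared Euclidean distance; with normalized rows this
    # equals 2 - 2 * cosine_similarity
    loss_12 = (p1 - z2).pow(2).sum(dim=1)
    loss_21 = (p2 - z1).pow(2).sum(dim=1)
    # Symmetrize over the two views and average over the batch
    return (loss_12 + loss_21).mean()


torch.manual_seed(1234)
p1, z1, p2, z2 = (torch.randn(64, 32) for _ in range(4))
loss = byol_loss_sketch(p1, z1, p2, z2)
print(loss)  # close to 4: random 32-d vectors are nearly orthogonal

# Cross-check against the cosine-similarity form of the same objective
alt = (2 - 2 * F.cosine_similarity(p1, z2, dim=1)).mean() + (
    2 - 2 * F.cosine_similarity(p2, z1, dim=1)
).mean()
print(torch.allclose(loss, alt))  # True
```

With normalization enabled, the loss does not change when any input is rescaled. Only the directions of the vectors matter.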

References

[BYOL]

J.-B. Grill, F. Strub, F. Altché, C. Tallec, P. Richemond, E. Buchatskaya, C. Doersch, B. Avila Pires, Z. Guo, M. Gheshlaghi Azar, et al., “Bootstrap your own latent - a new approach to self-supervised learning,” Advances in Neural Information Processing Systems, vol. 33, pp. 21271–21284, 2020.

Example

>>> import torch
>>> from selfeeg import losses
>>> torch.manual_seed(1234)
>>> p1 = torch.randn(64, 32)
>>> z1 = torch.randn(64, 32)
>>> p2 = torch.randn(64, 32)
>>> z2 = torch.randn(64, 32)
>>> loss = losses.byol_loss(p1, z1, p2, z2)
>>> print(loss) # will return 3.9357