byol_loss
- selfeeg.losses.losses.byol_loss(p1: Tensor, z1: Tensor, p2: Tensor, z2: Tensor, projections_norm: bool = True) -> Tensor
Simple PyTorch implementation of the BYOL loss function presented in [BYOL].
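For reference, the BYOL paper defines the loss for one pair of views as the squared L2 distance between the normalized predictor output and the normalized target projection, which reduces to a cosine-similarity form, and then symmetrizes it by swapping the two views. Written with this function's argument names, and assuming that p1 is compared with z2 and p2 with z1 as in the usual BYOL setup, the per-sample loss is shown below. How the per-sample values are reduced over the batch (for example, averaged) is not stated on this page.

```latex
\bar{v} = \frac{v}{\|v\|_2}, \qquad
\mathcal{L}_{\mathrm{BYOL}}
= \left\| \bar{p}_1 - \bar{z}_2 \right\|_2^2 + \left\| \bar{p}_2 - \bar{z}_1 \right\|_2^2
= 4 - 2\,\frac{\langle p_1, z_2 \rangle}{\|p_1\|_2\,\|z_2\|_2}
    - 2\,\frac{\langle p_2, z_1 \rangle}{\|p_2\|_2\,\|z_1\|_2}
```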
- Parameters:
p1 (torch.Tensor) – 2D Tensor with one augmented batch predictor output.
z1 (torch.Tensor) – 2D Tensor with one augmented batch projection output.
p2 (torch.Tensor) – Same as p1 but with the other augmented batch.
z2 (torch.Tensor) – Same as z1 but with the other augmented batch.
projections_norm (bool, optional) –
Whether to normalize the projections or not.
Default = True
- Returns:
loss (torch.Tensor) – The calculated loss.
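To illustrate how the parameters above enter the computation, here is a dependency-free Python sketch of the symmetric BYOL objective on plain lists of vectors. The name byol_loss_sketch, the pairing of p1 with z2 and p2 with z1, and the averaging over the batch are assumptions for illustration; this is not selfeeg's code, which operates on torch.Tensor inputs.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def _l2_normalize(v: Vector, eps: float = 1e-12) -> List[float]:
    # Divide by the L2 norm; eps guards against division by zero.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def _pair_term(p: Vector, z: Vector, normalize: bool) -> float:
    # 2 - 2 * <p, z>; with normalized inputs this equals 2 - 2 * cos(p, z),
    # i.e. the squared L2 distance between the normalized vectors.
    if normalize:
        p, z = _l2_normalize(p), _l2_normalize(z)
    dot = sum(a * b for a, b in zip(p, z))
    return 2.0 - 2.0 * dot


def byol_loss_sketch(p1, z1, p2, z2, projections_norm: bool = True) -> float:
    # Hypothetical helper: symmetric loss (p1 vs z2, p2 vs z1),
    # summed per sample and averaged over the batch (assumed reduction).
    batch = len(p1)
    total = 0.0
    for i in range(batch):
        total += _pair_term(p1[i], z2[i], projections_norm)
        total += _pair_term(p2[i], z1[i], projections_norm)
    return total / batch


if __name__ == "__main__":
    # Orthogonal predictions and projections: each term is 2.
    print(byol_loss_sketch([[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 0.0]], [[0.0, 1.0]]))  # → 4.0
```

With normalization each term lies in [0, 4], so the total is bounded; with projections_norm=False the raw dot product is unbounded, and aligned but large vectors can push the loss below zero. With random Gaussian inputs like those in the Example section, cosine similarities are close to zero, so each term is close to 2 and the total close to 4, which is consistent with the printed value of about 3.94.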
References
[BYOL] J.-B. Grill, F. Strub, F. Altché, C. Tallec, P. Richemond, E. Buchatskaya, C. Doersch, B. Avila Pires, Z. Guo, M. Gheshlaghi Azar, et al., “Bootstrap your own latent: a new approach to self-supervised learning,” Advances in Neural Information Processing Systems, vol. 33, pp. 21271–21284, 2020.
Example
>>> import torch
>>> from selfeeg import losses
>>> _ = torch.manual_seed(1234)
>>> p1 = torch.randn(64, 32)
>>> z1 = torch.randn(64, 32)
>>> p2 = torch.randn(64, 32)
>>> z2 = torch.randn(64, 32)
>>> loss = losses.byol_loss(p1, z1, p2, z2)
>>> print(loss) # will return 3.9357