EEGNetEncoder

class selfeeg.models.encoders.EEGNetEncoder(Chans: int, kernLength: int = 64, dropRate: float = 0.5, F1: int = 8, D: int = 2, F2: int = 16, dropType: str = 'Dropout', ELUalpha: int = 1, pool1: int = 4, pool2: int = 8, separable_kernel: int = 16, depthwise_max_norm: float = 1.0, seed: int = None)

PyTorch implementation of the EEGNet encoder.

See EEGNet for references. The expected input is a 3D tensor of size (Batch x Channels x Samples).

Parameters:
  • Chans (int) – The number of EEG channels.

  • kernLength (int, optional) –

    The kernel length of the temporal convolutional layer.

    Default = 64

  • dropRate (float, optional) –

    The dropout probability, in the range [0, 1].

    Default = 0.5

  • F1 (int, optional) –

    The number of filters in the first layer.

    Default = 8

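  • D (int, optional) –

    The depth multiplier of the depthwise convolutional layer, i.e. the number of spatial filters learned for each temporal filter.

    Default = 2

  • F2 (int, optional) –

    The number of filters in the separable convolutional layer.

    Default = 16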

  • dropType (str, optional) –

    The type of dropout. It can be either ‘Dropout’ or ‘SpatialDropout2D’.

    Default = ‘Dropout’

  • ELUalpha (float, optional) –

    The alpha value of the ELU activation function.

    Default = 1

  • pool1 (int, optional) –

    The kernel size of the first temporal average pooling layer.

    Default = 4

  • pool2 (int, optional) –

    The kernel size of the second temporal average pooling layer.

    Default = 8

  • separable_kernel (int, optional) –

    The kernel size of the temporal separable convolutional layer.

    Default = 16

  • depthwise_max_norm (float, optional) –

    The maximum norm each filter can have in the depthwise block. If None, no constraint is applied.

    Default = 1.0

  • seed (int, optional) –

    A custom seed for model initialization. It must be a nonnegative integer. If None is passed, no custom seed is set.

    Default = None
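The depthwise_max_norm constraint can be pictured as a per-filter L2 renormalization applied after each weight update, the same idea as the Keras max_norm constraint used on the depthwise layer of the original EEGNet. The pure-Python sketch below assumes that behaviour; apply_max_norm is a hypothetical helper, not part of selfeeg, and the actual layer operates on PyTorch weight tensors rather than lists.

```python
import math


def apply_max_norm(filters, max_norm=1.0):
    """Hypothetical helper: rescale each filter so its L2 norm is at most max_norm.

    filters is a list of flattened filters (each a list of floats).
    Filters already within the bound are returned unchanged.
    """
    if max_norm is None:
        # Mirrors depthwise_max_norm=None: no constraint is applied.
        return [list(f) for f in filters]
    constrained = []
    for f in filters:
        norm = math.sqrt(sum(w * w for w in f))
        # Shrink only filters whose norm exceeds the bound.
        scale = max_norm / norm if norm > max_norm else 1.0
        constrained.append([w * scale for w in f])
    return constrained


# A filter with norm 5 is scaled down to norm 1; one with norm 0.5 is kept.
print(apply_max_norm([[3.0, 4.0], [0.3, 0.4]], max_norm=1.0))
```

Because only the oversized filters are rescaled, the constraint bounds the weights without otherwise changing the direction of any filter.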

Note

This implementation follows the latest version of EEGNet, which can be found in the official repository.

Example

>>> from selfeeg import models
>>> import torch
>>> x = torch.randn(4,8,64)
>>> mdl = models.EEGNetEncoder(8)
>>> out = mdl(x)
>>> print(out.shape) # should print torch.Size([4, 32])
>>> print(torch.isnan(out).sum()) # should print tensor(0)
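The flattened size of 32 in the example follows from the default hyperparameters. The sketch below traces the per-sample shape through an EEGNet-style encoder, assuming 'same' padding on the temporal and separable convolutions and floor-mode average pooling; eegnet_encoder_shapes is a hypothetical helper, not part of selfeeg.

```python
def eegnet_encoder_shapes(chans, samples, F1=8, D=2, F2=16, pool1=4, pool2=8):
    """Hypothetical helper: trace (maps, height, width) for one input sample."""
    shapes = []
    # The 2D input (Chans x Samples) gets a singleton "image" channel.
    shapes.append(("input", (1, chans, samples)))
    # Temporal conv with 'same' padding: F1 maps, time length unchanged.
    shapes.append(("temporal conv", (F1, chans, samples)))
    # Depthwise spatial conv spans all channels: D maps per temporal filter.
    shapes.append(("depthwise conv", (F1 * D, 1, samples)))
    t = samples // pool1  # first average pooling (floor division)
    shapes.append(("avg pool 1", (F1 * D, 1, t)))
    # Separable conv with 'same' padding: F2 maps, time length unchanged.
    shapes.append(("separable conv", (F2, 1, t)))
    t = t // pool2  # second average pooling (floor division)
    shapes.append(("avg pool 2", (F2, 1, t)))
    # Flatten everything except the batch dimension.
    shapes.append(("flatten", (F2 * t,)))
    return shapes


for name, shape in eegnet_encoder_shapes(chans=8, samples=64):
    print(f"{name:>15}: {shape}")
```

With 64 samples, the time axis shrinks to 64 // 4 = 16 and then 16 // 8 = 2, so the encoder emits F2 * 2 = 16 * 2 = 32 features per sample.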