EEGConformerEncoder
- class selfeeg.models.encoders.EEGConformerEncoder(Chans, F: int = 40, K1: int = 25, Pool: int = 75, stride_pool: int = 15, d_model: int = 40, nlayers: int = 6, nheads: int = 10, dim_feedforward: int = 160, activation_transformer: str = 'gelu', p: float = 0.2, p_transformer: float = 0.5, seed: int = None)
PyTorch implementation of the EEGConformer encoder.
See EEGConformer for references. The expected input is a 3D tensor with shape (Batch x Channels x Samples).
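Because the encoder expects a 3D (Batch x Channels x Samples) input, a common mistake is passing a single 2D trial or data with the wrong channel count. The sketch below is a hypothetical helper (`check_encoder_input` is not part of selfeeg) that validates a shape before calling the model; it works on any shape-like sequence, including `torch.Size`, which is a tuple subclass.

```python
from typing import Sequence, Tuple


def check_encoder_input(shape: Sequence[int], chans: int) -> Tuple[int, int]:
    """Validate a (Batch x Channels x Samples) shape.

    Hypothetical helper, not part of selfeeg. Returns (batch, samples)
    or raises ValueError if the layout does not match.
    """
    shape = tuple(shape)
    if len(shape) != 3:
        raise ValueError(
            f"expected a 3D input (Batch x Channels x Samples), got {len(shape)}D"
        )
    batch, channels, samples = shape
    if channels != chans:
        raise ValueError(f"expected {chans} channels, got {channels}")
    return batch, samples


# A single trial of shape (Channels x Samples) needs a leading batch axis
# (e.g. x.unsqueeze(0) in PyTorch) before it matches the expected layout.
print(check_encoder_input((4, 8, 512), chans=8))  # (4, 512)
```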
- Parameters:
Chans (int) – The number of EEG channels.
F (int, optional) –
The number of output filters in the temporal convolution layer.
Default = 40
K1 (int, optional) –
The length of the temporal convolutional layer.
Default = 25
Pool (int, optional) –
The temporal pooling kernel size.
Default = 75
stride_pool (int, optional) –
The temporal pooling stride.
Default = 15
d_model (int, optional) –
The embedding size. It is the number of expected features in the input of the transformer encoder layer.
Default = 40
nlayers (int, optional) –
The number of transformer encoder layers.
Default = 6
nheads (int, optional) –
The number of heads in the multi-head attention layers.
Default = 10
dim_feedforward (int, optional) –
The dimension of the feedforward hidden layer in the transformer encoder.
Default = 160
activation_transformer (str or Callable, optional) –
The activation function in the transformer encoder. See the PyTorch TransformerEncoderLayer documentation for accepted inputs.
Default = “gelu”
p (float, optional) –
Dropout probability in the tokenizer. Must be in [0,1).
Default = 0.2
p_transformer (float, optional) –
Dropout probability in the transformer encoder. Must be in [0,1).
Default = 0.5
seed (int, optional) –
A custom seed for model initialization. It must be a nonnegative integer. If None is passed, no custom seed is set.
Default = None
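Some of these parameters constrain each other. PyTorch's multi-head attention requires the embedding size to be divisible by the number of heads, so with the defaults each of the 10 heads works on 40 / 10 = 4 features. The sketch below checks that constraint and the documented ranges for p, p_transformer, and seed, and estimates how many tokens reach the transformer. The token count assumes an unpadded temporal convolution of length K1 followed by average pooling with kernel Pool and stride stride_pool; this layout is an assumption, not taken from the selfeeg source. Both `check_hyperparameters` and `conformer_token_count` are hypothetical helpers, not selfeeg functions.

```python
from typing import Optional


def check_hyperparameters(
    d_model: int = 40,
    nheads: int = 10,
    p: float = 0.2,
    p_transformer: float = 0.5,
    seed: Optional[int] = None,
) -> None:
    """Hypothetical validator mirroring the documented constraints."""
    # Multi-head attention splits d_model evenly across the heads.
    if d_model % nheads != 0:
        raise ValueError(f"d_model={d_model} is not divisible by nheads={nheads}")
    # Both dropout probabilities are documented as lying in [0, 1).
    for name, value in (("p", p), ("p_transformer", p_transformer)):
        if not 0 <= value < 1:
            raise ValueError(f"{name}={value} must be in [0, 1)")
    # The seed is optional but must be nonnegative when given.
    if seed is not None and seed < 0:
        raise ValueError("seed must be a nonnegative integer")


def conformer_token_count(
    samples: int, K1: int = 25, Pool: int = 75, stride_pool: int = 15
) -> int:
    """Sequence length after conv + pooling, assuming no padding anywhere."""
    conv_len = samples - K1 + 1  # valid (unpadded) convolution output length
    if conv_len < Pool:
        raise ValueError("input is too short for the chosen K1 and Pool")
    return (conv_len - Pool) // stride_pool + 1


check_hyperparameters()  # defaults pass silently
print(conformer_token_count(512))  # 28
```

With 512 samples and the defaults, the formula gives (512 - 25 + 1 - 75) // 15 + 1 = 28 tokens, each of size d_model. If the actual layers use padding, the count will differ.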
Example
>>> import selfeeg.models as models
>>> import torch
>>> x = torch.randn(4, 8, 512)
>>> mdl = models.EEGConformerEncoder(8)
>>> out = mdl(x)
>>> print(out.shape)  # should return torch.Size([4, 224])
>>> print(torch.isnan(out).sum())  # should return 0