xEEGNet
- class selfeeg.models.zoo.xEEGNet(nb_classes: int, Chans: int, Samples: int, Fs: int, F1: int = 7, K1: int = 125, F2: int = 7, Pool: int = 75, p: float = 0.2, random_temporal_filter=False, freeze_temporal: int = 1000000000000.0, spatial_depthwise: bool = True, log_activation_base: str = 'dB', norm_type: str = 'batchnorm', global_pooling=True, bias: list[int, int, int] = [False, False, False], dense_hidden: int = -1, return_logits=True, seed=None)
PyTorch implementation of xEEGNet.
For more information, see the paper [xEEG]. The original implementation of xEEGNet can be found here [xEEGgit]. The expected input is a 3D tensor of size (Batch x Channels x Samples).
- Parameters:
nb_classes (int) – The number of classes. If less than 2, a binary classification problem is considered (output dimensions will be [batch, 1] in this case).
Chans (int) – The number of EEG channels.
Samples (int) – The sample length (number of time steps). It will be used to calculate the embedding size (for head initialization).
Fs (int) – The sampling rate of the EEG signal in Hz. It is used to initialize the temporal filter weights. It must be specified even if random_temporal_filter is True.
F1 (int, optional) –
The number of output filters in the temporal convolution layer.
Default = 7
K1 (int, optional) –
The kernel length (in samples) of the temporal convolutional layer.
Default = 125
F2 (int, optional) –
The number of output filters in the spatial convolution layer.
Default = 7
Pool (int, optional) –
Kernel size for temporal pooling.
Default = 75
p (float, optional) –
Dropout probability, in [0, 1).
Default = 0.2
random_temporal_filter (bool, optional) –
If True, initialize the temporal filter weights randomly. Otherwise, initialize them as band-pass FIR filters.
Default = False
freeze_temporal (int, optional) –
Number of forward steps during which the temporal layer is kept frozen. The default (1e12) effectively keeps it frozen for the whole training.
Default = 1e12
spatial_depthwise (bool, optional) –
Whether to apply a depthwise layer in the spatial convolution.
Default = True
log_activation_base (str, optional) –
Base for the logarithmic activation after pooling. Options: “e” (natural log), “10” (logarithm base 10), “dB” (decibel scale).
Default = “dB”
norm_type (str, optional) –
The type of normalization. Expected values are “batchnorm” or “instancenorm”.
Default = “batchnorm”
global_pooling (bool, optional) –
If True, apply global average pooling instead of flattening.
Default = True
bias (list[int, int, int], optional) –
A 3-element list of boolean values. If the first element is True, a bias is added to the temporal convolutional layer. If the second element is True, a bias is added to the spatial convolutional layer. If the third element is True, a bias is added to the final dense layer.
Default = [False, False, False]
return_logits (bool, optional) –
If True, return the output as logits. Keeping True is recommended, since the PyTorch cross-entropy loss (torch.nn.CrossEntropyLoss) applies the softmax internally.
Default = True
seed (int, optional) –
A custom seed for model initialization. It must be a nonnegative number. If None is passed, no custom seed is set.
Default = None
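With random_temporal_filter=False, the temporal filters of length K1 are initialized as band-pass FIR filters derived from the sampling rate Fs. The exact band edges and window selfeeg uses are not stated on this page, so the following is only a generic windowed-sinc sketch of what such a filter looks like. The helpers `bandpass_fir` and `gain_at`, the Hamming window, and the 8 to 13 Hz band are illustrative assumptions, not the library's initializer.

```python
import math


def bandpass_fir(num_taps: int, fs: float, low_hz: float, high_hz: float) -> list[float]:
    """Hypothetical windowed-sinc band-pass FIR (Hamming window).

    Illustrates how a kernel length (like K1) and a sampling rate (like Fs)
    define a band-pass filter; NOT the selfeeg initialization code.
    """
    if not 0 < low_hz < high_hz < fs / 2:
        raise ValueError("need 0 < low_hz < high_hz < fs / 2")
    f1, f2 = low_hz / fs, high_hz / fs  # cutoffs in cycles per sample
    center = (num_taps - 1) / 2
    taps = []
    for n in range(num_taps):
        m = n - center
        if m == 0:
            ideal = 2 * (f2 - f1)
        else:
            # Ideal band-pass = difference of two ideal low-pass responses
            ideal = (math.sin(2 * math.pi * f2 * m) - math.sin(2 * math.pi * f1 * m)) / (math.pi * m)
        # Hamming window tapers the truncated ideal response to reduce ripple
        window = 0.54 - 0.46 * math.cos(2 * math.pi * n / (num_taps - 1))
        taps.append(ideal * window)
    return taps


def gain_at(taps: list[float], fs: float, freq_hz: float) -> float:
    """Magnitude of the filter's frequency response at freq_hz."""
    w = 2 * math.pi * freq_hz / fs
    re = sum(h * math.cos(w * n) for n, h in enumerate(taps))
    im = -sum(h * math.sin(w * n) for n, h in enumerate(taps))
    return math.hypot(re, im)


if __name__ == "__main__":
    # K1 = 125 taps at Fs = 125 Hz, as in the defaults and the example below
    taps = bandpass_fir(125, 125.0, 8.0, 13.0)
    for f in (0.0, 10.5, 30.0):
        print(f"{f:5.1f} Hz -> gain {gain_at(taps, 125.0, f):.4f}")
```

Longer kernels (larger K1) or higher sampling rates give narrower transition bands, which is one reason both parameters matter for the initialization.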
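The log_activation_base options differ only by a constant scale factor: log10(x) = ln(x) / ln(10), and a decibel value is a fixed multiple of log10. The sketch below assumes that “dB” means 10·log10 (the power convention) and that a small eps is added to avoid log(0); both are assumptions, and `log_activation` is a hypothetical helper rather than the selfeeg layer.

```python
import math


def log_activation(x: float, base: str = "dB", eps: float = 1e-6) -> float:
    """Hypothetical log activation applied to a pooled (non-negative) value.

    Assumptions, not taken from selfeeg: "dB" means 10 * log10(x), and eps
    is added to keep the logarithm finite when x is zero.
    """
    x = x + eps
    if base == "e":
        return math.log(x)
    if base == "10":
        return math.log10(x)
    if base == "dB":
        return 10.0 * math.log10(x)
    raise ValueError(f"unknown base: {base!r}")


if __name__ == "__main__":
    for b in ("e", "10", "dB"):
        print(b, round(log_activation(100.0, b), 4))
```

Because the three bases are rescalings of one another, the choice mostly affects the numeric range of the activations seen by the following layers.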
References
[xEEG] Zanola et al., xEEGNet: Towards Explainable AI in EEG Dementia Classification. arXiv preprint, 2025. https://doi.org/10.48550/arXiv.2504.21457
Example
>>> from selfeeg import models
>>> import torch
>>> x = torch.randn(4, 8, 512)
>>> mdl = models.xEEGNet(3, 8, 512, 125)
>>> out = mdl(x)
>>> print(out.shape)  # should return torch.Size([4, 3])
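With return_logits=True (the default), the model returns raw scores. To read them as class probabilities at inference time, one option is sketched below in plain Python. `logits_to_probs` is a hypothetical helper, and treating the single output of the binary case (nb_classes < 2) as the logit of the positive class is an assumption that matches the usual torch.nn.BCEWithLogitsLoss convention.

```python
import math


def logits_to_probs(logits: list[float]) -> list[float]:
    """Hypothetical helper: turn one sample's logits into class probabilities.

    One logit (binary case, assumed to be the positive-class logit):
    numerically stable sigmoid, returned as [P(class 0), P(class 1)].
    Several logits: numerically stable softmax.
    """
    if len(logits) == 1:
        z = logits[0]
        if z >= 0:
            p1 = 1.0 / (1.0 + math.exp(-z))
        else:
            e = math.exp(z)  # avoids overflow of exp(-z) for very negative z
            p1 = e / (1.0 + e)
        return [1.0 - p1, p1]
    m = max(logits)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    print(logits_to_probs([2.0, 1.0, 0.1]))  # three-class output
    print(logits_to_probs([0.0]))            # binary output
```

On PyTorch tensors, torch.softmax(out, dim=1) and torch.sigmoid(out) perform the same conversions on a whole batch.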