Dataloading guide
This section is intended to provide a brief introduction to the dataloading module and its main functionalities.
In short, all functions and custom classes are designed to help you create an efficient PyTorch DataLoader to use during training. The main objective is to avoid loading the entire dataset at once and instead to iteratively load (possibly overlapping) time windows called “partitions”. A typical pipeline consists of the following steps:
Define the partition specs, i.e. the EEGs’ sampling rate, the window length and the overlap between consecutive windows.
Call the get_eeg_partition_number function to compute the dataset length, i.e. the number of partitions that can be extracted from the EEG datasets given the defined partition specs.
Call the get_eeg_split_table or the get_eeg_split_table_kfold function to split the data into train, validation and test sets.
Pass the results of the previous steps to the custom PyTorch Dataset, EEGDataset.
Optional: create a custom PyTorch Sampler, EEGSampler.
Create a PyTorch DataLoader with the custom Dataset (and Sampler).
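The core idea behind this pipeline can be illustrated in plain Python: keep only one recording in memory at a time, and cut it into overlapping windows on demand. This is a conceptual sketch, not selfeeg's implementation. iter_partitions and the in-memory fake_files dictionary are hypothetical stand-ins for the EEG files on disk, and the window length and step are given in samples.

```python
from typing import Callable, Dict, Iterator, List

Recording = List[List[float]]  # channels x samples

def iter_partitions(
    paths: List[str],
    load: Callable[[str], Recording],
    window: int,
    step: int,
) -> Iterator[Recording]:
    """Yield (channels x window) slices, loading one recording at a time."""
    for path in paths:
        eeg = load(path)  # only this recording is held in memory
        n_samples = len(eeg[0])
        start = 0
        while start + window <= n_samples:
            yield [channel[start:start + window] for channel in eeg]
            start += step  # step < window means consecutive windows overlap

# Hypothetical in-memory "files": two 2-channel recordings of 10 and 6 samples
fake_files: Dict[str, Recording] = {
    "a.pickle": [list(range(10)), list(range(10))],
    "b.pickle": [list(range(6)), list(range(6))],
}
windows = list(iter_partitions(list(fake_files), fake_files.__getitem__, window=4, step=3))
```

With window=4 and step=3, consecutive windows share one sample. The 10-sample recording yields three windows and the 6-sample one yields a single window, because this sketch drops incomplete windows.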
First, let’s import the dataloading module, along with the other libraries used in this tutorial.
[1]:
import os
import random
import pickle
import sys
sys.path.append('..') # Needed when running this from the selfeeg/doc folder
from selfeeg import dataloading as dl
import numpy as np
import torch
from torch.utils.data import DataLoader
# set seeds for reproducibility
seed = 12
torch.manual_seed( seed )
np.random.seed( seed )
random.seed( seed )
To provide a simple and executable tutorial, we will create a fake collection of EEG datasets (already aligned) and save it in a folder named “Simulated_EEG”. To be clear, we will just generate random arrays (np.random.randn) of random length and save them, so that no large dataset needs to be downloaded.
To keep the folder small, each file will be:
a 2-channel EEG;
of random length, between 1024 and 4096 samples;
stored with the name "{dataset_id}_{subject_id}_{session_id}_{trial_id}.pickle", which will be useful for the split part.
[2]:
# create the folder if it does not exist
if not os.path.isdir('Simulated_EEG'):
    os.mkdir('Simulated_EEG')
N = 1000
for i in range(N):
    x = np.random.randn(2, np.random.randint(1024, 4097))  # 2 channels, 1024-4096 samples
    y = np.random.randint(1, 5)                            # fake label in {1, 2, 3, 4}
    sample = {'data': x, 'label': y}
    dataset_id = i // 200 + 1          # 200 files per dataset
    subject_id = (i % 200) // 5 + 1    # 5 consecutive files per subject
    session_id = i % 5 + 1             # one file per session
    trial_id = 1
    file_name = f'Simulated_EEG/{dataset_id}_{subject_id}_{session_id}_{trial_id}.pickle'
    with open(file_name, 'wb') as f:
        pickle.dump(sample, f)
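Because of the naming scheme, each file's IDs can be recovered from its name alone. The split examples below rely on this through their dataset_id_extractor and subject_id_extractor lambdas. The small helpers here are hypothetical and not part of selfeeg. They build a name from the four IDs and parse it back.

```python
from typing import NamedTuple

class EEGFileId(NamedTuple):
    dataset_id: int
    subject_id: int
    session_id: int
    trial_id: int

def make_file_name(ids: EEGFileId) -> str:
    # Same pattern used when saving the simulated files
    return f"{ids.dataset_id}_{ids.subject_id}_{ids.session_id}_{ids.trial_id}.pickle"

def parse_file_name(file_name: str) -> EEGFileId:
    # Drop the extension, then read the four underscore-separated integer IDs
    stem = file_name.rsplit(".", 1)[0]
    return EEGFileId(*(int(part) for part in stem.split("_")))

name = make_file_name(EEGFileId(3, 27, 4, 1))
parsed = parse_file_name("1_10_2_1.pickle")
```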
Now we have a folder with 1000 simulated EEGs coming from:
5 datasets (IDs from 1 to 5);
40 subjects per dataset (IDs from 1 to 40);
5 sessions per subject (IDs from 1 to 5).
Each file is a pickle file containing a dictionary with keys:
'data': the 2D NumPy array (channels × samples);
'label': a fake label associated with the EEG file (from 1 to 4).
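As a quick sanity check of the counts listed above, the ID arithmetic of the generation loop can be replayed without touching the disk. ids_from_index is a hypothetical helper that mirrors that loop: 200 files per dataset, 5 consecutive files per subject, and a trial ID that is always 1.

```python
from collections import Counter

def ids_from_index(i: int) -> tuple:
    """(dataset_id, subject_id, session_id, trial_id) of the i-th simulated file."""
    dataset_id = i // 200 + 1          # 200 files per dataset
    subject_id = (i % 200) // 5 + 1    # 5 consecutive files per subject
    session_id = i % 5 + 1             # one file per session
    return dataset_id, subject_id, session_id, 1

all_ids = [ids_from_index(i) for i in range(1000)]
subjects = {(d, s) for d, s, _, _ in all_ids}
subjects_per_dataset = Counter(d for d, _ in subjects)
sessions_per_subject = Counter((d, s) for d, s, _, _ in all_ids)
```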
The get_eeg_partition_number function
This function computes the dataset length once the partition specs are defined. Let’s suppose the data have a sampling rate of 128 Hz and we want to extract 2-second samples with a 15% overlap.
To complicate things, let’s also assume that we want to remove the last half second of each recording (64 samples at 128 Hz), for example because it often contains poorly recorded data.
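The arithmetic behind these specs can be made explicit. The sketch below makes two assumptions about how window counts are derived: the overlap is rounded to whole samples, and incomplete trailing windows are discarded. get_eeg_partition_number may round differently or handle the last partial window in its own way, so its N_samples values are not guaranteed to match.

```python
def partition_specs(freq, window_s, overlap, trim_s):
    """Convert the specs from seconds/fractions to samples (assumed rounding)."""
    window_samples = int(window_s * freq)               # 2 s * 128 Hz = 256
    overlap_samples = round(overlap * window_samples)   # 0.15 * 256 = 38.4 -> 38
    step = window_samples - overlap_samples             # 218 samples between window starts
    trim_samples = int(trim_s * freq)                   # 0.5 s * 128 Hz = 64
    return window_samples, step, trim_samples

def count_full_windows(n_samples, window_samples, step, trim_samples):
    """Number of complete windows left after trimming the end of the recording."""
    usable = n_samples - trim_samples
    if usable < window_samples:
        return 0
    return (usable - window_samples) // step + 1

window_samples, step, trim = partition_specs(freq=128, window_s=2, overlap=0.15, trim_s=0.5)
shortest = count_full_windows(1024, window_samples, step, trim)
longest = count_full_windows(4096, window_samples, step, trim)
```

Under these assumptions, the simulated recordings (1024 to 4096 samples) give between 4 and 18 partitions each.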
WARNING:
remember that this function is not omniscient, so we need to tell it how to load the data. By default, the function tries scipy’s loadmat function with the syntax EEG = loadmat(path_to_file, simplify_cells=True)['DATA_STRUCT']['data'], which matches the output of the BIDSalign library provided by our team.
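Custom loaders receive the file path plus any extra arguments. The sketch below assumes that the values in optional_load_fun_args are unpacked positionally after the path. This is consistent with how loadEEG's return_label flag is set through [False] and [True] in the examples below. call_loader is a hypothetical helper, not a selfeeg function. The sketch writes one small pickle file to a temporary folder and loads it both ways.

```python
import os
import pickle
import tempfile

def loadEEG(path, return_label=False):
    # Same loader as in the tutorial: a pickled dict with 'data' and 'label'
    with open(path, 'rb') as handle:
        EEG = pickle.load(handle)
    if return_label:
        return EEG['data'], EEG['label']
    return EEG['data']

def call_loader(load_function, path, optional_load_fun_args=None):
    # Hypothetical helper: extra arguments are assumed to follow the path positionally
    args = optional_load_fun_args or []
    return load_function(path, *args)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, '1_1_1_1.pickle')
    with open(path, 'wb') as f:
        pickle.dump({'data': [[0.1, 0.2], [0.3, 0.4]], 'label': 3}, f)
    data_only = call_loader(loadEEG, path, [False])
    data, label = call_loader(loadEEG, path, [True])
```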
[3]:
# Define partition spec
eegpath = 'Simulated_EEG'
freq = 128       # sampling frequency in [Hz]
overlap = 0.15   # overlap between partitions
window = 2       # window length in [seconds]
# define functions to load and transform data
# SOME NOTES: these functions can be merged into a single one. Also, if
# they need extra arguments, these can be passed with the
# optional_load_fun_args and optional_transform_fun_args arguments
def loadEEG(path, return_label=False):
    with open(path, 'rb') as handle:
        EEG = pickle.load(handle)
    x = EEG['data']
    y = EEG['label']
    if return_label:
        return x, y
    else:
        return x
def transformEEG(EEG):
    # drop the last half second (64 samples at 128 Hz)
    EEG = EEG[:, :-64]
    return EEG
# call the function
EEGlen = dl.get_eeg_partition_number(
    eegpath,
    freq,
    window,
    overlap,
    file_format='*.pickle',
    load_function=loadEEG,
    optional_load_fun_args=[False],
    transform_function=transformEEG
)
EEGlen.head()
[3]:
| | full_path | file_name | N_samples |
|---|---|---|---|
| 0 | Simulated_EEG/1_10_1_1.pickle | 1_10_1_1.pickle | 15 |
| 1 | Simulated_EEG/1_10_2_1.pickle | 1_10_2_1.pickle | 5 |
| 2 | Simulated_EEG/1_10_3_1.pickle | 1_10_3_1.pickle | 7 |
| 3 | Simulated_EEG/1_10_4_1.pickle | 1_10_4_1.pickle | 12 |
| 4 | Simulated_EEG/1_10_5_1.pickle | 1_10_5_1.pickle | 6 |
The get_eeg_split_table function
Now that we have a table with the number of partitions associated with each EEG file (the N_samples column), let’s split the data.
The split can be performed at different levels of granularity (e.g. dataset, subject or file level) and in different ways, i.e. by giving the IDs to put in each set or simply the desired ratios. Some data can also be excluded. If a label is associated with each file (or there is a way to extract it), a stratified split can be performed, which preserves the label proportions in each set up to a certain tolerance.
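A simplified, hypothetical version shows what splitting at the subject level means in practice. Files are grouped by (dataset_id, subject_id), the groups are shuffled, and whole groups are assigned to a set. Unlike get_eeg_split_table, this sketch balances only the number of groups, not the number of partitions, and it supports neither exclusion nor stratification. It uses its own codes: 0 for train, 1 for validation and 2 for test.

```python
import random
from collections import defaultdict

def subject_of(file_name):
    """(dataset_id, subject_id) from '{dataset}_{subject}_{session}_{trial}.pickle'."""
    dataset_id, subject_id = file_name.split("_")[:2]
    return int(dataset_id), int(subject_id)

def group_split(file_names, group_key, val_ratio, test_ratio, seed=0):
    """Assign whole groups to 0 (train), 1 (validation) or 2 (test)."""
    groups = defaultdict(list)
    for name in file_names:
        groups[group_key(name)].append(name)
    keys = sorted(groups)
    random.Random(seed).shuffle(keys)
    n_test = round(test_ratio * len(keys))
    n_val = round(val_ratio * len(keys))
    assignment = {}
    for rank, key in enumerate(keys):
        split_set = 2 if rank < n_test else (1 if rank < n_test + n_val else 0)
        for name in groups[key]:
            assignment[name] = split_set
    return assignment

# 10 subjects x 5 sessions from dataset 1
files = [f"1_{s}_{ses}_1.pickle" for s in range(1, 11) for ses in range(1, 6)]
split = group_split(files, subject_of, val_ratio=0.1, test_ratio=0.1, seed=12)
```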
TIP
you can also create a table for cross-validation splits with the get_eeg_split_table_kfold function. It works similarly to the previous function, and if you want to extract a specific fold, you can use the extract_split function.
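Fold-based splitting can be sketched in the same spirit. kfold_groups and extract_fold are hypothetical helpers and do not reproduce the signatures of selfeeg's k-fold functions. Each group is assigned to exactly one fold. Extracting a fold holds out its groups, and the remaining groups are used for training.

```python
import random

def kfold_groups(group_ids, n_folds, seed=0):
    """Map each group ID to a fold index in range(n_folds)."""
    ids = sorted(set(group_ids))
    random.Random(seed).shuffle(ids)
    # round-robin over the shuffled groups keeps fold sizes within one group of each other
    return {gid: rank % n_folds for rank, gid in enumerate(ids)}

def extract_fold(fold_of, fold):
    """Return (training_groups, held_out_groups) for the chosen fold."""
    held_out = sorted(g for g, f in fold_of.items() if f == fold)
    training = sorted(g for g, f in fold_of.items() if f != fold)
    return training, held_out

subjects = list(range(1, 41))
fold_of = kfold_groups(subjects, n_folds=10, seed=12)
train_ids, held_out_ids = extract_fold(fold_of, 3)
```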
WARNING
stratification assumes that EEG files at the split granularity level share the same label. For example, if you want to split files at the subject level, make sure that all EEGs from the same subject are associated with the same label; otherwise, the split will not be performed correctly.
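This precondition can be checked before calling the split function. mixed_label_groups is a hypothetical helper. It returns every group whose files do not all share one label.

```python
from collections import defaultdict

def subject_key(file_name):
    """(dataset_id, subject_id) from '{dataset}_{subject}_{session}_{trial}.pickle'."""
    dataset_id, subject_id = file_name.split("_")[:2]
    return int(dataset_id), int(subject_id)

def mixed_label_groups(file_names, labels, group_key):
    """Return the groups whose files carry more than one distinct label."""
    seen = defaultdict(set)
    for name, label in zip(file_names, labels):
        seen[group_key(name)].add(label)
    return sorted(g for g, labs in seen.items() if len(labs) > 1)

names = ["1_1_1_1.pickle", "1_1_2_1.pickle", "1_2_1_1.pickle", "1_2_2_1.pickle"]
labels = [1, 1, 2, 3]
bad = mixed_label_groups(names, labels, subject_key)
```

In the simulated data, labels are drawn independently for each file. A subject-level run of this check would therefore flag almost every subject, while stratifying at the file level avoids the issue.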
For now, let’s perform a stratified split at the file level on all files, without excluding any data (exclusion is shown in the next example). Split ratios are 80/10/10.
[5]:
# for stratified split we need to create an array with the labels
# associated with each EEG file
Labels = np.zeros(EEGlen.shape[0], dtype=int)
for i in range(EEGlen.shape[0]):
    _, Labels[i] = loadEEG(EEGlen.iloc[i]['full_path'], return_label=True)
EEGsplit = dl.get_eeg_split_table(
    EEGlen,
    test_ratio = 0.1,
    val_ratio = 0.1,
    test_split_mode = 'file',
    val_split_mode = 'file',
    exclude_data_id = None,  # no exclusion here, e.g. {x:[13,23] for x in range(1,6)} to exclude subjects
    stratified = True,
    labels = Labels,
    perseverance = 5000,
    split_tolerance = 0.005,
    seed = seed
)
dl.check_split(EEGlen, EEGsplit, Labels)
train ratio: 0.80
validation ratio: 0.10
test ratio: 0.10
train labels ratio: 1=0.239, 2=0.244, 3=0.252, 4=0.265,
val labels ratio: 1=0.239, 2=0.244, 3=0.252, 4=0.265,
test labels ratio: 1=0.239, 2=0.244, 3=0.252, 4=0.265,
Here is another example: a non-stratified split at the subject level (all EEGs from the same subject end up in the same split set), this time excluding EEGs from subjects 13 and 23 of each dataset. Split ratios are again 80/10/10.
[6]:
EEGsplit2 = dl.get_eeg_split_table(
    EEGlen,
    test_ratio = 0.1,
    val_ratio = 0.1,
    test_split_mode = 'subject',
    val_split_mode = 'subject',
    exclude_data_id = {x:[13,23] for x in range(1,6)},
    dataset_id_extractor = lambda x: int(x.split('_')[0]),
    subject_id_extractor = lambda x: int(x.split('_')[1]),
    perseverance = 5000,
    split_tolerance = 0.005,
    seed = seed
)
dl.check_split(EEGlen, EEGsplit2)
# Given the structure of the created dataset (5 consecutive files per subject),
# it's easy to check whether splits are really subject-based
for i in range(EEGsplit2.shape[0]//5):
    if EEGsplit2.iloc[(5*i):(5*i+5)]['split_set'].sum() not in [-5, 0, 5, 10]:
        # split_set takes the values -1, 0, 1, 2, so the five files of a subject
        # must sum to five times one of these values
        print('wrong_split')
train ratio: 0.80
validation ratio: 0.10
test ratio: 0.10
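The loop above relies on the table being sorted so that each subject's five files are contiguous. A check that does not depend on row order can instead group rows by the (dataset_id, subject_id) pair parsed from the file name. groups_split_consistently is a hypothetical helper, and the last lines use toy lists rather than EEGsplit2.

```python
from collections import defaultdict

def subject_key(file_name):
    """(dataset_id, subject_id) parsed from '{dataset}_{subject}_{session}_{trial}.pickle'."""
    dataset_id, subject_id = file_name.split("_")[:2]
    return int(dataset_id), int(subject_id)

def groups_split_consistently(file_names, split_sets, group_key=subject_key):
    """True if, for every group, all of its files share the same split set."""
    sets_per_group = defaultdict(set)
    for name, split_set in zip(file_names, split_sets):
        sets_per_group[group_key(name)].add(split_set)
    return all(len(sets) == 1 for sets in sets_per_group.values())

# With the tutorial's table this would be called as
# groups_split_consistently(EEGsplit2['file_name'], EEGsplit2['split_set'])
names = ["1_1_1_1.pickle", "1_1_2_1.pickle", "1_2_1_1.pickle", "1_2_2_1.pickle"]
ok = groups_split_consistently(names, [0, 0, 2, 2])
broken = groups_split_consistently(names, [0, 1, 2, 2])
```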
[7]:
EEGsplit2
[7]:
| | file_name | split_set |
|---|---|---|
| 0 | 1_10_1_1.pickle | 0 |
| 1 | 1_10_2_1.pickle | 0 |
| 2 | 1_10_3_1.pickle | 0 |
| 3 | 1_10_4_1.pickle | 0 |
| 4 | 1_10_5_1.pickle | 0 |
| ... | ... | ... |
| 995 | 5_9_1_1.pickle | 0 |
| 996 | 5_9_2_1.pickle | 0 |
| 997 | 5_9_3_1.pickle | 0 |
| 998 | 5_9_4_1.pickle | 0 |
| 999 | 5_9_5_1.pickle | 0 |
1000 rows × 2 columns
The EEGDataset class
Now we have all the ingredients needed to initialize the custom dataset. The EEGDataset class is highly customizable, so we illustrate two examples. The first is typically used for pretraining and does not involve extracting labels from the EEG files. The second is typically used for fine-tuning and does use the labels.
To correctly initialize the EEGDataset class, you need:
the output of the get_eeg_partition_number function (used to calculate the length);
the output of the get_eeg_split_table function (used to extract data of a specific split set);
the partition spec as a list (format: [freq, window, overlap]).
Other important optional parameters are:
the mode (train, validation, test), used to select data from a specific split set;
the boolean supervised, which tells whether indexing with [] (the __getitem__ method) must also return the label associated with the sample;
the label_on_load argument, which tells whether indexing with [] (the __getitem__ method) gets the label from the loading function or must call a custom function.
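A dataset built on a partition table has to map a single integer index to a file and to a partition inside that file. A common way to do this is a cumulative sum of the per-file partition counts plus a binary search. This is shown here as an assumption, not a description of EEGDataset's internals. build_index and locate are hypothetical helpers, and the counts are the N_samples values of the first five rows of the table above.

```python
from bisect import bisect_right
from itertools import accumulate

def build_index(n_partitions):
    """Cumulative partition counts, e.g. [15, 5, 7] -> [15, 20, 27]."""
    return list(accumulate(n_partitions))

def locate(index, cumulative):
    """Map a global dataset index to (file_position, partition_within_file)."""
    if not 0 <= index < cumulative[-1]:
        raise IndexError(index)
    file_pos = bisect_right(cumulative, index)
    start = cumulative[file_pos - 1] if file_pos > 0 else 0
    return file_pos, index - start

cumulative = build_index([15, 5, 7, 12, 6])
dataset_length = cumulative[-1]  # total number of partitions
```

Files with zero partitions are skipped automatically, because bisect_right moves past repeated cumulative values.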
TIP 1
the EEGDataset class also accepts custom functions to load and transform the EEG files and to extract their labels.
TIP 2
if the label must be extracted from a dictionary, possibly with different files storing it under different keys, check the label_key argument.
CASE 1: Pretraining - no label
[8]:
dataset_pretrain = dl.EEGDataset(
    EEGlen,
    EEGsplit,
    [freq, window, overlap],  # partition spec must be given as a list
    mode = 'train',           # default, select all samples in the train set
    load_function = loadEEG,
    transform_function = transformEEG
)
sample_1 = dataset_pretrain[0]  # Grab the first sample
print(sample_1.shape)  # Note: the sample is automatically converted to a Tensor
torch.Size([2, 256])
CASE 2: Fine-tuning - with label
[9]:
dataset_finetune = dl.EEGDataset(
    EEGlen,
    EEGsplit,
    [freq, window, overlap],      # partition spec must be given as a list
    mode = 'train',               # the default, select all samples in the train set
    supervised = True,            # !!!!IMPORTANT!!!!
    load_function = loadEEG,
    optional_load_fun_args= [True],  # tells loadEEG to return a label
    transform_function=transformEEG,
    label_on_load=True,           # the default
)
sample_2, label_2 = dataset_finetune[0] # grab the first sample
print(sample_2.shape, label_2) # now we also have a label
torch.Size([2, 256]) 1
The EEGSampler class
Although optional, you can also create a custom sampler. The sampler can create two different types of iterators, which strike different balances between batch heterogeneity and batch creation speed:
Linear: returns a plain sequential iterator. It is useful when you want to minimize the number of EEG file loading operations. However, batches will contain consecutive partitions of the same file, which could affect the behavior of some layers, such as BatchNorm. To initialize the sampler in this mode, simply use the command EEGSampler(EEGDataset, Mode=0)
Shuffled: returns a customized iterator, constructed as follows:
samples are shuffled at the file level (i.e. the order of the files is shuffled);
samples of the same file are shuffled;
samples are rearranged based on the desired batch size and number of workers. This step is performed to exploit the parallelization of the PyTorch DataLoader and reduce the number of loading operations. To initialize the sampler in this mode, simply use the command EEGSampler(EEGDataset, BatchSize, Workers)
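The rearrangement step can be illustrated with a hypothetical ordering function; this is not selfeeg's EEGSampler code. The sketch rests on two assumptions. First, with num_workers > 0, the PyTorch DataLoader hands consecutive batches to its workers in round-robin order. Second, loading is cheaper when a worker keeps receiving partitions from the file it has just loaded. Under these assumptions, each worker gets a contiguous run of the shuffled list. The runs are then interleaved so that batch k lands on worker k % workers.

```python
import random

def shuffled_order(partitions_per_file, batch_size, workers, seed=0):
    """Hypothetical worker-aware shuffled ordering of global partition indices."""
    rng = random.Random(seed)
    # 1) Global partition indices grouped by file
    per_file, start = [], 0
    for n in partitions_per_file:
        per_file.append(list(range(start, start + n)))
        start += n
    # 2) Shuffle the file order, then the partitions inside each file
    rng.shuffle(per_file)
    for indices in per_file:
        rng.shuffle(indices)
    flat = [i for indices in per_file for i in indices]
    if workers <= 1:
        return flat
    # 3) Cut full batches; a trailing partial batch is kept for the very end
    n_full = len(flat) // batch_size
    batches = [flat[k * batch_size:(k + 1) * batch_size] for k in range(n_full)]
    remainder = flat[n_full * batch_size:]
    # 4) Give each worker a contiguous run of batches (earlier runs get the
    #    extra batch when n_full is not divisible by workers)
    q, r = divmod(n_full, workers)
    runs, pos = [], 0
    for w in range(workers):
        size = q + (1 if w < r else 0)
        runs.append(batches[pos:pos + size])
        pos += size
    # 5) Interleave the runs so that batch k lands on worker k % workers
    reordered = []
    for step in range(q + (1 if r else 0)):
        for run in runs:
            if step < len(run):
                reordered.extend(run[step])
    return reordered + remainder
```

The trailing incomplete batch is moved to the very end. This keeps it from shifting the batch boundaries that the DataLoader computes from the sampler's order.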
TIP
We suggest using the linear iterator for validation and testing, since it is faster and batch heterogeneity is not needed there.
Here is a schematic representation of how the shuffled iterator is constructed, with batch size = 5 and workers = 4:

[10]:
sampler_linear = dl.EEGSampler(dataset_pretrain, Mode=0)
sampler_custom = dl.EEGSampler(dataset_pretrain, 16, 4)
Final DataLoader
Now simply put everything together and create your custom DataLoader.
WARNING
If you have created a custom sampler, remember to pass the same batch size and number of workers to the DataLoader.
[11]:
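Final_Dataloader = DataLoader(
    dataset = dataset_pretrain,
    batch_size = 16,           # same batch size given to sampler_custom
    sampler = sampler_custom,
    num_workers = 0            # use 4 to match sampler_custom; 0 loads data in the main process
)
for X in Final_Dataloader:
    print(X.shape)
    break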
torch.Size([16, 2, 256])