{ "cells": [ { "cell_type": "markdown", "id": "8bdb80d3-737a-4306-8c46-3408f98851db", "metadata": {}, "source": [ "# Dataloading guide\n", "\n", "This section provides a brief introduction to the dataloading module and its main functionalities.\n", "\n", "In short, all functions and custom classes are designed to help you create an efficient PyTorch DataLoader to use during training. The main objective is to avoid loading the entire dataset at once and instead iteratively load (possibly overlapping) time windows called \"partitions\". A typical pipeline is based on the following steps:\n", "\n", "\n", "1) Define the **partition specs**, i.e. the EEGs' sampling rate, the window length and the overlap between consecutive windows.\n", "2) Call the **get_eeg_partition_number** function to extract the dataset length, i.e. the number of partitions that can be extracted from the EEG datasets given the defined partition specs.\n", "3) Call the **get_eeg_split_table** function (or **GetEEGSplitTableKfold** for cross-validation) to split the data into train, validation and test sets.\n", "4) Pass the results of the previous points to the custom PyTorch Dataset **EEGDataset**.\n", "5) Optional: create a custom PyTorch Sampler **EEGSampler**.\n", "6) Create a **PyTorch DataLoader** with the custom Dataset (and Sampler)." ] }, { "cell_type": "markdown", "id": "6103b7e9-d8cc-4889-83ce-ec0b5f5bfe97", "metadata": {}, "source": [ "First, let's import the dataloading module." ] }, { "cell_type": "code", "execution_count": 1, "id": "31f8ad6a", "metadata": {}, "outputs": [], "source": [ "import os\n", "import random\n", "import pickle\n", "import sys\n", "sys.path.append('..') # Needed when running this from the selfeeg/doc folder\n", "from selfeeg import dataloading as dl\n", "\n", "import numpy as np\n", "import torch\n", "from torch.utils.data import DataLoader\n", "\n", "# set seeds for reproducibility\n", "seed = 12\n", "torch.manual_seed( seed )\n", "np.random.seed( seed )\n",
"random.seed( seed )" ] }, { "cell_type": "markdown", "id": "6038e1ec-3373-4e5f-bc32-193da92c44b7", "metadata": {}, "source": [ "To provide a simple and excecutable tutorial, we will create a fake collection of EEG datasets (already aligned) which we will save in a folder \"Simulated EEG\".\n", "Just to be clear, we will generate randn arrays of random length and save them. This is just to avoid downloading large datasets.\n", "\n", "To keep the size of the folder low, each file will be:\n", "1) a 2 Channel EEG\n", "2) random length between 1024 and 4096 samples\n", "3) Stored with name `\"{dataset_id}_{subject_id}_{session_id}_{trial_id}.pickle\"`. This will be useful for the split part " ] }, { "cell_type": "code", "execution_count": 2, "id": "669a9877-adc2-4741-b291-2ff0d5a11396", "metadata": {}, "outputs": [], "source": [ "# create a folder if that not exists\n", "if not(os.path.isdir('Simulated_EEG')):\n", " os.mkdir('Simulated_EEG')\n", "\n", "N=1000\n", "for i in range(N):\n", " x = np.random.randn(2,np.random.randint(1024,4097))\n", " y = np.random.randint(1,5)\n", " sample = {'data': x, 'label': y}\n", " dataset_id = (int(i//200)+1)\n", " subject_id = (int( (i - 200*int(i//200)))//5+1)\n", " session_id = (i%5+1)\n", " trial_id = 1\n", " file_name = f'Simulated_EEG/{dataset_id}_{subject_id}_{session_id}_{trial_id}.pickle'\n", " with open(file_name, 'wb') as f:\n", " pickle.dump(sample, f)" ] }, { "cell_type": "markdown", "id": "39788994-04b8-4554-b41e-ea9f56ea3e52", "metadata": {}, "source": [ "Now we have a folder with simulated 1000 EEGs coming from:\n", "1) 5 datasets (ID from 1 to 5);\n", "2) 40 subjects per dataset (ID from 1 to 40)\n", "3) 5 session per subject (ID from 1 to 5)\n", "\n", "Each file is a pickle file with a dictionary having keys:\n", "1) `'data'`: the numpy 2D array\n", "2) `'label`': a fake label associated to the EEG file (from 1 to 4)" ] }, { "cell_type": "markdown", "id": "06d4ca58", "metadata": {}, "source": [ "## The 
get_eeg_partition_number function\n", "\n", "This function calculates the dataset length once the partition specs are defined. Let's suppose the data have a sampling rate of 128 Hz and we want to extract 2-second windows with a 15% overlap.\n", "\n", "To complicate things, let's assume that we also want to remove the last half second of each recording (64 samples at 128 Hz), for example because it often contains badly recorded data.\n", "\n", "
\n", "WARNING: \n", " \n", "Remember that this function is not omniscient, so we need to tell it how to load the data. By default, the function will try SciPy's `loadmat` function with the syntax\n", "`EEG = loadmat(path_to_file, simplify_cells=True)['DATA_STRUCT']['data']` \n", "which matches the output format of the BIDSalign library provided by our team.\n", "\n", "
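A back-of-the-envelope check of the partition count can help when choosing the specs. The sketch below uses a plain sliding-window rule (a window of `freq * window` samples and a step equal to the window minus the overlapping samples). This rule and the names `count_windows` and `trim` are assumptions made for illustration; the library may round differently.

```python
def count_windows(n_samples, freq, window, overlap, trim=0):
    # window length and overlap expressed in samples
    win = int(round(freq * window))
    ov = int(round(win * overlap))
    # samples left after dropping `trim` samples from the end
    usable = n_samples - trim
    if usable < win:
        return 0
    step = win - ov
    return (usable - win) // step + 1


# shortest and longest simulated files, minus the last half second (64 samples)
print(count_windows(1024, 128, 2, 0.15, trim=64))
print(count_windows(4096, 128, 2, 0.15, trim=64))
```

Under these assumptions, a simulated file yields between 4 and 18 windows, which is a useful sanity range for the `N_samples` column.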
" ] }, { "cell_type": "code", "execution_count": 3, "id": "0fe92ca6", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>full_path</th><th>file_name</th><th>N_samples</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>Simulated_EEG/1_10_1_1.pickle</td><td>1_10_1_1.pickle</td><td>15</td></tr>\n",
"    <tr><th>1</th><td>Simulated_EEG/1_10_2_1.pickle</td><td>1_10_2_1.pickle</td><td>5</td></tr>\n",
"    <tr><th>2</th><td>Simulated_EEG/1_10_3_1.pickle</td><td>1_10_3_1.pickle</td><td>7</td></tr>\n",
"    <tr><th>3</th><td>Simulated_EEG/1_10_4_1.pickle</td><td>1_10_4_1.pickle</td><td>12</td></tr>\n",
"    <tr><th>4</th><td>Simulated_EEG/1_10_5_1.pickle</td><td>1_10_5_1.pickle</td><td>6</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " full_path file_name N_samples\n", "0 Simulated_EEG/1_10_1_1.pickle 1_10_1_1.pickle 15\n", "1 Simulated_EEG/1_10_2_1.pickle 1_10_2_1.pickle 5\n", "2 Simulated_EEG/1_10_3_1.pickle 1_10_3_1.pickle 7\n", "3 Simulated_EEG/1_10_4_1.pickle 1_10_4_1.pickle 12\n", "4 Simulated_EEG/1_10_5_1.pickle 1_10_5_1.pickle 6" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define partition spec\n", "eegpath = 'Simulated_EEG'\n", "freq = 128 # sampling frequency in [Hz]\n", "overlap = 0.15 # overlap between partitions\n", "window = 2 # window length in [seconds]\n", "\n", "# define a function to load and transform data\n", "# SOME NOTES: these function can be fused to an unique one. Also, if\n", "# there's need to pass some arguments it's possible to pass them with\n", "# the optional_load_fun_args and optional_transform_fun_args arguments\n", "def loadEEG(path, return_label=False):\n", " with open(path, 'rb') as handle:\n", " EEG = pickle.load(handle)\n", " x = EEG['data']\n", " y = EEG['label']\n", " if return_label:\n", " return x, y\n", " else:\n", " return x\n", "\n", "def transformEEG(EEG):\n", " EEG = EEG[:,:-64]\n", " return EEG\n", "\n", "# call the function\n", "EEGlen = dl.get_eeg_partition_number(\n", " eegpath,\n", " freq,\n", " window,\n", " overlap,\n", " file_format='*.pickle',\n", " load_function=loadEEG,\n", " optional_load_fun_args=[False],\n", " transform_function=transformEEG\n", ")\n", "EEGlen.head()" ] }, { "cell_type": "markdown", "id": "d1a02f78-456b-4463-a502-3d5491094cc1", "metadata": {}, "source": [ "## The GetEEGSplitTable function\n", "\n", "Now that we have a table with the exact number of samples associated to each EEG file, let's split the data.\n", "\n", "Split can be performed with different level of granularity (e.g. dataset, subject, file level), and can be performed in different ways, i.e. by giving the ID to put in a set, or simply the ratio. 
Some data can also be excluded and, if each file has an associated label (or a way to extract it), it is possible to perform a stratified split, where the label proportions are preserved in each set up to a given tolerance.\n", "\n", "
\n", "TIP \n", " \n", "You can also create a table for cross-validation splits with the `GetEEGSplitTableKfold` function. It works similarly to the previous function and, if you want to extract a specific split, you can use the `ExtractSplit` function.\n", "\n", "
\n", "\n", "
\n", "WARNING \n", "\n", "Stratification assumes that EEG files at the chosen split granularity level share the same label. For example, if you want to split files at the subject level, make sure that all EEGs from the same subject are associated with the same label; otherwise, the split will not be executed correctly. \n", "\n", "
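The requirement in the warning above can be verified before asking for a stratified split. The helper below is a minimal sketch, not part of the library: `inconsistent_groups` is a hypothetical name, and it assumes file names follow the `dataset_subject_session_trial` pattern used in this tutorial.

```python
from collections import defaultdict


def inconsistent_groups(file_names, labels, level='subject'):
    # number of leading ID fields that identify a group at each level
    depth = {'dataset': 1, 'subject': 2, 'session': 3}[level]
    seen = defaultdict(set)
    for name, label in zip(file_names, labels):
        key = tuple(name.split('_')[:depth])
        seen[key].add(label)
    # groups whose files carry more than one distinct label
    return sorted(key for key, found in seen.items() if len(found) > 1)


names = ['1_1_1_1.pickle', '1_1_2_1.pickle', '1_2_1_1.pickle']
print(inconsistent_groups(names, [1, 2, 3]))
print(inconsistent_groups(names, [1, 1, 3]))
```

An empty list means every group has a single label. With the simulated data of this tutorial, where labels are drawn at random for each file, such a check would flag most subjects, so stratification is only meaningful at the file level here.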
" ] }, { "cell_type": "markdown", "id": "f368d5d2-7d49-4a14-91d8-4e538b12f6ab", "metadata": {}, "source": [ "For now, let's assume we want to do a **stratified split** at the **file level**, but we want to **exclude EEGs from subjects 13 and 23 of each dataset**. Split ratios are **80/10/10**" ] }, { "cell_type": "code", "execution_count": 5, "id": "aabf2f21", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "train ratio: 0.80\n", "validation ratio: 0.10\n", "test ratio: 0.10\n", "\n", "train labels ratio: 1=0.239, 2=0.244, 3=0.252, 4=0.265, \n", "val labels ratio: 1=0.239, 2=0.244, 3=0.252, 4=0.265, \n", "test labels ratio: 1=0.239, 2=0.244, 3=0.252, 4=0.265, \n", "\n" ] } ], "source": [ "# for stratified split we need to create an array with the labels\n", "# associated to each eeg file\n", "Labels = np.zeros(EEGlen.shape[0], dtype=int)\n", "for i in range(EEGlen.shape[0]):\n", " _ , Labels[i] = loadEEG(EEGlen.iloc[i]['full_path'], return_label=True)\n", "\n", "EEGsplit = dl.get_eeg_split_table(\n", " EEGlen, \n", " test_ratio = 0.1,\n", " val_ratio = 0.1,\n", " test_split_mode = 'file',\n", " val_split_mode = 'file',\n", " exclude_data_id = None, #{x:[13,23] for x in range(1,6)},\n", " stratified = True,\n", " labels = Labels,\n", " perseverance = 5000,\n", " split_tolerance = 0.005,\n", " seed = seed\n", ")\n", "dl.check_split(EEGlen, EEGsplit, Labels)" ] }, { "cell_type": "markdown", "id": "5749be37-fc41-4762-97e1-f1dbd7d4a6df", "metadata": {}, "source": [ "
Here is another example: a **non-stratified split** at the **subject level** (all EEGs from the same subject end up in the same split set), where we also **exclude EEGs from subjects 13 and 23 of each dataset**. Split ratios are **80/10/10**." ] }, { "cell_type": "code", "execution_count": 6, "id": "80f99b9f-a634-4e1c-90cd-10c30f5c20ae", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "train ratio: 0.80\n", "validation ratio: 0.10\n", "test ratio: 0.10\n" ] } ], "source": [ "EEGsplit2 = dl.get_eeg_split_table(\n", "    EEGlen,\n", "    test_ratio = 0.1,\n", "    val_ratio = 0.1,\n", "    test_split_mode = 'subject',\n", "    val_split_mode = 'subject',\n", "    exclude_data_id = {x:[13,23] for x in range(1,6)},\n", "    dataset_id_extractor = lambda x: int(x.split('_')[0]),\n", "    subject_id_extractor = lambda x: int(x.split('_')[1]),\n", "    perseverance = 5000,\n", "    split_tolerance = 0.005,\n", "    seed = seed\n", ")\n", "\n", "dl.check_split(EEGlen, EEGsplit2)\n", "\n", "# Given the structure of the simulated dataset (5 consecutive rows per subject),\n", "# it is easy to check whether the split is really subject based\n", "for i in range(EEGsplit2.shape[0]//5):\n", "    # split_set takes values -1, 0, 1, 2, so the five files of a subject\n", "    # are in the same set only if they all share one value\n", "    if EEGsplit2.iloc[(5*i):(5*i+5)]['split_set'].nunique() != 1:\n", "        print('wrong_split')" ] }, { "cell_type": "code", "execution_count": 7, "id": "2a9af910-958f-4d65-8615-827c5a9e38eb", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>file_name</th><th>split_set</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>1_10_1_1.pickle</td><td>0</td></tr>\n",
"    <tr><th>1</th><td>1_10_2_1.pickle</td><td>0</td></tr>\n",
"    <tr><th>2</th><td>1_10_3_1.pickle</td><td>0</td></tr>\n",
"    <tr><th>3</th><td>1_10_4_1.pickle</td><td>0</td></tr>\n",
"    <tr><th>4</th><td>1_10_5_1.pickle</td><td>0</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td></tr>\n",
"    <tr><th>995</th><td>5_9_1_1.pickle</td><td>0</td></tr>\n",
"    <tr><th>996</th><td>5_9_2_1.pickle</td><td>0</td></tr>\n",
"    <tr><th>997</th><td>5_9_3_1.pickle</td><td>0</td></tr>\n",
"    <tr><th>998</th><td>5_9_4_1.pickle</td><td>0</td></tr>\n",
"    <tr><th>999</th><td>5_9_5_1.pickle</td><td>0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>1000 rows × 2 columns</p>\n",
"</div>
" ], "text/plain": [ " file_name split_set\n", "0 1_10_1_1.pickle 0\n", "1 1_10_2_1.pickle 0\n", "2 1_10_3_1.pickle 0\n", "3 1_10_4_1.pickle 0\n", "4 1_10_5_1.pickle 0\n", ".. ... ...\n", "995 5_9_1_1.pickle 0\n", "996 5_9_2_1.pickle 0\n", "997 5_9_3_1.pickle 0\n", "998 5_9_4_1.pickle 0\n", "999 5_9_5_1.pickle 0\n", "\n", "[1000 rows x 2 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "EEGsplit2" ] }, { "cell_type": "markdown", "id": "e57547ad", "metadata": {}, "source": [ "## The EEGDataset class\n", "\n", "Now we have all the ingredients necessary to initialize the custom dataset. The EEGDataset class is highly customizable, so we illustrate two examples, one usually employed for the pretraining, which doesn't involve the extraction of labels from the EEG files, and the other usually employed for fine-tuning, which instead use the labels.\n", "\n", "To initialize correctly the class EEGdataset you need :\n", "1. the output of the `GetEEGPartitionNumber` function (used to calculate the length)\n", "2. the output of the `GetEEGSplitTable` function (used to extract data of a specific split set)\n", "3. the partition spec as a **list** (format: \\[freq, window, overlap\\])\n", "\n", "other optional important parameters are:\n", "1. the mode (train, validation, test), used to select data from a specific split set\n", "2. the boolean 'supervised', used to tell if indexing using `[]` (the `__getitem__` method) must extract a label associated to the sample\n", "3. the label_on_load argument, used to tell if indexing using `[]` (the `__getitem__` method) will get the label from the loading function or it must call a custom function\n", "\n", "
\n", "TIP 1 \n", " \n", "The EEGDataset class also accepts custom functions to load, transform and extract labels from the EEG files.\n", "\n", "
\n", "\n", "
\n", "TIP 2 \n", " \n", "If the label must be extracted from a dictionary, possibly with different files storing the label under different keys, check the `label_key` argument, which handles this case. \n", "\n", "
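To see how a table of partition counts can drive indexing, the sketch below maps a flat sample index to a file and to a window inside that file. It only illustrates the idea and is not taken from the `EEGDataset` source: `locate` is a hypothetical helper, and it assumes a plain sliding window whose step is the window length minus the overlapping samples.

```python
from bisect import bisect_right
from itertools import accumulate


def locate(index, partitions_per_file, freq=128, window=2, overlap=0.15):
    # cumulative number of partitions up to and including each file
    cum = list(accumulate(partitions_per_file))
    if not 0 <= index < cum[-1]:
        raise IndexError('sample index out of range')
    # first file whose cumulative count exceeds the index
    file_idx = bisect_right(cum, index)
    local = index - (cum[file_idx - 1] if file_idx else 0)
    # window length and step in samples
    win = int(round(freq * window))
    step = win - int(round(win * overlap))
    start = local * step
    return file_idx, start, start + win


# partition counts of the first five files in the table above
print(locate(15, [15, 5, 7, 12, 6]))
```

With these counts, index 15 falls on the first window of the second file, so only that file has to be opened to serve the sample.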
" ] }, { "cell_type": "markdown", "id": "45aef05d-8483-4ffa-aadf-7f3202f76486", "metadata": {}, "source": [ "**CASE 1: Pretraining - no label**" ] }, { "cell_type": "code", "execution_count": 8, "id": "80f58518", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2, 256])\n" ] } ], "source": [ "dataset_pretrain = dl.EEGDataset(\n", " EEGlen,\n", " EEGsplit, \n", " [freq, window, overlap], # split parameters must be given as list\n", " mode = 'train', # default, select all samples in the train set\n", " load_function = loadEEG, \n", " transform_function = transformEEG\n", ")\n", "sample_1 = dataset_pretrain[0] # Grab the first sample\n", "print(sample_1.shape) # Note: the sample is automatically converted in a Tensor " ] }, { "cell_type": "markdown", "id": "6e84996a-be9f-47be-964f-322ed90c78a9", "metadata": {}, "source": [ "
**CASE 2: Fine-tuning - with label**" ] }, { "cell_type": "code", "execution_count": 9, "id": "914261e4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2, 256]) 1\n" ] } ], "source": [ "dataset_finetune = dl.EEGDataset(\n", "    EEGlen,\n", "    EEGsplit,\n", "    [freq, window, overlap], # partition specs must be given as a list\n", "    mode = 'train', # the default, selects all samples in the train set\n", "    supervised = True, # IMPORTANT: makes indexing return the label too\n", "    load_function = loadEEG,\n", "    optional_load_fun_args = [True], # tells loadEEG to return a label\n", "    transform_function = transformEEG,\n", "    label_on_load = True, # the default\n", ")\n", "sample_2, label_2 = dataset_finetune[0] # grab the first sample\n", "print(sample_2.shape, label_2) # now we also have a label" ] }, { "cell_type": "markdown", "id": "f9f8388b-b9e2-431d-8ec6-f1147c9762a2", "metadata": {}, "source": [ "## The EEGSampler\n", "\n", "Although optional, you can also create a custom sampler. The sampler can create two different types of iterator, which strike a different balance between batch heterogeneity and batch creation speed:\n", "\n", "1. **Linear**: simply returns a linear iterator. It is useful when you want to minimize the number of EEG file loading operations. However, batches will contain consecutive partitions of the same file, which could affect the behaviour of some layers such as BatchNorm. To initialize the sampler in this mode, simply use the command
`EEGSampler(EEGDataset, Mode=0)`\n", "2. **Shuffled**: returns a customized iterator, constructed as follows:\n", "    1) Samples are shuffled at the file level;\n", "    2) Samples of the same file are shuffled;\n", "    3) Samples are rearranged based on the desired batch size and number of workers. This step exploits the parallelization properties of the PyTorch DataLoader and reduces the number of loading operations. To initialize the sampler in this mode, simply use the command
`EEGSampler(EEGDataset, BatchSize, Workers)` \n", "\n", "
\n", "TIP \n", " \n", "We suggest using the linear iterator for validation and testing, since it is faster and these phases do not require any batch heterogeneity.\n", "\n", "
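The first two steps of the Shuffled iterator can be sketched in a few lines. This is an illustration under stated assumptions rather than the sampler's actual code: `shuffled_order` is a hypothetical name, samples are identified by flat indices grouped by file, and the third step (rearranging by batch size and workers) is deliberately left out.

```python
import random


def shuffled_order(partitions_per_file, seed=12):
    rng = random.Random(seed)
    # one block of consecutive flat indices per file
    blocks, start = [], 0
    for n in partitions_per_file:
        blocks.append(list(range(start, start + n)))
        start += n
    # step 1: shuffle at the file level
    rng.shuffle(blocks)
    # step 2: shuffle the samples inside each file
    for block in blocks:
        rng.shuffle(block)
    return [i for block in blocks for i in block]


print(shuffled_order([3, 2, 4]))
```

Samples of the same file stay next to each other in the result, which is what keeps the number of file loading operations low while still changing the order seen by the model.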
\n", "\n", "\n", "Here is a schematic representation of how the Shuffled iterator is constructed, with **batch size = 5** and **workers = 4**.\n", "\n", "![scheme](../Images/sampler_example.png \"Title\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "abdbddd5-346b-4b5d-860b-2e3459b41646", "metadata": {}, "outputs": [], "source": [ "sampler_linear = dl.EEGSampler(dataset_pretrain, Mode=0)\n", "sampler_custom = dl.EEGSampler(dataset_pretrain, 16, 4) # batch size = 16, workers = 4" ] }, { "cell_type": "markdown", "id": "cee4b61a-751b-49d3-b422-ae8645816292", "metadata": {}, "source": [ "## Final Dataloader\n", "\n", "Now simply put everything together and create your custom DataLoader. \n", "\n", "
\n", "WARNING \n", " \n", "If you have created a custom sampler, remember to pass the DataLoader the same batch size and number of workers used to build the sampler.\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": 11, "id": "57c84ba5-1bd9-4403-92ed-b7c3f791d6bc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([16, 2, 256])\n" ] } ], "source": [ "Final_Dataloader = DataLoader(\n", " dataset = dataset_pretrain,\n", " batch_size = 16,\n", " sampler = sampler_custom,\n", " num_workers = 0\n", ")\n", "\n", "for X in Final_Dataloader:\n", " print(X.shape)\n", " break" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }