.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/motor-imagery/plot_frequency_band_selection.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_motor-imagery_plot_frequency_band_selection.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_motor-imagery_plot_frequency_band_selection.py:


=========================================================================
Frequency band selection on the manifold for motor imagery classification
=========================================================================

Find the optimal frequency band using a class distinctiveness measure on
the manifold, and compare the classification performance on motor imagery
data against a baseline without frequency band selection [1]_.

.. GENERATED FROM PYTHON SOURCE LINES 12-30

.. code-block:: Python

    # Authors: Maria Sayu Yamamoto
    #
    # License: BSD (3-clause)

    from time import time

    import numpy as np
    from matplotlib import pyplot as plt

    from mne import Epochs, pick_types, events_from_annotations
    from mne.io import concatenate_raws
    from mne.io.edf import read_raw_edf
    from mne.datasets import eegbci

    from sklearn.model_selection import cross_val_score, ShuffleSplit

    from pyriemann.classification import MDM
    from pyriemann.estimation import Covariances
    from helpers.frequencybandselection_helpers import freq_selection_class_dis


.. GENERATED FROM PYTHON SOURCE LINES 31-33

Set basic parameters and read data
----------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 33-51

.. code-block:: Python

    tmin, tmax = 0.5, 2.5
    event_id = dict(T1=2, T2=3)
    subject = 1
    runs = [4, 8, 12]  # motor imagery: left hand vs right hand

    raw_files = [
        read_raw_edf(f, preload=True) for f in eegbci.load_data(subject, runs)
    ]
    raw = concatenate_raws(raw_files)
    picks = pick_types(
        raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads')
    # keep every second electrode
    picks = picks[::2]

    # cross validation: a single shuffled 80/20 train/test split
    cv = ShuffleSplit(n_splits=1, test_size=0.2, random_state=42)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading EEGBCI data
    Download complete in 08s (7.4 MB)
    Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R04.edf...
    EDF file detected
    Setting channel info structure...
    Creating raw.info structure...
    Reading 0 ... 19999 = 0.000 ... 124.994 secs...
    Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R08.edf...
    EDF file detected
    Setting channel info structure...
    Creating raw.info structure...
    Reading 0 ... 19999 = 0.000 ... 124.994 secs...
    Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R12.edf...
    EDF file detected
    Setting channel info structure...
    Creating raw.info structure...
    Reading 0 ... 19999 = 0.000 ... 124.994 secs...


.. GENERATED FROM PYTHON SOURCE LINES 52-57

Baseline pipeline without frequency band selection
--------------------------------------------------

Apply a band-pass filter over a wide frequency band, 5-35 Hz, then train
and evaluate the classifier.

.. GENERATED FROM PYTHON SOURCE LINES 57-92

.. code-block:: Python

    t0 = time()
    raw_filter = raw.copy().filter(5., 35., method='iir', picks=picks,
                                   verbose=False)

    events, _ = events_from_annotations(raw_filter, event_id, verbose=False)

    # Read epochs (training uses only the 0.5 to 2.5 s window)
    epochs = Epochs(
        raw_filter,
        events,
        event_id,
        tmin,
        tmax,
        proj=True,
        picks=picks,
        baseline=None,
        preload=True,
        verbose=False)
    # Map the event codes 2 and 3 to the class labels 0 and 1
    labels = epochs.events[:, -1] - 2

    # Get epochs data
    epochs_data_baseline = epochs.get_data(units="uV", copy=False)

    # Compute covariance matrices
    cov_data_baseline = Covariances().transform(epochs_data_baseline)

    # Set classifier
    model = MDM(metric=dict(mean='riemann', distance='riemann'))

    # Classification with minimum distance to mean
    acc_baseline = cross_val_score(model, cov_data_baseline, labels,
                                   cv=cv, n_jobs=1)
    t1 = time() - t0


.. GENERATED FROM PYTHON SOURCE LINES 93-100

Pipeline with a frequency band selection based on the class distinctiveness
---------------------------------------------------------------------------

Step 1: Select the frequency band maximizing class distinctiveness on
the training set.

Define the parameters for frequency band selection.

.. GENERATED FROM PYTHON SOURCE LINES 100-117

.. code-block:: Python

    t2 = time()
    freq_band = [5., 35.]
    sub_band_width = 4.
    sub_band_step = 2.
    alpha = 0.4

    # Select frequency band using training set
    best_freq, all_class_dis = \
        freq_selection_class_dis(raw, freq_band, sub_band_width,
                                 sub_band_step, alpha,
                                 tmin, tmax,
                                 picks, event_id,
                                 cv,
                                 return_class_dis=True, verbose=False)

    print(f'Selected frequency band : {best_freq[0][0]} - {best_freq[0][1]} Hz')


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Selected frequency band : 9.0 - 15.0 Hz


.. GENERATED FROM PYTHON SOURCE LINES 118-120

Step 2: Train the classifier using the selected frequency band and
evaluate its performance on the test set.

.. GENERATED FROM PYTHON SOURCE LINES 120-151

.. code-block:: Python

    # Apply band-pass filter using the best frequency band
    best_raw_filter = raw.copy().filter(best_freq[0][0], best_freq[0][1],
                                        method='iir', picks=picks,
                                        verbose=False)

    events, _ = events_from_annotations(best_raw_filter, event_id,
                                        verbose=False)

    # Read epochs (training uses only the 0.5 to 2.5 s window)
    epochs = Epochs(
        best_raw_filter,
        events,
        event_id,
        tmin,
        tmax,
        proj=True,
        picks=picks,
        baseline=None,
        preload=True,
        verbose=False)

    # Get epochs data
    epochs_data_train = epochs.get_data(units="uV", copy=False)

    # Estimate covariance matrices
    cov_data = Covariances().transform(epochs_data_train)

    # Classification with minimum distance to mean
    acc = cross_val_score(model, cov_data, labels, cv=cv, n_jobs=1)

    t3 = time() - t2


.. GENERATED FROM PYTHON SOURCE LINES 152-154

Compare pipelines: accuracies and training times
------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 154-164

.. code-block:: Python

    print("Classification accuracy without frequency band selection: "
          + f"{acc_baseline[0]:.02f}")
    print("Total computational time without frequency band selection: "
          + f"{t1:.5f} s")
    print("Classification accuracy with frequency band selection: "
          + f"{acc[0]:.02f}")
    print("Total computational time with frequency band selection: "
          + f"{t3:.5f} s")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Classification accuracy without frequency band selection: 0.56
    Total computational time without frequency band selection: 0.28910 s
    Classification accuracy with frequency band selection: 0.67
    Total computational time with frequency band selection: 9.90647 s


.. GENERATED FROM PYTHON SOURCE LINES 165-170

Plot selected frequency bands
-----------------------------

Plot the class distinctiveness value of each sub-band and highlight the
frequency band that was finally selected.

.. GENERATED FROM PYTHON SOURCE LINES 170-203

.. code-block:: Python

    # Lower and upper edges of every sub-band in the filter bank
    subband_fmin = list(np.arange(freq_band[0],
                                  freq_band[1] - sub_band_width + 1.,
                                  sub_band_step))
    subband_fmax = list(np.arange(freq_band[0] + sub_band_width,
                                  freq_band[1] + 1., sub_band_step))
    n_subband = len(subband_fmin)

    x = list(range(0, n_subband, 1))
    fig = plt.figure(figsize=(10, 5))

    # Positions of the first and last sub-band of the selected band
    freq_start = subband_fmin.index(best_freq[0][0])
    freq_end = subband_fmax.index(best_freq[0][1])

    plt.subplot(1, 1, 1)
    plt.grid()
    plt.plot(x, all_class_dis[0], marker='o')
    # One tick per sub-band, labelled with its [fmin, fmax] edges
    plt.xticks(x,
               [[int(i), int(j)] for i, j in zip(subband_fmin, subband_fmax)])

    plt.axvspan(freq_start, freq_end, color="orange", alpha=0.3,
                label='Selected frequency band')
    plt.ylabel('Class distinctiveness')
    plt.xlabel('Filter bank [Hz]')
    plt.title('Class distinctiveness value of each subband')
    plt.legend(loc='upper right')

    fig.tight_layout()
    plt.show()

    print(f'Optimal frequency band for this subject is '
          f'{best_freq[0][0]} - {best_freq[0][1]} Hz')


.. image-sg:: /auto_examples/motor-imagery/images/sphx_glr_plot_frequency_band_selection_001.png
   :alt: Class distinctiveness value of each subband
   :srcset: /auto_examples/motor-imagery/images/sphx_glr_plot_frequency_band_selection_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Optimal frequency band for this subject is 9.0 - 15.0 Hz


.. GENERATED FROM PYTHON SOURCE LINES 204-213

References
----------

.. [1] Class-distinctiveness-based frequency band selection on the
   Riemannian manifold for oscillatory activity-based BCIs: preliminary
   results.
   M. S. Yamamoto, F. Lotte, F. Yger, and S. Chevallier.
   44th Annual International Conference of the IEEE Engineering
   in Medicine & Biology Society (EMBC2022), 2022.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 18.980 seconds)


.. _sphx_glr_download_auto_examples_motor-imagery_plot_frequency_band_selection.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_frequency_band_selection.ipynb <plot_frequency_band_selection.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_frequency_band_selection.py <plot_frequency_band_selection.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery
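
As a rough illustration of the sub-band grid that the plotting code above
rebuilds with ``np.arange``, the sketch below enumerates the overlapping
sub-bands from the same ``freq_band``, ``sub_band_width`` and
``sub_band_step`` values. The helper name ``enumerate_sub_bands`` is
hypothetical: it is not part of pyRiemann or of the example's helpers, and
how ``freq_selection_class_dis`` builds its filter bank internally is not
shown in this example, so the grid is only assumed to match the one used
for the plot.

.. code-block:: Python

    def enumerate_sub_bands(freq_band, sub_band_width, sub_band_step):
        """List (fmin, fmax) pairs of overlapping sub-bands inside freq_band.

        Hypothetical helper, for illustration only.
        """
        f_low, f_high = freq_band
        sub_bands = []
        fmin = f_low
        # Slide a fixed-width window until it would exceed the upper bound
        while fmin + sub_band_width <= f_high:
            sub_bands.append((fmin, fmin + sub_band_width))
            fmin += sub_band_step
        return sub_bands


    # Same parameters as in the example above
    sub_bands = enumerate_sub_bands([5., 35.], 4., 2.)
    print(len(sub_bands))
    print(sub_bands[0], sub_bands[-1])

With these parameters the grid contains 14 sub-bands, from 5-9 Hz up to
31-35 Hz, which corresponds to the 14 points on the x-axis of the figure.
The selected band, 9-15 Hz, spans several adjacent sub-bands of this grid.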