.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/SSVEP/plot_classify_ssvep_mdm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_SSVEP_plot_classify_ssvep_mdm.py:


====================================================================
Offline SSVEP-based BCI Multiclass Prediction
====================================================================

This example builds extended covariance matrices for an SSVEP-based BCI and
displays them. A Minimum Distance to Mean (MDM) classifier is then trained to
solve a 4-class problem in an offline setup.

.. GENERATED FROM PYTHON SOURCE LINES 10-27

.. code-block:: Python

    # Authors: Sylvain Chevallier,
    #          Emmanuel Kalunga, Quentin Barthélemy, David Ojeda
    #
    # License: BSD (3-clause)

    import numpy as np
    import matplotlib.pyplot as plt
    from mne import find_events, Epochs
    from mne.io import Raw
    from sklearn.model_selection import cross_val_score, RepeatedKFold

    from pyriemann.estimation import BlockCovariances
    from pyriemann.utils.mean import mean_riemann
    from pyriemann.classification import MDM

    from helpers.ssvep_helpers import download_data, extend_signal

.. GENERATED FROM PYTHON SOURCE LINES 28-32

Loading EEG data
----------------

The data are loaded with an MNE reader.

.. GENERATED FROM PYTHON SOURCE LINES 32-44

.. code-block:: Python

    # Download data
    destination = download_data(subject=12, session=1)
    # Read data in MNE Raw and numpy format
    raw = Raw(destination, preload=True, verbose='ERROR')
    events = find_events(raw, shortest_event=0, verbose=False)
    raw = raw.pick("eeg")

    event_id = {'13 Hz': 2, '17 Hz': 4, '21 Hz': 3, 'resting-state': 1}
    sfreq = int(raw.info['sfreq'])

    eeg_data = raw.get_data()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using default location ~/mne_data for ssvep...

.. GENERATED FROM PYTHON SOURCE LINES 79-85

Extended signals for spatial covariance
---------------------------------------

Using the approach proposed by [1]_, the SSVEP signal is extended to include
the signals filtered around each stimulation frequency. The filtered signals
are stacked to build an extended signal. With the 8 EEG channels of this
recording and 3 stimulation frequencies, the extended signal has 24 channels.

.. GENERATED FROM PYTHON SOURCE LINES 85-93

.. code-block:: Python

    # We stack the filtered signals to build an extended signal
    frequencies = [13, 17, 21]
    freq_band = 0.1
    raw_ext = extend_signal(raw, frequencies, freq_band)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Creating RawArray with float64 data, n_channels=24, n_times=92384
        Range : 0 ... 92383 = 0.000 ... 360.871 secs
    Ready.

.. GENERATED FROM PYTHON SOURCE LINES 94-95

Plot the extended signal.

.. GENERATED FROM PYTHON SOURCE LINES 95-99

.. code-block:: Python

    raw_ext.plot(duration=n_seconds, start=14, n_channels=24,
                 scalings={'eeg': 5e-4}, color={'eeg': 'steelblue'})

.. image-sg:: /auto_examples/SSVEP/images/sphx_glr_plot_classify_ssvep_mdm_004.png
   :alt: plot classify ssvep mdm
   :srcset: /auto_examples/SSVEP/images/sphx_glr_plot_classify_ssvep_mdm_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 100-101

Build epochs and plot 3 s of the signal from electrode Oz for one trial, after
filtering in each frequency band.

.. GENERATED FROM PYTHON SOURCE LINES 101-118

.. code-block:: Python

    epochs = Epochs(
        raw_ext, events, event_id, tmin=2, tmax=5, baseline=None
    ).get_data(copy=False)

    n_seconds = 3
    n_samples = n_seconds * sfreq
    time = np.linspace(0, n_seconds, n_samples, endpoint=False)
    # first channel of each frequency band, i.e. Oz filtered around 13, 17, 21 Hz
    channels = range(0, len(raw_ext.ch_names), len(raw.ch_names))
    plt.figure(figsize=(7, 5))
    for f, c in zip(frequencies, channels):
        # the epoch holds n_samples + 1 points (tmin and tmax are both included)
        plt.plot(time, epochs[5, c, :n_samples], label=str(int(f)) + ' Hz')
    plt.xlabel("Time (s)")
    plt.ylabel(r"Oz after filtering ($\mu$V)")
    plt.legend(loc='upper right')
    plt.show()

.. image-sg:: /auto_examples/SSVEP/images/sphx_glr_plot_classify_ssvep_mdm_005.png
   :alt: plot classify ssvep mdm
   :srcset: /auto_examples/SSVEP/images/sphx_glr_plot_classify_ssvep_mdm_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Not setting metadata
    32 matching events found
    No baseline correction applied
    0 projection items activated
    Using data from preloaded Raw for 32 events and 769 original time points ...
    0 bad epochs dropped

.. GENERATED FROM PYTHON SOURCE LINES 119-128

As can be seen in this example, the subject is watching the 13 Hz stimulation:
the EEG activity shows increased amplitude in this frequency band, while the
other frequencies have lower amplitudes.

Spatial covariance for SSVEP
----------------------------

The covariance matrices are estimated on the extended signal with the
Ledoit-Wolf shrinkage estimator, using one 8x8 block per frequency band
(``block_size=8``).

.. GENERATED FROM PYTHON SOURCE LINES 128-150

.. code-block:: Python

    cov_ext_trials = BlockCovariances(
        estimator='lwf', block_size=8
    ).transform(epochs)

    # This plot shows an example of a covariance matrix observed for each class:
    ch_names = raw_ext.info['ch_names']

    plt.figure(figsize=(7, 7))
    for i, l in enumerate(event_id):
        ax = plt.subplot(2, 2, i+1)
        plt.imshow(cov_ext_trials[events[:, 2] == event_id[l]][0],
                   cmap=plt.get_cmap('RdBu_r'))
        plt.title('Cov for class: '+l)
        plt.xticks([])
        if i == 0 or i == 2:
            plt.yticks(np.arange(len(ch_names)), ch_names)
            ax.tick_params(axis='both', which='major', labelsize=7)
        else:
            plt.yticks([])
    plt.show()

.. image-sg:: /auto_examples/SSVEP/images/sphx_glr_plot_classify_ssvep_mdm_006.png
   :alt: Cov for class: 13 Hz, Cov for class: 17 Hz, Cov for class: 21 Hz, Cov for class: resting-state
   :srcset: /auto_examples/SSVEP/images/sphx_glr_plot_classify_ssvep_mdm_006.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 151-162

Each class clearly yields a differently structured covariance matrix.
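
To see where the block structure of these matrices comes from, here is a
minimal NumPy sketch of a block Ledoit-Wolf estimate for a single trial. It
makes two assumptions. First, the extended signal is ordered band by band (8
channels per band). Second, pyRiemann's ``'lwf'`` option follows the
Ledoit-Wolf formula of scikit-learn's ``ledoit_wolf``, with shrinkage toward a
scaled identity. The helpers ``ledoit_wolf_cov`` and ``block_covariance``, as
well as the random trial, are hypothetical and used for illustration only.
This is not pyRiemann's implementation.

.. code-block:: Python

    import numpy as np


    def ledoit_wolf_cov(x):
        """Ledoit-Wolf shrunk covariance of x, shaped (n_channels, n_times)."""
        n_channels, n_times = x.shape
        xc = x - x.mean(axis=1, keepdims=True)          # center each channel
        emp_cov = xc @ xc.T / n_times                    # sample covariance S
        mu = np.trace(emp_cov) / n_channels              # target is mu * I
        target = mu * np.eye(n_channels)
        # squared distance between S and the target, divided by n_channels
        delta = np.sum((emp_cov - target) ** 2) / n_channels
        # sum_t ||x_t x_t^T - S||_F^2 / (n_channels * n_times**2), expanded
        sq_norms = np.sum(xc ** 2, axis=0)               # ||x_t||^2 per sample
        beta = (np.sum(sq_norms ** 2) / n_times - np.sum(emp_cov ** 2)) / (
            n_channels * n_times)
        shrinkage = 0.0 if delta == 0 else min(beta, delta) / delta
        return (1 - shrinkage) * emp_cov + shrinkage * target


    def block_covariance(x, block_size):
        """Block-diagonal covariance: one Ledoit-Wolf estimate per channel block."""
        n_channels = x.shape[0]
        if n_channels % block_size:
            raise ValueError("n_channels must be a multiple of block_size")
        cov = np.zeros((n_channels, n_channels))        # cross-block terms stay 0
        for start in range(0, n_channels, block_size):
            idx = slice(start, start + block_size)
            cov[idx, idx] = ledoit_wolf_cov(x[idx])
        return cov


    # Hypothetical trial: 3 bands x 8 channels, 769 samples (3 s at 256 Hz)
    rng = np.random.default_rng(42)
    trial = rng.standard_normal((24, 769))
    cov = block_covariance(trial, block_size=8)
    print(cov.shape)                                  # (24, 24)
    print(np.allclose(cov[:8, 8:], 0))                # True: no cross-band terms
    print(bool(np.all(np.linalg.eigvalsh(cov) > 0)))  # True: positive definite

Shrinking each block toward a scaled identity keeps it well conditioned. This
matters because the Riemannian mean and distance used later by MDM require
symmetric positive definite matrices. With pyRiemann and scikit-learn
installed, ``BlockCovariances(estimator='lwf', block_size=8).transform(trial[np.newaxis])[0]``
is expected to be close to ``cov``, although this sketch has not been checked
against every pyRiemann version.
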
Each stimulation (13, 17 and 21 Hz) generates higher covariance values for the
EEG signal filtered in the matching frequency band, and no visible activation
in the other bands. The resting state, where the subject focuses on the center
of the display, away from all blinking stimuli, shows higher correlations in
the 13 Hz band and lower but still visible activity in the other bands.

Classify with MDM
-----------------

Plot the Riemannian mean of each class.

.. GENERATED FROM PYTHON SOURCE LINES 162-180

.. code-block:: Python

    cov_centers = np.empty((len(event_id), 24, 24))
    for i, l in enumerate(event_id):
        cov_centers[i] = mean_riemann(cov_ext_trials[events[:, 2] == event_id[l]])

    plt.figure(figsize=(7, 7))
    for i, l in enumerate(event_id):
        ax = plt.subplot(2, 2, i+1)
        plt.imshow(cov_centers[i], cmap=plt.get_cmap('RdBu_r'))
        plt.title('Cov mean for class: '+l)
        plt.xticks([])
        if i == 0 or i == 2:
            plt.yticks(np.arange(len(ch_names)), ch_names)
            ax.tick_params(axis='both', which='major', labelsize=7)
        else:
            plt.yticks([])
    plt.show()

.. image-sg:: /auto_examples/SSVEP/images/sphx_glr_plot_classify_ssvep_mdm_007.png
   :alt: Cov mean for class: 13 Hz, Cov mean for class: 17 Hz, Cov mean for class: 21 Hz, Cov mean for class: resting-state
   :srcset: /auto_examples/SSVEP/images/sphx_glr_plot_classify_ssvep_mdm_007.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 181-183

Minimum Distance to Mean is a simple and robust algorithm for BCI decoding:
each trial is assigned to the class whose Riemannian mean is closest to the
trial's covariance matrix. It reproduces the results of [2]_ for the first
session of subject 12.

.. GENERATED FROM PYTHON SOURCE LINES 183-194

.. code-block:: Python

    print("Number of trials: {}".format(len(cov_ext_trials)))

    cv = RepeatedKFold(n_splits=2, n_repeats=10, random_state=42)
    mdm = MDM(metric=dict(mean='riemann', distance='riemann'))
    scores = cross_val_score(mdm, cov_ext_trials, events[:, 2], cv=cv, n_jobs=1)
    print("MDM accuracy: {:.2f}% +/- {:.2f}".format(np.mean(scores)*100,
                                                      np.std(scores)*100))
    # The obtained accuracy is about 81% +/- 16% for this session, with a
    # 2-fold cross-validation repeated 10 times.

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Number of trials: 32
    /home/docs/checkouts/readthedocs.org/user_builds/pyriemann/checkouts/latest/pyriemann/utils/mean.py:540: UserWarning: Convergence not reached
      warnings.warn("Convergence not reached")
    MDM accuracy: 80.94% +/- 16.23

.. GENERATED FROM PYTHON SOURCE LINES 195-206

References
----------

.. [1] M. Congedo, A. Barachant, A. Andreev. *A New Generation of
   Brain-Computer Interface Based on Riemannian Geometry*. Research report,
   2013.

.. [2] S. Chevallier, E. K. Kalunga, Q. Barthélemy, E. Monacelli. *Review of
   Riemannian distances and divergences, applied to SSVEP-based BCI*.
   Neuroinformatics, Springer, 2021, 19 (1), pp. 93-106.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 11.216 seconds)