pyRiemann: Biosignals classification with Riemannian geometry¶
pyRiemann is a Python machine learning package based on the scikit-learn API. It provides a high-level interface for processing and classification of multivariate time series through the Riemannian geometry of symmetric positive definite (SPD) matrices.
pyRiemann aims to be a generic package for multivariate time series classification, but it has been designed around the manipulation of multichannel biosignals (such as EEG, MEG or EMG) for brain-computer interfaces (BCI): multichannel time series are transformed into covariance matrices, which are then classified using the Riemannian geometry of SPD matrices.
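As a rough illustration of this idea, the sketch below turns multichannel trials into sample covariance matrices and compares two of them with the affine-invariant Riemannian distance, using NumPy only. The helpers `trial_covariances` and `riemann_distance` are hypothetical names written for this sketch, not pyRiemann's API; in practice the estimators and distances shipped with the package should be preferred.

```python
import numpy as np


def trial_covariances(X):
    """Sample covariance matrix of each trial.

    X has shape (n_trials, n_channels, n_times).
    """
    Xc = X - X.mean(axis=2, keepdims=True)  # remove the mean of each channel
    return Xc @ Xc.transpose(0, 2, 1) / (X.shape[2] - 1)


def riemann_distance(A, B):
    """Affine-invariant Riemannian distance between two SPD matrices."""
    # Whiten B by A^{-1/2}; the result has the same eigenvalues as A^{-1} B
    w, V = np.linalg.eigh(A)
    A_isqrt = (V / np.sqrt(w)) @ V.T
    eigvals = np.linalg.eigvalsh(A_isqrt @ B @ A_isqrt)
    return np.sqrt(np.sum(np.log(eigvals) ** 2))


rng = np.random.default_rng(0)
X = rng.standard_normal((2, 4, 200))  # 2 trials, 4 channels, 200 samples
covs = trial_covariances(X)
d = riemann_distance(covs[0], covs[1])

# The distance is unchanged by an invertible linear mixing W of the channels
W = rng.standard_normal((4, 4)) + 4 * np.eye(4)
d_mixed = riemann_distance(W @ covs[0] @ W.T, W @ covs[1] @ W.T)
print(d, d_mixed)
```

The invariance to channel mixing shown at the end is one reason this geometry suits biosignals, where sensors record linear mixtures of underlying sources.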
For a brief introduction to the ideas behind the package, you can read the introductory notes. More practical information is on the installation page. You may also want to browse the example gallery to get a sense of what you can do with pyRiemann, and the API reference to find out how.
To see the code or report a bug, please visit the GitHub repository.
Content
Introduction to pyRiemann¶
What’s new in the package¶
A catalog of new features, improvements, and bug-fixes in each release.
v0.4 (Feb 2023)¶
Add exponential and logarithmic maps for three main metrics: ‘euclid’, ‘logeuclid’ and ‘riemann’. pyriemann.utils.tangentspace.tangent_space() is split into two steps: (i) log_map_*() projecting SPD matrices into tangent space depending on the metric; and (ii) pyriemann.utils.tangentspace.upper() taking the upper triangular part of matrices. Similarly, pyriemann.utils.tangentspace.untangent_space() is split into (i) pyriemann.utils.tangentspace.unupper() and (ii) exp_map_*(). The different metrics for tangent space mapping can now be defined in pyriemann.tangentspace.TangentSpace, then used for transform() as well as for inverse_transform(). #195 by @qbarthelemy
Enhance AJD: add init to pyriemann.utils.ajd.ajd_pham() and pyriemann.utils.ajd.rjd(), add warm_restart to pyriemann.spatialfilters.AJDC. #196 by @qbarthelemy
Add parameter sampling_method to pyriemann.datasets.sample_gaussian_spd(), with rejection accelerating 2x2 matrices generation. #198 by @Artim436
Add geometric medians for Euclidean and Riemannian metrics: pyriemann.utils.median_euclid() and pyriemann.utils.median_riemann(), and add an example in gallery to compare means and medians on synthetic datasets. #200 by @qbarthelemy
Add score() to pyriemann.regression.KNearestNeighborRegressor. #205 by @qbarthelemy
Add Transfer Learning module and examples, including RPA and MDWM. #189 by @plcrodrigues, @qbarthelemy and @sylvchev
Add class distinctiveness function to measure the distinctiveness between classes on the manifold, pyriemann.classification.class_distinctiveness(), and complete an example in gallery to show how it works on synthetic datasets. #215 by @MSYamamoto
Add example on ensemble learning applied to functional connectivity, and add pyriemann.utils.base.nearest_sym_pos_def(). #202 by @mccorsi and @sylvchev
Add kernel matrices representation pyriemann.estimation.Kernels and complete example comparing estimators. #217 by @qbarthelemy
Add a new covariance estimator, robust fixed point covariance, and add kwds arguments for all covariance based functions and classes. #220 by @qbarthelemy
Add example in gallery on frequency band selection using class distinctiveness measure. #219 by @MSYamamoto
Add pyriemann.utils.covariance.covariance_mest() supporting three robust M-estimators (Huber, Student-t and Tyler) and available for all covariance based functions and classes; and add an example on robust covariance estimation for corrupted data. Also add pyriemann.utils.distance.distance_mahalanobis() between vectors and a Gaussian distribution. #223 by @qbarthelemy
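The two-step tangent space mapping described in the first v0.4 entry can be sketched with NumPy as follows. This is a minimal illustration for the ‘riemann’ metric, not the library's code: `log_map_riemann_sketch` and `upper_sketch` are hypothetical names, and the sqrt(2) weighting of off-diagonal terms is an assumption made here so that the Euclidean norm of the vector matches the Frobenius norm of the matrix.

```python
import numpy as np


def _apply_to_eigenvalues(S, func):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    w, V = np.linalg.eigh(S)
    return (V * func(w)) @ V.T


def log_map_riemann_sketch(X, Cref):
    """Step (i): project the SPD matrix X into the tangent space at Cref."""
    Cref_isqrt = _apply_to_eigenvalues(Cref, lambda w: 1.0 / np.sqrt(w))
    return _apply_to_eigenvalues(Cref_isqrt @ X @ Cref_isqrt, np.log)


def upper_sketch(T):
    """Step (ii): vectorize the upper triangular part of a symmetric matrix.

    Off-diagonal terms are weighted by sqrt(2), so that the Euclidean norm
    of the vector equals the Frobenius norm of the matrix.
    """
    idx = np.triu_indices_from(T)
    weights = np.where(idx[0] == idx[1], 1.0, np.sqrt(2.0))
    return weights * T[idx]


Cref = np.array([[2.0, 0.5], [0.5, 1.0]])
X = np.array([[1.0, 0.2], [0.2, 3.0]])
T = log_map_riemann_sketch(X, Cref)
v = upper_sketch(T)
print(v.shape)  # (3,) for 2x2 matrices: n * (n + 1) / 2 coefficients
```

Splitting the mapping this way keeps the metric-dependent part (the logarithmic map) separate from the metric-independent vectorization.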
v0.3 (July 2022)¶
Correct spectral estimation in pyriemann.utils.covariance.cross_spectrum() to obtain equivalence with SciPy. #133 by @qbarthelemy
Add instantaneous, lagged and imaginary coherences in pyriemann.utils.covariance.coherence() and pyriemann.estimation.Coherences. #132 by @qbarthelemy
Add partial_fit in pyriemann.clustering.Potato, useful for an online update; and update example on artifact detection. #133 by @qbarthelemy
Deprecate pyriemann.utils.viz.plot_confusion_matrix() as sklearn integrates its own version. #135 by @sylvchev
Add Ando-Li-Mathias mean estimation in pyriemann.utils.mean.mean_covariance(). #56 by @sylvchev
Add Schaefer-Strimmer covariance estimator in pyriemann.utils.covariance.covariances(), and an example to compare estimators. #59 by @sylvchev
Refactor tests + fix refit of pyriemann.tangentspace.TangentSpace. #136 by @sylvchev
Add pyriemann.clustering.PotatoField, and an example on artifact detection. #142 by @qbarthelemy
Add sampling SPD matrices from a Riemannian Gaussian distribution in pyriemann.datasets.sample_gaussian_spd(). #140 by @plcrodrigues
Add new function pyriemann.datasets.make_gaussian_blobs() for generating random datasets with SPD matrices. #140 by @plcrodrigues
Add module pyriemann.utils.viz in API, add pyriemann.utils.viz.plot_waveforms(), and add an example on ERP visualization. #144 by @qbarthelemy
Add a special form covariance matrix pyriemann.utils.covariance.covariances_X(). #147 by @qbarthelemy
Add masked and NaN means with Riemannian metric: pyriemann.utils.mean.maskedmean_riemann() and pyriemann.utils.mean.nanmean_riemann(). #149 by @qbarthelemy and @sylvchev
Add corr option in pyriemann.utils.covariance.normalize(), to normalize covariance into correlation matrices. #153 by @qbarthelemy
Add block covariance matrix: pyriemann.estimation.BlockCovariances and pyriemann.utils.covariance.block_covariances(). #154 by @gabelstein
Add Riemannian Locally Linear Embedding: pyriemann.embedding.LocallyLinearEmbedding and pyriemann.embedding.locally_linear_embedding(). #159 by @gabelstein
Add Riemannian Kernel Function: pyriemann.utils.kernel.kernel_riemann(). #159 by @gabelstein
Fix fit in pyriemann.channelselection.ElectrodeSelection. #166 by @qbarthelemy
Add power mean estimation in pyriemann.utils.mean.mean_power(). #170 by @qbarthelemy and @plcrodrigues
Add example in gallery to compare classifiers on synthetic datasets. #175 by @qbarthelemy
Add predict_proba in pyriemann.classification.KNearestNeighbor, and correct attribute classes_. #171 by @qbarthelemy
Add Riemannian Support Vector Machine classifier: pyriemann.classification.SVC. #175 by @gabelstein and @qbarthelemy
Add Riemannian Support Vector Machine regressor: pyriemann.regression.SVR. #175 by @gabelstein and @qbarthelemy
Add K-Nearest-Neighbor regressor: pyriemann.regression.KNearestNeighborRegressor. #164 by @gabelstein, @qbarthelemy and @agramfort
Add Minimum Distance to Mean Field classifier: pyriemann.classification.MeanField. #172 by @qbarthelemy and @plcrodrigues
Add example on principal geodesic analysis (PGA) for SSVEP classification. #169 by @qbarthelemy
Add pyriemann.utils.distance.distance_harmonic(), and sort functions by their names in code, doc and tests. #183 by @qbarthelemy
Parallelize functions for dataset generation: pyriemann.datasets.make_gaussian_blobs(). #179 by @sylvchev
Fix dispersion when generating datasets: pyriemann.datasets.sample_gaussian_spd(). #179 by @sylvchev
Enhance base and distance functions, to process ndarrays of SPD matrices. #186 and #187 by @qbarthelemy
Enhance utils functions, to process ndarrays of SPD matrices. #190 by @qbarthelemy
Enhance means functions, with faster implementations and warning when convergence is not reached. #188 by @qbarthelemy
v0.2.7 (June 2021)¶
Add example on SSVEP classification
Fix compatibility with scikit-learn v0.24
Correct probas of pyriemann.classification.MDM
Add predict_proba for pyriemann.clustering.Potato, and an example on artifact detection
Add weights to Pham’s AJD algorithm pyriemann.utils.ajd.ajd_pham()
Add pyriemann.utils.covariance.cross_spectrum(), fix pyriemann.utils.covariance.cospectrum(); pyriemann.utils.covariance.coherence() output is kept unchanged
Add pyriemann.spatialfilters.AJDC for BSS and gBSS, with an example on artifact correction
Add pyriemann.preprocessing.Whitening, with optional dimension reduction
v0.2.6 (March 2020)¶
Updated for better Scikit-Learn v0.22 support
v0.2.5 (January 2018)¶
Added BilinearFilter
Added a permutation test for generic scikit-learn estimator
Stats module refactoring, with distance based t-test and f-test
Removed two way permutation test
Added FlatChannelRemover
Support for python 3.5 in travis
Added Shrinkage transformer
Added Coherences transformer
Added Embedding class.
v0.2.4 (June 2016)¶
Improved documentation
Added TSclassifier for out-of-the-box tangent space classification.
Added Wasserstein distance and mean.
Added NearestNeighbor classifier.
Added Softmax probabilities for MDM.
Added CSP for covariance matrices.
Added Approximate Joint diagonalization algorithms (JADE, PHAM, UWEDGE).
Added ALE mean.
Added Multiclass CSP.
API: parameter name changes in CospCovariances to comply with Scikit-Learn.
API: attribute name changes in most modules to comply with the Scikit-Learn naming convention.
Added HankelCovariances estimation
Added SPoC spatial filtering
Added Harmonic mean
Added Kullback-Leibler mean
v0.2.3 (November 2015)¶
Added multiprocessing for MDM with joblib.
Added Kullback-Leibler divergence.
Added Riemannian Potato.
Added sample_weight for mean estimation and MDM.
Installing pyRiemann¶
The easiest way to install a stable version of pyRiemann is through PyPI, the Python Package Index:
pip install pyriemann
For a bleeding-edge version, you can clone the source code from GitHub and install the package directly from source:
pip install -e .
The install script will install the required dependencies. If you also want to build the documentation and run the tests locally, you can install all development dependencies with:
pip install -e .[docs,tests]
If you use a zsh shell, square brackets are interpreted as glob patterns, so you need to escape them: pip install -e .\[docs,tests\]. If you do not know what zsh is, you can use the command above as written.
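After installing, it can be handy to confirm which version ended up in the current environment. The snippet below is a small sketch using only the standard library; `installed_version` is a hypothetical helper, and `importlib.metadata` requires Python 3.8 or later.

```python
import importlib.util
from importlib import metadata


def installed_version(package):
    """Return the installed version of a package, or None if it is absent."""
    if importlib.util.find_spec(package) is None:
        return None
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


print("pyriemann:", installed_version("pyriemann"))
```

A result of `None` means the package is not importable from the interpreter that ran the snippet, which usually points to a different virtual environment than the one used for `pip install`.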
Dependencies¶
Python (>= 3.7)
Mandatory dependencies¶
Recommended dependencies¶
These dependencies are recommended for using the plotting functions of pyRiemann or for running the examples and tutorials, but they are not mandatory:
Examples Gallery¶
Contents
Classification of ERP¶
Using Riemannian geometry for classifying event-related potentials (ERP).
Classification of SSVEP¶
Using Riemannian geometry for classifying steady-state visually evoked potentials (SSVEP).

Visualization of SSVEP-based BCI Classification in Tangent Space
Artifact management¶
Using Riemannian geometry to detect, reject or correct artifacts.

Artifact Correction by AJDC-based Blind Source Separation

Online Artifact Detection with Riemannian Potato Field
Classification of motor imagery¶
Using Riemannian geometry for classifying motor imagery.

Frequency band selection on the manifold for motor imagery classification
Covariance estimation¶
Examples for covariance matrix estimation.

Compare covariance and kernel estimators with different time windows
Simulated data¶
Examples using datasets sampled from known probability distributions.

Sample from the Riemannian Gaussian distribution in the SPD manifold

Classification accuracy vs class distinctiveness vs class separability
Permutation test¶
Permutation test with pyRiemann.
Transfer learning¶
Using Riemannian geometry for transfer learning and domain adaptation.

Plot the data transformations in the Riemannian Procrustes Analysis
Classification of ERP¶
Using Riemannian geometry for classifying event-related potentials (ERP).
Embedding ERP MEG data in 2D Euclidean space¶
Riemannian embeddings via Laplacian Eigenmaps (LE) and Locally Linear Embedding (LLE) of a set of ERP data. Embedding via Laplacian Eigenmaps is referred to as Spectral Embedding (SE).
Locally Linear Embedding (LLE) assumes that the local neighborhood of a point on the manifold can be well approximated by the affine subspace spanned by the k-nearest neighbors of the point and finds a low-dimensional embedding of the data based on these affine approximations.
Laplacian Eigenmaps (LE) compute the low-dimensional representation that best preserves locality, rather than the local linearity used by LLE [1].
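To make the idea of Laplacian Eigenmaps more concrete, here is a generic NumPy sketch that embeds points from a matrix of pairwise distances. It is a textbook formulation and not pyRiemann's implementation; the function name `laplacian_eigenmaps` and the Gaussian kernel bandwidth `eps` are assumptions of this sketch. For SPD matrices, `dist` would hold pairwise Riemannian distances instead of the toy distances used here.

```python
import numpy as np


def laplacian_eigenmaps(dist, n_components=2, eps=None):
    """Embed points from a matrix of pairwise distances.

    dist has shape (n_points, n_points); eps is the kernel bandwidth.
    """
    if eps is None:
        eps = np.median(dist) ** 2
    W = np.exp(-dist ** 2 / eps)          # affinity between points
    deg = W.sum(axis=1)                   # degree of each point
    d_isqrt = 1.0 / np.sqrt(deg)
    # Symmetric normalized graph Laplacian: I - D^{-1/2} W D^{-1/2}
    L = np.eye(len(W)) - d_isqrt[:, None] * W * d_isqrt[None, :]
    eigvals, eigvecs = np.linalg.eigh(L)  # eigenvalues in ascending order
    # Skip the first (trivial) eigenvector, keep the next n_components
    return d_isqrt[:, None] * eigvecs[:, 1:n_components + 1]


# Toy data: two well-separated groups of points on a line
x = np.array([0.0, 0.1, 0.2, 5.0, 5.1, 5.2])
dist = np.abs(x[:, None] - x[None, :])
embedding = laplacian_eigenmaps(dist, n_components=1, eps=10.0)
print(embedding.ravel())
```

On this toy example, the sign of the single embedding coordinate separates the two groups, which is the locality-preserving behavior the method is designed for.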
# Authors: Pedro Rodrigues <pedro.rodrigues01@gmail.com>,
# Gabriel Wagner vom Berg <gabriel@bccn-berlin.de>
# License: BSD (3-clause)
from pyriemann.estimation import XdawnCovariances
from pyriemann.utils.viz import plot_embedding
import mne
from mne import io
from mne.datasets import sample
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
print(__doc__)
Set parameters and read data
data_path = str(sample.data_path())
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
tmin, tmax = -0., 1
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
# Setup for reading the raw data
raw = io.Raw(raw_fname, preload=True, verbose=False)
raw.filter(2, None, method='iir') # replace baselining with high-pass
events = mne.read_events(event_fname)
raw.info['bads'] = ['MEG 2443'] # set bad channels
picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=False,
exclude='bads')
# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
picks=picks, baseline=None, preload=True, verbose=False)
X = epochs.get_data()
y = epochs.events[:, -1]
Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 2 Hz
IIR filter parameters
---------------------
Butterworth highpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 8 (effective, after forward-backward)
- Cutoff at 2.00 Hz: -6.02 dB
Embedding of Xdawn covariance matrices
nfilter = 4
xdwn = XdawnCovariances(estimator='scm', nfilter=nfilter)
split = train_test_split(X, y, train_size=0.25, random_state=42)
Xtrain, Xtest, ytrain, ytest = split
covs = xdwn.fit(Xtrain, ytrain).transform(Xtest)
Laplacian Eigenmaps (LE), also called Spectral Embedding (SE)¶
plot_embedding(covs, ytest, metric='riemann', embd_type='Spectral',
normalize=True)
plt.show()

Locally Linear Embedding (LLE)¶
plot_embedding(covs, ytest, metric='riemann', embd_type='LocallyLinear',
normalize=False)
plt.show()

References¶
- 1
Clustering and dimensionality reduction on Riemannian manifolds A. Goh and R Vidal, in 2008 IEEE Conference on Computer Vision and Pattern Recognition.
Total running time of the script: ( 0 minutes 20.283 seconds)
Display ERP¶
Different ways to display a multichannel event-related potential (ERP).
# Authors: Quentin Barthélemy
#
# License: BSD (3-clause)
import numpy as np
import mne
from matplotlib import pyplot as plt
from pyriemann.utils.viz import plot_waveforms
Load EEG data¶
# Set filenames
data_path = str(mne.datasets.sample.data_path())
raw_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw.fif"
event_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif"
# Read raw data, select occipital channels and high-pass filter signal
raw = mne.io.Raw(raw_fname, preload=True, verbose=False)
raw.pick_channels(['EEG 057', 'EEG 058', 'EEG 059'], ordered=True)
raw.rename_channels({'EEG 057': 'O1', 'EEG 058': 'Oz', 'EEG 059': 'O2'})
n_channels = len(raw.ch_names)
raw.filter(1.0, None, method="iir")
# Read epochs and get responses to left visual field stimulus
tmin, tmax = -0.1, 0.8
epochs = mne.Epochs(
raw, mne.read_events(event_fname), {'vis_l': 3}, tmin, tmax, proj=False,
baseline=None, preload=True, verbose=False)
X = 5e5 * epochs.get_data()
print('Number of trials:', X.shape[0])
times = np.linspace(tmin, tmax, num=X.shape[2])
plt.rcParams["figure.figsize"] = (7, 12)
ylims = []
Removing projector <Projection | PCA-v1, active : False, n_channels : 102>
Removing projector <Projection | PCA-v2, active : False, n_channels : 102>
Removing projector <Projection | PCA-v3, active : False, n_channels : 102>
Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 1 Hz
IIR filter parameters
---------------------
Butterworth highpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 8 (effective, after forward-backward)
- Cutoff at 1.00 Hz: -6.02 dB
Number of trials: 73
Plot all trials¶
This kind of plot is a little bit messy.
fig = plot_waveforms(X, 'all', times=times, alpha=0.3)
fig.suptitle('Plot all trials', fontsize=16)
for i_channel in range(n_channels):
    fig.axes[i_channel].set(ylabel=raw.ch_names[i_channel])
    fig.axes[i_channel].set_xlim(tmin, tmax)
    ylims.append(fig.axes[i_channel].get_ylim())
fig.axes[n_channels - 1].set(xlabel='Time')
plt.show()

Plot central tendency and dispersion of trials¶
This kind of plot is widespread, but the mean and standard deviation can be contaminated by artifacts, and they assume a symmetric amplitude distribution.
fig = plot_waveforms(X, 'mean+/-std', times=times)
fig.suptitle('Plot mean+/-std of trials', fontsize=16)
for i_channel in range(n_channels):
    fig.axes[i_channel].set(ylabel=raw.ch_names[i_channel])
    fig.axes[i_channel].set_xlim(tmin, tmax)
    fig.axes[i_channel].set_ylim(ylims[i_channel])
fig.axes[n_channels - 1].set(xlabel='Time')
plt.show()
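The sensitivity of the mean and standard deviation to artifacts can be checked on simulated data. The following sketch is independent of pyRiemann and of the MNE sample dataset: it builds synthetic single-channel trials, corrupts one of them, and compares mean and standard deviation with the median and interquartile range. All values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
n_trials, n_times = 50, 100
# Simulated single-channel trials: a common waveform plus noise
waveform = np.sin(np.linspace(0, 2 * np.pi, n_times))
trials = waveform + 0.5 * rng.standard_normal((n_trials, n_times))
trials[0] += 100.0  # one trial corrupted by a large artifact

# Mean and standard deviation are pulled away by the corrupted trial
mean, std = trials.mean(axis=0), trials.std(axis=0)
# Median and interquartile range stay close to the clean trials
median = np.median(trials, axis=0)
q25, q75 = np.percentile(trials, [25, 75], axis=0)

print("max |mean - waveform|:  ", np.abs(mean - waveform).max())
print("max |median - waveform|:", np.abs(median - waveform).max())
```

A single corrupted trial out of fifty is enough to shift the mean and inflate the standard deviation at every time sample, whereas the median and quartiles barely move.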

Plot histogram of trials¶
This plot estimates a 2D histogram of trials [1].
fig = plot_waveforms(X, 'hist', times=times, n_bins=25, cmap=plt.cm.Greys)
fig.suptitle('Plot histogram of trials', fontsize=16)
for i_channel in range(n_channels):
    fig.axes[i_channel].set(ylabel=raw.ch_names[i_channel])
    fig.axes[i_channel].set_ylim(ylims[i_channel])
fig.axes[n_channels - 1].set(xlabel='Time')
plt.show()
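One plausible way to compute such a 2D histogram for a single channel is to histogram the trial amplitudes separately at each time sample, with shared amplitude bins. The sketch below does this with NumPy on random data; `trials_histogram` is a hypothetical helper, and it is not claimed to match what `plot_waveforms` does internally.

```python
import numpy as np


def trials_histogram(trials, n_bins=25):
    """2D histogram of the trials of one channel: amplitude bins versus time.

    trials has shape (n_trials, n_times); the output has shape
    (n_bins, n_times) and each column counts the trials per amplitude bin.
    """
    edges = np.linspace(trials.min(), trials.max(), n_bins + 1)
    hist = np.stack(
        [np.histogram(trials[:, t], bins=edges)[0]
         for t in range(trials.shape[1])],
        axis=1,
    )
    return hist, edges


rng = np.random.default_rng(0)
trials = rng.standard_normal((73, 40))  # 73 trials, 40 time samples
hist, edges = trials_histogram(trials, n_bins=25)
print(hist.shape)  # (25, 40)
```

Unlike the mean and standard deviation, this representation makes no assumption on the shape of the amplitude distribution at each time sample.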

References¶
- 1
Improved estimation of EEG evoked potentials by jitter compensation and enhancing spatial filters A. Souloumiac and B. Rivet. 2013 IEEE International Conference on Acoustics, Speech and Signal Processing.
Total running time of the script: ( 0 minutes 1.135 seconds)
ERP EEG decoding in Tangent space.¶
Decoding applied to EEG data in sensor space decomposed using Xdawn. After spatial filtering, covariance matrices are estimated, then projected into the tangent space and classified with a logistic regression.
# Authors: Alexandre Barachant <alexandre.barachant@gmail.com>
#
# License: BSD (3-clause)
import numpy as np
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
import mne
from mne import io
from mne.datasets import sample
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.pipeline import make_pipeline
from matplotlib import pyplot as plt
print(__doc__)
Set parameters and read data
data_path = str(sample.data_path())
raw_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw.fif"
event_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif"
tmin, tmax = -0.0, 1
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
# Setup for reading the raw data
raw = io.Raw(raw_fname, preload=True, verbose=False)
raw.filter(2, None, method="iir") # replace baselining with high-pass
events = mne.read_events(event_fname)
raw.info["bads"] = ["MEG 2443"] # set bad channels
picks = mne.pick_types(
raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads"
)
# Read epochs
epochs = mne.Epochs(
raw,
events,
event_id,
tmin,
tmax,
proj=False,
picks=picks,
baseline=None,
preload=True,
verbose=False,
)
labels = epochs.events[:, -1]
evoked = epochs.average()
Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 2 Hz
IIR filter parameters
---------------------
Butterworth highpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 8 (effective, after forward-backward)
- Cutoff at 2.00 Hz: -6.02 dB
Removing projector <Projection | PCA-v1, active : False, n_channels : 102>
Removing projector <Projection | PCA-v2, active : False, n_channels : 102>
Removing projector <Projection | PCA-v3, active : False, n_channels : 102>
Decoding in tangent space with a logistic regression
n_components = 2 # pick some components
# Define a shuffled 10-fold cross-validation generator:
cv = KFold(n_splits=10, shuffle=True, random_state=42)
epochs_data = epochs.get_data()
clf = make_pipeline(
    XdawnCovariances(n_components),
    TangentSpace(metric="riemann"),
    LogisticRegression(),
)
preds = np.zeros(len(labels))
for train_idx, test_idx in cv.split(epochs_data):
    y_train, y_test = labels[train_idx], labels[test_idx]
    clf.fit(epochs_data[train_idx], y_train)
    preds[test_idx] = clf.predict(epochs_data[test_idx])
# Printing the results
acc = np.mean(preds == labels)
print("Classification accuracy: %f " % (acc))
names = ["audio left", "audio right", "vis left", "vis right"]
cm = confusion_matrix(labels, preds)
ConfusionMatrixDisplay(cm, display_labels=names).plot()
plt.show()

Classification accuracy: 0.750000
Total running time of the script: ( 0 minutes 6.134 seconds)
Multiclass MEG ERP Decoding¶
Decoding applied to MEG data in sensor space decomposed using Xdawn. After spatial filtering, covariance matrices are estimated and classified by the MDM algorithm (nearest centroid).
Four Xdawn spatial patterns (one for each class) are displayed, matching the four mean covariance matrices used by the classification algorithm.
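The nearest-centroid principle behind MDM can be sketched in a few lines of NumPy: estimate one mean covariance matrix per class, then assign each new matrix to the class whose mean is closest in Riemannian distance. The class `MDMSketch` and the helper functions below are hypothetical and simplified; they are not pyRiemann's `MDM`, and the data are synthetic.

```python
import numpy as np


def _eig_fun(S, func):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    w, V = np.linalg.eigh(S)
    return (V * func(w)) @ V.T


def riemann_mean(covs, n_iter=50, tol=1e-8):
    """Riemannian mean of SPD matrices by fixed-point iteration."""
    M = covs.mean(axis=0)  # start from the arithmetic mean
    for _ in range(n_iter):
        M_sqrt = _eig_fun(M, np.sqrt)
        M_isqrt = _eig_fun(M, lambda w: 1.0 / np.sqrt(w))
        # Average of the matrices mapped to the tangent space at M
        T = np.mean(
            [_eig_fun(M_isqrt @ C @ M_isqrt, np.log) for C in covs], axis=0
        )
        M = M_sqrt @ _eig_fun(T, np.exp) @ M_sqrt
        if np.linalg.norm(T) < tol:
            break
    return M


def riemann_distance(A, B):
    """Affine-invariant Riemannian distance between two SPD matrices."""
    A_isqrt = _eig_fun(A, lambda w: 1.0 / np.sqrt(w))
    eigvals = np.linalg.eigvalsh(A_isqrt @ B @ A_isqrt)
    return np.sqrt(np.sum(np.log(eigvals) ** 2))


class MDMSketch:
    """Minimum distance to mean: one mean matrix per class."""

    def fit(self, covs, y):
        self.classes_ = np.unique(y)
        self.means_ = [riemann_mean(covs[y == c]) for c in self.classes_]
        return self

    def predict(self, covs):
        dists = np.array(
            [[riemann_distance(M, C) for M in self.means_] for C in covs]
        )
        return self.classes_[dists.argmin(axis=1)]


# Synthetic 3-channel trials: each class has a different high-variance channel
rng = np.random.default_rng(0)
scales = {0: np.array([3.0, 1.0, 1.0]), 1: np.array([1.0, 1.0, 3.0])}
y = np.repeat([0, 1], 20)
X = np.stack([scales[c][:, None] * rng.standard_normal((3, 100)) for c in y])
covs = X @ X.transpose(0, 2, 1) / X.shape[2]

clf = MDMSketch().fit(covs[::2], y[::2])  # train on every other trial
accuracy = np.mean(clf.predict(covs[1::2]) == y[1::2])
print("accuracy:", accuracy)
```

The classifier has no hyperparameter beyond the choice of metric, which is part of its appeal for small BCI datasets.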
# Authors: Alexandre Barachant <alexandre.barachant@gmail.com>
#
# License: BSD (3-clause)
import numpy as np
from matplotlib import pyplot as plt
from pyriemann.estimation import XdawnCovariances
from pyriemann.classification import MDM
import mne
from mne import io
from mne.datasets import sample
from sklearn.metrics import (
classification_report,
confusion_matrix,
ConfusionMatrixDisplay,
)
from sklearn.model_selection import KFold
from sklearn.pipeline import make_pipeline
print(__doc__)
Set parameters and read data
data_path = str(sample.data_path())
raw_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw.fif"
event_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif"
tmin, tmax = -0.0, 1
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
# Setup for reading the raw data
raw = io.Raw(raw_fname, preload=True)
raw.filter(2, None, method="iir") # replace baselining with high-pass
events = mne.read_events(event_fname)
raw.info["bads"] = ["MEG 2443"] # set bad channels
picks = mne.pick_types(
raw.info, meg="grad", eeg=False, stim=False, eog=False, exclude="bads"
)
# Read epochs
epochs = mne.Epochs(
raw,
events,
event_id,
tmin,
tmax,
proj=False,
picks=picks,
baseline=None,
preload=True,
verbose=False,
)
labels = epochs.events[:, -1]
evoked = epochs.average()
Opening raw data file /home/docs/mne_data/MNE-sample-data/MEG/sample/sample_audvis_filt-0-40_raw.fif...
Read a total of 4 projection items:
PCA-v1 (1 x 102) idle
PCA-v2 (1 x 102) idle
PCA-v3 (1 x 102) idle
Average EEG reference (1 x 60) idle
Range : 6450 ... 48149 = 42.956 ... 320.665 secs
Ready.
Reading 0 ... 41699 = 0.000 ... 277.709 secs...
Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 2 Hz
IIR filter parameters
---------------------
Butterworth highpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 8 (effective, after forward-backward)
- Cutoff at 2.00 Hz: -6.02 dB
Removing projector <Projection | PCA-v1, active : False, n_channels : 102>
Removing projector <Projection | PCA-v2, active : False, n_channels : 102>
Removing projector <Projection | PCA-v3, active : False, n_channels : 102>
Removing projector <Projection | Average EEG reference, active : False, n_channels : 60>
Decoding with Xdawn + MDM
n_components = 3 # pick some components
# Define a shuffled 10-fold cross-validation generator:
cv = KFold(n_splits=10, shuffle=True, random_state=42)
pr = np.zeros(len(labels))
epochs_data = epochs.get_data()
print("Multiclass classification with XDAWN + MDM")
clf = make_pipeline(XdawnCovariances(n_components), MDM())
for train_idx, test_idx in cv.split(epochs_data):
    y_train, y_test = labels[train_idx], labels[test_idx]
    clf.fit(epochs_data[train_idx], y_train)
    pr[test_idx] = clf.predict(epochs_data[test_idx])
print(classification_report(labels, pr))
Multiclass classification with XDAWN + MDM
              precision    recall  f1-score   support

           1       0.89      0.93      0.91        72
           2       0.90      0.89      0.90        73
           3       0.92      0.96      0.94        73
           4       0.97      0.90      0.93        70

    accuracy                           0.92       288
   macro avg       0.92      0.92      0.92       288
weighted avg       0.92      0.92      0.92       288
Plot the spatial patterns
xd = XdawnCovariances(n_components)
xd.fit(epochs_data, labels)
info = evoked.copy().resample(1).info # make it 1Hz for plotting
patterns = mne.EvokedArray(
data=xd.Xd_.patterns_.T, info=info
)
patterns.plot_topomap(
times=[0, n_components, 2 * n_components, 3 * n_components],
ch_type="grad",
colorbar=False,
size=1.5,
time_format="Pattern %d"
)

<MNEFigure size 900x287.5 with 4 Axes>
Plot the confusion matrix
names = ["audio left", "audio right", "vis left", "vis right"]
cm = confusion_matrix(labels, pr)
ConfusionMatrixDisplay(cm, display_labels=names).plot()
plt.show()

Total running time of the script: ( 0 minutes 12.535 seconds)
Classification of SSVEP¶
Using Riemannian geometry for classifying steady-state visually evoked potentials (SSVEP).

Visualization of SSVEP-based BCI Classification in Tangent Space
Offline SSVEP-based BCI Multiclass Prediction¶
Building extended covariance matrices for SSVEP-based BCI. The obtained matrices are shown. A Minimum Distance to Mean classifier is trained to predict a 4-class problem for an offline setup.
# Authors: Sylvain Chevallier <sylvain.chevallier@uvsq.fr>,
# Emmanuel Kalunga, Quentin Barthélemy, David Ojeda
#
# License: BSD (3-clause)
import numpy as np
import matplotlib.pyplot as plt
from mne import find_events, Epochs
from mne.io import Raw
from sklearn.model_selection import cross_val_score, RepeatedKFold
from pyriemann.estimation import BlockCovariances
from pyriemann.utils.mean import mean_riemann
from pyriemann.classification import MDM
from helpers.ssvep_helpers import download_data, extend_signal
Loading EEG data¶
The data are loaded through an MNE loader.
# Download data
destination = download_data(subject=12, session=1)
# Read data in MNE Raw and numpy format
raw = Raw(destination, preload=True, verbose='ERROR')
events = find_events(raw, shortest_event=0, verbose=False)
raw = raw.pick_types(eeg=True)
event_id = {'13 Hz': 2, '17 Hz': 4, '21 Hz': 3, 'resting-state': 1}
sfreq = int(raw.info['sfreq'])
eeg_data = raw.get_data()
Using default location ~/mne_data for ssvep...
/home/docs/checkouts/readthedocs.org/user_builds/pyriemann/checkouts/stable/examples/SSVEP/helpers/ssvep_helpers.py:79: RuntimeWarning: Setting non-standard config type: "MNE_DATASETS_SSVEPEXO_PATH"
data_path = fetch_dataset(dataset_params, force_update=True)
Visualization of raw EEG data¶
Plot a few seconds of signal from the Oz electrode using matplotlib:
n_seconds = 2
time = np.linspace(0, n_seconds, n_seconds * sfreq,
endpoint=False)[np.newaxis, :]
plt.figure(figsize=(10, 4))
plt.plot(time.T, eeg_data[np.array(raw.ch_names) == 'Oz', :n_seconds*sfreq].T,
color='C0', lw=0.5)
plt.xlabel("Time (s)")
plt.ylabel(r"Oz ($\mu$V)")
plt.show()

And of all electrodes:
plt.figure(figsize=(10, 4))
for ch_idx, ch_name in enumerate(raw.ch_names):
    plt.plot(time.T, eeg_data[ch_idx, :n_seconds*sfreq].T, lw=0.5,
             label=ch_name)
plt.xlabel("Time (s)")
plt.ylabel(r"EEG ($\mu$V)")
plt.legend(loc='upper right')
plt.show()

With MNE, it is much easier to visualize the data:
raw.plot(duration=n_seconds, start=0, n_channels=8, scalings={'eeg': 4e-2},
color={'eeg': 'steelblue'})

Using matplotlib as 2D backend.
<MNEBrowseFigure size 800x800 with 4 Axes>
Extended signals for spatial covariance¶
Using the approach proposed in [1], the SSVEP signal is extended to include the filtered signals for each stimulation frequency. We stack the filtered signals to build an extended signal.
# We stack the filtered signals to build an extended signal
frequencies = [13, 17, 21]
freq_band = 0.1
raw_ext = extend_signal(raw, frequencies, freq_band)
Creating RawArray with float64 data, n_channels=24, n_times=92384
Range : 0 ... 92383 = 0.000 ... 360.871 secs
Ready.
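The stacking step can be illustrated on synthetic data. The sketch below band-passes a signal around each stimulation frequency and stacks the results along the channel axis, so that 3 frequencies and 8 channels would give the 24 channels reported above. It is not the `extend_signal` helper used in this example: `bandpass_fft` is a crude FFT-masking filter chosen to keep the sketch NumPy-only, and `half_width` is a hypothetical parameter.

```python
import numpy as np


def bandpass_fft(data, sfreq, low, high):
    """Crude band-pass filter: zero all FFT bins outside [low, high] Hz."""
    freqs = np.fft.rfftfreq(data.shape[-1], d=1.0 / sfreq)
    spectrum = np.fft.rfft(data, axis=-1)
    spectrum[..., (freqs < low) | (freqs > high)] = 0.0
    return np.fft.irfft(spectrum, n=data.shape[-1], axis=-1)


def extend_signal_sketch(data, sfreq, frequencies, half_width=0.5):
    """Stack one band-passed copy of the signal per stimulation frequency.

    data has shape (n_channels, n_times); the output has shape
    (n_frequencies * n_channels, n_times).
    """
    return np.vstack([
        bandpass_fft(data, sfreq, f - half_width, f + half_width)
        for f in frequencies
    ])


sfreq = 256
t = np.arange(4 * sfreq) / sfreq
rng = np.random.default_rng(0)
# Two channels with a 13 Hz oscillation buried in noise
data = np.sin(2 * np.pi * 13 * t) + 0.5 * rng.standard_normal((2, t.size))
ext = extend_signal_sketch(data, sfreq, frequencies=[13, 17, 21])
print(ext.shape)  # (6, 1024)

# Average power in each band: the 13 Hz band dominates
band_power = ext.reshape(3, 2, -1).var(axis=2).mean(axis=1)
print(band_power)
```

The covariance matrix of such an extended signal then carries both spatial and frequency information, which is what the classifier exploits.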
Plot the extended signal
raw_ext.plot(duration=n_seconds, start=14, n_channels=24,
scalings={'eeg': 5e-4}, color={'eeg': 'steelblue'})

<MNEBrowseFigure size 800x800 with 4 Axes>
Building Epochs and plotting 3 s of the signal from electrode Oz for a trial
epochs = Epochs(raw_ext, events, event_id, tmin=2, tmax=5, baseline=None)
n_seconds = 3
time = np.linspace(0, n_seconds, n_seconds * sfreq,
                   endpoint=False)[np.newaxis, :]
channels = range(0, len(raw_ext.ch_names), len(raw.ch_names))
plt.figure(figsize=(7, 5))
for f, c in zip(frequencies, channels):
    plt.plot(epochs.get_data()[5, c, :].T, label=str(int(f))+' Hz')
plt.xlabel("Time (s)")
plt.ylabel(r"Oz after filtering ($\mu$V)")
plt.legend(loc='upper right')
plt.show()

Not setting metadata
32 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 32 events and 769 original time points ...
0 bad epochs dropped
Using data from preloaded Raw for 32 events and 769 original time points ...
Using data from preloaded Raw for 32 events and 769 original time points ...
As can be seen in this example, the subject is watching the 13 Hz stimulation: the EEG activity increases in this frequency band, while the other frequencies have lower amplitudes.
Spatial covariance for SSVEP¶
The covariance matrices will be estimated using the Ledoit-Wolf shrinkage estimator on the extended signal.
cov_ext_trials = BlockCovariances(estimator='lwf',
                                  block_size=8).transform(epochs.get_data())
# This plot shows an example of a covariance matrix observed for each class:
ch_names = raw_ext.info['ch_names']
plt.figure(figsize=(7, 7))
for i, l in enumerate(event_id):
    ax = plt.subplot(2, 2, i+1)
    plt.imshow(cov_ext_trials[events[:, 2] == event_id[l]][0],
               cmap=plt.get_cmap('RdBu_r'))
    plt.title('Cov for class: '+l)
    plt.xticks([])
    if i == 0 or i == 2:
        plt.yticks(np.arange(len(ch_names)), ch_names)
        ax.tick_params(axis='both', which='major', labelsize=7)
    else:
        plt.yticks([])
plt.show()

Using data from preloaded Raw for 32 events and 769 original time points ...
It appears clearly that each class yields a different structure of the covariance matrix. Each stimulation (13, 17 and 21 Hz) generates higher covariance values for the EEG signal filtered in the matching frequency band and no activation at all for the other bands. The resting state, where the subject focuses on the center of the display, far from all blinking stimuli, shows an activity with higher correlation in the 13 Hz band and lower but still visible activity in the other bands.
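To make the block layout of these matrices concrete, here is a small sketch under stated assumptions: the extended channels are split into consecutive groups of `block_size`, each group gets its own covariance, and cross-group entries are left at zero. The function `block_covariance` is a hypothetical illustration, not the pyRiemann implementation, and the plain sample covariance (`np.cov`) is swapped in for the Ledoit-Wolf shrinkage estimator used in the example.

```python
import numpy as np


def block_covariance(trial, block_size):
    """Block-diagonal covariance of one trial of shape (n_channels, n_times).

    Channels are split into consecutive groups of `block_size`; each group
    gets its own sample covariance, and cross-group entries stay at zero.
    """
    n_channels = trial.shape[0]
    if n_channels % block_size:
        raise ValueError("n_channels must be a multiple of block_size")
    cov = np.zeros((n_channels, n_channels))
    for start in range(0, n_channels, block_size):
        block = slice(start, start + block_size)
        cov[block, block] = np.cov(trial[block])
    return cov


rng = np.random.default_rng(0)
trial = rng.standard_normal((24, 769))  # 3 bands x 8 channels, 769 samples
cov = block_covariance(trial, block_size=8)
print(cov.shape)                      # (24, 24)
print(np.allclose(cov[:8, 8:], 0.0))  # True
```

Each 8 x 8 diagonal block then describes the spatial covariance of the electrodes within one frequency band.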
Classify with MDM¶
Plotting mean of each class
cov_centers = np.empty((len(event_id), 24, 24))
for i, l in enumerate(event_id):
    cov_centers[i] = mean_riemann(cov_ext_trials[events[:, 2] == event_id[l]])
plt.figure(figsize=(7, 7))
for i, l in enumerate(event_id):
    ax = plt.subplot(2, 2, i+1)
    plt.imshow(cov_centers[i], cmap=plt.get_cmap('RdBu_r'))
    plt.title('Cov mean for class: '+l)
    plt.xticks([])
    if i == 0 or i == 2:
        plt.yticks(np.arange(len(ch_names)), ch_names)
        ax.tick_params(axis='both', which='major', labelsize=7)
    else:
        plt.yticks([])
plt.show()

Minimum distance to mean is a simple and robust algorithm for BCI decoding. It reproduces results of 2 for the first session of subject 12.
print("Number of trials: {}".format(len(cov_ext_trials)))
cv = RepeatedKFold(n_splits=2, n_repeats=10, random_state=42)
mdm = MDM(metric=dict(mean='riemann', distance='riemann'))
scores = cross_val_score(mdm, cov_ext_trials, events[:, 2], cv=cv, n_jobs=1)
print("MDM accuracy: {:.2f}% +/- {:.2f}".format(np.mean(scores)*100,
                                                np.std(scores)*100))
# The obtained accuracy is about 81% +/- 16 for this session, with a 2-fold
# cross-validation repeated 10 times.
Number of trials: 32
/home/docs/checkouts/readthedocs.org/user_builds/pyriemann/checkouts/stable/pyriemann/utils/mean.py:470: UserWarning: Convergence not reached
warnings.warn('Convergence not reached')
MDM accuracy: 80.94% +/- 16.23
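The principle of minimum distance to mean can be sketched in a few lines of NumPy: estimate one mean matrix per class, then assign each new matrix to the class whose mean is closest. This is only an illustration under stated assumptions, not pyRiemann's `MDM`: the helper names are hypothetical, the data are synthetic, and the closed-form log-Euclidean mean is swapped in for the iterative Riemannian mean, while the distance is the affine-invariant Riemannian one.

```python
import numpy as np


def _eig_fn(mat, fn):
    """Apply a scalar function to a symmetric matrix through its eigenvalues."""
    w, v = np.linalg.eigh(mat)
    return (v * fn(w)) @ v.T


def logeuclid_mean(covs):
    """Log-Euclidean mean: exp of the average of the matrix logarithms."""
    logs = np.array([_eig_fn(c, np.log) for c in covs])
    return _eig_fn(logs.mean(axis=0), np.exp)


def riemann_distance(a, b):
    """Affine-invariant distance from the eigenvalues of a^{-1} b."""
    eigvals = np.linalg.eigvals(np.linalg.solve(a, b)).real
    return np.sqrt(np.sum(np.log(eigvals) ** 2))


def mdm_fit(covs, labels):
    """Return the class labels and one mean matrix per class."""
    classes = np.unique(labels)
    means = np.array([logeuclid_mean(covs[labels == c]) for c in classes])
    return classes, means


def mdm_predict(covs, classes, means):
    """Assign each matrix to the class with the closest mean."""
    dists = np.array([[riemann_distance(c, m) for m in means] for c in covs])
    return classes[dists.argmin(axis=1)]


# Synthetic two-class data: class 1 has more variance on the first channel
rng = np.random.default_rng(42)
scales = {0: np.array([1.0, 1.0, 1.0, 1.0]), 1: np.array([3.0, 1.0, 1.0, 1.0])}
covs, labels = [], []
for label, scale in scales.items():
    for _ in range(16):
        x = scale[:, None] * rng.standard_normal((4, 200))
        covs.append(x @ x.T / x.shape[1])
        labels.append(label)
covs, labels = np.array(covs), np.array(labels)

classes, means = mdm_fit(covs, labels)
pred = mdm_predict(covs, classes, means)
print("Training accuracy:", np.mean(pred == labels))
```

The classifier has no hyperparameter beyond the choice of mean and distance, which is part of why it is described as simple and robust.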
References¶
- 1
A New generation of Brain-Computer Interface Based on Riemannian Geometry M. Congedo, A. Barachant, A. Andreev. Research report, 2013.
- 2
Review of Riemannian distances and divergences, applied to SSVEP-based BCI S. Chevallier, E. K. Kalunga, Q. Barthélemy, E. Monacelli. Neuroinformatics, Springer, 2021, 19 (1), pp.93-106
Total running time of the script: ( 0 minutes 16.174 seconds)
Visualization of SSVEP-based BCI Classification in Tangent Space¶
Project extended covariance matrices of SSVEP-based BCI in the tangent space, using principal geodesic analysis (PGA).
You should have a look at “Offline SSVEP-based BCI Multiclass Prediction” before this example.
# Authors: Quentin Barthélemy, Emmanuel Kalunga and Sylvain Chevallier
#
# License: BSD (3-clause)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mne import find_events, Epochs, make_fixed_length_epochs
from mne.io import Raw
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import PCA
from pyriemann.estimation import BlockCovariances
from pyriemann.classification import MDM
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import _add_alpha
from helpers.ssvep_helpers import download_data, extend_signal
clabel = ['resting-state', '13 Hz', '17 Hz', '21 Hz']
clist = plt.cm.viridis(np.array([0, 1, 2, 3])/3)
cmap = "viridis"
def plot_pga(ax, data, labels, centers):
    sc = ax.scatter(data[:, 0], data[:, 1], c=labels, marker='P', cmap=cmap)
    ax.scatter(
        centers[:, 0], centers[:, 1], c=clist, marker='o', s=100, cmap=cmap
    )
    ax.set(xlabel='PGA, 1st axis', ylabel='PGA, 2nd axis')
    for i in range(len(clabel)):
        ax.scatter([], [], color=clist[i], marker='o', s=50, label=clabel[i])
    ax.legend(loc='upper right')
    return sc
Load EEG and extract covariance matrices for SSVEP¶
frequencies = [13, 17, 21]
freq_band = 0.1
events_id = {'13 Hz': 2, '17 Hz': 4, '21 Hz': 3, 'resting-state': 1}
duration = 2.5 # duration of epochs
interval = 0.25 # interval between successive epochs for online processing
# Subject 12: session 1 for training, session 4 for testing
# Training set
raw = Raw(download_data(subject=12, session=1), preload=True, verbose=False)
events = find_events(raw, shortest_event=0, verbose=False)
raw = raw.pick_types(eeg=True)
ch_count = len(raw.info['ch_names'])
raw_ext = extend_signal(raw, frequencies, freq_band)
epochs = Epochs(
    raw_ext, events, events_id, tmin=2, tmax=5, baseline=None, verbose=False)
x_train = BlockCovariances(
    estimator='lwf', block_size=ch_count).transform(epochs.get_data())
y_train = events[:, 2]
# Testing set
raw = Raw(download_data(subject=12, session=4), preload=True, verbose=False)
raw = raw.pick_types(eeg=True)
raw_ext = extend_signal(raw, frequencies, freq_band)
epochs = make_fixed_length_epochs(
    raw_ext, duration=duration, overlap=duration - interval, verbose=False)
x_test = BlockCovariances(
    estimator='lwf', block_size=ch_count).transform(epochs.get_data())
Creating RawArray with float64 data, n_channels=24, n_times=92384
Range : 0 ... 92383 = 0.000 ... 360.871 secs
Ready.
Using data from preloaded Raw for 32 events and 769 original time points ...
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=24, n_times=148544
Range : 0 ... 148543 = 0.000 ... 580.246 secs
Ready.
Using data from preloaded Raw for 2312 events and 640 original time points ...
0 bad epochs dropped
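The number of test epochs follows from the sliding-window settings: windows of `duration` seconds that start every `interval` seconds. The arithmetic can be checked with a short sketch. The function name `sliding_window_starts` is hypothetical, and the sampling rate of 256 Hz is an assumption inferred from the log above (148543 samples for 580.246 s), not a value stated in the text.

```python
def sliding_window_starts(n_times, sfreq, duration, interval):
    """Start sample of every full window of `duration` s, stepped by `interval` s."""
    window = int(round(duration * sfreq))  # 2.5 s -> 640 samples
    step = int(round(interval * sfreq))    # 0.25 s -> 64 samples
    return list(range(0, n_times - window + 1, step))


# n_times is read from the log above; sfreq = 256 Hz is inferred from it
starts = sliding_window_starts(
    n_times=148544, sfreq=256, duration=2.5, interval=0.25
)
print(len(starts))  # 2312, matching the number of events in the log
print(starts[:3])   # [0, 64, 128]
```

Each window therefore shares 2.25 s of signal with the previous one, which is the value passed as `overlap=duration - interval`.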
Classification with minimum distance to mean (MDM)¶
Classification for a 4-class SSVEP BCI, including resting-state class.
print("Number of training trials: {}".format(len(x_train)))
mdm = MDM(metric=dict(mean='riemann', distance='riemann'))
mdm.fit(x_train, y_train)
Number of training trials: 32
MDM(metric={'distance': 'riemann', 'mean': 'riemann'})
Projection in tangent space with principal geodesic analysis (PGA)¶
Project covariance matrices from the Riemannian manifold into the Euclidean tangent space at the grand average, and apply a principal component analysis (PCA) to obtain an unsupervised dimension reduction 1.
pga = make_pipeline(
    TangentSpace(metric="riemann", tsupdate=False),
    PCA(n_components=2)
)
ts_train = pga.fit_transform(x_train)
ts_means = pga.transform(np.array(mdm.covmeans_))
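As a rough illustration of what such a pipeline computes, the sketch below maps SPD matrices to tangent vectors at a reference point and then projects them on two principal axes. It rests on several assumptions: the helper names are hypothetical, the data are synthetic, the arithmetic mean is swapped in for the Riemannian grand average as reference point, and PCA is done with a plain SVD rather than scikit-learn.

```python
import numpy as np


def _eig_fn(mat, fn):
    """Apply a scalar function to a symmetric matrix through its eigenvalues."""
    w, v = np.linalg.eigh(mat)
    return (v * fn(w)) @ v.T


def tangent_vectors(covs, ref):
    """Map SPD matrices to tangent vectors at `ref`.

    Each matrix is whitened by ref^{-1/2}, its matrix logarithm is taken,
    and the upper triangle is vectorized with sqrt(2) weights off the
    diagonal so that the Euclidean norm of the vector is preserved.
    """
    ref_isqrt = _eig_fn(ref, lambda w: 1.0 / np.sqrt(w))
    rows, cols = np.triu_indices(ref.shape[0])
    weights = np.where(rows == cols, 1.0, np.sqrt(2.0))
    vecs = [
        weights * _eig_fn(ref_isqrt @ c @ ref_isqrt, np.log)[rows, cols]
        for c in covs
    ]
    return np.array(vecs)


def pca_2d(x):
    """Project the rows of x on their first two principal axes."""
    centered = x - x.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


# Synthetic SPD matrices: 20 trials of 4 channels
rng = np.random.default_rng(0)
covs = []
for _ in range(20):
    x = rng.standard_normal((4, 100))
    covs.append(x @ x.T / x.shape[1])
covs = np.array(covs)

ref = covs.mean(axis=0)  # arithmetic mean as a simple stand-in reference
ts = tangent_vectors(covs, ref)
proj = pca_2d(ts)
print(ts.shape)    # (20, 10)
print(proj.shape)  # (20, 2)
```

With the sqrt(2) weighting, the norm of each tangent vector equals the Riemannian distance between the matrix and the reference point, so distances to the reference are kept before the PCA step.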
Offline training of MDM visualized by PGA¶
These figures show the trajectory in the tangent space taken by covariance matrices during a 4-class SSVEP experiment, and how they are classified epoch by epoch.
This figure reproduces Fig 3(c) of reference 2, showing training trials of the best subject.
fig, ax = plt.subplots(figsize=(8, 8))
fig.suptitle('PGA of training set', fontsize=16)
plot_pga(ax, ts_train, y_train, ts_means)
plt.show()

Online classification by MDM visualized by PGA¶
This figure reproduces Fig 6 of reference 2, with an animation to imitate an online acquisition, processing and classification of EEG time-series.
Warning: reference 2 uses a curve-based online classification, while a single-trial classification is used here.
# Prepare data for online classification
test_visu = 50 # nb of matrices to display simultaneously
colors, ts_visu = [], np.empty([0, 2])
alphas = np.linspace(0, 1, test_visu)
fig, ax = plt.subplots(figsize=(8, 8))
fig.suptitle('PGA of testing set', fontsize=16)
pl = plot_pga(ax, ts_visu, colors, ts_means)
pl.axes.set_xlim(-5, 6)
pl.axes.set_ylim(-5, 5)

(-5.0, 5.0)
# Prepare animation for online classification
def online_classify(t):
    global colors, ts_visu
    # Online classification
    y = mdm.predict(x_test[np.newaxis, t])
    color = clist[int(y[0] - 1)]
    ts_test = pga.transform(x_test[np.newaxis, t])
    # Update data
    colors.append(color)
    ts_visu = np.vstack((ts_visu, ts_test))
    if len(ts_visu) > test_visu:
        colors.pop(0)
        ts_visu = ts_visu[1:]
    colors = _add_alpha(colors, alphas)
    # Update plot
    pl.set_offsets(np.c_[ts_visu[:, 0], ts_visu[:, 1]])
    pl.set_color(colors)
    return pl
interval_display = 1.0 # can be changed for a slower display
visu = FuncAnimation(fig, online_classify,
                     frames=range(0, len(x_test)),
                     interval=interval_display, blit=False, repeat=False)
# Plot online classification
# Plot complete visu: a dynamic display is required
plt.show()
# Plot only 10s, for animated documentation
try:
    from IPython.display import HTML
except ImportError:
    raise ImportError("Install IPython to plot animation in documentation")
plt.rcParams["animation.embed_limit"] = 10
HTML(visu.to_jshtml(fps=5, default_mode='loop'))
Animation size has reached 10525072 bytes, exceeding the limit of 10485760.0. If you're sure you want a larger animation embedded, set the animation.embed_limit rc parameter to a larger value (in MB). This and further frames will be dropped.
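The animation keeps only the most recent matrices on screen and fades the older ones. A minimal sketch of that rolling buffer is given below, assuming a `collections.deque` with a fixed `maxlen` instead of the manual `pop(0)` and slicing used in the example. All values are stand-ins, and the behavior of the private helper `_add_alpha` is not reproduced here since it is not shown in the text.

```python
from collections import deque

import numpy as np

n_display = 5  # number of most recent points kept on screen
points = deque(maxlen=n_display)  # oldest entries are dropped automatically
labels = deque(maxlen=n_display)

for t in range(8):  # stand-in for incoming test epochs
    points.append((float(t), float(-t)))  # stand-in for a 2-D PGA projection
    labels.append(t % 4)                  # stand-in for a predicted class

# One transparency value per displayed point: oldest faintest, newest opaque
alphas = np.linspace(0.0, 1.0, len(points))
print(list(labels))  # [3, 0, 1, 2, 3]
print(alphas)
```

A bounded deque avoids the need to check the buffer length at every frame, at the cost of converting it to an array before updating the scatter plot.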