Offline SSVEP-based BCI Multiclass Prediction

Building extended covariance matrices for SSVEP-based BCI. The obtained matrices are shown. A Minimum Distance to Mean classifier is trained to predict a 4-class problem for an offline setup.

# Authors: Sylvain Chevallier <>,
# Emmanuel Kalunga, Quentin Barthélemy, David Ojeda
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt
from mne import find_events, Epochs
from mne.io import Raw
from sklearn.model_selection import cross_val_score, RepeatedKFold

from pyriemann.estimation import BlockCovariances
from pyriemann.utils.mean import mean_riemann
from pyriemann.classification import MDM
from helpers.ssvep_helpers import download_data, extend_signal

Loading EEG data

The data are loaded through an MNE loader.

# Download data
destination = download_data(subject=12, session=1)
# Read data in MNE Raw and numpy format
raw = Raw(destination, preload=True, verbose='ERROR')
events = find_events(raw, shortest_event=0, verbose=False)
raw = raw.pick("eeg")

event_id = {'13 Hz': 2, '17 Hz': 4, '21 Hz': 3, 'resting-state': 1}
sfreq = int(raw.info['sfreq'])
eeg_data = raw.get_data()
Using default location ~/mne_data for ssvep...

/home/docs/checkouts/ RuntimeWarning: Setting non-standard config type: "MNE_DATASETS_SSVEPEXO_PATH"
  data_path = fetch_dataset(dataset_params, force_update=True)
Download complete in 04s (3.2 MB)

Visualization of raw EEG data

Plot a few seconds of signal from the Oz electrode using matplotlib:

n_seconds = 2
time = np.linspace(0, n_seconds, n_seconds * sfreq,
                   endpoint=False)[np.newaxis, :]
plt.figure(figsize=(10, 4))
plt.plot(time.T, eeg_data[np.array(raw.ch_names) == 'Oz', :n_seconds*sfreq].T,
         color='C0', lw=0.5)
plt.xlabel("Time (s)")
plt.ylabel(r"Oz ($\mu$V)")

And from all electrodes:

plt.figure(figsize=(10, 4))
for ch_idx, ch_name in enumerate(raw.ch_names):
    plt.plot(time.T, eeg_data[ch_idx, :n_seconds*sfreq].T, lw=0.5,
             label=ch_name)
plt.xlabel("Time (s)")
plt.ylabel(r"EEG ($\mu$V)")
plt.legend(loc='upper right')

With MNE, visualizing the data is much easier:

raw.plot(duration=n_seconds, start=0, n_channels=8, scalings={'eeg': 4e-2},
         color={'eeg': 'steelblue'})
Using matplotlib as 2D backend.

<MNEBrowseFigure size 800x800 with 4 Axes>

Extended signals for spatial covariance

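Using the approach proposed by [1], the SSVEP signal is extended with band-pass filtered copies of the EEG, one copy per stimulation frequency. Stacking these filtered signals yields an extended signal: with 8 EEG channels and 3 stimulation frequencies, it has 24 channels, as reported in the output below.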

# We stack the filtered signals to build an extended signal
frequencies = [13, 17, 21]
freq_band = 0.1
raw_ext = extend_signal(raw, frequencies, freq_band)
Creating RawArray with float64 data, n_channels=24, n_times=92384
    Range : 0 ... 92383 =      0.000 ...   360.871 secs
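The helper `extend_signal` is not shown in this example. The following is a minimal sketch of the filter-bank idea under stated assumptions: the name `extend_signal_sketch`, the zero-phase Butterworth filter, and the `half_width` parameter (a band of `[f - half_width, f + half_width]` Hz, 1 Hz by default purely for illustration) are all hypothetical, and the real helper may interpret `freq_band` differently.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt


def extend_signal_sketch(data, sfreq, frequencies, half_width=1.0, order=4):
    """Hypothetical stand-in for the helper ``extend_signal``.

    data: array of shape (n_channels, n_times).
    Returns an array of shape (len(frequencies) * n_channels, n_times):
    one band-pass filtered copy of every channel per stimulation frequency,
    stacked block by block (all channels at f1, then all at f2, ...).
    """
    blocks = []
    for f in frequencies:
        # Assumed band [f - half_width, f + half_width] Hz, zero-phase IIR
        sos = butter(order, [f - half_width, f + half_width],
                     btype="bandpass", fs=sfreq, output="sos")
        blocks.append(sosfiltfilt(sos, data, axis=-1))
    return np.concatenate(blocks, axis=0)


# Synthetic data: 8 channels, 10 s at 256 Hz, driven at 13 Hz plus noise
sfreq = 256
t = np.arange(10 * sfreq) / sfreq
rng = np.random.default_rng(0)
data = np.sin(2 * np.pi * 13 * t) + 0.5 * rng.standard_normal((8, t.size))
ext = extend_signal_sketch(data, sfreq, [13, 17, 21])
print(ext.shape)  # (24, 2560)
```

Because the blocks are stacked frequency by frequency, channel `k` of the original signal reappears at indices `k`, `k + 8` and `k + 16`, which is why the example later picks channels with `range(0, len(raw_ext.ch_names), len(raw.ch_names))`.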

Plot the extended signal

raw_ext.plot(duration=n_seconds, start=14, n_channels=24,
             scalings={'eeg': 5e-4}, color={'eeg': 'steelblue'})
<MNEBrowseFigure size 800x800 with 4 Axes>

Building Epochs and plotting 3 s of the signal from electrode Oz for a trial

epochs = Epochs(
    raw_ext, events, event_id, tmin=2, tmax=5, baseline=None
)

n_seconds = 3
time = np.linspace(0, n_seconds, n_seconds * sfreq,
                   endpoint=False)[np.newaxis, :]
channels = range(0, len(raw_ext.ch_names), len(raw.ch_names))
plt.figure(figsize=(7, 5))
for f, c in zip(frequencies, channels):
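    # Trial index 5, first channel of each 8-channel block (Oz filtered at f)
    plt.plot(time.T, epochs.get_data()[5, c, :n_seconds*sfreq].T,
             label=str(int(f))+' Hz')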
plt.xlabel("Time (s)")
plt.ylabel(r"Oz after filtering ($\mu$V)")
plt.legend(loc='upper right')
plot classify ssvep mdm
Not setting metadata
32 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 32 events and 769 original time points ...
0 bad epochs dropped
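The sampling rate follows from the RawArray log above (92383 samples over 360.871 s, i.e. 256 Hz). With `tmin=2` and `tmax=5`, MNE keeps both endpoints, so each epoch has 3 × 256 + 1 = 769 samples, matching "769 original time points". The NumPy sketch below reproduces that slicing; `cut_epochs` and the event samples are hypothetical, and it ignores baseline correction, rejection and any `raw.first_samp` offset.

```python
import numpy as np


def cut_epochs(data, onsets, sfreq, tmin, tmax):
    """Minimal NumPy sketch of epoching (no baseline, no rejection).

    data: (n_channels, n_times) continuous signal.
    onsets: event sample indices, e.g. ``events[:, 0]`` from ``find_events``.
    Returns (n_epochs, n_channels, n_samples) with both endpoints included.
    Assumes every window lies inside the recording.
    """
    start = int(round(tmin * sfreq))
    stop = int(round(tmax * sfreq)) + 1  # +1 so the tmax sample is included
    return np.stack([data[:, o + start:o + stop] for o in onsets])


sfreq = 256
data = np.random.default_rng(0).standard_normal((24, 60 * sfreq))
onsets = np.array([1000, 4000, 8000])  # hypothetical event samples
X = cut_epochs(data, onsets, sfreq, tmin=2, tmax=5)
print(X.shape)  # (3, 24, 769)
```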

As can be seen in this example, the subject is watching the 13 Hz stimulation: the EEG activity shows increased amplitude in this frequency band, while the other frequencies have lower amplitudes.
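One way to check such an observation quantitatively is to compare spectral power at each stimulation frequency. A minimal sketch with an FFT on a synthetic 3 s single-channel trial; `power_at` is a hypothetical helper, not part of the example's code.

```python
import numpy as np


def power_at(signal, sfreq, freqs):
    """Return the FFT power at the bins closest to each frequency in freqs."""
    spectrum = np.abs(np.fft.rfft(signal)) ** 2
    bins = np.fft.rfftfreq(signal.size, d=1 / sfreq)
    return np.array([spectrum[np.argmin(np.abs(bins - f))] for f in freqs])


# Synthetic 3 s "trial" at 256 Hz dominated by a 13 Hz oscillation
sfreq = 256
t = np.arange(3 * sfreq) / sfreq
rng = np.random.default_rng(1)
trial = np.sin(2 * np.pi * 13 * t) + 0.3 * rng.standard_normal(t.size)
p = power_at(trial, sfreq, [13, 17, 21])
print([13, 17, 21][int(np.argmax(p))])  # 13
```

With 3 s of data the frequency resolution is 1/3 Hz, so 13, 17 and 21 Hz each fall exactly on an FFT bin.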

Spatial covariance for SSVEP

The covariance matrices will be estimated using the Ledoit-Wolf shrinkage estimator on the extended signal.

cov_ext_trials = BlockCovariances(
    estimator='lwf', block_size=8
).transform(epochs.get_data())

# This plot shows an example of a covariance matrix observed for each class:
ch_names = raw_ext.info['ch_names']

plt.figure(figsize=(7, 7))
for i, l in enumerate(event_id):
    ax = plt.subplot(2, 2, i+1)
    plt.imshow(cov_ext_trials[events[:, 2] == event_id[l]][0],
               cmap=plt.get_cmap('RdBu_r'))
    plt.title('Cov for class: '+l)
    if i == 0 or i == 2:
        plt.yticks(np.arange(len(ch_names)), ch_names)
        ax.tick_params(axis='both', which='major', labelsize=7)
Figure: covariance matrix of the first trial of each class (panels: 13 Hz, 17 Hz, 21 Hz, resting-state).

It clearly appears that each class yields a different covariance structure: each stimulation (13, 17 and 21 Hz) generates higher covariance values for the EEG signal filtered in the corresponding frequency band and no activation at all in the other bands. The resting state, where the subject focuses on the center of the display, far from all blinking stimuli, shows higher correlation in the 13 Hz band and lower but still visible activity in the other bands.
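Assuming `BlockCovariances` builds block-diagonal matrices from one estimate per block of channels (as the pyRiemann documentation describes it), the cross-frequency blocks of each 24 × 24 matrix are zero by construction, and the class differences appear in the three 8 × 8 diagonal blocks. A minimal sketch with scikit-learn's `ledoit_wolf`; the function name is hypothetical, and pyRiemann's `'lwf'` estimator may differ in details such as centering.

```python
import numpy as np
from scipy.linalg import block_diag
from sklearn.covariance import ledoit_wolf


def block_lwf_covariance(trial, block_size):
    """Block-diagonal covariance of one trial of shape (n_channels, n_times).

    Each consecutive group of ``block_size`` channels gets its own
    Ledoit-Wolf shrinkage estimate; cross-block entries stay at zero.
    """
    n_channels = trial.shape[0]
    blocks = []
    for start in range(0, n_channels, block_size):
        # scikit-learn expects (n_samples, n_features), hence the transpose
        cov, _shrinkage = ledoit_wolf(trial[start:start + block_size].T)
        blocks.append(cov)
    return block_diag(*blocks)


rng = np.random.default_rng(0)
trial = rng.standard_normal((24, 769))  # same shape as one extended epoch
C = block_lwf_covariance(trial, block_size=8)
print(C.shape)  # (24, 24)
print(np.allclose(C[:8, 8:], 0))  # True
```

Shrinkage keeps each block well conditioned, which matters because the classifier below relies on the matrices being symmetric positive definite.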

Classify with MDM

Plotting the mean of each class

cov_centers = np.empty((len(event_id), 24, 24))
for i, l in enumerate(event_id):
    cov_centers[i] = mean_riemann(cov_ext_trials[events[:, 2] == event_id[l]])

plt.figure(figsize=(7, 7))
for i, l in enumerate(event_id):
    ax = plt.subplot(2, 2, i+1)
    plt.imshow(cov_centers[i], cmap=plt.get_cmap('RdBu_r'))
    plt.title('Cov mean for class: '+l)
    if i == 0 or i == 2:
        plt.yticks(np.arange(len(ch_names)), ch_names)
        ax.tick_params(axis='both', which='major', labelsize=7)
Figure: Riemannian mean covariance matrix of each class (panels: 13 Hz, 17 Hz, 21 Hz, resting-state).
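`mean_riemann` computes a Riemannian (Karcher) mean of SPD matrices. The sketch below is one standard way to compute the affine-invariant mean, a fixed-point iteration that averages the matrices in the tangent space at the current estimate and maps the result back; it is not pyRiemann's source, and the function names, tolerance and iteration limit are assumptions. The "Convergence not reached" warnings printed during cross-validation below are presumably emitted when pyRiemann's iterative mean stops at its iteration limit before meeting its tolerance.

```python
import numpy as np


def _spd_fn(C, fn):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    eigvals, eigvecs = np.linalg.eigh(C)
    return (eigvecs * fn(eigvals)) @ eigvecs.T


def riemannian_mean_sketch(covs, tol=1e-10, maxiter=100):
    """Affine-invariant Riemannian mean by fixed-point iteration.

    covs: array of shape (n_matrices, n, n) of SPD matrices.
    """
    M = covs.mean(axis=0)  # start from the arithmetic mean
    for _ in range(maxiter):
        M_sqrt = _spd_fn(M, np.sqrt)
        M_isqrt = _spd_fn(M, lambda w: 1 / np.sqrt(w))
        logs = []
        for C in covs:
            A = M_isqrt @ C @ M_isqrt
            logs.append(_spd_fn((A + A.T) / 2, np.log))  # symmetrize, then log
        T = np.mean(logs, axis=0)  # average in the tangent space at M
        M = M_sqrt @ _spd_fn(T, np.exp) @ M_sqrt  # map back to the manifold
        if np.linalg.norm(T) < tol:
            break
    return M


covs = np.array([np.diag([1.0, 4.0]), np.diag([4.0, 1.0])])
print(riemannian_mean_sketch(covs))  # close to diag(2, 2)
```

For commuting (here diagonal) matrices, the Riemannian mean reduces to the element-wise geometric mean, which is why the result is 2 · I rather than the arithmetic mean 2.5 · I.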

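Minimum Distance to Mean (MDM) is a simple and robust algorithm for BCI decoding: each class is represented by the Riemannian mean of its covariance matrices, and a new trial is assigned to the class whose mean is closest in Riemannian distance. This example reproduces the results of [2] for the first session of subject 12.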
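The decision rule can be sketched in a few lines. The affine-invariant distance between SPD matrices A and B is sqrt(Σ log² λᵢ), where λᵢ are the eigenvalues of A⁻¹B. This toy `MDMSketch` class is hypothetical and is not pyRiemann's `MDM`; for brevity it swaps in the log-Euclidean mean for the class centers instead of the Riemannian mean used with `metric=dict(mean='riemann', ...)`.

```python
import numpy as np
from scipy.linalg import eigvalsh


def _spd_fn(C, fn):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    eigvals, eigvecs = np.linalg.eigh(C)
    return (eigvecs * fn(eigvals)) @ eigvecs.T


def riemann_distance(A, B):
    """Affine-invariant distance: sqrt(sum(log(eig(A^-1 B)) ** 2))."""
    # eigvalsh(B, A) solves B v = lambda A v, i.e. eigenvalues of A^-1 B
    return np.sqrt(np.sum(np.log(eigvalsh(B, A)) ** 2))


class MDMSketch:
    """Toy Minimum Distance to Mean classifier (log-Euclidean centers)."""

    def fit(self, covs, y):
        self.classes_ = np.unique(y)
        self.centers_ = [
            _spd_fn(np.mean([_spd_fn(C, np.log) for C in covs[y == c]],
                            axis=0), np.exp)
            for c in self.classes_
        ]
        return self

    def predict(self, covs):
        # Distance of every matrix to every class center
        dist = np.array([[riemann_distance(M, C) for M in self.centers_]
                         for C in covs])
        return self.classes_[np.argmin(dist, axis=1)]


# Two synthetic classes of 2 x 2 SPD matrices with opposite power profiles
rng = np.random.default_rng(0)
covs, y = [], []
for label, scale in [(1, [1.0, 5.0]), (2, [5.0, 1.0])]:
    for _ in range(10):
        covs.append(np.diag(np.array(scale) * rng.uniform(0.8, 1.2, 2)))
        y.append(label)
covs, y = np.array(covs), np.array(y)
mdm = MDMSketch().fit(covs, y)
print((mdm.predict(covs) == y).mean())  # 1.0
```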

print("Number of trials: {}".format(len(cov_ext_trials)))

cv = RepeatedKFold(n_splits=2, n_repeats=10, random_state=42)
mdm = MDM(metric=dict(mean='riemann', distance='riemann'))
scores = cross_val_score(mdm, cov_ext_trials, events[:, 2], cv=cv, n_jobs=1)
print("MDM accuracy: {:.2f}% +/- {:.2f}".format(np.mean(scores)*100,
                                                np.std(scores)*100))
# With 10 repetitions of 2-fold cross-validation, the accuracy obtained for
# this session is about 81% +/- 16 (see the output below).
Number of trials: 32
/home/docs/checkouts/ UserWarning: Convergence not reached
  warnings.warn("Convergence not reached")
MDM accuracy: 80.94% +/- 16.23
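The mean and standard deviation above are computed over the 20 scores produced by `RepeatedKFold(n_splits=2, n_repeats=10)`, so each MDM model is trained on only 16 of the 32 trials and each fold's accuracy can only take values in steps of 1/16 = 6.25 %. A quick check with scikit-learn, using a dummy feature array in place of the covariance matrices:

```python
import numpy as np
from sklearn.model_selection import RepeatedKFold

cv = RepeatedKFold(n_splits=2, n_repeats=10, random_state=42)
X_dummy = np.zeros((32, 1))  # 32 trials, as reported above
splits = list(cv.split(X_dummy))
print(cv.get_n_splits())  # 20
print({(len(train), len(test)) for train, test in splits})  # {(16, 16)}
```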


Total running time of the script: (0 minutes 11.216 seconds)

Gallery generated by Sphinx-Gallery