Note
Go to the end to download the full example code
Visualization of SSVEP-based BCI Classification in Tangent Space
Project extended covariance matrices of SSVEP-based BCI in the tangent space, using principal geodesic analysis (PGA).
You should have a look at “Offline SSVEP-based BCI Multiclass Prediction” before reading this example.
# Authors: Quentin Barthélemy, Emmanuel Kalunga and Sylvain Chevallier
#
# License: BSD (3-clause)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mne import find_events, Epochs, make_fixed_length_epochs
from mne.io import Raw
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import PCA
from pyriemann.estimation import BlockCovariances
from pyriemann.classification import MDM
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import _add_alpha
from helpers.ssvep_helpers import download_data, extend_signal
# Class names, in the order of their 1-based event labels (label k -> clabel[k - 1])
clabel = ['resting-state', '13 Hz', '17 Hz', '21 Hz']
cmap = "viridis"
# One RGBA color per class, sampled evenly along ``cmap`` from 0 to 1.
# Deriving it from ``cmap`` and ``clabel`` keeps colors, legend and class count
# in sync (for 4 classes this gives the same colors as viridis([0, 1, 2, 3]/3)).
clist = plt.get_cmap(cmap)(np.linspace(0, 1, len(clabel)))
def plot_pga(ax, data, labels, centers):
    """Scatter PGA-projected matrices and class centers on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data : ndarray, shape (n_matrices, >=2)
        Projected matrices; only the first two columns are plotted.
    labels : array-like, shape (n_matrices,)
        1-based class labels of ``data`` (label k is drawn with
        ``clist[k - 1]``). May be empty, e.g. to initialise an animation.
    centers : ndarray, shape (n_classes, >=2)
        Projected class centers, in the same order as ``clist``.

    Returns
    -------
    sc : PathCollection
        The scatter of ``data``, returned so callers can update it later.
    """
    # Pin the normalization to the full label range 1..n_classes so that label
    # k maps to clist[k - 1] even when some classes are absent from ``labels``.
    sc = ax.scatter(
        data[:, 0], data[:, 1], c=labels, marker='P', cmap=cmap,
        vmin=1, vmax=len(clabel),
    )
    # Centers get explicit RGBA colors: passing ``cmap`` here had no effect and
    # only triggered a matplotlib UserWarning.
    ax.scatter(centers[:, 0], centers[:, 1], c=clist, marker='o', s=100)
    ax.set(xlabel='PGA, 1st axis', ylabel='PGA, 2nd axis')
    # Empty scatters act as legend proxies: one entry per class
    for color, name in zip(clist, clabel):
        ax.scatter([], [], color=color, marker='o', s=50, label=name)
    ax.legend(loc='upper right')
    return sc
Load EEG and extract covariance matrices for SSVEP
frequencies = [13, 17, 21]  # stimulation frequencies, in Hz
freq_band = 0.1
events_id = {'13 Hz': 2, '17 Hz': 4, '21 Hz': 3, 'resting-state': 1}
duration = 2.5  # duration of epochs, in seconds
interval = 0.25  # interval between successive epochs for online processing, in s

# Subject 12: one session (session=1) for training, another (session=4) for test

# Training set
raw = Raw(download_data(subject=12, session=1), preload=True, verbose=False)
events = find_events(raw, shortest_event=0, verbose=False)
# ``pick_types`` is legacy in MNE; ``exclude="bads"`` keeps its default behavior
raw = raw.pick("eeg", exclude="bads")
ch_count = len(raw.info['ch_names'])
raw_ext = extend_signal(raw, frequencies, freq_band)
epochs = Epochs(
    raw_ext, events, events_id, tmin=2, tmax=5, baseline=None, verbose=False)
# One estimator for both sets; block_size=ch_count presumably gives one block
# per filtered copy of the EEG channels built by extend_signal (TODO confirm)
cov_estimator = BlockCovariances(estimator='lwf', block_size=ch_count)
x_train = cov_estimator.transform(epochs.get_data())
# Labels come from the epochs actually kept (read after get_data()), not from
# the raw event array: Epochs only keeps events listed in ``events_id`` and may
# drop epochs, which would misalign y_train with x_train.
y_train = epochs.events[:, 2]

# Testing set
raw = Raw(download_data(subject=12, session=4), preload=True, verbose=False)
raw = raw.pick("eeg", exclude="bads")
raw_ext = extend_signal(raw, frequencies, freq_band)
# Sliding windows of ``duration`` seconds, shifted by ``interval`` seconds
epochs = make_fixed_length_epochs(
    raw_ext, duration=duration, overlap=duration - interval, verbose=False)
x_test = cov_estimator.transform(epochs.get_data())
Download complete in 00s (3.2 MB)
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Creating RawArray with float64 data, n_channels=24, n_times=92384
Range : 0 ... 92383 = 0.000 ... 360.871 secs
Ready.
Using data from preloaded Raw for 32 events and 769 original time points ...
0 bad epochs dropped
0%| | 0.00/5.35M [00:00<?, ?B/s]
1%|▏ | 31.7k/5.35M [00:00<00:17, 306kB/s]
1%|▍ | 64.5k/5.35M [00:00<00:17, 311kB/s]
2%|▋ | 97.3k/5.35M [00:00<00:16, 312kB/s]
3%|█▏ | 163k/5.35M [00:00<00:11, 437kB/s]
5%|█▊ | 245k/5.35M [00:00<00:09, 562kB/s]
6%|██▎ | 310k/5.35M [00:00<00:08, 584kB/s]
7%|██▋ | 376k/5.35M [00:00<00:08, 597kB/s]
8%|███▏ | 441k/5.35M [00:00<00:08, 607kB/s]
9%|███▋ | 507k/5.35M [00:00<00:07, 613kB/s]
11%|████▏ | 572k/5.35M [00:01<00:07, 617kB/s]
12%|████▋ | 638k/5.35M [00:01<00:07, 618kB/s]
13%|█████▏ | 703k/5.35M [00:01<00:07, 619kB/s]
14%|█████▌ | 769k/5.35M [00:01<00:07, 621kB/s]
16%|██████ | 831k/5.35M [00:01<00:07, 614kB/s]
17%|██████▌ | 893k/5.35M [00:01<00:07, 606kB/s]
18%|██████▉ | 954k/5.35M [00:01<00:07, 599kB/s]
19%|███████▏ | 1.01M/5.35M [00:01<00:07, 592kB/s]
20%|███████▋ | 1.07M/5.35M [00:01<00:07, 584kB/s]
21%|████████ | 1.13M/5.35M [00:01<00:07, 578kB/s]
22%|████████▍ | 1.19M/5.35M [00:02<00:07, 572kB/s]
23%|████████▊ | 1.25M/5.35M [00:02<00:07, 566kB/s]
24%|█████████▎ | 1.31M/5.35M [00:02<00:07, 569kB/s]
26%|█████████▊ | 1.38M/5.35M [00:02<00:06, 587kB/s]
27%|██████████▏ | 1.44M/5.35M [00:02<00:06, 599kB/s]
28%|██████████▋ | 1.51M/5.35M [00:02<00:06, 606kB/s]
29%|███████████▏ | 1.57M/5.35M [00:02<00:06, 612kB/s]
31%|███████████▋ | 1.64M/5.35M [00:02<00:06, 617kB/s]
32%|████████████ | 1.70M/5.35M [00:02<00:05, 618kB/s]
33%|████████████▌ | 1.77M/5.35M [00:03<00:05, 619kB/s]
34%|█████████████ | 1.83M/5.35M [00:03<00:05, 621kB/s]
36%|█████████████▍ | 1.90M/5.35M [00:03<00:05, 621kB/s]
37%|█████████████▉ | 1.97M/5.35M [00:03<00:05, 621kB/s]
38%|██████████████▍ | 2.03M/5.35M [00:03<00:05, 622kB/s]
39%|██████████████▉ | 2.10M/5.35M [00:03<00:05, 620kB/s]
40%|███████████████▎ | 2.16M/5.35M [00:03<00:05, 620kB/s]
42%|███████████████▊ | 2.23M/5.35M [00:03<00:05, 620kB/s]
43%|████████████████▎ | 2.29M/5.35M [00:03<00:04, 621kB/s]
44%|████████████████▋ | 2.36M/5.35M [00:03<00:04, 622kB/s]
45%|█████████████████▏ | 2.42M/5.35M [00:04<00:04, 622kB/s]
47%|█████████████████▋ | 2.49M/5.35M [00:04<00:04, 624kB/s]
48%|██████████████████▏ | 2.55M/5.35M [00:04<00:04, 624kB/s]
49%|██████████████████▌ | 2.62M/5.35M [00:04<00:04, 625kB/s]
50%|███████████████████ | 2.69M/5.35M [00:04<00:04, 626kB/s]
51%|███████████████████▌ | 2.75M/5.35M [00:04<00:04, 626kB/s]
53%|████████████████████ | 2.82M/5.35M [00:04<00:04, 626kB/s]
54%|████████████████████▍ | 2.88M/5.35M [00:04<00:03, 626kB/s]
55%|████████████████████▉ | 2.95M/5.35M [00:04<00:03, 626kB/s]
56%|█████████████████████▍ | 3.01M/5.35M [00:05<00:03, 626kB/s]
58%|█████████████████████▊ | 3.08M/5.35M [00:05<00:03, 627kB/s]
59%|██████████████████████▎ | 3.14M/5.35M [00:05<00:03, 627kB/s]
60%|██████████████████████▊ | 3.21M/5.35M [00:05<00:03, 627kB/s]
61%|███████████████████████▎ | 3.28M/5.35M [00:05<00:03, 627kB/s]
62%|███████████████████████▋ | 3.34M/5.35M [00:05<00:03, 627kB/s]
64%|████████████████████████▏ | 3.41M/5.35M [00:05<00:03, 626kB/s]
65%|████████████████████████▋ | 3.47M/5.35M [00:05<00:02, 627kB/s]
66%|█████████████████████████▏ | 3.54M/5.35M [00:05<00:02, 625kB/s]
67%|█████████████████████████▌ | 3.60M/5.35M [00:05<00:02, 625kB/s]
69%|██████████████████████████ | 3.67M/5.35M [00:06<00:02, 624kB/s]
70%|██████████████████████████▌ | 3.73M/5.35M [00:06<00:02, 623kB/s]
71%|██████████████████████████▉ | 3.80M/5.35M [00:06<00:02, 622kB/s]
72%|███████████████████████████▍ | 3.87M/5.35M [00:06<00:02, 623kB/s]
73%|███████████████████████████▉ | 3.93M/5.35M [00:06<00:02, 624kB/s]
75%|████████████████████████████▍ | 4.00M/5.35M [00:06<00:02, 625kB/s]
76%|████████████████████████████▊ | 4.06M/5.35M [00:06<00:02, 624kB/s]
77%|█████████████████████████████▎ | 4.13M/5.35M [00:06<00:01, 623kB/s]
78%|█████████████████████████████▊ | 4.19M/5.35M [00:06<00:01, 623kB/s]
80%|██████████████████████████████▏ | 4.26M/5.35M [00:07<00:01, 622kB/s]
81%|██████████████████████████████▋ | 4.32M/5.35M [00:07<00:01, 622kB/s]
82%|███████████████████████████████▏ | 4.39M/5.35M [00:07<00:01, 622kB/s]
83%|███████████████████████████████▋ | 4.46M/5.35M [00:07<00:01, 623kB/s]
84%|████████████████████████████████ | 4.52M/5.35M [00:07<00:01, 624kB/s]
86%|████████████████████████████████▌ | 4.59M/5.35M [00:07<00:01, 623kB/s]
87%|█████████████████████████████████ | 4.65M/5.35M [00:07<00:01, 614kB/s]
88%|█████████████████████████████████▌ | 4.72M/5.35M [00:07<00:01, 616kB/s]
89%|█████████████████████████████████▉ | 4.78M/5.35M [00:07<00:00, 618kB/s]
91%|██████████████████████████████████▍ | 4.85M/5.35M [00:07<00:00, 620kB/s]
92%|██████████████████████████████████▉ | 4.91M/5.35M [00:08<00:00, 621kB/s]
93%|███████████████████████████████████▎ | 4.98M/5.35M [00:08<00:00, 621kB/s]
94%|███████████████████████████████████▊ | 5.05M/5.35M [00:08<00:00, 621kB/s]
96%|████████████████████████████████████▎ | 5.11M/5.35M [00:08<00:00, 621kB/s]
97%|████████████████████████████████████▊ | 5.18M/5.35M [00:08<00:00, 621kB/s]
98%|█████████████████████████████████████▏| 5.24M/5.35M [00:08<00:00, 623kB/s]
99%|█████████████████████████████████████▋| 5.31M/5.35M [00:08<00:00, 622kB/s]
0%| | 0.00/5.35M [00:00<?, ?B/s]
100%|█████████████████████████████████████| 5.35M/5.35M [00:00<00:00, 20.8GB/s]
Download complete in 09s (5.1 MB)
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Creating RawArray with float64 data, n_channels=24, n_times=148544
Range : 0 ... 148543 = 0.000 ... 580.246 secs
Ready.
Using data from preloaded Raw for 2312 events and 640 original time points ...
0 bad epochs dropped
Classification with minimum distance to mean (MDM)
Classification for a 4-class SSVEP BCI, including the resting-state class.
# Fit a Riemannian MDM classifier on the training covariance matrices
print(f"Number of training trials: {len(x_train)}")
mdm = MDM(metric={'mean': 'riemann', 'distance': 'riemann'})
mdm.fit(x_train, y_train)
Number of training trials: 32
Projection in tangent space with principal geodesic analysis (PGA)
Project covariance matrices from the Riemannian manifold into the Euclidean tangent space at the grand average, and apply a principal component analysis (PCA) to obtain an unsupervised dimension reduction [1].
# PGA = tangent-space projection (fitted on the training set, reference not
# updated on new data) followed by a 2-component PCA
tangent_step = TangentSpace(metric="riemann", tsupdate=False)
pca_step = PCA(n_components=2)
pga = make_pipeline(tangent_step, pca_step)
ts_train = pga.fit_transform(x_train)
# Project the MDM class centers with the same fitted pipeline
ts_means = pga.transform(np.array(mdm.covmeans_))
Offline training of MDM visualized by PGA
These figures show the trajectory in the tangent space taken by covariance matrices during a 4-class SSVEP experiment, and how they are classified epoch by epoch.
This figure reproduces Fig 3(c) of reference [2], showing the training trials of the best subject.
# Training trials and MDM class centers in the PGA plane
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
fig.suptitle('PGA of training set', fontsize=16)
plot_pga(ax=ax, data=ts_train, labels=y_train, centers=ts_means)
plt.show()

/home/docs/checkouts/readthedocs.org/user_builds/pyriemann/checkouts/latest/examples/SSVEP/plot_classify_ssvep_pga.py:41: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
ax.scatter(
Online classification by MDM visualized by PGA
This figure reproduces Fig 6 of reference [2], with an animation to imitate an online acquisition, processing and classification of EEG time-series.
Warning: [2] uses a curve-based online classification, while a single-trial classification is used here.
# Prepare data for online classification
test_visu = 50  # number of test matrices displayed simultaneously
ts_visu = np.empty([0, 2])  # displayed PGA points, filled frame by frame
colors = []  # one color per displayed point
# Alpha ramp over the displayed window, consumed by _add_alpha in the animation
alphas = np.linspace(0, 1, test_visu)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
fig.suptitle('PGA of testing set', fontsize=16)
# Start from an empty scatter; the animation fills in its offsets and colors
pl = plot_pga(ax, ts_visu, colors, ts_means)
ax.set_xlim(-5, 6)
ax.set_ylim(-5, 5)

(-5.0, 5.0)
# Prepare animation for online classification
def online_classify(t):
    """Animation step: classify test matrix ``t`` and add it to the PGA plot.

    The matrix is classified by ``mdm`` and projected by ``pga``. Its point is
    appended to the displayed window, which holds at most ``test_visu`` points
    (the oldest point is removed first). The window's colors are then
    re-weighted by ``alphas`` through ``_add_alpha``.

    Parameters
    ----------
    t : int
        Index of the test matrix in ``x_test`` (the animation frame).

    Returns
    -------
    pl : PathCollection
        The updated scatter artist.
    """
    global colors, ts_visu

    # Keep a leading axis so that the estimators receive a batch of one matrix
    batch = x_test[np.newaxis, t]
    predicted = mdm.predict(batch)[0]
    point = pga.transform(batch)

    # Labels are 1-based, clist is 0-based
    colors.append(clist[int(predicted - 1)])
    ts_visu = np.vstack((ts_visu, point))
    if len(ts_visu) > test_visu:
        # Sliding window: forget the oldest matrix
        del colors[0]
        ts_visu = ts_visu[1:]
    colors = _add_alpha(colors, alphas)

    pl.set_offsets(np.column_stack((ts_visu[:, 0], ts_visu[:, 1])))
    pl.set_color(colors)
    return pl
# Delay between frames, in milliseconds: increase it for a slower display
interval_display = 1.0
# One animation frame per test matrix
visu = FuncAnimation(
    fig,
    online_classify,
    frames=range(len(x_test)),
    interval=interval_display,
    blit=False,
    repeat=False,
)

# Plot online classification
# Plot complete visu: a dynamic display is required
plt.show()
# Plot only 10s, for animated documentation
# Embed the animation in the documentation as JavaScript/HTML.
# ``animation.embed_limit`` is expressed in MB: once the embedded animation
# reaches 10 MB, matplotlib drops the remaining frames (with a warning), so
# only the beginning of the online classification is shown.
try:
    from IPython.display import HTML
except ImportError as err:
    raise ImportError(
        "Install IPython to plot animation in documentation") from err

plt.rcParams["animation.embed_limit"] = 10
HTML(visu.to_jshtml(fps=5, default_mode='loop'))
Animation size has reached 10525072 bytes, exceeding the limit of 10485760.0. If you're sure you want a larger animation embedded, set the animation.embed_limit rc parameter to a larger value (in MB). This and further frames will be dropped.