Motor imagery classification

Classify motor imagery data with Riemannian geometry.

# generic import
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

# mne import
from mne import Epochs, pick_types
from mne.io import concatenate_raws
from mne.io.edf import read_raw_edf
from mne.datasets import eegbci
from mne.event import find_events
from mne.decoding import CSP

# pyriemann import
from pyriemann.classification import MDM, TSclassifier
from pyriemann.estimation import Covariances

# sklearn imports
from sklearn.model_selection import cross_val_score, KFold
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression

Set parameters and read data

# avoid classification of evoked responses by using epochs that start 1s after
# cue onset.
tmin, tmax = 1., 2.
event_id = dict(hands=2, feet=3)
subject = 7
runs = [6, 10, 14]  # motor imagery: hands vs feet

raw_files = [read_raw_edf(f, preload=True)
             for f in eegbci.load_data(subject, runs)]
raw = concatenate_raws(raw_files)

picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')
# subsample electrodes (keep every other channel)
picks = picks[::2]

# Apply band-pass filter
raw.filter(7., 35., method='iir', picks=picks)

events = find_events(raw, shortest_event=0, stim_channel='STI 014')

# Read epochs (training uses only the window from 1 to 2 s after cue onset)
epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
                baseline=None, preload=True, verbose=False)
# map event codes 2 (hands) and 3 (feet) to class labels 0 and 1
labels = epochs.events[:, -1] - 2


# 10-fold cross validation with shuffling
cv = KFold(n_splits=10, shuffle=True, random_state=42)
# get epochs data, scaled from volts to microvolts
epochs_data_train = 1e6 * epochs.get_data()

# compute covariance matrices
cov_data_train = Covariances().transform(epochs_data_train)

Out:

Downloading http://www.physionet.org/physiobank/database/eegmmidb/S007/S007R06.edf (2.5 MB)

Downloading http://www.physionet.org/physiobank/database/eegmmidb/S007/S007R10.edf (2.5 MB)

Downloading http://www.physionet.org/physiobank/database/eegmmidb/S007/S007R14.edf (2.5 MB)

Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S007/S007R06.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S007/S007R10.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S007/S007R14.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Filtering a subset of channels. The highpass and lowpass values in the measurement info will not be updated.
Setting up band-pass filter from 7 - 35 Hz
Setting up band-pass filter from 7 - 35 Hz
Setting up band-pass filter from 7 - 35 Hz
Trigger channel has a non-zero initial value of 1 (consider using initial_event=True to detect this event)
Removing orphaned offset at the beginning of the file.
89 events found
Event IDs: [1 2 3]
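
The covariance step above can be pictured with plain NumPy. The sketch below assumes the sample covariance estimator, computed independently for each epoch. The function name `epoch_covariances` and the random input are illustrative assumptions, not part of pyriemann.

```python
import numpy as np


def epoch_covariances(X):
    """Sample covariance of each epoch (hypothetical helper).

    X has shape (n_epochs, n_channels, n_times); the result has shape
    (n_epochs, n_channels, n_channels).
    """
    # Remove the temporal mean of every channel in every epoch.
    Xc = X - X.mean(axis=2, keepdims=True)
    n_times = X.shape[2]
    # Batched matrix product: one (n_channels x n_channels) matrix per epoch.
    return Xc @ Xc.transpose(0, 2, 1) / (n_times - 1)


# Illustrative input only: 5 random "epochs", 3 channels, 100 samples.
demo_rng = np.random.default_rng(0)
demo_covs = epoch_covariances(demo_rng.standard_normal((5, 3, 100)))
print(demo_covs.shape)
```

Each resulting matrix is symmetric and, when there are more time samples than channels, typically positive definite, which is what the Riemannian classifiers below rely on.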

Classification with Minimum distance to mean

mdm = MDM(metric=dict(mean='riemann', distance='riemann'))

# Evaluate with scikit-learn's cross_val_score function
scores = cross_val_score(mdm, cov_data_train, labels, cv=cv, n_jobs=1)

# Printing the results
class_balance = np.mean(labels == labels[0])
class_balance = max(class_balance, 1. - class_balance)
print("MDM Classification accuracy: %f / Chance level: %f" % (np.mean(scores),
                                                              class_balance))

Out:

MDM Classification accuracy: 0.850000 / Chance level: 0.511111
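
To make the minimum-distance-to-mean rule concrete, here is a minimal NumPy sketch of the affine-invariant Riemannian distance between symmetric positive definite (SPD) matrices and of the nearest-mean decision. The helper names are hypothetical and the class means are assumed to be given. This illustrates the idea and is not pyriemann's implementation.

```python
import numpy as np


def _inv_sqrt_spd(A):
    """Inverse square root of an SPD matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(A)
    return (V / np.sqrt(w)) @ V.T


def distance_riemann(A, B):
    """Affine-invariant distance: sqrt(sum_i log^2(lambda_i)), where the
    lambda_i are the eigenvalues of A^{-1/2} B A^{-1/2}."""
    iA = _inv_sqrt_spd(A)
    eigvals = np.linalg.eigvalsh(iA @ B @ iA)
    return np.sqrt(np.sum(np.log(eigvals) ** 2))


def predict_min_distance(covs, class_means):
    """Assign each covariance matrix to the index of the closest class mean."""
    dists = np.array([[distance_riemann(C, M) for M in class_means]
                      for C in covs])
    return np.argmin(dists, axis=1)
```

The distance is unchanged when both matrices are transformed by the same invertible matrix, so a fixed linear mixing of the channels does not alter the decision.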

Classification with Tangent Space Logistic Regression

clf = TSclassifier()
# Evaluate with scikit-learn's cross_val_score function
scores = cross_val_score(clf, cov_data_train, labels, cv=cv, n_jobs=1)

# Printing the results
class_balance = np.mean(labels == labels[0])
class_balance = max(class_balance, 1. - class_balance)
print("Tangent space Classification accuracy: %f / Chance level: %f" %
      (np.mean(scores), class_balance))

Out:

Tangent space Classification accuracy: 0.960000 / Chance level: 0.511111
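
The tangent space approach can be sketched as follows: each covariance matrix is projected onto the tangent space at a reference matrix and flattened into a vector, which an ordinary linear classifier such as logistic regression can then use. The function names below are hypothetical. The sketch assumes the reference point is supplied by the caller (in practice it is usually a mean of the training covariances) and uses the common convention of scaling off-diagonal terms by sqrt(2).

```python
import numpy as np


def _sym_apply(A, fn):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    w, V = np.linalg.eigh(A)
    return (V * fn(w)) @ V.T


def tangent_space(covs, ref):
    """Map SPD matrices to tangent vectors at the reference matrix `ref`.

    Returns an array of shape (n_matrices, n_channels * (n_channels + 1) / 2).
    """
    ref_isqrt = _sym_apply(ref, lambda w: 1.0 / np.sqrt(w))
    rows, cols = np.triu_indices(ref.shape[0])
    # Off-diagonal entries appear twice in the matrix, so they are weighted
    # by sqrt(2) to preserve the Frobenius norm after flattening.
    weights = np.where(rows == cols, 1.0, np.sqrt(2.0))
    feats = []
    for C in covs:
        S = _sym_apply(ref_isqrt @ C @ ref_isqrt, np.log)
        feats.append(weights * S[rows, cols])
    return np.array(feats)
```

With this weighting, the Euclidean norm of a tangent vector equals the Riemannian distance between the corresponding matrix and the reference point.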

Classification with CSP + logistic regression

# Assemble a classifier
lr = LogisticRegression()
csp = CSP(n_components=4, reg='ledoit_wolf', log=True)


clf = Pipeline([('CSP', csp), ('LogisticRegression', lr)])
scores = cross_val_score(clf, epochs_data_train, labels, cv=cv, n_jobs=1)

# Printing the results
class_balance = np.mean(labels == labels[0])
class_balance = max(class_balance, 1. - class_balance)
print("CSP + LR Classification accuracy: %f / Chance level: %f" %
      (np.mean(scores), class_balance))

Out:

Estimating covariance using LEDOIT_WOLF
Done.
CSP + LR Classification accuracy: 0.880000 / Chance level: 0.511111
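
For comparison with the Riemannian methods, the CSP step can be sketched in NumPy for two classes. The sketch whitens the sum of the two class covariances, diagonalizes one class in the whitened space, keeps the filters whose variance ratio is furthest from 0.5, and uses log-variance as features. The function names and the ordering rule are assumptions made for illustration and do not necessarily match what `mne.decoding.CSP` does internally.

```python
import numpy as np


def fit_csp(covs, labels, n_components=4):
    """Two-class CSP spatial filters (rows of the returned array)."""
    covs = np.asarray(covs)
    labels = np.asarray(labels)
    C0 = covs[labels == 0].mean(axis=0)
    C1 = covs[labels == 1].mean(axis=0)
    # Whitening matrix P such that P (C0 + C1) P^T = I.
    w, V = np.linalg.eigh(C0 + C1)
    P = (V / np.sqrt(w)).T
    # Eigenvalues d lie in [0, 1]: share of variance belonging to class 0.
    d, U = np.linalg.eigh(P @ C0 @ P.T)
    filters = U.T @ P
    # Most discriminative filters have d far from 0.5 (assumed ordering).
    order = np.argsort(np.abs(d - 0.5))[::-1]
    return filters[order[:n_components]]


def csp_features(X, filters):
    """Log of the average power of each spatially filtered epoch.

    X has shape (n_epochs, n_channels, n_times).
    """
    Z = filters @ X  # broadcasts to (n_epochs, n_components, n_times)
    return np.log(np.mean(Z ** 2, axis=2))
```

The resulting feature matrix has one row per epoch and one column per filter, and can be passed to any scikit-learn classifier.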

Display MDM centroid

mdm = MDM()
mdm.fit(cov_data_train, labels)

fig, axes = plt.subplots(1, 2, figsize=[8, 4])
ch_names = [ch.replace('.', '') for ch in epochs.ch_names]

df = pd.DataFrame(data=mdm.covmeans_[0], index=ch_names, columns=ch_names)
g = sns.heatmap(df, ax=axes[0], square=True, cbar=False, xticklabels=2,
                yticklabels=2)
g.set_title('Mean covariance - hands')

df = pd.DataFrame(data=mdm.covmeans_[1], index=ch_names, columns=ch_names)
g = sns.heatmap(df, ax=axes[1], square=True, cbar=False, xticklabels=2,
                yticklabels=2)
plt.xticks(rotation='vertical')
plt.yticks(rotation='horizontal')
g.set_title('Mean covariance - feet')

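# dirty fix: plt.xticks/plt.yticks only act on the current axes, so repeat
# the tick rotation for the first subplot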
plt.sca(axes[0])
plt.xticks(rotation='vertical')
plt.yticks(rotation='horizontal')
plt.show()

Figure: mean covariance matrices estimated by MDM for the hands (left) and feet (right) classes (sphx_glr_plot_single_001.png).
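
The centroids displayed above are means of covariance matrices taken on the SPD manifold. A minimal sketch of one way to compute such a mean is the fixed-point iteration below, which repeatedly averages the matrices in the tangent space at the current estimate and maps the result back. The function names, the iteration count and the tolerance are assumptions, and pyriemann may use a different algorithm or stopping rule.

```python
import numpy as np


def _sym_apply(A, fn):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    w, V = np.linalg.eigh(A)
    return (V * fn(w)) @ V.T


def mean_riemann(covs, n_iter=50, tol=1e-8):
    """Riemannian (geometric) mean of SPD matrices by fixed-point iteration."""
    M = np.mean(covs, axis=0)  # start from the arithmetic mean
    for _ in range(n_iter):
        M_sqrt = _sym_apply(M, np.sqrt)
        M_isqrt = _sym_apply(M, lambda w: 1.0 / np.sqrt(w))
        # Average of the matrices expressed in the tangent space at M.
        T = np.mean([_sym_apply(M_isqrt @ C @ M_isqrt, np.log)
                     for C in covs], axis=0)
        # Map the averaged tangent vector back onto the manifold.
        M = M_sqrt @ _sym_apply(T, np.exp) @ M_sqrt
        M = (M + M.T) / 2.0  # remove numerical asymmetry
        if np.linalg.norm(T) < tol:
            break
    return M
```

For matrices that commute, this reduces to the element-wise geometric mean of their eigenvalues; for example, the mean of diag(1, 4) and diag(4, 1) is diag(2, 2) rather than the arithmetic mean diag(2.5, 2.5).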

Total running time of the script: ( 0 minutes 18.585 seconds)
