.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/motor-imagery/plot_ensemble_coherence.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_motor-imagery_plot_ensemble_coherence.py:


====================================================================
Ensemble learning on functional connectivity
====================================================================

This example shows how to compute SPD matrices from functional connectivity
estimators and how to combine classification with ensemble learning [1]_.

.. GENERATED FROM PYTHON SOURCE LINES 10-38

.. code-block:: Python

    # Authors: Sylvain Chevallier ,
    #          Marie-Constance Corsi
    #
    # License: BSD (3-clause)

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    from mne import Epochs, pick_types, events_from_annotations
    from mne.io import concatenate_raws
    from mne.io.edf import read_raw_edf
    from mne.datasets import eegbci
    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.ensemble import StackingClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV, StratifiedKFold
    from sklearn.pipeline import Pipeline
    from sklearn.svm import SVC

    from pyriemann.classification import FgMDM
    from pyriemann.estimation import Coherences, Covariances
    from pyriemann.spatialfilters import CSP
    from pyriemann.tangentspace import TangentSpace

    from helpers.coherence_helpers import NearestSPD, get_results

.. GENERATED FROM PYTHON SOURCE LINES 39-44

Define connectivity transformer
-------------------------------

This transformer computes functional connectivity matrices from the input
signals with :class:`pyriemann.estimation.Coherences`, then averages the
coherence values over the frequency bins between ``fmin`` and ``fmax``.

.. GENERATED FROM PYTHON SOURCE LINES 44-70

.. code-block:: Python

    class Connectivities(TransformerMixin, BaseEstimator):
        """Compute connectivity features from epochs."""

        def __init__(self, method="ordinary", fmin=8, fmax=35, fs=None):
            self.method = method
            self.fmin = fmin
            self.fmax = fmax
            self.fs = fs

        def fit(self, X, y=None):
            self._coh = Coherences(
                coh=self.method,
                fmin=self.fmin,
                fmax=self.fmax,
                fs=self.fs,
            )
            return self

        def transform(self, X):
            # one coherence matrix per frequency bin:
            # shape (n_epochs, n_channels, n_channels, n_freqs)
            X_coh = self._coh.fit_transform(X)
            # average over frequency bins: (n_epochs, n_channels, n_channels)
            X_con = np.mean(X_coh, axis=-1, keepdims=False)
            return X_con

.. GENERATED FROM PYTHON SOURCE LINES 71-73

Load EEG data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 73-115

.. code-block:: Python

    # Avoid classifying evoked responses by using epochs that start 1 s after
    # cue onset.
    tmin, tmax = 1.0, 2.0
    # runs 4 and 8 are imagery of the left (T1) vs right (T2) hand
    event_id = dict(left_hand=2, right_hand=3)
    subject = 7
    runs = [4, 8]  # motor imagery: left vs right hand

    raw_files = [
        read_raw_edf(f, preload=True) for f in eegbci.load_data(subject, runs)
    ]
    raw = concatenate_raws(raw_files)

    picks = pick_types(
        raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads"
    )
    # keep every other EEG channel to reduce the number of electrodes
    picks = picks[::2]

    # Apply band-pass filter
    raw.filter(7.0, 35.0, method="iir", picks=picks)

    events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))

    # Read epochs (training uses only the window between 1 and 2 s)
    epochs = Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        proj=True,
        picks=picks,
        baseline=None,
        preload=True,
        verbose=False,
    )
    labels = epochs.events[:, -1] - 2
    fs = epochs.info["sfreq"]
    X = 1e6 * epochs.get_data()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading EEGBCI data
    Download complete in 06s (5.0 MB)
    Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S007/S007R04.edf...
    EDF file detected
    Setting channel info structure...
    Creating raw.info structure...
    Reading 0 ... 19999 = 0.000 ... 124.994 secs...
    Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S007/S007R08.edf...
    EDF file detected
    Setting channel info structure...
    Creating raw.info structure...
    Reading 0 ... 19999 = 0.000 ... 124.994 secs...
    Filtering a subset of channels. The highpass and lowpass values in the measurement info will not be updated.
    Filtering raw data in 2 contiguous segments
    Setting up band-pass filter from 7 - 35 Hz

    IIR filter parameters
    ---------------------
    Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
    - Filter order 16 (effective, after forward-backward)
    - Cutoffs at 7.00, 35.00 Hz: -6.02, -6.02 dB
    Used Annotations descriptions: ['T1', 'T2']
    /home/docs/checkouts/readthedocs.org/user_builds/pyriemann/checkouts/latest/examples/motor-imagery/plot_ensemble_coherence.py:112: FutureWarning: The current default of copy=False will change to copy=True in 1.7. Set the value of copy explicitly to avoid this warning
      X = 1e6 * epochs.get_data()

.. GENERATED FROM PYTHON SOURCE LINES 116-122

Defining pipelines
------------------

The following pipelines are compared: CSP+SVM and FgMDM on covariance
matrices; tangent-space logistic regression on covariance, lagged coherence,
and instantaneous coherence; and an ensemble method that stacks these three
tangent-space pipelines.

.. GENERATED FROM PYTHON SOURCE LINES 122-125

.. code-block:: Python

    ppl_baseline, ppl_fc, ppl_ens = {}, {}, {}

.. GENERATED FROM PYTHON SOURCE LINES 126-127

The baseline algorithms are CSP followed by an SVM whose kernel and ``C`` are
selected by grid search, and FgMDM. Both operate on covariance matrices.

.. GENERATED FROM PYTHON SOURCE LINES 127-142

.. code-block:: Python

    param_svm = {"kernel": ("linear", "rbf"), "C": [0.1, 1, 10]}
    step_csp = [
        ("cov", Covariances(estimator="lwf")),
        ("csp", CSP(nfilter=6)),
        ("optsvm", GridSearchCV(SVC(), param_svm, cv=3)),
    ]
    ppl_baseline["CSP+optSVM"] = Pipeline(steps=step_csp)

    step_mdm = [
        ("cov", Covariances(estimator="lwf")),
        ("fgmdm", FgMDM(metric="riemann", tsupdate=False)),
    ]
    ppl_baseline["FgMDM"] = Pipeline(steps=step_mdm)

.. GENERATED FROM PYTHON SOURCE LINES 143-146

The functional connectivity pipelines classify with a logistic regression in
the tangent space.
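
The ``NearestSPD`` step comes from the local ``helpers.coherence_helpers``
module, whose code is not shown in this example. Averaged coherence matrices
are symmetric but may not be strictly positive definite, which is presumably
why each pipeline projects them onto SPD matrices before the tangent-space
mapping. The NumPy sketch below illustrates these two steps under stated
assumptions: the projection symmetrizes the matrix and clips its eigenvalues
at a floor ``eps`` (the real helper may use another method), and the tangent
space is taken at the log-Euclidean mean, a simplification of the Riemannian
mean used by ``TangentSpace(metric="riemann")``. All function names are
hypothetical.

.. code-block:: Python

    import numpy as np


    def nearest_spd(C, eps=1e-6):
        """Hypothetical SPD projection (assumption, not the actual helper).

        Symmetrizes C and clips its eigenvalues from below at eps, which gives
        the closest matrix in Frobenius norm whose eigenvalues are >= eps.
        """
        S = (C + C.T) / 2.0
        eigvals, eigvecs = np.linalg.eigh(S)
        return (eigvecs * np.maximum(eigvals, eps)) @ eigvecs.T


    def apply_to_eigvals(S, func):
        """Apply a scalar function to the eigenvalues of a symmetric matrix."""
        eigvals, eigvecs = np.linalg.eigh(S)
        return (eigvecs * func(eigvals)) @ eigvecs.T


    def log_euclidean_mean(mats):
        """Log-Euclidean mean, used here as a simple reference point."""
        logs = [apply_to_eigvals(M, np.log) for M in mats]
        return apply_to_eigvals(np.mean(logs, axis=0), np.exp)


    def tangent_vector(C, C_ref):
        """Map SPD matrix C to the tangent space at C_ref and vectorize it.

        The upper triangle is kept, with off-diagonal terms weighted by
        sqrt(2) so that the vector norm equals the Frobenius norm of the
        tangent matrix.
        """
        isqrt_ref = apply_to_eigvals(C_ref, lambda w: 1.0 / np.sqrt(w))
        M = isqrt_ref @ C @ isqrt_ref
        T = apply_to_eigvals((M + M.T) / 2.0, np.log)
        rows, cols = np.triu_indices(C.shape[0])
        weights = np.where(rows == cols, 1.0, np.sqrt(2.0))
        return weights * T[rows, cols]


    # Toy usage on random matrices (not real coherence estimates)
    rng = np.random.default_rng(0)
    raw_mats = [rng.uniform(-1.0, 1.0, size=(3, 3)) for _ in range(4)]
    spd_mats = [nearest_spd(A, eps=1e-3) for A in raw_mats]
    C_ref = log_euclidean_mean(spd_mats)
    features = np.array([tangent_vector(C, C_ref) for C in spd_mats])
    print(features.shape)  # (4, 6): 3 * (3 + 1) / 2 upper-triangular entries

The resulting vectors live in a flat Euclidean space, which is what allows a
standard linear classifier such as the elastic-net logistic regression to be
applied after the mapping.
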
The SPD matrices fed to these pipelines are estimated from covariance, lagged
coherence, and instantaneous coherence.

.. GENERATED FROM PYTHON SOURCE LINES 146-171

.. code-block:: Python

    spectral_met = ["cov", "lagged", "instantaneous"]
    fmin, fmax = 8, 35
    param_lr = {
        "penalty": "elasticnet",
        "l1_ratio": 0.15,
        "intercept_scaling": 1000.0,
        "solver": "saga",
    }
    param_ft = {"fmin": fmin, "fmax": fmax, "fs": fs}
    step_fc = [
        ("spd", NearestSPD()),
        ("tg", TangentSpace(metric="riemann")),
        ("LogistReg", LogisticRegression(**param_lr)),
    ]
    for sm in spectral_met:
        pname = sm + "+elasticnet"
        if sm == "cov":
            ppl_fc[pname] = Pipeline(
                steps=[("cov", Covariances(estimator="lwf"))] + step_fc
            )
        else:
            ft = Connectivities(**param_ft, method=sm)
            ppl_fc[pname] = Pipeline(steps=[("ft", ft)] + step_fc)

.. GENERATED FROM PYTHON SOURCE LINES 172-174

The ensemble classifier stacks a logistic regression on top of the three
functional connectivity pipelines: the cross-validated predicted probabilities
of the three pipelines are the input features of the final logistic
regression, which makes the global prediction.

.. GENERATED FROM PYTHON SOURCE LINES 174-187

.. code-block:: Python

    fc_estim = [(n, ppl_fc[n]) for n in ppl_fc]
    cvkf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    lr = LogisticRegression(**param_lr)
    ppl_ens["ensemble"] = StackingClassifier(
        estimators=fc_estim,
        cv=cvkf,
        n_jobs=1,
        final_estimator=lr,
        stack_method="predict_proba",
    )

.. GENERATED FROM PYTHON SOURCE LINES 188-190

Evaluation
----------

.. GENERATED FROM PYTHON SOURCE LINES 190-199

.. code-block:: Python

    all_ppl = {**ppl_baseline, **ppl_ens}

    # Compute results
    results = get_results(X, labels, all_ppl)
    results = pd.DataFrame(results)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /home/docs/checkouts/readthedocs.org/user_builds/pyriemann/checkouts/latest/pyriemann/utils/mean.py:540: UserWarning: Convergence not reached
      warnings.warn("Convergence not reached")
    /home/docs/checkouts/readthedocs.org/user_builds/pyriemann/envs/latest/lib/python3.8/site-packages/sklearn/linear_model/_sag.py:350: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge
      warnings.warn(
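
``get_results`` is also imported from the local ``helpers.coherence_helpers``
module and is not shown. The plotting code only relies on it returning records
with ``pipeline`` and ``score`` fields. The sketch below shows what such a
helper could look like, assuming it scores each pipeline with stratified
5-fold cross-validated accuracy; the real helper may use another validation
scheme or metric, and the name ``get_results_sketch`` is hypothetical.

.. code-block:: Python

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import StratifiedKFold, cross_val_score


    def get_results_sketch(X, y, pipelines, n_splits=5, random_state=42):
        """Hypothetical stand-in for get_results (assumption).

        Returns one record per (pipeline, fold), which pd.DataFrame turns
        into a table with pipeline, fold and score columns.
        """
        cv = StratifiedKFold(
            n_splits=n_splits, shuffle=True, random_state=random_state
        )
        records = []
        for name, pipeline in pipelines.items():
            # cross_val_score clones the estimator, so every fold starts unfitted
            scores = cross_val_score(pipeline, X, y, cv=cv, scoring="accuracy")
            for fold, score in enumerate(scores):
                records.append({"pipeline": name, "fold": fold, "score": score})
        return records


    # Toy usage on synthetic features instead of EEG epochs
    X_toy, y_toy = make_classification(n_samples=60, n_features=5, random_state=0)
    toy_results = get_results_sketch(X_toy, y_toy, {"lr": LogisticRegression()})
    print(len(toy_results))  # 5 records: one per fold

Returning one record per fold, rather than a single mean, keeps the fold-to-fold
spread available, which the bar plot uses to draw its error bars.
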

.. GENERATED FROM PYTHON SOURCE LINES 200-202

Plot
----

.. GENERATED FROM PYTHON SOURCE LINES 202-218

.. code-block:: Python

    list_fc_ens = ["ensemble", "CSP+optSVM", "FgMDM"] + [
        sm + "+elasticnet" for sm in spectral_met
    ]

    g = sns.catplot(
        data=results,
        x="pipeline",
        y="score",
        kind="bar",
        order=list_fc_ens,
        height=7,
        aspect=2,
    )
    plt.show()

.. image-sg:: /auto_examples/motor-imagery/images/sphx_glr_plot_ensemble_coherence_001.png
   :alt: plot ensemble coherence
   :srcset: /auto_examples/motor-imagery/images/sphx_glr_plot_ensemble_coherence_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 219-226

References
----------

.. [1] `Functional connectivity ensemble method to enhance BCI performance (FUCONE) `_
   Corsi, M.-C., Chevallier, S., De Vico Fallani, F. & Yger, F.
   IEEE TBME, 2022

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 40.184 seconds)

.. _sphx_glr_download_auto_examples_motor-imagery_plot_ensemble_coherence.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_ensemble_coherence.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_ensemble_coherence.py `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_