.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/covariance-estimation/plot_covariance_estimation.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_covariance-estimation_plot_covariance_estimation.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_covariance-estimation_plot_covariance_estimation.py:


===============================================================================
Compare covariance and kernel estimators
===============================================================================

Comparison of covariance estimators for different EEG signal lengths, and of
their impact on classification [1]_. Kernel estimators are also compared [2]_.

.. GENERATED FROM PYTHON SOURCE LINES 9-29

.. code-block:: Python


    # Authors: Sylvain Chevallier and Quentin Barthélemy
    #
    # License: BSD (3-clause)

    from matplotlib import pyplot as plt
    from mne import Epochs, pick_types, events_from_annotations
    from mne.datasets import eegbci
    from mne.io import concatenate_raws
    from mne.io.edf import read_raw_edf
    import numpy as np
    import pandas as pd
    import seaborn as sns
    from sklearn.model_selection import cross_val_score, StratifiedKFold
    from sklearn.pipeline import make_pipeline

    from pyriemann.classification import MDM
    from pyriemann.estimation import Covariances, Kernels
    from pyriemann.utils.distance import distance

.. GENERATED FROM PYTHON SOURCE LINES 30-35

Estimating covariance on synthetic data
----------------------------------------

Generate synthetic data sampled from a known distribution, whose covariance
matrices serve as the ground truth.

.. GENERATED FROM PYTHON SOURCE LINES 35-49
.. code-block:: Python


    rs = np.random.RandomState(42)
    n_matrices, n_channels, n_times = 10, 5, 1000
    var = 2.0 + 0.1 * rs.randn(n_matrices, n_channels)
    A = 2 * rs.rand(n_channels, n_channels) - 1
    A /= np.linalg.norm(A, axis=1)[:, np.newaxis]
    true_covs = np.empty(shape=(n_matrices, n_channels, n_channels))
    X = np.empty(shape=(n_matrices, n_channels, n_times))
    for i in range(n_matrices):
        # ground-truth covariance A diag(var_i) A^T, then zero-mean samples
        true_covs[i] = A @ np.diag(var[i]) @ A.T
        X[i] = rs.multivariate_normal(
            np.array([0.0] * n_channels), true_covs[i], size=n_times
        ).T

.. GENERATED FROM PYTHON SOURCE LINES 50-61

The `Covariances()` class offers several estimators:

- sample covariance matrix (SCM),
- Ledoit-Wolf (LWF),
- Schaefer-Strimmer (SCH),
- oracle approximating shrinkage (OAS),
- minimum covariance determinant (MCD),
- and others.

We compare the Riemannian distance between the LWF, OAS and SCH estimates and
the ground truth, as the number of time samples increases.

.. GENERATED FROM PYTHON SOURCE LINES 61-72

.. code-block:: Python


    estimators = ["lwf", "oas", "sch"]
    w_len = np.linspace(10, n_times, 20, dtype=int)
    dfd = list()
    for est in estimators:
        for wl in w_len:
            # estimate covariances on the first wl time samples only
            est_covs = Covariances(estimator=est).transform(X[:, :, :wl])
            dists = distance(est_covs, true_covs, metric="riemann")
            dfd.extend([dict(estimator=est, wlen=wl, dist=d) for d in dists])
    dfd = pd.DataFrame(dfd)

.. GENERATED FROM PYTHON SOURCE LINES 73-83

.. code-block:: Python


    fig, ax = plt.subplots(figsize=(6, 4))
    ax.set(xscale="log")
    sns.lineplot(data=dfd, x="wlen", y="dist", hue="estimator", ax=ax)
    ax.set_title("Distance to groundtruth covariance matrix")
    ax.set_xlabel("Number of time samples")
    ax.set_ylabel(r"$\delta(\Sigma, \hat{\Sigma})$")
    plt.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/covariance-estimation/images/sphx_glr_plot_covariance_estimation_001.png
   :alt: Distance to groundtruth covariance matrix
   :srcset: /auto_examples/covariance-estimation/images/sphx_glr_plot_covariance_estimation_001.png
   :class: sphx-glr-single-img
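
The y-axis above is a Riemannian distance between SPD matrices. As an
illustration only, the affine-invariant Riemannian distance can be written with
the generalized eigenvalues :math:`\lambda_i` of the pair
:math:`(\Sigma, \hat{\Sigma})` as
:math:`\delta(\Sigma, \hat{\Sigma}) = \sqrt{\sum_i \log^2 \lambda_i}`.
The sketch below assumes that this is the quantity computed by
``distance(..., metric="riemann")``; the helper ``riemann_distance`` is
hypothetical and not part of pyRiemann. It only relies on NumPy and SciPy, and
measures how far a sample covariance matrix lies from its ground truth.

.. code-block:: Python

    import numpy as np
    from scipy.linalg import eigvalsh


    def riemann_distance(A, B):
        """Hypothetical helper: affine-invariant distance between SPD matrices.

        eigvalsh(A, B) solves the generalized problem A v = lambda B v, i.e. it
        returns the eigenvalues of inv(B) @ A, which are positive for SPD inputs.
        """
        evals = eigvalsh(A, B)
        return np.sqrt(np.sum(np.log(evals) ** 2))


    rs = np.random.RandomState(0)
    n, n_samples = 5, 200
    true_cov = np.diag(np.arange(1.0, n + 1))  # known ground truth
    # zero-mean samples drawn from N(0, true_cov), shape (n, n_samples)
    samples = rs.multivariate_normal(np.zeros(n), true_cov, size=n_samples).T
    # uncentered estimate, valid here because the data are zero-mean by design
    scm = samples @ samples.T / n_samples
    print(riemann_distance(scm, true_cov))
    # the distance between I and 2I in dimension 3 is sqrt(3) * log(2)
    print(riemann_distance(np.eye(3), 2 * np.eye(3)))
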
.. GENERATED FROM PYTHON SOURCE LINES 84-88

Choice of estimator for motor imagery data
------------------------------------------

Load data from the PhysioNet motor imagery (MI) dataset for subject 1.

.. GENERATED FROM PYTHON SOURCE LINES 88-106

.. code-block:: Python


    subject = 1
    runs = [6, 10]  # motor imagery: hands vs feet

    raw_files = [
        read_raw_edf(f, preload=True, stim_channel="auto")
        for f in eegbci.load_data(subject, runs, update_path=True)
    ]
    raw = concatenate_raws(raw_files)

    picks = pick_types(raw.info, eeg=True, exclude="bads")
    # subsample electrodes: keep every second EEG channel
    picks = picks[::2]

    # Apply band-pass filter
    raw.filter(7.0, 35.0, method="iir", picks=picks, skip_by_annotation="edge")

    # map annotations T1 and T2 to hands (2) and feet (3) events
    events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))
    event_ids = dict(hands=2, feet=3)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using default location ~/mne_data for EEGBCI...
    Downloading EEGBCI data
    Attempting to create new mne-python configuration file: /home/docs/.mne/mne-python.json
    Could not read the /home/docs/.mne/mne-python.json json file during the writing. Assuming it is empty. Got: Expecting value: line 1 column 1 (char 0)
    Download complete in 10s (5.0 MB)
    Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R06.edf...
    Setting channel info structure...
    Creating raw.info structure...
    Reading 0 ... 19999 = 0.000 ... 124.994 secs...
    Extracting EDF parameters from /home/docs/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R10.edf...
    Setting channel info structure...
    Creating raw.info structure...
    Reading 0 ... 19999 = 0.000 ... 124.994 secs...
    Filtering a subset of channels. The highpass and lowpass values in the measurement info will not be updated.
    Filtering raw data in 2 contiguous segments
    Setting up band-pass filter from 7 - 35 Hz

    IIR filter parameters
    ---------------------
    Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
    - Filter order 16 (effective, after forward-backward)
    - Cutoffs at 7.00, 35.00 Hz: -6.02, -6.02 dB
    Used Annotations descriptions: [np.str_('T1'), np.str_('T2')]

.. GENERATED FROM PYTHON SOURCE LINES 107-124

Influence of shrinkage on matrix estimation
-------------------------------------------

Sample covariance matrix (SCM) estimation can lead to ill-conditioned matrices,
depending on the quality and quantity of the available EEG data. For a
symmetric positive definite (SPD) matrix, the condition number is the ratio
between its largest and smallest eigenvalues: high values indicate
ill-conditioned matrices, which are not suitable for classification.

A common approach to mitigate this issue is to regularize covariance matrices
by shrinkage, as done by the Ledoit-Wolf, Schaefer-Strimmer and oracle
approximating shrinkage (OAS) estimators.

In addition to covariance matrices, kernel matrices are computed for three
kernel functions:

- radial basis function (RBF),
- polynomial,
- Laplacian.

.. GENERATED FROM PYTHON SOURCE LINES 124-156

.. code-block:: Python


    estimators = [
        "cov-lwf",
        "cov-oas",
        "cov-sch",
        "cov-scm",
        "ker-rbf",
        "ker-polynomial",
        "ker-laplacian",
    ]
    tmin = -0.2
    w_len = np.linspace(0.5, 2.5, 5)
    dfc = list()
    for wl in w_len:
        X = Epochs(
            raw,
            events,
            event_id=event_ids,
            tmin=tmin,
            tmax=tmin + wl,
            picks=picks,
            preload=True,
            verbose=False,
        ).get_data(copy=False)
        for est in estimators:
            # estimator names follow "<class>-<parameter>", e.g. "cov-lwf"
            est_class, est_param = est.split("-")
            if est_class == "ker":
                covs = Kernels(metric=est_param).transform(X)
            else:
                covs = Covariances(estimator=est_param).transform(X)
            # condition number: largest over smallest eigenvalue
            evals = np.linalg.eigvalsh(covs)
            dfc.extend(
                [dict(estimator=est, wlen=wl, cond=max(e) / min(e)) for e in evals]
            )
    dfc = pd.DataFrame(dfc)

.. GENERATED FROM PYTHON SOURCE LINES 157-167
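
To see why shrinkage helps, here is a minimal sketch on white noise, assuming a
fixed, hand-picked shrinkage intensity ``alpha`` toward a scaled identity
matrix. This is a simplification: the Ledoit-Wolf, Schaefer-Strimmer and OAS
estimators estimate the shrinkage intensity from the data, and may use a
different shrinkage target. Shrinking toward :math:`\mu I`, where :math:`\mu` is
the mean eigenvalue, replaces each eigenvalue :math:`\lambda_i` by
:math:`(1 - \alpha) \lambda_i + \alpha \mu`, so the condition number decreases
as :math:`\alpha` grows. The helpers ``shrink`` and ``condition_number`` are
hypothetical, not pyRiemann functions.

.. code-block:: Python

    import numpy as np


    def condition_number(C):
        """Hypothetical helper: largest over smallest eigenvalue of an SPD matrix."""
        evals = np.linalg.eigvalsh(C)  # eigenvalues in ascending order
        return evals[-1] / evals[0]


    def shrink(C, alpha):
        """Hypothetical helper: fixed shrinkage of C toward mu * I.

        alpha = 0 returns C unchanged, alpha = 1 returns mu * I, where mu is the
        mean eigenvalue trace(C) / n. The trace of C is preserved.
        """
        n = C.shape[0]
        mu = np.trace(C) / n
        return (1 - alpha) * C + alpha * mu * np.eye(n)


    rs = np.random.RandomState(42)
    # few time samples relative to the number of channels
    n_chans, n_samples = 32, 40
    noise = rs.randn(n_chans, n_samples)
    scm = np.cov(noise)  # sample covariance matrix, rows are channels
    for alpha in [0.0, 0.1, 0.5]:
        cond = condition_number(shrink(scm, alpha))
        print(f"alpha={alpha}: condition number = {cond:.1f}")
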
.. code-block:: Python


    fig, ax = plt.subplots(figsize=(6, 4))
    ax.set(yscale="log")
    sns.lineplot(data=dfc, x="wlen", y="cond", hue="estimator", ax=ax)
    ax.set_title("Condition number of estimated matrices")
    ax.set_xlabel("Epoch length (s)")
    ax.set_ylabel(r"$\lambda_{\max}$/$\lambda_{\min}$")
    plt.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/covariance-estimation/images/sphx_glr_plot_covariance_estimation_002.png
   :alt: Condition number of estimated matrices
   :srcset: /auto_examples/covariance-estimation/images/sphx_glr_plot_covariance_estimation_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 168-173

Picking a good estimator for classification
-------------------------------------------

The choice of estimator has an impact on classification, especially when the
matrices are estimated on short time windows.

.. GENERATED FROM PYTHON SOURCE LINES 173-210

.. code-block:: Python


    tmin = 0.0
    w_len = np.linspace(0.5, 2.5, 5)
    n_splits = 5
    dfa = list()
    sc = "balanced_accuracy"

    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=123)
    for wl in w_len:
        epochs = Epochs(
            raw,
            events,
            event_ids,
            tmin,
            tmin + wl,
            proj=True,
            picks=picks,
            preload=True,
            baseline=None,
            verbose=False,
        )
        X = epochs.get_data(copy=False)
        # binary labels: 0 for hands (event 2), 1 for feet (event 3)
        y = np.array([0 if ev == 2 else 1 for ev in epochs.events[:, -1]])
        for est in estimators:
            est_class, est_param = est.split("-")
            if est_class == "ker":
                clf = make_pipeline(Kernels(metric=est_param), MDM())
            else:
                clf = make_pipeline(Covariances(estimator=est_param), MDM())
            try:
                score = cross_val_score(clf, X, y, cv=cv, scoring=sc)
                dfa += [dict(estimator=est, wlen=wl, accuracy=s) for s in score]
            except ValueError:
                print(f"{est}: {wl} s is not sufficient to estimate an SPD matrix")
                dfa += [dict(estimator=est, wlen=wl, accuracy=np.nan)] * n_splits
    dfa = pd.DataFrame(dfa)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    /home/docs/checkouts/readthedocs.org/user_builds/pyriemann/checkouts/v0.10/pyriemann/utils/mean.py:688: UserWarning: Convergence not reached
      warnings.warn("Convergence not reached")

.. GENERATED FROM PYTHON SOURCE LINES 211-230

.. code-block:: Python


    fig, ax = plt.subplots(figsize=(6, 4))
    sns.lineplot(
        data=dfa,
        x="wlen",
        y="accuracy",
        hue="estimator",
        style="estimator",
        ax=ax,
        errorbar=None,
        markers=True,
        dashes=False,
    )
    ax.set_title("Accuracy for different estimators and epoch lengths")
    ax.set_xlabel("Epoch length (s)")
    ax.set_ylabel("Classification accuracy")
    plt.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/covariance-estimation/images/sphx_glr_plot_covariance_estimation_003.png
   :alt: Accuracy for different estimators and epoch lengths
   :srcset: /auto_examples/covariance-estimation/images/sphx_glr_plot_covariance_estimation_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 231-242

References
----------
.. [1] `Riemannian classification for SSVEP based BCI: offline versus online
   implementations `_
   S. Chevallier, E. Kalunga, Q. Barthélemy, F. Yger. Brain–Computer Interfaces
   Handbook: Technological and Theoretical Advances, 2018.
.. [2] `Beyond Covariance: Feature Representation with Nonlinear Kernel
   Matrices `_
   L. Wang, J. Zhang, L. Zhou, C. Tang, W. Li. ICCV, 2015.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 23.294 seconds)


.. _sphx_glr_download_auto_examples_covariance-estimation_plot_covariance_estimation.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_covariance_estimation.ipynb <plot_covariance_estimation.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_covariance_estimation.py <plot_covariance_estimation.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_covariance_estimation.zip <plot_covariance_estimation.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_