
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/artifacts/plot_detect_riemannian_potato_field_EEG.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_artifacts_plot_detect_riemannian_potato_field_EEG.py:


===============================================================================
Online Artifact Detection with Riemannian Potato Field
===============================================================================

Example of the Riemannian Potato Field (RPF) [1]_ applied to EEG time-series to
detect artifacts in online processing. It is compared to the Riemannian Potato
(RP) [2]_.

.. GENERATED FROM PYTHON SOURCE LINES 10-28

.. code-block:: default

    # Authors: Quentin Barthélemy
    #
    # License: BSD (3-clause)

    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib.animation import FuncAnimation
    from mne.datasets import eegbci
    from mne.io import read_raw_edf
    from mne.channels import make_standard_montage
    from mne import make_fixed_length_epochs
    from pyriemann.estimation import Covariances
    from pyriemann.utils.covariance import normalize
    from pyriemann.clustering import Potato, PotatoField


.. GENERATED FROM PYTHON SOURCE LINES 29-70

.. code-block:: default

    def filter_bandpass(signal, low_freq, high_freq, channels=None, method="iir"):
        """Filter signal on specific channels and in a specific frequency band"""
        sig = signal.copy()
        if channels is not None:
            sig.pick_channels(channels)
        sig.filter(l_freq=low_freq, h_freq=high_freq, method=method, verbose=False)
        return sig


    def plot_detection(ax, rp_label, rpf_label):
        labels = []
        ylims = ax.get_ylim()
        height = ylims[1] - ylims[0]
        if not rp_label:
            r1 = ax.axhspan(
                ylims[0] + 0.06 * height, ylims[1] - 0.05 * height,
                edgecolor='r', facecolor='none',
                xmin=-test_time_start / test_duration - 0.005,
                xmax=(duration - test_time_start) / test_duration - 0.005)
            labels.append(r1)
            ax.text(0.25, 0.95, 'RP', color='r', size=16, transform=ax.transAxes)
        if not rpf_label:
            r2 = ax.axhspan(
                ylims[0] + 0.05 * height, ylims[1] - 0.06 * height,
                edgecolor='m', facecolor='none',
                xmin=-test_time_start / test_duration + 0.005,
                xmax=(duration - test_time_start) / test_duration + 0.005)
            labels.append(r2)
            ax.text(0.65, 0.95, 'RPF', color='m', size=16, transform=ax.transAxes)
        if rp_label and rpf_label:
            r3 = ax.axhspan(
                ylims[0] + 0.05 * height, ylims[1] - 0.05 * height,
                edgecolor='k', facecolor='none',
                xmin=-test_time_start / test_duration,
                xmax=(duration - test_time_start) / test_duration)
            labels.append(r3)
        return labels


.. GENERATED FROM PYTHON SOURCE LINES 71-73

Load EEG data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 73-92

.. code-block:: default

    # Load motor imagery data
    raw = read_raw_edf(eegbci.load_data(2, [5])[0], preload=True, verbose=False)
    eegbci.standardize(raw)
    raw.set_montage(make_standard_montage('standard_1005'))
    sfreq = int(raw.info['sfreq'])  # 160 Hz

    # Select the 21 channels of the 10-20 montage
    raw.pick_channels(
        ['Fp1', 'Fpz', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T7', 'C3', 'Cz',
         'C4', 'T8', 'P7', 'P3', 'Pz', 'P4', 'P8', 'O1', 'Oz', 'O2'],
        ordered=True)
    ch_names = raw.ch_names
    ch_count = len(ch_names)

    # Define time-series epoching with a sliding window
    duration = 2.5  # duration of epochs
    interval = 0.2  # interval between epochs


.. GENERATED FROM PYTHON SOURCE LINES 93-98

Riemannian potato
-----------------

The Riemannian potato (RP) [2]_ uses all channels, band-pass filtered between
1 and 35 Hz.

.. GENERATED FROM PYTHON SOURCE LINES 98-116

.. code-block:: default

    # RP definition
    z_th = 2.0  # z-score threshold
    low_freq, high_freq = 1., 35.
    rp = Potato(metric='riemann', threshold=z_th)

    # EEG processing for RP
    rp_sig = filter_bandpass(raw, low_freq, high_freq)  # band-pass filter
    rp_epochs = make_fixed_length_epochs(  # epoch time-series
        rp_sig, duration=duration, overlap=duration - interval, verbose=False)
    rp_covs = Covariances(estimator='scm').transform(rp_epochs.get_data())

    # RP training
    train_covs = 45  # nb of matrices for training
    train_set = range(train_covs)
    rp.fit(rp_covs[train_set])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using data from preloaded Raw for 603 events and 400 original time points ...
    0 bad epochs dropped

    Potato(threshold=2.0)


.. GENERATED FROM PYTHON SOURCE LINES 117-128

Riemannian potato field
-----------------------

The Riemannian potato field (RPF) [1]_ combines several potatoes of low
dimensionality, each one designed to capture a different kind of artifact
typically affecting some specific spatial area (i.e. subsets of channels)
and/or specific frequency bands.
BCI or NFB applications aim at the modulation of specific brain oscillations;
it is thus advisable to exclude such frequencies from the potatoes, so that
desirable brain modulations are not detected as artifacts.

.. GENERATED FROM PYTHON SOURCE LINES 128-165

.. code-block:: default

    # RPF definition
    p_th = 0.01  # probability threshold
    rpf_config = {
        'RPF eye_blinks': {  # for eye-blinks
            'ch_names': ['Fp1', 'Fpz', 'Fp2'],
            'low_freq': 1.,
            'high_freq': 20.},
        'RPF occipital': {  # for high-frequency artifacts in occipital area
            'ch_names': ['O1', 'Oz', 'O2'],
            'low_freq': 25.,
            'high_freq': 45.,
            'cov_normalization': 'trace'},  # trace-norm to be insensitive to power
        'RPF global_lf': {  # for low-frequency artifacts in all channels
            'ch_names': None,
            'low_freq': 0.5,
            'high_freq': 3.}
    }
    rpf = PotatoField(metric='riemann', z_threshold=z_th, p_threshold=p_th,
                      n_potatoes=len(rpf_config))

    # EEG processing for RPF
    rpf_covs = []
    for p in rpf_config.values():  # loop on potatoes
        rpf_sig = filter_bandpass(raw, p.get('low_freq'), p.get('high_freq'),
                                  channels=p.get('ch_names'))
        rpf_epochs = make_fixed_length_epochs(
            rpf_sig, duration=duration, overlap=duration - interval, verbose=False)
        covs_ = Covariances(estimator='scm').transform(rpf_epochs.get_data())
        if p.get('cov_normalization'):
            covs_ = normalize(covs_, p.get('cov_normalization'))
        rpf_covs.append(covs_)

    # RPF training
    rpf.fit([c[train_set] for c in rpf_covs])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using data from preloaded Raw for 603 events and 400 original time points ...
    0 bad epochs dropped
    Using data from preloaded Raw for 603 events and 400 original time points ...
    0 bad epochs dropped
    Using data from preloaded Raw for 603 events and 400 original time points ...
    0 bad epochs dropped
    /home/docs/checkouts/readthedocs.org/user_builds/pyriemann/checkouts/v0.4/pyriemann/utils/mean.py:470: UserWarning: Convergence not reached
      warnings.warn('Convergence not reached')

    PotatoField(n_potatoes=3, z_threshold=2.0)


.. GENERATED FROM PYTHON SOURCE LINES 166-173

Online Artifact Detection with Potatoes
---------------------------------------

Detect artifacts/outliers on the test set, with an animation imitating online
acquisition, processing and artifact detection of EEG time-series.
Note that all these potatoes are semi-dynamic: they are updated only when the
EEG is free of artifacts [1]_.

.. GENERATED FROM PYTHON SOURCE LINES 173-214

.. code-block:: default

    # Prepare data for online detection
    test_covs_max = 400  # nb of epochs to visualize in this example
    test_covs_visu = 100  # nb of z-scores/proba to display simultaneously
    test_time_start = -2  # start time to display signal
    test_time_end = 5  # end time to display signal

    test_duration = test_time_end - test_time_start
    time_start = train_covs * interval + test_time_start
    time_end = train_covs * interval + test_time_end
    time = np.linspace(time_start, time_end, int((time_end - time_start) * sfreq),
                       endpoint=False)
    raw.filter(l_freq=0.5, h_freq=75., method='iir', verbose=False)
    eeg_data = 1e5 * raw.get_data()
    sig = eeg_data[:, int(time_start * sfreq):int(time_end * sfreq)]
    eeg_offset = - 15 * np.linspace(1, ch_count, ch_count, endpoint=False)
    covs_t, covs_z = np.empty([0]), np.empty([len(rpf_config) + 1, 0])
    covs_p = np.empty([0])

    fig, ax = plt.subplots(figsize=(12, 10), nrows=2, ncols=1)
    fig.suptitle('Online artifact detection, RP vs RPF', fontsize=16)
    ax[0].set(xlabel='Time (s)', ylabel='EEG channels')
    ax[0].set_xlim([time[0], time[-1]])
    ax[0].set_yticks(eeg_offset)
    ax[0].set_yticklabels(ch_names)
    pl = ax[0].plot(time, sig.T + eeg_offset.T, lw=0.75)
    labels = []
    ax[1].set(ylabel='Z-scores of distances to references')
    pl2 = ax[1].plot(covs_t, covs_z.T, lw=0.75)
    for c, l in enumerate(['RP'] + [*rpf_config]):
        pl2[c].set_label(l)
    ax[1].set_ylim([-1.5, 8.5])
    ax[1].legend(loc='upper left')
    axp = ax[1].twinx()
    axp.set(ylabel='RPF probability of clean EEG')
    pl3 = axp.plot(covs_t, covs_p, lw=0.75, c='k', label='RPF proba')
    axp.set_ylim([0, 1])
    axp.legend(loc='upper right')


.. image-sg:: /auto_examples/artifacts/images/sphx_glr_plot_detect_riemannian_potato_field_EEG_001.png
   :alt: Online artifact detection, RP vs RPF
   :srcset: /auto_examples/artifacts/images/sphx_glr_plot_detect_riemannian_potato_field_EEG_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


.. GENERATED FROM PYTHON SOURCE LINES 215-269

.. code-block:: default

    # Prepare animation for online detection
    def online_detect(t):
        global time, sig, labels, covs_t, covs_z, covs_p

        # Online artifact detection
        rp_label = rp.predict(rp_covs[np.newaxis, t])[0]
        rp_zscore = rp.transform(rp_covs[np.newaxis, t])
        rpf_label = rpf.predict([c[np.newaxis, t] for c in rpf_covs])[0]
        rpf_zscores = rpf.transform([c[np.newaxis, t] for c in rpf_covs])
        rpf_proba = rpf.predict_proba([c[np.newaxis, t] for c in rpf_covs])
        if rp_label == 1:
            rp.partial_fit(rp_covs[np.newaxis, t], alpha=1 / t)
        if rpf_label == 1:
            rpf.partial_fit([c[np.newaxis, t] for c in rpf_covs], alpha=1 / t)

        # Update data
        time_start = t * interval + test_time_end
        time_end = (t + 1) * interval + test_time_end
        time_ = np.linspace(time_start, time_end, int(interval * sfreq),
                            endpoint=False)
        time = np.r_[time[int(interval * sfreq):], time_]
        sig = np.hstack((sig[:, int(interval * sfreq):],
                         eeg_data[:, int(time_start*sfreq):int(time_end*sfreq)]))
        covs_t = np.r_[covs_t, time_start]
        covs_z = np.hstack((covs_z, np.vstack((rp_zscore[np.newaxis],
                                               rpf_zscores.T))))
        covs_p = np.r_[covs_p, rpf_proba]
        if len(covs_p) > test_covs_visu:
            covs_t, covs_z, covs_p = covs_t[1:], covs_z[:, 1:], covs_p[1:]

        # Update plot
        for c in range(ch_count):
            pl[c].set_data(time, sig[c] + eeg_offset[c])
            pl[c].axes.set_xlim(time[0], time[-1])
        for lbl in labels:
            lbl.remove()
        for txt in ax[0].texts:
            txt.set_visible(False)
        labels = plot_detection(ax[0], rp_label, rpf_label)
        for c in range(len(pl2)):
            pl2[c].set_data(covs_t, covs_z[c])
            pl2[c].axes.set_xlim(covs_t[0] - 0.1, covs_t[-1])
        pl3[0].set_data(covs_t, covs_p)

        return pl, pl2, pl3


    interval_display = 1.0  # can be changed for a slower display

    potato = FuncAnimation(fig, online_detect,
                           frames=range(train_covs, test_covs_max),
                           interval=interval_display, blit=False, repeat=False)


.. GENERATED FROM PYTHON SOURCE LINES 270-286

.. code-block:: default

    # Plot online detection

    # Plot complete visu: a dynamic display is required
    plt.show()

    # Plot only 10s, for animated documentation
    try:
        from IPython.display import HTML
    except ImportError:
        raise ImportError("Install IPython to plot animation in documentation")

    plt.rcParams["animation.embed_limit"] = 10
    HTML(potato.to_jshtml(fps=5, default_mode='loop'))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Animation size has reached 10958720 bytes, exceeding the limit of 10485760.0. If you're sure you want a larger animation embedded, set the animation.embed_limit rc parameter to a larger value (in MB). This and further frames will be dropped.

.. raw:: html

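The semi-dynamic behaviour used in ``online_detect`` can be illustrated with a
scalar analogue. This is only a minimal sketch under stated assumptions: it
works on plain numbers instead of covariance matrices, uses an ordinary mean
and standard deviation instead of a Riemannian reference, and the helper name
``semi_dynamic_detect`` is hypothetical. It is not pyRiemann's implementation.
What it mirrors is the update rule: a sample labelled clean refreshes the
reference with weight ``alpha = 1 / t``, while a sample labelled as artifact
leaves the reference untouched.

```python
import math


def semi_dynamic_detect(samples, n_train, z_threshold=2.0):
    """Scalar analogue of a semi-dynamic detector (hypothetical helper).

    The first ``n_train`` samples fit the reference (mean and std).
    Each later sample is labelled clean (True) if its z-score is below
    ``z_threshold``; only clean samples update the reference.
    """
    train = samples[:n_train]
    mean = sum(train) / n_train
    std = math.sqrt(sum((x - mean) ** 2 for x in train) / n_train)

    labels = []
    for t, x in enumerate(samples[n_train:], start=n_train):
        z = (x - mean) / std  # one-sided z-score w.r.t. current reference
        is_clean = z < z_threshold
        labels.append(is_clean)
        if is_clean:
            # Same weighting as in the example: alpha shrinks as t grows
            alpha = 1 / t
            new_mean = (1 - alpha) * mean + alpha * x
            var = (1 - alpha) * std ** 2 + alpha * (x - new_mean) ** 2
            mean, std = new_mean, math.sqrt(var)
        # Artifacted samples are skipped: the reference is not modified
    return labels, mean


# Toy signal: five calibration samples, then a clean one, an outlier, a clean one
signal = [1.0, 1.2, 0.8, 1.1, 0.9, 1.0, 5.0, 1.05]
labels, reference = semi_dynamic_detect(signal, n_train=5)
print(labels)  # [True, False, True]
print(reference)
```

Because the outlier at 5.0 is rejected before any update, the reference stays
close to the calibration mean, which is the point of updating only on clean
data.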

.. GENERATED FROM PYTHON SOURCE LINES 287-300

References
----------
.. [1] `The Riemannian Potato Field: A Tool for Online Signal Quality Index of
   EEG `_
   Q. Barthélemy, L. Mayaud, D. Ojeda, and M. Congedo. IEEE Transactions on
   Neural Systems and Rehabilitation Engineering, IEEE Institute of Electrical
   and Electronics Engineers, 2019, 27 (2), pp.244-255

.. [2] `The Riemannian Potato: an automatic and adaptive artifact detection
   method for online experiments using Riemannian geometry `_
   A. Barachant, A. Andreev, and M. Congedo. TOBI Workshop IV, Jan 2013, Sion,
   Switzerland. pp.19-20.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 1 minutes 8.752 seconds)


.. _sphx_glr_download_auto_examples_artifacts_plot_detect_riemannian_potato_field_EEG.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_detect_riemannian_potato_field_EEG.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_detect_riemannian_potato_field_EEG.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_