"""Code for channel selection."""
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from .classification import MDM
from .geometry.distance import distance
class ElectrodeSelection(TransformerMixin, BaseEstimator):
    """Channel selection based on a Riemannian geometry criterion.

    For each class, a centroid is estimated, and the channel selection is based
    on the maximization of the distance between centroids. This is done by a
    backward elimination where the electrode that carries the least distance
    is removed from the subset at each iteration [1]_.

    Parameters
    ----------
    nelec : int, default=16
        The number of electrodes to keep in the final subset.
    metric : string | dict, default="riemann"
        Metric used for mean estimation (for the list of supported metrics,
        see :func:`pyriemann.geometry.mean.gmean`) and for distance estimation
        (see :func:`pyriemann.geometry.distance.distance`).
        The metric can be a dict with two keys, "mean" and "distance"
        in order to pass different metrics.
    n_jobs : int, default=1
        Number of jobs to use for the computation. This works by computing
        each class centroid in parallel.
        If -1 all CPUs are used. If 1 is given, no parallel computing code is
        used at all, which is useful for debugging. For n_jobs below -1,
        (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one
        are used.

    Attributes
    ----------
    covmeans_ : ndarray, shape (n_classes, n_channels, n_channels)
        Centroids for each class.
    dist_ : list
        At each elimination step, the sum of pairwise distances between class
        centroids restricted to the channels kept after that step.
    subelec_ : list
        Indices of selected channels.

    See Also
    --------
    Kmeans
    FgMDM

    References
    ----------
    .. [1] `Channel selection procedure using riemannian distance for BCI
        applications
        <https://hal.archives-ouvertes.fr/hal-00602707>`_
        A. Barachant and S. Bonnet. The 5th International IEEE EMBS Conference
        on Neural Engineering, Apr 2011, Cancun, Mexico.
    """

    def __init__(self, nelec=16, metric="riemann", n_jobs=1):
        """Init."""
        self.nelec = nelec
        self.metric = metric
        self.n_jobs = n_jobs

    def fit(self, X, y=None, sample_weight=None):
        """Find the optimal subset of electrodes.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : None | ndarray, shape (n_matrices,), default=None
            Labels for each matrix. If None, all matrices are put in a single
            class; with a single class there are no centroid pairs, so every
            candidate scores zero and the selection is not informative.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix. If None, it uses equal weights.

        Returns
        -------
        self : ElectrodeSelection instance
            The ElectrodeSelection instance.
        """
        if y is None:
            y = np.ones((X.shape[0]))
        mdm = MDM(metric=self.metric, n_jobs=self.n_jobs)
        mdm.fit(X, y, sample_weight=sample_weight)
        self.covmeans_ = mdm.covmeans_
        n_classes, n_channels, _ = self.covmeans_.shape

        self.dist_ = []
        self.subelec_ = list(range(n_channels))
        # Backward elimination: at each step, drop the channel whose removal
        # keeps the class centroids farthest apart.
        while len(self.subelec_) > self.nelec:
            n_sub = len(self.subelec_)
            di = np.zeros((n_sub, 1))
            for idx in range(n_sub):
                # Candidate subset: current subset without channel at idx.
                sub = self.subelec_[:idx] + self.subelec_[idx + 1:]
                # Row/column selector, same as [:, sub][sub, :].
                sel = np.ix_(sub, sub)
                for i in range(n_classes):
                    for j in range(i + 1, n_classes):
                        di[idx] += distance(
                            self.covmeans_[i][sel],
                            self.covmeans_[j][sel],
                            metric=mdm._metric_dist,
                        )
            torm = di.argmax()
            self.dist_.append(di.max())
            self.subelec_.pop(torm)
        return self

    def transform(self, X):
        """Return matrices reduced to the selected channels.

        Parameters
        ----------
        X : ndarray, shape (..., n_channels, n_channels)
            Set of SPD matrices.

        Returns
        -------
        X_new : ndarray, shape (..., n_selected, n_selected)
            Matrices restricted to the rows and columns listed in
            ``subelec_``.
        """
        return X[..., self.subelec_, :][..., self.subelec_]
class FlatChannelRemover(TransformerMixin, BaseEstimator):
    """Flat channel removal.

    A channel is considered flat when its temporal variance is zero in every
    input time-series (the mean over matrices of non-negative variances is
    zero only if all of them are zero).

    Attributes
    ----------
    channels_ : ndarray, shape (n_good_channels,)
        Indices of the non-flat channels.
    """

    def fit(self, X, y=None):
        """Find flat channels.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : FlatChannelRemover instance
            The FlatChannelRemover instance.
        """
        # Per-channel variance over time, averaged across matrices.
        var = np.mean(np.std(X, axis=2) ** 2, 0)
        self.channels_ = np.flatnonzero(var)
        return self

    def transform(self, X):
        """Remove flat channels.

        Parameters
        ----------
        X : ndarray, shape (..., n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        X_new : ndarray, shape (..., n_good_channels, n_times)
            Multi-channel time-series restricted to the channels listed in
            ``channels_``.
        """
        return X[..., self.channels_, :]