"""Clustering functions."""
from math import floor
import warnings
from joblib import Parallel, delayed
import numpy as np
from scipy.stats import combine_pvalues, norm
import sklearn
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import KMeans as sklearnKMeans
from sklearn.utils.validation import check_random_state
from .classification import MDM
from .datasets import sample_gaussian_spd
from .geometry.covariance import covariance_scm
from .geometry.distance import (
distance,
pairwise_distance,
distance_mahalanobis,
)
from .geometry.geodesic import geodesic
from .geometry.mean import gmean
from .geometry.tangentspace import exp_map, log_map, tangent_space
from .utils._base import SpdClassifMixin, SpdClustMixin, SpdTransfMixin
from .utils._check import check_metric, check_function, check_weights
def _init_centroids(X, n_clusters, init, random_state, x_squared_norms):
    """Initialize centroids using the private helper of scikit-learn KMeans.

    Parameters
    ----------
    X : ndarray, shape (n_matrices, ...)
        Set of matrices, forwarded as is to scikit-learn.
    n_clusters : int
        Number of centroids to initialize.
    init : str | ndarray
        Initialization method or initial centroids, forwarded to
        ``sklearn.cluster.KMeans``.
    random_state : None | int | np.random.RandomState
        Seed or generator used for the initialization. It is normalized with
        ``check_random_state`` so that scikit-learn always receives a
        generator instance.
    x_squared_norms : ndarray, shape (n_matrices,)
        Squared norm of each matrix.

    Returns
    -------
    centroids
        Initial centroids, as returned by scikit-learn.
    """
    # The private scikit-learn helper expects a generator instance: None, an
    # int seed and an existing RandomState are all normalized here.
    random_state = check_random_state(random_state)
    kmeans = sklearnKMeans(n_clusters=n_clusters, init=init)

    # Compare (major, minor) numerically: a plain string comparison would
    # wrongly rank "1.10.0" before "1.3.0".
    major, minor = (int(part) for part in sklearn.__version__.split(".")[:2])
    if (major, minor) < (1, 3):
        return kmeans._init_centroids(
            X,
            x_squared_norms,
            init,
            random_state,
        )

    # From scikit-learn 1.3, sample weights are passed to the initialization.
    n_matrices = X.shape[0]
    return kmeans._init_centroids(
        X,
        x_squared_norms,
        init,
        random_state,
        sample_weight=np.ones(n_matrices) / n_matrices,
    )
def _fit_single(X, y=None, n_clusters=2, init="random", random_state=None,
                metric="riemann", max_iter=100, tol=1e-4, n_jobs=1):
    """Fit a single run of k-means.

    Parameters
    ----------
    X : ndarray, shape (n_matrices, n_channels, n_channels)
        Set of matrices.
    y : None
        Not used, kept for signature compatibility with callers.
    n_clusters : int, default=2
        Number of clusters.
    init : str | ndarray, default="random"
        Initialization of centroids, forwarded to ``_init_centroids``.
    random_state : None | int | np.random.RandomState, default=None
        Seed or generator forwarded to ``_init_centroids``.
    metric : string | dict, default="riemann"
        Metric passed to ``MDM`` and ``check_metric``.
    max_iter : int, default=100
        Maximum number of centroid updates. At least one update is always
        performed.
    tol : float, default=1e-4
        The loop stops when the proportion of unchanged labels between two
        iterations is greater than ``1 - tol``.
    n_jobs : int, default=1
        Number of jobs passed to ``MDM``.

    Returns
    -------
    labels : ndarray, shape (n_matrices,)
        Label of the closest centroid for each matrix.
    inertia : float
        Sum of distances of matrices to their assigned centroid.
    mdm : MDM instance
        MDM instance holding the fitted centroids.
    """
    mdm = MDM(metric=metric, n_jobs=n_jobs)
    mdm._metric_mean, mdm._metric_dist = check_metric(metric)

    squared_norms = np.linalg.norm(X, ord="fro", axis=(1, 2)) ** 2
    mdm.covmeans_ = _init_centroids(
        X,
        n_clusters,
        init,
        random_state=random_state,
        x_squared_norms=squared_norms,
    )
    mdm.classes_ = np.arange(n_clusters)

    labels = mdm.predict(X)
    n_iter = 0
    while True:
        old_labels = labels.copy()
        # update the centroids from the current assignment, then re-assign
        mdm.fit(X, old_labels)
        dist = mdm._predict_distances(X)
        labels = mdm.classes_[dist.argmin(axis=1)]
        n_iter += 1

        converged = np.mean(labels == old_labels) > (1 - tol)
        # ">=" so that at most max_iter updates are run (not max_iter + 1)
        if converged or n_iter >= max_iter:
            break

    # column i of dist holds the distances to the centroid of classes_[i]
    inertia = sum(
        dist[labels == label, i].sum()
        for i, label in enumerate(mdm.classes_)
    )
    return labels, inertia, mdm
[docs]
class Kmeans(SpdClassifMixin, SpdClustMixin, SpdTransfMixin, BaseEstimator):
    """Clustering by k-means with SPD/HPD matrices as inputs.

    The k-means is a clustering method used to find clusters that minimize the
    sum of squared distances between centroids and SPD/HPD matrices [1]_.
    Then, for each new matrix, the class is affected according to the nearest
    centroid.

    Parameters
    ----------
    n_clusters : int, default=2
        Number of clusters.
    max_iter : int, default=100
        Maximum number of iteration to reach convergence.
    metric : string | dict, default="riemann"
        Metric used for mean estimation (for the list of supported metrics,
        see :func:`pyriemann.geometry.mean.gmean`) and for distance estimation
        (see :func:`pyriemann.geometry.distance.distance`).
        The metric can be a dict with two keys, "mean" and "distance"
        in order to pass different metrics.
    random_state : None | integer | np.RandomState, default=None
        The generator used to initialize the centroids. If an integer is
        given, it fixes the seed. Defaults to the global numpy random
        number generator.
    init : "random" | ndarray, shape (n_clusters, n_channels, n_channels), \
            default="random"
        Method for initialization of centroids.
        If "random", it chooses k matrices at random for the initial centroids.
        If an ndarray is passed, it should be of shape
        (n_clusters, n_channels, n_channels) and gives the initial centroids.
    n_init : int, default=10
        Number of time the k-means algorithm will be run with different
        centroid seeds. The final results will be the best output of
        n_init consecutive runs in terms of inertia.
    n_jobs : int, default=1
        Number of jobs to use for the computation. This works by computing
        each of the n_init runs in parallel.
        If -1 all CPUs are used. If 1 is given, no parallel computing code is
        used at all, which is useful for debugging. For n_jobs below -1,
        (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one
        are used.
    tol : float, default=1e-4
        Stopping criterion to stop convergence, representing the minimum
        amount of change in labels between two iterations.

    Attributes
    ----------
    mdm_ : MDM instance
        MDM instance containing the centroids.
    labels_ : ndarray, shape (n_matrices,)
        Labels, ie centroid indices, of each matrix of training set.
    inertia_ : float
        Sum of distances of matrices to their closest cluster centroids.

    Notes
    -----
    .. versionadded:: 0.2.2

    See Also
    --------
    MeanShift
    MDM

    References
    ----------
    .. [1] `Commande robuste d'un effecteur par une interface cerveau machine
        EEG asynchrone
        <https://theses.hal.science/tel-01196752/>`_
        A. Barachant, Thesis, 2012
    """

    def __init__(
        self,
        n_clusters=2,
        max_iter=100,
        metric="riemann",
        random_state=None,
        init="random",
        n_init=10,
        n_jobs=1,
        tol=1e-4,
    ):
        """Init."""
        self.metric = metric
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.random_state = random_state
        self.init = init
        self.n_init = n_init
        self.tol = tol
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """Fit the centroids of clusters.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : Kmeans instance
            The Kmeans instance.
        """
        if isinstance(self.init, str) and self.init == "random":
            # Draw one seed per run from a local generator rather than
            # re-seeding the global numpy generator: this leaves the caller's
            # global random state untouched and accepts RandomState
            # instances. For an integer seed, the drawn seeds are the same
            # as those obtained after np.random.seed(seed).
            rng = check_random_state(self.random_state)
            seeds = rng.randint(
                np.iinfo(np.int32).max,
                size=self.n_init,
            )
            res = Parallel(n_jobs=self.n_jobs)(
                delayed(_fit_single)(
                    X,
                    y,
                    n_clusters=self.n_clusters,
                    init=self.init,
                    random_state=seed,
                    metric=self.metric,
                    max_iter=self.max_iter,
                    tol=self.tol,
                    n_jobs=1,
                ) for seed in seeds
            )
            labels, inertia, mdm = zip(*res)

            # keep the run with the lowest inertia
            best = np.argmin(inertia)
            mdm = mdm[best]
            labels = labels[best]
            inertia = inertia[best]

        else:
            # no need to iterate if init is not random
            labels, inertia, mdm = _fit_single(
                X,
                y,
                n_clusters=self.n_clusters,
                init=self.init,
                random_state=self.random_state,
                metric=self.metric,
                max_iter=self.max_iter,
                tol=self.tol,
                n_jobs=self.n_jobs,
            )

        self.mdm_ = mdm
        self.inertia_ = inertia
        self.labels_ = labels

        return self

    def predict(self, X):
        """Get the predictions.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.

        Returns
        -------
        pred : ndarray of int, shape (n_matrices,)
            Prediction for each matrix according to the closest centroid.
        """
        return self.mdm_.predict(X)

    def centroids(self):
        """Helper for fast access to the centroids.

        Returns
        -------
        centroids : ndarray, shape (n_clusters, n_channels, n_channels)
            Centroids of each cluster.
        """
        return self.mdm_.covmeans_
###############################################################################
@np.vectorize
def kernel_normal(x):
    """Normal kernel, evaluated element-wise as exp(-x^2)."""
    squared = x * x
    return np.exp(-squared)
@np.vectorize
def kernel_uniform(x):
    """Uniform kernel: 1 when |x| <= 1, and 0 otherwise (element-wise)."""
    return 1 if np.abs(x) <= 1 else 0
# Registry mapping kernel names to kernel functions; it is the lookup table
# passed to check_function when MeanShift resolves its "kernel" parameter.
ker_clust_functions = {
    "normal": kernel_normal,
    "uniform": kernel_uniform,
}
[docs]
class MeanShift(SpdClustMixin, BaseEstimator):
    """Clustering by mean shift with SPD/HPD matrices as inputs.

    The mean shift is a non-parametric clustering method used to find clusters
    on the manifold of SPD/HPD matrices, estimating the gradient of matrices
    density [1]_.

    Parameters
    ----------
    kernel : {"normal", "uniform"} | callable, default="uniform"
        Kernel used for kernel density estimation.
    bandwidth : None | float, default=None
        Bandwidth of the kernel. If None, it is estimated from the pairwise
        distances between the training matrices.
    metric : string | dict, default="riemann"
        Metric used for map estimation (for the list of supported metrics,
        see :func:`pyriemann.geometry.tangentspace.log_map`) and
        for distance estimation
        (see :func:`pyriemann.geometry.distance.distance`).
        The metric can be a dict with two keys, "map" and "distance"
        in order to pass different metrics.
    tol : float, default=1e-3
        Stopping criterion to stop convergence, representing the norm of
        gradient.
    max_iter : int, default=100
        Maximum number of iteration to reach convergence.
    n_jobs : int, default=1
        Number of jobs to use for the computation. This works by seeking the
        mode of each matrix, and by computing the distances to each mode,
        in parallel.
        If -1 all CPUs are used. If 1 is given, no parallel computing code is
        used at all, which is useful for debugging. For n_jobs below -1,
        (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one
        are used.

    Attributes
    ----------
    modes_ : ndarray, shape (n_modes, n_channels, n_channels)
        Modes of each cluster.
    labels_ : ndarray, shape (n_matrices,)
        Labels, ie mode indices, of each matrix of training set.

    Notes
    -----
    .. versionadded:: 0.9

    See Also
    --------
    Kmeans

    References
    ----------
    .. [1] `Nonlinear Mean Shift over Riemannian Manifolds
        <https://sites.rutgers.edu/peter-meer/wp-content/uploads/sites/69/2019/01/manifoldmsijcv.pdf>`_
        R. Subbarao & P. Meer. International Journal of Computer Vision, 84,
        1-20, 2009
    """  # noqa

    def __init__(
        self,
        kernel="uniform",
        bandwidth=None,
        metric="riemann",
        tol=1e-3,
        max_iter=100,
        n_jobs=1
    ):
        """Init."""
        self.kernel = kernel
        self.bandwidth = bandwidth
        self.metric = metric
        self.tol = tol
        self.max_iter = max_iter
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """Fit the modes of clusters.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : MeanShift instance
            The MeanShift instance.
        """
        self._kernel_fun = check_function(self.kernel, ker_clust_functions)
        # NOTE(review): the class docstring advertises the dict keys "map" and
        # "distance", while the keys requested here are "map" and "dist" --
        # confirm which ones check_metric expects.
        self._metric_map, self._metric_dist = check_metric(
            self.metric, ["map", "dist"]
        )

        # A user-supplied bandwidth must be stored too: the private attribute
        # is read below and in _fuse_mode.
        if self.bandwidth is None:
            self._bandwidth = self._estimate_bandwidth(X, quantile=0.3)
        else:
            self._bandwidth = self.bandwidth
        self._bandwidth2 = self._bandwidth ** 2

        # seek a mode starting from each training matrix, then merge the
        # modes that are closer than the bandwidth
        modes = Parallel(n_jobs=self.n_jobs)(
            delayed(self._seek_mode)(X, x) for x in X
        )
        modes = self._fuse_mode(modes)
        self.modes_ = np.array(modes)

        self.labels_ = self.predict(X)

        return self

    def _estimate_bandwidth(self, X, quantile):
        """Estimate bandwidth as a quantile of non-null pairwise distances."""
        dist = pairwise_distance(X, None, metric=self._metric_dist)
        # keep each pair only once (strict upper triangle)
        dist = np.triu(dist, 1)
        dist_sorted = np.sort(dist[dist > 0])
        bandwidth = dist_sorted[floor(quantile * len(dist_sorted))]
        print(f"MeanShift bandwidth={bandwidth:.3f}")
        return bandwidth

    def _seek_mode(self, X, mean):
        """Iterate mean shift steps from an initial point to reach a mode."""
        for _ in range(self.max_iter):
            T = log_map(X, mean, metric=self._metric_map)
            dist2 = distance(X, mean, metric=self._metric_dist, squared=True)
            weights = self._kernel_fun(dist2[:, 0] / self._bandwidth2)
            # kernel-weighted average of the log-mapped matrices
            meanshift = np.einsum("a,abc->bc", weights, T) / np.sum(weights)
            mean = exp_map(meanshift, mean, metric=self._metric_map)
            if np.linalg.norm(meanshift) <= self.tol:
                break
        else:
            warnings.warn("Convergence not reached")

        return mean

    def _fuse_mode(self, in_modes):
        """Remove modes closer than the bandwidth to another kept mode."""
        out_modes = in_modes.copy()
        in_modes = np.stack(in_modes, axis=0)

        dist = pairwise_distance(in_modes, None, metric=self._metric_dist)
        # a mode must not be compared with itself
        np.fill_diagonal(dist, self._bandwidth + 1)

        # iterate backwards so that deletions do not shift pending indices
        for i in range(dist.shape[0] - 1, -1, -1):
            if np.min(dist[i]) < self._bandwidth:
                del out_modes[i]
                # a removed mode can no longer trigger other removals
                dist[:, i] = self._bandwidth + 1

        if len(out_modes) == 0:
            raise ValueError(
                "No mode found, try other parameters (Got "
                f"kernel={self.kernel} and bandwith={self._bandwidth:.3f})"
            )

        return out_modes

    def predict(self, X):
        """Get the predictions.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.

        Returns
        -------
        pred : ndarray of int, shape (n_matrices,)
            Prediction for each matrix according to the closest mode.
        """
        dist = Parallel(n_jobs=self.n_jobs)(
            delayed(distance)(X, mode, self._metric_dist)
            for mode in self.modes_
        )
        dist = np.concatenate(dist, axis=1)
        return dist.argmin(axis=1)
###############################################################################
class Gaussian:
    """Gaussian model.

    Gaussian model for Riemannian manifold of SPD matrices,
    defined with a mean in manifold and a covariance in tangent space [1]_.

    Parameters
    ----------
    n : integer
        Dimension of the matrices.
    mu : ndarray, shape (n, n)
        Mean of the Gaussian, in manifold.
    sigma : None | ndarray, shape (n * (n + 1) / 2, n * (n + 1) / 2), \
            default=None
        Covariance of the Gaussian, in tangent space.
        If None, it uses identity matrix.
    metric : string | dict, default="riemann"
        Metric used for mean update (for the list of supported metrics,
        see :func:`pyriemann.geometry.mean.gmean`) and for tangent space map
        (see :func:`pyriemann.geometry.tangent_space.tangent_space`).
        The metric can be a dict with two keys, "mean" and "map"
        in order to pass different metrics.

    Notes
    -----
    .. versionadded:: 0.11

    References
    ----------
    .. [1] `Intrinsic statistics on Riemannian manifolds: Basic tools for
        geometric measurements
        <https://www.cis.jhu.edu/~tingli/App_of_Lie_group/Intrinsic%20Statistics%20on%20Riemannian%20Manifolds.pdf>`_
        X. Pennec. Journal of Mathematical Imaging and Vision, 2006
    """  # noqa

    def __init__(self, n, mu, sigma=None, metric="riemann"):
        self.n = n
        self.mu = mu
        # identity covariance in tangent space when none is given
        self.sigma = np.eye(n * (n + 1) // 2) if sigma is None else sigma
        self.metric = metric
        self._metric_mean, self._metric_map = check_metric(
            metric, ["mean", "map"]
        )

    def pdf(self, X, *, reg=1e-16, use_pi=True):
        """Compute approximate probability density function (pdf) of matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n, n)
            Set of SPD matrices.
        reg : float, default=1e-16
            Regularization parameter for pdf normalization term.
        use_pi : bool, default=True
            If true, use (2 pi)^n to compute the full denominator.
            If false, do not use (2 pi)^n, because will be simplified with
            upcoming normalizations.

        Returns
        -------
        pdf : ndarray, shape (n_matrices,)
            Probability density function of each matrix.
        """
        tangent_vecs = tangent_space(X, self.mu, metric=self._metric_map)
        maha2 = distance_mahalanobis(tangent_vecs.T, self.sigma, squared=True)

        det_sigma = np.linalg.det(self.sigma)
        normalization = det_sigma
        if use_pi:
            normalization = ((2 * np.pi) ** self.n) * det_sigma

        # reg prevents a division by zero when the determinant vanishes
        return np.exp(-0.5 * maha2) / (np.sqrt(normalization) + reg)

    def update_mean(self, X, sample_weight):
        """Update mean in manifold.

        Compute weighted mean of matrices, initialized on previous mean.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n, n)
            Set of SPD matrices.
        sample_weight : ndarray, shape (n_matrices,)
            Weights for each matrix.
        """
        previous_mean = self.mu
        self.mu = gmean(
            X,
            metric=self._metric_mean,
            sample_weight=sample_weight,
            init=previous_mean,
        )

    def update_covariance(self, X, sample_weight):
        """Update covariance in tangent space.

        Compute weighted covariance of tangent vectors.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n, n)
            Set of SPD matrices.
        sample_weight : ndarray, shape (n_matrices,)
            Weights for each matrix.
        """
        tangent_vecs = tangent_space(X, self.mu, metric=self._metric_map)
        # tangent vectors are taken at the current mean, so they are treated
        # as already centered
        self.sigma = covariance_scm(
            tangent_vecs.T,
            assume_centered=True,
            weights=sample_weight,
        )
[docs]
class GaussianMixture(SpdClustMixin, BaseEstimator):
    """Gaussian mixture model.

    Representation of a Gaussian mixture model (GMM) probability distribution
    for SPD matrices by expectation-maximization (EM) algorithm [1]_.

    Parameters
    ----------
    n_components : integer, default=1
        Number of mixture components.
    metric : string | dict, default="riemann"
        Metric used for mean update (for the list of supported metrics,
        see :func:`pyriemann.geometry.mean.gmean`) and for tangent space map
        (see :func:`pyriemann.geometry.tangent_space.tangent_space`).
        The metric can be a dict with two keys, "mean" and "map"
        in order to pass different metrics.
    weights_init : None | ndarray, shape (n_components,), default=None
        Initial weights. If None, it uses equal weights.
    means_init : None | ndarray, shape (n_components, n_channels, \
            n_channels), default=None
        Initial means of Gaussians. If None, or if its shape does not match,
        it randomly selects training matrices.
    tol : float, default=1e-5
        Tolerance to stop the EM algorithm.
    maxiter : int, default=100
        Maximum number of iterations of EM algorithm.
    random_state : None | integer | np.RandomState, default=None
        The generator used to initialize the Gaussian models. If an integer is
        given, it fixes the seed. Defaults to the global numpy random
        number generator.
    verbose : bool, default=False
        Verbose flag.

    Attributes
    ----------
    weights_ : ndarray, shape (n_components,)
        Weight of each mixture component.
    means_ : ndarray, shape (n_components, n_channels, n_channels)
        Mean of each mixture component.
    covariances_ : ndarray, shape (n_components, n_ts, n_ts)
        Covariance of each mixture component.

    Notes
    -----
    .. versionadded:: 0.11

    References
    ----------
    .. [1] `Gaussian mixture regression on symmetric positive definite matrices
        manifolds: Application to wrist motion estimation with sEMG
        <https://calinon.ch/papers/Jaquier-IROS2017.pdf>`_
        N. Jacquier & S. Calinon. IEEE IROS, 2017
    """

    def __init__(
        self,
        n_components=1,
        metric="riemann",
        weights_init=None,
        means_init=None,
        tol=1e-5,
        maxiter=100,
        random_state=None,
        verbose=False,
    ):
        """Init."""
        self.n_components = n_components
        self.metric = metric
        self.weights_init = weights_init
        self.means_init = means_init
        self.tol = tol
        self.maxiter = maxiter
        self.random_state = random_state
        self.verbose = verbose

    @property
    def means_(self):
        return np.stack([component.mu for component in self._components])

    @property
    def covariances_(self):
        return np.stack([component.sigma for component in self._components])

    def _get_wlik(self, X, use_pi=True):
        """Compute weighted likelihoods.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        use_pi : bool, default=True
            If true, use (2 pi)^n to compute the full denominator of pdf.
            If false, do not use (2 pi)^n, because will be simplified with
            upcoming normalizations.

        Returns
        -------
        wlik : ndarray, shape (n_matrices, n_components)
            Weighted likelihood of each matrix given component.
        """
        wlik = np.zeros((X.shape[0], self.n_components))
        for k in range(self.n_components):
            lik = self._components[k].pdf(X, use_pi=use_pi)
            wlik[:, k] = self.weights_[k] * lik
        return wlik

    def _get_proba(self, X, reg=1e-16):
        """Compute posterior probabilities.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        reg : float, default=1e-16
            Regularization parameter for probabilities normalization.

        Returns
        -------
        prob : ndarray, shape (n_matrices, n_components)
            Posterior probability of each component given matrix.
        """
        num = self._get_wlik(X, use_pi=False)
        prob = num / (np.sum(num, axis=1, keepdims=True) + reg)
        return prob

    def _log(self, X):
        """Log after clip."""
        return np.log(np.clip(X, a_min=1e-10, a_max=None))

    def fit(self, X, y=None):
        """Fit the mixture with EM.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : GaussianMixture instance
            The GaussianMixture instance.
        """
        n_matrices, n_channels, _ = X.shape
        if n_channels * (n_channels + 1) // 2 > n_matrices:
            raise ValueError("Not enough matrices for training GMM.")

        # initialization
        # The generator is stored on a private attribute: overwriting the
        # random_state parameter would make a second fit with the same
        # integer seed give different results.
        self._rng = check_random_state(self.random_state)
        if isinstance(self.means_init, np.ndarray) and self.means_init.shape \
                == (self.n_components, n_channels, n_channels):
            means_init = self.means_init
        else:
            inds = self._rng.randint(
                n_matrices,
                size=(self.n_components,)
            )
            means_init = X[inds]
        self._components = [
            Gaussian(
                n_channels,
                mu=means_init[k],
                sigma=None,
                metric=self.metric,
            )
            for k in range(self.n_components)
        ]
        self.weights_ = check_weights(self.weights_init, self.n_components)

        # expectation-maximization
        crit = 0
        for _ in range(self.maxiter):
            # e-step
            prob = self._get_proba(X)

            # m-step
            self.weights_ = np.sum(prob, axis=0) / n_matrices
            # re-normalization (necessary because of approx Gaussian pdf?)
            self.weights_ = self.weights_ / self.weights_.sum()
            for k in range(self.n_components):
                self._components[k].update_mean(X, prob[:, k])
                self._components[k].update_covariance(X, prob[:, k])

            # check convergence
            crit_new = -np.sum(self._log(np.sum(self._get_wlik(X), axis=1)))
            if self.verbose:
                print(f"neg log-likelihood = {crit_new}")
            if abs(crit - crit_new) < self.tol:
                break
            crit = crit_new
        else:
            warnings.warn("EM convergence not reached")

        return self

    def predict_proba(self, X):
        """Predict probabilities.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.

        Returns
        -------
        prob : ndarray, shape (n_matrices, n_components)
            Probabilities for each component.
        """
        return self._get_proba(X)

    def predict(self, X):
        """Get the predictions.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.

        Returns
        -------
        pred : ndarray of int, shape (n_matrices,)
            Predictions for each matrix.
        """
        prob = self._get_proba(X)
        return np.argmax(prob, axis=1)

    def score(self, X, y=None):
        """Compute the average log-likelihood of the given matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        score : float
            Log-likelihood of matrices under the Gaussian mixture model.
        """
        lik = np.sum(self._get_wlik(X), axis=1)
        return np.mean(self._log(lik))

    def sample(self, n_matrices=1):
        """Generate random matrices from the fitted Gaussian distribution.

        The component of each matrix is drawn according to the fitted mixture
        weights, then the matrix is drawn from this component.

        Warning: GMM is calibrated using the Gaussian model [1]_,
        while this sampling uses the wrapped Gaussian model [2]_.

        Parameters
        ----------
        n_matrices : int, default=1
            Number of matrices to generate.

        Returns
        -------
        X : array, shape (n_matrices, n_channels, n_channels)
            Randomly generated matrices.
        y : array, shape (n_matrices,)
            Component labels.

        References
        ----------
        .. [1] `Intrinsic statistics on Riemannian manifolds: Basic tools for
            geometric measurements
            <https://www.cis.jhu.edu/~tingli/App_of_Lie_group/Intrinsic%20Statistics%20on%20Riemannian%20Manifolds.pdf>`_
            X. Pennec. Journal of Mathematical Imaging and Vision, 2006
        .. [2] `Wrapped gaussian on the manifold of symmetric positive
            definite matrices
            <https://openreview.net/pdf?id=EhStXG4dCS>`_
            T. de Surrel, F. Lotte, S. Chevallier, and F. Yger. ICML, 2025
        """  # noqa
        rng = self._rng

        # Draw the component labels from the fitted mixing proportions: a
        # uniform draw would ignore weights_. The weights are re-normalized
        # because the generator requires probabilities summing to one.
        weights = np.asarray(self.weights_, dtype=float)
        y = rng.choice(
            self.n_components,
            size=(n_matrices,),
            p=weights / weights.sum(),
        )

        means, covariances = self.means_, self.covariances_
        n_channels = means.shape[-1]
        X = np.zeros((n_matrices, n_channels, n_channels))
        for i in np.unique(y):
            X[y == i] = sample_gaussian_spd(
                np.count_nonzero(y == i),
                mean=means[i],
                sigma=covariances[i],
                random_state=rng
            )

        return X, y
###############################################################################
[docs]
class Potato(TransformerMixin, SpdClassifMixin, BaseEstimator):
    """Artifact detection with the Riemannian Potato.

    The Riemannian Potato [1]_ is a clustering method used to detect artifact
    in multichannel signals. Processing SPD/HPD matrices,
    the algorithm iteratively estimates the centroid of clean
    matrices by rejecting every matrix that is too far from it.

    Parameters
    ----------
    metric : string | dict, default="riemann"
        Metric used for mean estimation (for the list of supported metrics,
        see :func:`pyriemann.geometry.mean.gmean`) and for distance estimation
        (see :func:`pyriemann.geometry.distance.distance`).
        The metric can be a dict with two keys, "mean" and "distance"
        in order to pass different metrics.
    threshold : float, default=3
        Threshold on z-score of distance to reject artifacts. It is the number
        of standard deviations from the mean of distances to the centroid.
    n_iter_max : int, default=100
        The maximum number of iteration to reach convergence.
    pos_label : int, default=1
        The positive label corresponding to clean data.
    neg_label : int, default=0
        The negative label corresponding to artifact data.

    Attributes
    ----------
    covmean_ : ndarray, shape (n_channels, n_channels)
        Centroid of potato.

    Notes
    -----
    .. versionadded:: 0.2.3

    See Also
    --------
    MDM

    References
    ----------
    .. [1] `The Riemannian Potato: an automatic and adaptive artifact detection
        method for online experiments using Riemannian geometry
        <https://hal.archives-ouvertes.fr/hal-00781701>`_
        A. Barachant, A Andreev, and M. Congedo. TOBI Workshop lV, Jan 2013,
        Sion, Switzerland. pp.19-20.
    .. [2] `The Riemannian Potato Field: A Tool for Online Signal Quality Index
        of EEG
        <https://hal.archives-ouvertes.fr/hal-02015909>`_
        Q. Barthélemy, L. Mayaud, D. Ojeda, and M. Congedo. IEEE Transactions
        on Neural Systems and Rehabilitation Engineering, IEEE Institute of
        Electrical and Electronics Engineers, 2019, 27 (2), pp.244-255
    """

    def __init__(
        self,
        metric="riemann",
        threshold=3,
        n_iter_max=100,
        pos_label=1,
        neg_label=0,
    ):
        """Init."""
        self.metric = metric
        self.threshold = threshold
        self.n_iter_max = n_iter_max
        self.pos_label = pos_label
        self.neg_label = neg_label

    def fit(self, X, y=None, sample_weight=None):
        """Fit the potato.

        Fit the potato from SPD/HPD matrices, with an iterative outlier
        removal to obtain a reliable potato.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.
        y : None | ndarray, shape (n_matrices,), default=None
            Labels corresponding to each matrix: positive (resp. negative)
            label corresponds to a clean (resp. artifact) matrix.
            If None, all matrices are considered as clean.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix. If None, it uses equal weights.

        Returns
        -------
        self : Potato instance
            The Potato instance.
        """
        if self.pos_label == self.neg_label:
            raise ValueError("Positive and negative labels must be different")

        n_matrices, _, _ = X.shape
        # labels are binarized: 1 for clean matrices, 0 for artifacts
        labels_prev = self._check_labels(X, y)
        if sample_weight is None:
            sample_weight = np.ones(n_matrices)

        self._metric_mean, _ = check_metric(self.metric)
        self._mdm = MDM(metric=self.metric)

        for _ in range(self.n_iter_max):
            clean = (labels_prev == 1)
            if not any(clean):
                raise ValueError("Iterative outlier removal has rejected all "
                                 "matrices. Choose a higher threshold.")

            # re-estimate the centroid on the matrices still considered clean
            self._mdm.fit(
                X[clean],
                labels_prev[clean],
                sample_weight=sample_weight[clean],
            )

            # statistics of the log-distances of clean matrices to centroid
            log_dist = np.squeeze(np.log(self._mdm.transform(X[clean])))
            self._mean = np.mean(log_dist)
            self._std = np.std(log_dist)

            # matrices already rejected stay rejected
            labels_new = np.zeros(n_matrices)
            labels_new[clean] = self._get_z_score(log_dist) < self.threshold

            if np.array_equal(labels_new, labels_prev):
                break
            labels_prev = labels_new

        self.covmean_ = self._mdm.covmeans_[0]

        return self

    def partial_fit(self, X, y=None, *, sample_weight=None, alpha=0.1):
        """Partially fit the potato.

        This partial fit can be used to update dynamic or semi-dymanic online
        potatoes with clean matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.
        y : None | ndarray, shape (n_matrices,), default=None
            Labels corresponding to each matrix: positive (resp. negative)
            label corresponds to a clean (resp. artifact) matrix.
            If None, all matrices are considered as clean.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix. If None, it uses equal weights.
        alpha : float, default=0.1
            Update rate in [0, 1] for the centroid, and mean and standard
            deviation of log-distances: 0 for no update, 1 for full update.

        Returns
        -------
        self : Potato instance
            The Potato instance.

        Notes
        -----
        .. versionadded:: 0.3
        """
        if not hasattr(self, "_mdm"):
            raise ValueError(
                "partial_fit can be called only on an already fitted potato."
            )

        n_matrices, n_channels, _ = X.shape
        n_channels_fitted = self._mdm.covmeans_[0].shape[0]
        if n_channels != n_channels_fitted:
            raise ValueError(
                "X does not have the good number of channels. Should be %d but"
                " got %d." % (self._mdm.covmeans_[0].shape[0], n_channels)
            )

        y = self._check_labels(X, y)
        if sample_weight is None:
            sample_weight = np.ones(X.shape[0])

        if not 0 <= alpha <= 1:
            raise ValueError("Parameter alpha must be in [0, 1]")
        if alpha == 0:
            # nothing to update
            return self

        # mean of the new clean matrices
        is_clean = (y == self.pos_label)
        Xm = gmean(
            X[is_clean],
            metric=self._metric_mean,
            sample_weight=sample_weight[is_clean],
        )

        # move the centroid towards this mean along the geodesic
        self._mdm.covmeans_[0] = geodesic(
            self._mdm.covmeans_[0], Xm, alpha, metric=self._metric_mean
        )

        # update the statistics of log-distances with the same rate
        log_dist = np.squeeze(
            np.log(self._mdm.transform(Xm[np.newaxis, ...]))
        )
        self._mean = (1 - alpha) * self._mean + alpha * log_dist
        self._std = np.sqrt(
            (1 - alpha) * self._std**2 + alpha * (log_dist - self._mean)**2
        )

        self.covmean_ = self._mdm.covmeans_[0]

        return self

    def predict(self, X):
        """Predict artifact from data.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.

        Returns
        -------
        pred : ndarray, shape (n_matrices,)
            The artifact detection: ``pos_label`` if the z-score of the matrix
            is below the threshold (clean), and ``neg_label`` otherwise
            (artifact).
        """
        # NOTE(review): transform is not defined in this class body; it is
        # presumably provided by a base class -- verify.
        z = self.transform(X)
        out = np.zeros_like(z) + self.neg_label
        out[z < self.threshold] = self.pos_label
        return out

    def predict_proba(self, X):
        """Return probability of belonging to the potato / being clean.

        It is the probability to reject the null hypothesis "clean data",
        computing the right-tailed probability from z-score.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.

        Returns
        -------
        proba : ndarray, shape (n_matrices,)
            Matrix is considered as normal/clean for high value of proba.
            It is considered as abnormal/artifacted for low value of proba.

        Notes
        -----
        .. versionadded:: 0.2.7
        """
        return self._get_proba(self.transform(X))

    def _check_labels(self, X, y):
        """Check validity of labels, and binarize them (1 for positive)."""
        if y is None:
            # without labels, every matrix is considered as clean
            return np.ones(len(X))

        if len(y) != len(X):
            raise ValueError("y must be the same length of X")

        classes = np.int32(np.unique(y))
        if len(classes) > 2:
            raise ValueError("number of classes must be maximum 2")
        if self.pos_label not in classes:
            raise ValueError("y must contain a positive class")

        return np.int32(np.array(y) == self.pos_label)

    def _get_z_score(self, d):
        """Get z-score from distance."""
        centered = d - self._mean
        return centered / self._std

    def _get_proba(self, z):
        """Get right-tailed proba from z-score."""
        return 1 - norm.cdf(z)
def _check_n_matrices(X, n_matrices):
"""Check number of matrices in ndarray."""
if X.shape[0] != n_matrices:
raise ValueError(
"Unequal n_matrices between ndarray of X. Should be %d but"
" got %d." % (n_matrices, X.shape[0])
)
[docs]
class PotatoField(TransformerMixin, SpdClassifMixin, BaseEstimator):
"""Artifact detection with the Riemannian Potato Field.
The Riemannian Potato Field [1]_ is a clustering method used to detect
artifact in multichannel signals. Processing SPD/HPD matrices,
the algorithm combines several potatoes of low dimension,
each one being designed to capture specific artifact typically
affecting specific subsets of channels and/or specific frequency bands.
Parameters
----------
n_potatoes : int, default=1
Number of potatoes in the field.
p_threshold : float, default=0.01
Threshold on probability to being clean, in (0, 1), combining
probabilities of potatoes using ``method_combination``.
z_threshold : float, default=3
Threshold on z-score of distance to reject artifacts. It is the number
of standard deviations from the mean of distances to the centroid.
metric : string | dict | list, default="riemann"
Metric used for mean estimation (for the list of supported metrics,
see :func:`pyriemann.geometry.mean.gmean`) and for distance estimation
(see :func:`pyriemann.geometry.distance.distance`).
The metric can be a single str;
or a dict with two keys, "mean" and "distance",
in order to pass different metrics for mean and distance;
or a list of ``n_potatoes`` str or dict,
in order to pass different metrics for each potato [2]_.
.. versionchanged:: 0.11
Allow a different metric per potato.
n_iter_max : int, default=10
The maximum number of iteration to reach convergence.
pos_label : int, default=1
The positive label corresponding to clean data.
neg_label : int, default=0
The negative label corresponding to artifact data.
method_combination : {"fisher", "stouffer"} | callable, default="fisher"
Method to combine probabilities from the different potatoes:
* fisher: Fisher's method;
* stouffer: Stouffer's z-score method;
* callable: for a custom combination function, with an axis argument.
.. versionadded:: 0.11
Notes
-----
.. versionadded:: 0.3
See Also
--------
Potato
References
----------
.. [1] `The Riemannian Potato Field: A Tool for Online Signal Quality Index
of EEG
<https://hal.archives-ouvertes.fr/hal-02015909>`_
Q. Barthélemy, L. Mayaud, D. Ojeda, and M. Congedo. IEEE
Transactions on Neural Systems and Rehabilitation Engineering, 2019
.. [2] `Improved Riemannian potato field: an Automatic Artifact Rejection
Method for EEG
<https://arxiv.org/pdf/2509.09264>`_
D. Hajhassani, Q. Barthélemy, J. Mattout & M. Congedo.
Biomedical Signal Processing and Control, 2026
"""
[docs]
def __init__(
self,
n_potatoes=1,
p_threshold=0.01,
z_threshold=3,
metric="riemann",
n_iter_max=10,
pos_label=1,
neg_label=0,
method_combination="fisher",
):
"""Init."""
self.n_potatoes = int(n_potatoes)
self.p_threshold = p_threshold
self.metric = metric
self.z_threshold = z_threshold
self.n_iter_max = n_iter_max
self.pos_label = pos_label
self.neg_label = neg_label
self.method_combination = method_combination
[docs]
def fit(self, X, y=None, sample_weight=None):
    """Fit one potato per set of matrices.

    After validating the parameters, a :class:`Potato` is created for each
    of the ``n_potatoes`` entries of ``X``, using the metric associated
    with that entry, and is fitted on it.

    Parameters
    ----------
    X : list of n_potatoes ndarrays of shape (n_matrices, n_channels, \
            n_channels) with same n_matrices but potentially different \
            n_channels
        List of sets of SPD/HPD matrices, each corresponding to a different
        subset of channels and/or filtering with a specific frequency band.
    y : None | ndarray, shape (n_matrices,), default=None
        Labels corresponding to each matrix: positive (resp. negative)
        label corresponds to a clean (resp. artifact) matrix.
        If None, all matrices are considered as clean.
    sample_weight : None | ndarray, shape (n_matrices,), default=None
        Weights for each matrix. If None, it uses equal weights.

    Returns
    -------
    self : PotatoField instance
        The PotatoField instance.

    Raises
    ------
    ValueError
        If ``n_potatoes`` is lower than 1, if ``p_threshold`` is not in
        (0, 1), if the length of ``X`` differs from ``n_potatoes``, or if
        ``metric`` is a list whose length differs from ``n_potatoes``.
    TypeError
        If ``metric`` is neither a str, a dict nor a list.
    """
    # Parameter validation.
    if self.n_potatoes < 1:
        raise ValueError("Parameter n_potatoes must be at least 1")
    if not 0 < self.p_threshold < 1:
        raise ValueError("Parameter p_threshold must be in (0, 1)")

    self._check_length(X)
    # The first set gives the reference number of matrices.
    n_matrices = X[0].shape[0]

    # Resolve the metric into one entry per potato.
    if isinstance(self.metric, (str, dict)):
        # A single metric is shared by all potatoes.
        metrics = [self.metric] * self.n_potatoes
    elif not isinstance(self.metric, list):
        raise TypeError(
            "Metric must be a str, a dict or a list, "
            f"but got {type(self.metric)}."
        )
    elif len(self.metric) != self.n_potatoes:
        raise ValueError(
            f"Metric must be a list with {self.n_potatoes} elements."
        )
    else:
        metrics = self.metric

    # Each potato is registered in self._potatoes before being fitted.
    self._potatoes = []
    for i, metric_i in enumerate(metrics):
        X_i = X[i]
        _check_n_matrices(X_i, n_matrices)
        potato = Potato(
            metric=metric_i,
            threshold=self.z_threshold,
            n_iter_max=self.n_iter_max,
            pos_label=self.pos_label,
            neg_label=self.neg_label,
        )
        self._potatoes.append(potato)
        potato.fit(X_i, y, sample_weight=sample_weight)

    return self
[docs]
def partial_fit(self, X, y=None, *, sample_weight=None, alpha=0.1):
    """Partially fit the potato field.

    This partial fit can be used to update dynamic or semi-dynamic online
    potatoes with clean matrices. The call is forwarded to every potato of
    the field, each one receiving its own set of matrices.

    Parameters
    ----------
    X : list of n_potatoes ndarrays of shape (n_matrices, n_channels, \
            n_channels) with same n_matrices but potentially different \
            n_channels
        List of sets of SPD/HPD matrices, each corresponding to a different
        subset of channels and/or filtering with a specific frequency band.
    y : None | ndarray, shape (n_matrices,), default=None
        Labels corresponding to each matrix: positive (resp. negative)
        label corresponds to a clean (resp. artifact) matrix.
        If None, all matrices are considered as clean.
    sample_weight : None | ndarray, shape (n_matrices,), default=None
        Weights for each matrix. If None, it uses equal weights.
    alpha : float, default=0.1
        Update rate in [0, 1] for the centroid, and mean and standard
        deviation of log-distances: 0 for no update, 1 for full update.

    Returns
    -------
    self : PotatoField instance
        The PotatoField instance.

    Raises
    ------
    ValueError
        If the field has not been fitted yet, or if the length of ``X``
        differs from ``n_potatoes``.
    """
    # The potatoes only exist once fit has been called.
    already_fitted = hasattr(self, "_potatoes")
    if not already_fitted:
        raise ValueError("partial_fit can be called only on an already "
                         "fitted potato field.")

    self._check_length(X)
    # The first set gives the reference number of matrices.
    n_matrices = X[0].shape[0]

    for i in range(self.n_potatoes):
        X_i = X[i]
        _check_n_matrices(X_i, n_matrices)
        potato = self._potatoes[i]
        potato.partial_fit(X_i, y, sample_weight=sample_weight, alpha=alpha)

    return self
[docs]
def predict(self, X):
    """Predict artifact from data.

    The combined probability returned by ``predict_proba`` is compared to
    ``p_threshold`` to label each matrix.

    Parameters
    ----------
    X : list of n_potatoes ndarrays of shape (n_matrices, n_channels, \
            n_channels) with same n_matrices but potentially different \
            n_channels
        List of sets of SPD/HPD matrices, each corresponding to a different
        subset of channels and/or filtering with a specific frequency band.

    Returns
    -------
    pred : ndarray, shape (n_matrices,)
        The artifact detection: ``pos_label`` where the combined
        probability is strictly greater than ``p_threshold`` (matrix
        considered as clean), and ``neg_label`` elsewhere (matrix
        considered as containing an artifact).
    """
    proba = self.predict_proba(X)
    is_clean = proba > self.p_threshold

    # Start with every matrix labelled as negative, then overwrite the
    # entries whose probability exceeds the threshold.
    labels = np.zeros_like(proba) + self.neg_label
    labels[is_clean] = self.pos_label
    return labels
[docs]
def predict_proba(self, X):
    """Predict probability combining probabilities of potatoes.

    The probability of each potato is computed on its own set of matrices,
    then the probabilities of the different potatoes are combined using
    ``method_combination``.
    With Fisher's method, a threshold of 0.01 can be used.

    Parameters
    ----------
    X : list of n_potatoes ndarrays of shape (n_matrices, n_channels, \
            n_channels) with same n_matrices but potentially different \
            n_channels
        List of sets of SPD/HPD matrices, each corresponding to a different
        subset of channels and/or filtering with a specific frequency band.

    Returns
    -------
    proba : ndarray, shape (n_matrices,)
        Matrix is considered as normal/clean for high value of proba.
        It is considered as abnormal/artifacted for low value of proba.

    Raises
    ------
    ValueError
        If the length of ``X`` differs from ``n_potatoes``.
    TypeError
        If ``method_combination`` is neither a str nor a callable.
    """
    self._check_length(X)
    # The first set gives the reference number of matrices.
    n_matrices = X[0].shape[0]

    # One row of probabilities per potato.
    probas = np.zeros((self.n_potatoes, n_matrices))
    for i in range(self.n_potatoes):
        X_i = X[i]
        _check_n_matrices(X_i, n_matrices)
        probas[i] = self._potatoes[i].predict_proba(X_i)

    # Keep probabilities strictly positive to avoid trouble with the log.
    probas = np.clip(probas, 1e-10, 1)

    # Combine across potatoes, i.e. along the first axis.
    combiner = self.method_combination
    if isinstance(combiner, str):
        _, combined = combine_pvalues(probas, method=combiner, axis=0)
        return combined
    if callable(combiner):
        return combiner(probas, axis=0)
    raise TypeError(
        "method_combination must be a str or a callable, "
        f"but got {type(self.method_combination)}."
    )
def _check_length(self, X):
"""Check validity of input length."""
if len(X) != self.n_potatoes:
raise ValueError(
"Length of X is not equal to n_potatoes. Should be %d but got "
"%d." % (self.n_potatoes, len(X))
)