from functools import partial
import warnings
from joblib import Parallel, delayed
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.base import BaseEstimator
from sklearn.utils import check_random_state
from ..geometry._docs import deprecated
from ..geometry.base import ctranspose, sqrtm
from ..geometry.geodesic import geodesic
from ..geometry.tangentspace import exp_map_riemann, unupper
from ..geometry.test import is_herm_pos_semi_def as is_hpsd
def _pdf_r(r, sigma):
"""Pdf for the log of eigenvalues of a SPD matrix.
Probability density function for the logarithm of the eigenvalues of a SPD
matrix samples from the Riemannian Gaussian distribution.
See Said2017 for the mathematical details.
Parameters
----------
r : ndarray, shape (n_dim,)
Vector with the logarithm of the eigenvalues of a SPD matrix.
sigma : float
Dispersion of the Riemannian Gaussian distribution.
Returns
-------
p : float
Probability density function applied to r.
"""
if (sigma <= 0):
raise ValueError(f"sigma must be a positive number (Got {sigma})")
n_dim = len(r)
partial_1 = -np.sum(r**2) / (2 * sigma**2)
partial_2 = 0
for i in range(n_dim):
for j in range(i + 1, n_dim):
partial_2 = partial_2 + np.log(np.sinh(np.abs(r[i] - r[j]) / 2))
return np.exp(partial_1 + partial_2)
def _rejection_sampling_2D_gfunction_plus(sigma, r_sample):
"""Auxiliary function for the 2D rejection sampling algorithm.
It is used in the case where r is sampled with the function g+.
Parameters
----------
sigma : float
Dispersion of the Riemannian Gaussian distribution.
r_sample : ndarray, shape (1, n_dim)
Sample of the r parameters of the Riemannian Gaussian distribution.
Returns
-------
p : float
Probability of acceptation.
Notes
-----
.. versionadded:: 0.4
"""
mu_a = np.array([-sigma**2 / 2, (sigma**2) / 2])
cov_matrix = (sigma**2) * np.eye(2)
m = np.pi * (sigma**2) * np.exp(sigma**2 / 4)
if r_sample[0] >= r_sample[1]:
num = _pdf_r(r_sample, sigma)
den = multivariate_normal.pdf(r_sample, mean=mu_a, cov=cov_matrix) * m
return num / den
return 0
def _rejection_sampling_2D_gfunction_minus(sigma, r_sample):
"""Auxiliary function for the 2D rejection sampling algorithm.
It is used in the case where r is sampled with the function g-.
Parameters
----------
sigma : float
Dispersion of the Riemannian Gaussian distribution.
r_sample : ndarray, shape (1, n_dim)
Sample of the r parameters of the Riemannian Gaussian distribution.
Returns
-------
p : float
Probability of acceptation.
Notes
-----
.. versionadded:: 0.4
"""
mu_b = np.array([(sigma**2) / 2, -sigma**2 / 2])
cov_matrix = (sigma**2) * np.eye(2)
m = np.pi * (sigma**2) * np.exp(sigma**2 / 4)
if r_sample[0] < r_sample[1]:
num = _pdf_r(r_sample, sigma)
den = multivariate_normal.pdf(r_sample, mean=mu_b, cov=cov_matrix) * m
return num / den
return 0
def _rejection_sampling_2D(n_samples, sigma, random_state=None,
return_acceptance_rate=False):
"""Rejection sampling algorithm for the 2D case.
Implementation of a rejection sampling algorithm.
The implementation follows the description given in p528 of Christopher
Bishop's book "Pattern recognition and Machine Learning" (2006).
Parameters
----------
n_samples : int
Number of samples to get from the target distribution.
sigma : float
Dispersion of the Riemannian Gaussian distribution.
random_state : int | RandomState instance | None, default=None
Pass an int for reproducible output across multiple function calls.
return_acceptance_rate : boolean, default=False
Whether to return the acceptance rate with the sample (number of
samples obtained divided by the number of samples generated by
the algorithm).
.. versionadded:: 0.5
Returns
-------
r_samples : ndarray, shape (n_samples, n_dim)
Samples of the r parameters of the Riemannian Gaussian distribution.
acceptance_rate : float
Acceptance rate empirically computed for the generation of the sample.
Only returned if ``return_acceptance_rate=True``.
Notes
-----
.. versionadded:: 0.4
"""
mu_a = np.array([-sigma**2 / 2, (sigma**2) / 2])
mu_b = np.array([(sigma**2) / 2, -sigma**2 / 2])
cov_matrix = (sigma**2) * np.eye(2)
r_samples = []
cpt = 0
acc = 0
rs = check_random_state(random_state)
while cpt != n_samples:
acc += 1
if (rs.binomial(1, 0.5, 1) == 1):
r_sample = multivariate_normal.rvs(mu_a, cov_matrix, 1, rs)
res = _rejection_sampling_2D_gfunction_plus(sigma, r_sample)
if rs.rand(1) < res:
r_samples.append(r_sample)
cpt += 1
else:
r_sample = multivariate_normal.rvs(mu_b, cov_matrix, 1, rs)
res = _rejection_sampling_2D_gfunction_minus(sigma, r_sample)
if rs.rand(1) < res:
r_samples.append(r_sample)
cpt += 1
if return_acceptance_rate:
return np.array(r_samples), n_samples / acc
return np.array(r_samples)
def _slice_one_sample(ptarget, x0, w, rs):
"""Slice sampling for one sample
Parameters
----------
ptarget : function with one input
The target pdf to sample from or a multiple of it.
x0 : ndarray
Initial state for the MCMC procedure. Note that the shape of this array
defines the dimensionality n_dim of the matrices to be sampled.
w : float
Initial bracket width.
rs : int | RandomState instance | None
Pass an int for reproducible output across multiple function calls.
Returns
-------
sample : ndarray, shape (n_dim,)
Sample from the target pdf.
"""
xt = np.copy(x0)
n_dim = len(x0)
for i in range(n_dim):
ei = np.zeros(n_dim)
ei[i] = 1
# step 1 : evaluate ptarget(xt)
Px = ptarget(xt)
# step 2 : draw vertical coordinate uprime ~ U(0, ptarget(xt))
uprime_i = Px * rs.rand()
# step 3 : create a horizontal interval (xl_i, xr_i) enclosing xt_i
r = rs.rand()
xl_i = xt[i] - r * w
xr_i = xt[i] + (1-r) * w
while ptarget(xt + (xl_i - xt[i]) * ei) > uprime_i:
xl_i = xl_i - w
while ptarget(xt + (xr_i - xt[i]) * ei) > uprime_i:
xr_i = xr_i + w
# step 4 : loop
while True:
xprime_i = xl_i + (xr_i - xl_i) * rs.rand()
Px = ptarget(xt + (xprime_i - xt[i]) * ei)
if Px > uprime_i:
break
else:
if xprime_i > xt[i]:
xr_i = xprime_i
else:
xl_i = xprime_i
# store coordinate i of new sample
xt = np.copy(xt)
xt[i] = xprime_i
return xt
def _slice_sampling(ptarget, n_samples, x0, n_burnin=20, thin=10,
random_state=None, n_jobs=1):
"""Slice sampling procedure.
Implementation of a slice sampling algorithm for sampling from any target
pdf or a multiple of it.
The implementation follows the description given in p375 of David McKay's
book "Information Theory, Inference, and Learning Algorithms" (2003).
Parameters
----------
ptarget : function with one input
The target pdf to sample from or a multiple of it.
n_samples : int
Number of samples to get from the ptarget distribution.
x0 : ndarray
Initial state for the MCMC procedure. Note that the shape of this array
defines the dimensionality n_dim of the matrices to be sampled.
n_burnin : int, default=20
Number of samples to discard from the beginning of the chain generated
by the slice sampling procedure. Usually the first samples are prone to
non-stationary behavior and do not follow very well the target pdf.
thin : int, default=10
Thinning factor for the slice sampling procedure. MCMC samples are
often correlated between them, so taking one sample every ``thin``
samples can help reducing this correlation. Note that this makes the
algorithm actually sample ``thin`` x n_samples samples from the pdf, so
expect the whole sampling procedure to take longer.
random_state : int | RandomState instance | None, default=None
Pass an int for reproducible output across multiple function calls.
n_jobs : int, default=1
Number of jobs to use for the computation. This works by computing
each sample in parallel. If -1 all CPUs are used.
Returns
-------
samples : ndarray, shape (n_samples, n_dim)
Samples from the target pdf.
"""
if (n_samples <= 0) or (not isinstance(n_samples, int)):
raise ValueError(
f"n_samples must be a positive integer (Got {n_samples})"
)
if (n_burnin <= 0) or (not isinstance(n_burnin, int)):
raise ValueError(
f"n_burnin must be a positive integer (Got {n_burnin})"
)
if (thin <= 0) or (not isinstance(thin, int)):
raise ValueError(f"thin must be a positive integer (Got {thin})")
rs = check_random_state(random_state)
w = 1.0 # initial bracket width
n_samples_total = (n_samples + n_burnin) * thin
samples = Parallel(n_jobs=n_jobs)(
delayed(_slice_one_sample)(ptarget, x0, w, rs)
for _ in range(n_samples_total)
)
samples = np.array(samples)[(n_burnin * thin):][::thin]
return samples
def _sample_parameter_r(n_samples, n_dim, sigma,
random_state=None, n_jobs=1, sampling_method="auto"):
"""Sample the r parameters of a Riemannian Gaussian distribution.
Sample the logarithm of the eigenvalues of a SPD matrix following a
Riemannian Gaussian distribution.
See Said2017 for the mathematical details.
Parameters
----------
n_samples : int
Number of samples to generate.
n_dim : int
Dimensionality of the SPD matrices to be sampled.
sigma : float
Dispersion of the Riemannian Gaussian distribution.
random_state : int | RandomState instance | None, default=None
Pass an int for reproducible output across multiple function calls.
n_jobs : int, default=1
Number of jobs to use for the computation. This works by computing
each sample in parallel. If -1 all CPUs are used.
sampling_method : {"auto", "slice", "rejection"}, default="auto"
Method used to sample parameter r: "auto", "slice" or "rejection".
If "auto", sampling_method will be equal to "slice" for n_dim != 2 and
equal to "rejection" for n_dim = 2.
.. versionadded:: 0.4
Returns
-------
r_samples : ndarray, shape (n_samples, n_dim)
Samples of the r parameters of the Riemannian Gaussian distribution.
"""
if sampling_method not in ["slice", "rejection", "auto"]:
raise ValueError(f"Unknown sampling method {sampling_method}, "
"try slice or rejection")
if n_dim == 2 and sampling_method != "slice":
return _rejection_sampling_2D(n_samples, sigma,
random_state=random_state)
if n_dim != 2 and sampling_method == "rejection":
raise ValueError(
f"n_dim={n_dim} is not yet supported with rejection sampling"
)
rs = check_random_state(random_state)
x0 = rs.randn(n_dim)
ptarget = partial(_pdf_r, sigma=sigma)
r_samples = _slice_sampling(
ptarget,
n_samples=n_samples,
x0=x0,
random_state=random_state,
n_jobs=n_jobs,
)
return r_samples
def _sample_parameter_U(n_samples, n_dim, random_state=None,
                        is_complex=False):
    """Sample the U parameters of a Riemannian Gaussian distribution.

    Sample the eigenvectors of a SPD or HPD matrix following a Riemannian
    Gaussian distribution. Each sample is the Q factor of the QR decomposition
    of a matrix whose entries are drawn from a standard normal distribution
    (real and imaginary parts drawn separately in the complex case).
    See Said2017 for the mathematical details.

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.
    n_dim : int
        Dimensionality of the matrices to be sampled.
    random_state : int | RandomState instance | None, default=None
        Pass an int for reproducible output across multiple function calls.
    is_complex : bool, default=False
        If True, generate complex-valued unitary matrices for HPD sampling.

        .. versionadded:: 0.12

    Returns
    -------
    u_samples : ndarray, shape (n_samples, n_dim, n_dim)
        Samples of the U parameters of the Riemannian Gaussian distribution,
        of dtype complex128 if ``is_complex`` else float64.
    """
    rs = check_random_state(random_state)
    # complex128 is the double-precision counterpart of float64: storing the
    # QR factors as complex64 would truncate them to single precision.
    dtype = np.complex128 if is_complex else np.float64
    u_samples = np.zeros((n_samples, n_dim, n_dim), dtype=dtype)
    for i in range(n_samples):
        A = rs.randn(n_dim, n_dim)
        if is_complex:
            A = A + 1j * rs.randn(n_dim, n_dim)
        u_samples[i], _ = np.linalg.qr(A)
    return u_samples
def _sample_gaussian_centered(n_matrices, n_dim, sigma, random_state=None,
                              n_jobs=1, sampling_method="auto",
                              is_complex=False):
    """Sample a Riemannian Gaussian distribution centered at the Identity.

    Sample SPD or HPD matrices from a Riemannian Gaussian distribution
    centered at the Identity, which has the role of the origin in the manifold,
    and dispersion parametrized by sigma.
    See Said2017 for the mathematical details.

    Parameters
    ----------
    n_matrices : int
        Number of matrices to generate.
    n_dim : int
        Dimensionality of the matrices to be sampled.
    sigma : float
        Dispersion of the Riemannian Gaussian distribution.
    random_state : int | RandomState instance | None, default=None
        Pass an int for reproducible output across multiple function calls.
    n_jobs : int, default=1
        Number of jobs to use for the computation. This works by computing
        each sample in parallel. If -1 all CPUs are used.
    sampling_method : {"auto", "slice", "rejection"}, default="auto"
        Method used to sample parameter r: "auto", "slice" or "rejection".
        If "auto", sampling_method will be equal to "slice" for n_dim != 2 and
        equal to "rejection" for n_dim = 2.

        .. versionadded:: 0.4
    is_complex : bool, default=False
        If True, generate complex-valued HPD matrices
        instead of real-valued SPD matrices.

        .. versionadded:: 0.12

    Returns
    -------
    samples : ndarray, shape (n_matrices, n_dim, n_dim)
        Samples of the Riemannian Gaussian distribution, of dtype complex128
        if ``is_complex`` else float64.

    Notes
    -----
    .. versionadded:: 0.3
    """
    # NOTE(review): the same `random_state` is forwarded to both helpers; with
    # an int seed they presumably each start from the same generator state -
    # confirm this is intended.
    samples_r = _sample_parameter_r(
        n_samples=n_matrices,
        n_dim=n_dim,
        sigma=sigma,
        random_state=random_state,
        n_jobs=n_jobs,
        sampling_method=sampling_method,
    )
    samples_U = _sample_parameter_U(
        n_samples=n_matrices,
        n_dim=n_dim,
        random_state=random_state,
        is_complex=is_complex,
    )

    # complex128 matches the float64 precision of the real case; complex64
    # would silently truncate the assembled matrices to single precision.
    dtype = np.complex128 if is_complex else np.float64
    samples = np.zeros((n_matrices, n_dim, n_dim), dtype=dtype)
    for i in range(n_matrices):
        Ui = samples_U[i]
        ri = samples_r[i]
        # U^H diag(exp(r)) U : broadcasting scales the columns of U^H
        sample = (Ui.conj().T * np.exp(ri)) @ Ui
        # average with the conjugate transpose to remove round-off asymmetry
        samples[i] = 0.5 * (sample + ctranspose(sample))
    return samples
def sample_gaussian(n_matrices, mean, sigma, random_state=None,
                    n_jobs=1, sampling_method="auto"):
    """Sample a Riemannian Gaussian distribution.

    Sample SPD or HPD matrices from a Riemannian Gaussian distribution
    centered at ``mean`` and with dispersion parametrized by ``sigma``.
    If ``mean`` has a complex dtype, Hermitian positive-definite (HPD)
    matrices are generated instead of symmetric positive-definite (SPD).

    If ``sigma`` is a float, it samples from the distribution defined in [1]_
    that generalizes the notion of a Gaussian distribution to the space of
    SPD/HPD matrices. This sampling is based on a spectral factorization of
    SPD/HPD matrices in terms of their eigenvectors (U-parameters) and the log
    of the eigenvalues (r-parameters).

    If ``sigma`` is a covariance matrix, it samples from the wrapped Gaussian
    distribution defined in [2]_.

    Parameters
    ----------
    n_matrices : int
        Number of matrices to generate.
    mean : ndarray, shape (n_dim, n_dim)
        Center of the Riemannian Gaussian distribution.
        If complex, HPD matrices are generated;
        if real, SPD matrices are generated.

        .. versionchanged:: 0.12
    sigma : float | ndarray, shape (n_dim * (n_dim + 1) / 2, \
            n_dim * (n_dim + 1) / 2)
        If float, dispersion of the Riemannian Gaussian distribution [1]_.
        Builtin and NumPy real scalars are accepted.
        If ndarray, covariance matrix of the wrapped Gaussian
        distribution [2]_.

        .. versionchanged:: 0.11
    random_state : int | RandomState instance | None, default=None
        Pass an int for reproducible output across multiple function calls.
    n_jobs : int, default=1
        When sigma is a float,
        the number of jobs to use for the computation. This works by computing
        each sample in parallel. If -1 all CPUs are used.
    sampling_method : {"auto", "slice", "rejection"}, default="auto"
        When sigma is a float,
        method used to sample eigenvalues: "auto", "slice" or "rejection".
        If "auto", sampling_method will be equal to "slice" for n_dim != 2 and
        equal to "rejection" for n_dim = 2.

        .. versionadded:: 0.4

    Returns
    -------
    samples : ndarray, shape (n_matrices, n_dim, n_dim)
        Samples of the Riemannian Gaussian distribution.

    Raises
    ------
    ValueError
        If ``mean`` is not a square matrix, if ``sigma`` is neither a real
        scalar nor a ndarray, or if a ndarray ``sigma`` does not have the
        expected shape.
    NotImplementedError
        If ``sigma`` is a ndarray and ``mean`` is complex.

    Notes
    -----
    .. versionadded:: 0.3
    .. versionchanged:: 0.11
        Add support for ``sigma`` defined as a covariance matrix.
    .. versionchanged:: 0.12
        Rename sample_gaussian_spd into sample_gaussian.
        Add support for HPD matrices for float ``sigma``.

    References
    ----------
    .. [1] `Riemannian Gaussian distributions on the space of symmetric
        positive definite matrices
        <https://hal.archives-ouvertes.fr/hal-01710191>`_
        S. Said, L. Bombrun, Y. Berthoumieu, and J. Manton. IEEE Trans Inf
        Theory, vol. 63, pp. 2153-2170, 2017.
    .. [2] `Wrapped gaussian on the manifold of symmetric positive
        definite matrices
        <https://openreview.net/pdf?id=EhStXG4dCS>`_
        T. de Surrel, F. Lotte, S. Chevallier, and F. Yger. International
        Conference on Machine Learning (ICML), July 2025, Vancouver, Canada.
    """
    if mean.ndim != 2 or mean.shape[0] != mean.shape[1]:
        raise ValueError(
            f"mean must be a square matrix (Got shape {mean.shape})."
        )
    n_dim = mean.shape[0]
    is_complex = np.iscomplexobj(mean)

    # NumPy scalars such as np.float32 or np.int64 do not subclass the builtin
    # float/int, so they are listed explicitly.
    if isinstance(sigma, (int, float, np.integer, np.floating)):
        # generate samples centered at identity
        samples_centered = _sample_gaussian_centered(
            n_matrices=n_matrices,
            n_dim=n_dim,
            sigma=sigma / np.sqrt(n_dim),  # dispersion corrected w.r.t. dim
            random_state=random_state,
            n_jobs=n_jobs,
            sampling_method=sampling_method,
            is_complex=is_complex,
        )

        # apply the parallel transport from identity to mean on samples
        mean_sqrt = sqrtm(mean)
        samples = mean_sqrt @ samples_centered @ ctranspose(mean_sqrt)

    elif isinstance(sigma, np.ndarray):
        if is_complex:
            raise NotImplementedError(
                "Wrapped Gaussian sampling (ndarray sigma) is not yet "
                "supported for HPD matrices. Use a float sigma instead."
            )

        n_ts = n_dim * (n_dim + 1) // 2
        if sigma.shape != (n_ts, n_ts):
            raise ValueError(
                f"sigma must be a covariance matrix of shape ({n_ts}, {n_ts})."
            )

        # generate samples from the multivariate normal distribution
        rs = check_random_state(random_state)
        samples_ts_norm = rs.multivariate_normal(
            size=n_matrices,
            mean=np.zeros(n_ts),
            cov=sigma,
        )

        # send the tangent space at mean
        mean_sqrt = sqrtm(mean)
        samples_ = mean_sqrt @ unupper(samples_ts_norm) @ ctranspose(mean_sqrt)

        # map back to the manifold
        samples = exp_map_riemann(samples_, mean, Cm12=True)

    else:
        raise ValueError("sigma must be either a float or a ndarray.")

    if not is_hpsd(samples):
        msg = "Some of the sampled matrices are very badly conditioned and " \
              "may not behave numerically as positive definite matrices. " \
              "Try sampling again or reducing the dimensionality of matrices."
        # stacklevel=2 attributes the warning to the caller's line
        warnings.warn(msg, stacklevel=2)

    return samples
@deprecated(
    "sample_gaussian_spd() is deprecated and will be removed in 0.14.0; "
    "please use sample_gaussian()."
)
def sample_gaussian_spd(n_matrices, mean, sigma, random_state=None,
                        n_jobs=1, sampling_method="auto"):
    # Deprecated alias: forwards every argument unchanged to sample_gaussian.
    return sample_gaussian(n_matrices, mean, sigma, random_state=random_state,
                           n_jobs=n_jobs, sampling_method=sampling_method)
###############################################################################
class RandomOverSampler(BaseEstimator):
    """Random over-sampling for SPD/HPD matrices.

    For each class, output SPD/HPD matrices are interpolated along the geodesic
    between input SPD/HPD matrices [1]_.

    Parameters
    ----------
    metric : string, default="riemann"
        Metric used for SPD/HPD matrices interpolation
        (see :func:`pyriemann.geometry.geodesic.geodesic`).
    sampling_strategy : str, default="auto"
        Specify the class targeted by the resampling. The number of matrices in
        the different classes will be equalized. Possible choices are:

        - "minority": resample only the minority class;
        - "not minority": resample all classes but the minority class;
        - "not majority": resample all classes but the majority class;
        - "all": resample all classes;
        - "auto": equivalent to "not majority".
    random_state : int | RandomState instance | None, default=None
        Pass an int for reproducible output across multiple function calls.
    n_jobs : int, default=1
        Number of jobs to use for the computation. This works by computing
        each of the class resampling in parallel.
        If -1 all CPUs are used. If 1 is given, no parallel computing code is
        used at all, which is useful for debugging. For n_jobs below -1,
        (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one
        are used.

    Notes
    -----
    .. versionadded:: 0.10

    References
    ----------
    .. [1] `Data augmentation in Riemannian space for brain-computer interfaces
        <https://hal.science/hal-01351990/>`_
        E. Kalunga, S. Chevallier and Q. Barthélemy.
        ICML Workshop on Statistics, Machine Learning and Neuroscience, 2015.
    """

    def __init__(
        self,
        metric="riemann",
        sampling_strategy="auto",
        random_state=None,
        n_jobs=1
    ):
        """Init."""
        self.metric = metric
        self.sampling_strategy = sampling_strategy
        self.random_state = random_state
        self.n_jobs = n_jobs

    def fit(self, X, y):
        """Check parameters of the sampler.

        You should use ``fit_resample`` in all cases.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        self : object
            Return the instance itself.
        """
        self._rs = check_random_state(self.random_state)
        return self

    def fit_resample(self, X, y):
        """Resample the matrices.

        The input matrices are kept and, for every targeted class, new
        matrices are appended until it counts as many matrices as the
        majority class.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD/HPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        X_resampled : ndarray, shape (n_matrices_new, n_channels, n_channels)
            Set of resampled SPD/HPD matrices.
        y_resampled : ndarray, shape (n_matrices_new,)
            Labels for each resampled matrix.
        """
        self.fit(X, y)
        _, self._channels, _ = X.shape
        output_counts = self._check_sampling_strategy(y)

        res = Parallel(n_jobs=self.n_jobs)(
            delayed(self._resample)(X[y == c], c, n_mats)
            for c, n_mats in output_counts.items()
        )

        # Lists are built explicitly rather than unpacking zip(*res), which
        # fails when no class is targeted (e.g. a single-class input).
        X_new = [X_c for X_c, _ in res]
        y_new = [y_c for _, y_c in res]
        X_resampled = np.concatenate([X, *X_new], axis=0)
        y_resampled = np.concatenate([y, *y_new], axis=0)
        return X_resampled, y_resampled

    def _check_sampling_strategy(self, y):
        """Return the number of matrices to add to each targeted class.

        Parameters
        ----------
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        output_counts : dict
            Maps each targeted class label to the number of matrices missing
            to reach the size of the majority class.

        Raises
        ------
        ValueError
            If ``sampling_strategy`` is not supported.
        """
        classes, counts = np.unique(y, return_counts=True)
        input_counts = dict(zip(classes, counts))
        n_mats_majority = max(input_counts.values())
        class_minority = min(input_counts, key=input_counts.get)
        class_majority = max(input_counts, key=input_counts.get)

        strategy = self.sampling_strategy
        if strategy == "minority":
            targets = [c for c in input_counts if c == class_minority]
        elif strategy == "not minority":
            targets = [c for c in input_counts if c != class_minority]
        elif strategy in ("not majority", "auto"):
            targets = [c for c in input_counts if c != class_majority]
        elif strategy == "all":
            targets = list(input_counts)
        else:
            raise ValueError(
                f"Sampling strategy {self.sampling_strategy} is not supported."
            )

        return {c: n_mats_majority - input_counts[c] for c in targets}

    def _resample(self, X, y, n_mats):
        """Generate ``n_mats`` new matrices for one class.

        Each new matrix is a point on the geodesic between two distinct
        matrices of ``X`` chosen at random, at a position drawn uniformly
        in [0, 1).

        Parameters
        ----------
        X : ndarray, shape (n_matrices_class, n_channels, n_channels)
            Matrices of the class to resample.
        y : object
            Label of the class.
        n_mats : int
            Number of matrices to generate.

        Returns
        -------
        X_resampled : ndarray, shape (n_mats, n_channels, n_channels)
            Generated matrices.
        y_resampled : ndarray, shape (n_mats,)
            Label repeated for each generated matrix.

        Raises
        ------
        ValueError
            If matrices are requested from a class with fewer than two
            matrices.
        """
        if n_mats > 0 and len(X) < 2:
            raise ValueError(
                f"Class {y} needs at least 2 matrices to be resampled "
                f"(Got {len(X)})."
            )

        # Promote with float64 so that real inputs still give float64 while
        # complex (HPD) inputs keep their imaginary part.
        dtype = np.result_type(X.dtype, np.float64)
        X_resampled = np.empty(
            (n_mats, self._channels, self._channels), dtype=dtype
        )
        for n in range(n_mats):
            i, j = self._rs.choice(len(X), size=2, replace=False)
            alpha = self._rs.uniform(0, 1)
            X_resampled[n] = geodesic(X[i], X[j], alpha, metric=self.metric)
        return X_resampled, np.full(n_mats, y)