Source code for pyriemann.geometry.test

from array_api_compat import (
    array_namespace as get_namespace,
    is_torch_namespace,
)
import numpy as np


def _allclose(A, B):
    """Return True when every entry of ``A`` is close to the matching entry of ``B``.

    Backend-agnostic stand-in for ``numpy.allclose``: the element-wise
    ``isclose`` of the shared array namespace is evaluated with its default
    tolerances, then reduced with ``all``.
    """
    xp = get_namespace(A, B)
    elementwise = xp.isclose(A, B)
    return bool(xp.all(elementwise))


def _get_eigenvals(X):
    """Return the real part of the eigenvalues of each trailing (n, n) matrix.

    All leading dimensions are flattened into a single batch axis before the
    decomposition. ``xp.linalg.eigvals`` always returns a complex dtype on
    torch (even for real inputs), and complex tensors cannot be compared
    against a float tolerance, so the real part is taken here once for all
    callers.
    """
    xp = get_namespace(X)
    size = X.shape[-1]
    stacked = X.reshape((-1, size, size))
    eigvals = xp.linalg.eigvals(stacked)
    return xp.real(eigvals)


def is_square(X):
    """Check if matrices are square.

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.

    Returns
    -------
    ret : bool
        True if matrices are square.
    """
    # Fewer than two dimensions cannot describe a matrix at all.
    if X.ndim < 2:
        return False
    n_rows, n_cols = X.shape[-2], X.shape[-1]
    return n_rows == n_cols
def is_sym(X):
    """Check if all matrices are symmetric.

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.

    Returns
    -------
    ret : bool
        True if all matrices are symmetric.
    """
    if not is_square(X):
        return False
    # Symmetric: each matrix matches its own (batched) transpose.
    return _allclose(X, X.mT)
def is_skew_sym(X):
    """Check if all matrices are skew-symmetric.

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.

    Returns
    -------
    ret : bool
        True if all matrices are skew-symmetric.
    """
    if not is_square(X):
        return False
    # Skew-symmetric: each matrix equals the negation of its transpose.
    return _allclose(X, -X.mT)
def is_hankel(X):
    """Check if matrix is a Hankel matrix.

    A Hankel matrix is constant along each anti-diagonal, ie ``X[i, j]``
    depends only on ``i + j``. Entries are compared exactly (no tolerance),
    and a matrix containing NaN is never considered Hankel.

    Parameters
    ----------
    X : ndarray, shape (n, n)
        Square matrix.

    Returns
    -------
    ret : bool
        True if Hankel matrix.
    """
    if not is_square(X) or X.ndim != 2:
        return False
    xp = get_namespace(X)
    # Constant anti-diagonals <=> every entry equals its up-right neighbour,
    # X[i + 1, j] == X[i, j + 1]. A single vectorized comparison replaces the
    # per-element Python loop, which needed one ``.item()`` host sync per
    # entry on device tensors.
    anti_diagonals_constant = bool(xp.all(X[1:, :-1] == X[:-1, 1:]))
    # NaN never compares equal to anything, itself included: reject any NaN,
    # including in corner entries that have no neighbour to compare against.
    nan_free = bool(xp.all(X == X))
    return anti_diagonals_constant and nan_free
def is_real(X):
    """Check if all matrices are strictly real.

    Better management of numerical imprecisions than np.all(np.isreal()).

    Parameters
    ----------
    X : ndarray, shape (..., n, m)
        The set of matrices.

    Returns
    -------
    ret : bool
        True if all matrices are strictly real.
    """
    if not is_real_type(X):
        # Complex dtype: real only if every imaginary part is (close to) zero.
        xp = get_namespace(X)
        imaginary = xp.imag(X)
        return _allclose(imaginary, xp.zeros_like(imaginary))
    return True
def is_real_type(X):
    """Check if matrices are real type.

    Parameters
    ----------
    X : ndarray, shape (..., n, m)
        The set of matrices.

    Returns
    -------
    ret : bool
        True if matrices are real type.

    Notes
    -----
    .. versionadded:: 0.6
    """
    xp = get_namespace(X)
    if not is_torch_namespace(xp):
        return np.isrealobj(X)
    # torch dtypes expose complexness directly.
    return not X.dtype.is_complex
def is_hermitian(X):
    """Check if all matrices are Hermitian.

    Check if all matrices are Hermitian, ie with a symmetric real part and a
    skew-symmetric imaginary part.

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.

    Returns
    -------
    ret : bool
        True if all matrices are Hermitian.
    """
    if not is_real_type(X):
        xp = get_namespace(X)
        return is_sym(xp.real(X)) and is_skew_sym(xp.imag(X))
    # A real-typed matrix is Hermitian exactly when it is symmetric.
    return is_sym(X)
def is_pos_def(X, tol=0.0, fast_mode=False):
    """Check if all matrices are positive definite (PD).

    Check if all matrices are positive definite: the fast verification
    attempts a Cholesky decomposition, while the full check computes all
    eigenvalues and verifies that their real parts exceed ``tol``.

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.
    tol : float, default=0.0
        Threshold below which eigen values are considered zero.
        Ignored when ``fast_mode=True``.
    fast_mode : bool, default=False
        Use Cholesky decomposition to avoid computing all eigenvalues.

    Returns
    -------
    ret : bool
        True if all matrices are positive definite.
    """
    xp = get_namespace(X)
    # Decide non-square input here rather than relying on each backend's
    # Cholesky raising an exception type listed below.
    if not is_square(X):
        return False
    if not fast_mode:
        return bool(xp.all(_get_eigenvals(X) > tol))
    try:
        factor = xp.linalg.cholesky(X)
    except (np.linalg.LinAlgError, RuntimeError):
        # LinAlgError for NumPy; RuntimeError is caught for torch backends.
        return False
    # Some backends (eg JAX) report a failed factorization by returning NaNs
    # instead of raising, so a NaN in the factor also means "not PD".
    return not bool(xp.any(xp.isnan(factor)))
def is_pos_semi_def(X):
    """Check if all matrices are positive semi-definite (PSD).

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.

    Returns
    -------
    ret : bool
        True if all matrices are positive semi-definite.
    """
    xp = get_namespace(X)
    if is_square(X):
        # Every eigenvalue (real part) must be non-negative.
        eigvals = _get_eigenvals(X)
        return bool(xp.all(eigvals >= 0.0))
    return False
def is_sym_pos_def(X, tol=0.0):
    """Check if all matrices are symmetric positive-definite (SPD).

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.
    tol : float, default=0.0
        Threshold below which eigen values are considered zero.

    Returns
    -------
    ret : bool
        True if all matrices are symmetric positive-definite.
    """
    if not is_sym(X):
        return False
    return is_pos_def(X, tol=tol)
def is_sym_pos_semi_def(X):
    """Check if all matrices are symmetric positive semi-definite (SPSD).

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.

    Returns
    -------
    ret : bool
        True if all matrices are symmetric positive semi-definite.
    """
    if not is_sym(X):
        return False
    return is_pos_semi_def(X)
def is_herm_pos_def(X, tol=0.0):
    """Check if all matrices are Hermitian positive-definite (HPD).

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.
    tol : float, default=0.0
        Threshold below which eigen values are considered zero.

    Returns
    -------
    ret : bool
        True if all matrices are Hermitian positive-definite.
    """
    if not is_hermitian(X):
        return False
    return is_pos_def(X, tol=tol)
def is_herm_pos_semi_def(X):
    """Check if all matrices are Hermitian positive semi-definite (HPSD).

    Parameters
    ----------
    X : ndarray, shape (..., n, n)
        The set of square matrices.

    Returns
    -------
    ret : bool
        True if all matrices are Hermitian positive semi-definite.
    """
    if not is_hermitian(X):
        return False
    return is_pos_semi_def(X)