pyriemann.transfer.TLEstimator

class pyriemann.transfer.TLEstimator(target_domain, estimator, domain_weight=None)

Transfer learning wrapper for estimators.

This is a wrapper for any BaseEstimator (i.e., a classifier or regressor). It converts the extended labels used in transfer learning, which encode both the domain and the label of each matrix, into the usual y array, and then trains the classifier or regressor of choice.
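As a rough illustration of this conversion, the sketch below splits extended labels back into a domain array and a plain label array. It assumes extended labels are strings of the form `domain/label`. pyRiemann provides `encode_domains` and `decode_domains` in `pyriemann.transfer` for this purpose, but the helper `split_extended_labels` here is hypothetical and does not depend on them.

```python
import numpy as np


def split_extended_labels(y_enc):
    """Hypothetical helper: split 'domain/label' strings into two arrays.

    Assumes the domain name is everything before the first '/'.
    """
    pairs = [str(s).split("/", 1) for s in y_enc]
    domains = np.array([d for d, _ in pairs])
    labels = np.array([lab for _, lab in pairs])
    return domains, labels


y_enc = ["source/left", "source/right", "target/left"]
domains, y = split_extended_labels(y_enc)
print(domains)  # ['source' 'source' 'target']
print(y)        # ['left' 'right' 'left']
```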

Parameters
target_domain : str

Domain to consider as target.

estimator : BaseEstimator

The estimator to apply on matrices. It can be any regressor or classifier from pyRiemann.

domain_weight : None | dict, default=None

Weights used to combine matrices from each domain when training the estimator. The dict maps each domain name (key) to the weight assigned to its matrices (value). If None, all domains get equal weights.
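The sketch below shows one plausible way such a dict turns into one weight per matrix, which could then be passed as `sample_weight` to an estimator that accepts it. The helper name and the domain names are hypothetical, and whether pyRiemann applies the weights exactly this way is an assumption.

```python
import numpy as np


def domain_sample_weights(domains, domain_weight=None):
    """Hypothetical helper: one weight per matrix from a domain -> weight dict."""
    domains = np.asarray(domains)
    if domain_weight is None:
        # No dict given: every matrix counts equally
        return np.ones(len(domains))
    # Look up the weight of each matrix's domain (a missing domain raises KeyError)
    return np.array([domain_weight[d] for d in domains], dtype=float)


domains = ["source", "source", "target"]
weights = domain_sample_weights(domains, {"source": 0.2, "target": 1.0})
print(weights)
```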

Notes

New in version 0.4.

__init__(target_domain, estimator, domain_weight=None)

Initialize the wrapper with its target domain, wrapped estimator and optional domain weights.

fit(X, y_enc)

Fit TLEstimator.

Parameters
X : ndarray, shape (n_matrices, n_channels, n_channels)

Set of SPD matrices.

y_enc : ndarray, shape (n_matrices,)

Extended labels for each matrix.

Returns
self : TLEstimator instance

The TLEstimator instance.
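A minimal sketch of the fit/predict flow, under stated assumptions: extended labels are `domain/label` strings, domain weights are forwarded as `sample_weight`, and `predict` delegates to the fitted estimator. `TLEstimatorSketch` and `WeightedMeanClassifier` are hypothetical. The toy classifier uses Euclidean class means for brevity, whereas a pyRiemann estimator such as MDM would use Riemannian means. `target_domain` is stored but not used here.

```python
import numpy as np


class WeightedMeanClassifier:
    """Hypothetical toy classifier: nearest weighted Euclidean class mean."""

    def fit(self, X, y, sample_weight=None):
        y = np.asarray(y)
        w = np.ones(len(X)) if sample_weight is None else np.asarray(sample_weight, float)
        self.classes_ = np.unique(y)
        # Weighted average of the matrices of each class (weights must not sum to 0)
        self.means_ = np.stack(
            [np.average(X[y == c], axis=0, weights=w[y == c]) for c in self.classes_]
        )
        return self

    def predict(self, X):
        # Frobenius distance from every matrix to every class mean
        dist = np.linalg.norm(X[:, None] - self.means_[None], axis=(2, 3))
        return self.classes_[np.argmin(dist, axis=1)]


class TLEstimatorSketch:
    """Hypothetical re-creation of the wrapper's fit/predict flow."""

    def __init__(self, target_domain, estimator, domain_weight=None):
        self.target_domain = target_domain  # stored only; unused in this sketch
        self.estimator = estimator
        self.domain_weight = domain_weight

    def fit(self, X, y_enc):
        # 1. Decode assumed 'domain/label' strings into domains and labels
        pairs = [s.split("/", 1) for s in y_enc]
        domains = np.array([d for d, _ in pairs])
        y = np.array([lab for _, lab in pairs])
        # 2. Turn the domain weights into one weight per matrix
        if self.domain_weight is None:
            weights = np.ones(len(X))
        else:
            weights = np.array([self.domain_weight[d] for d in domains], float)
        # 3. Train the wrapped estimator on plain labels
        self.estimator.fit(X, y, sample_weight=weights)
        return self

    def predict(self, X):
        # Plain matrices in, one prediction per matrix out
        return self.estimator.predict(X)


X = np.stack([s * np.eye(2) for s in (1.0, 1.1, 3.0, 3.2)])
y_enc = ["source/a", "target/a", "source/b", "target/b"]
tl = TLEstimatorSketch(
    "target", WeightedMeanClassifier(), domain_weight={"source": 0.5, "target": 1.0}
).fit(X, y_enc)
pred = tl.predict(np.stack([np.eye(2), 3 * np.eye(2)]))
print(pred)
```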

get_params(deep=True)

Get parameters for this estimator.

Parameters
deep : bool, default=True

If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns
params : dict

Parameter names mapped to their values.
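This method follows the scikit-learn `BaseEstimator` convention. The sketch below re-implements the `deep=True` behaviour with hypothetical toy classes (`ToyMDM`, `ToyTLEstimator`) only to show which keys appear: with `deep=True`, the parameters of the wrapped estimator are added under `estimator__<parameter>`.

```python
class ToyMDM:
    """Hypothetical stand-in for a wrapped pyRiemann estimator."""

    def __init__(self, metric="riemann"):
        self.metric = metric

    def get_params(self, deep=True):
        return {"metric": self.metric}


class ToyTLEstimator:
    """Hypothetical wrapper mirroring TLEstimator's constructor arguments."""

    def __init__(self, target_domain, estimator, domain_weight=None):
        self.target_domain = target_domain
        self.estimator = estimator
        self.domain_weight = domain_weight

    def get_params(self, deep=True):
        params = {
            "target_domain": self.target_domain,
            "estimator": self.estimator,
            "domain_weight": self.domain_weight,
        }
        if deep:
            # Expose sub-estimator parameters as <component>__<parameter>
            for name, value in list(params.items()):
                if hasattr(value, "get_params"):
                    for sub_name, sub_value in value.get_params(deep=True).items():
                        params[f"{name}__{sub_name}"] = sub_value
        return params


tl = ToyTLEstimator("target", ToyMDM())
print(sorted(tl.get_params(deep=False)))  # ['domain_weight', 'estimator', 'target_domain']
print(sorted(tl.get_params(deep=True)))   # ['domain_weight', 'estimator', 'estimator__metric', 'target_domain']
```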

predict(X)

Get the predictions.

Parameters
X : ndarray, shape (n_matrices, n_channels, n_channels)

Set of SPD matrices.

Returns
pred : ndarray, shape (n_matrices,)

Predictions for each matrix according to the estimator.
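Unlike `fit`, this method takes plain matrices without extended labels. If held-out data are stored with extended labels, one simple way to keep only the target-domain matrices before calling `predict` is sketched below. The `domain/label` format and the helper `select_domain` are assumptions for illustration.

```python
import numpy as np


def select_domain(X, y_enc, domain):
    """Hypothetical helper: keep matrices whose extended label starts with 'domain/'."""
    y_enc = np.asarray(y_enc)
    mask = np.array([s.split("/", 1)[0] == domain for s in y_enc], dtype=bool)
    return X[mask], y_enc[mask]


rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3, 3))
X = A @ A.transpose(0, 2, 1) + 3 * np.eye(3)  # four SPD matrices
y_enc = ["source/left", "target/left", "source/right", "target/right"]
X_target, y_target = select_domain(X, y_enc, "target")
print(X_target.shape)  # (2, 3, 3)
# X_target can then be passed to predict(), which returns one prediction per matrix
```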

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters
**params : dict

Estimator parameters.

Returns
self : estimator instance

Estimator instance.
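The `<component>__<parameter>` convention can be illustrated by how keyword arguments are split between the wrapper itself and its wrapped estimator. The hypothetical `route_params` below only shows the routing; scikit-learn's `BaseEstimator.set_params` additionally validates names before setting them. The key `estimator__metric` assumes the wrapped estimator has a `metric` parameter.

```python
def route_params(params):
    """Hypothetical helper: split set_params kwargs into own and nested parts."""
    own, nested = {}, {}
    for key, value in params.items():
        component, sep, sub_key = key.partition("__")
        if sep:
            # '<component>__<parameter>' goes to the sub-estimator; deeper
            # keys such as 'estimator__clf__C' keep their remaining prefix
            nested.setdefault(component, {})[sub_key] = value
        else:
            own[key] = value
    return own, nested


own, nested = route_params(
    {"target_domain": "subject_2", "estimator__metric": "logeuclid"}
)
print(own)     # {'target_domain': 'subject_2'}
print(nested)  # {'estimator': {'metric': 'logeuclid'}}
```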