pyriemann.transfer.TLRegressor

class pyriemann.transfer.TLRegressor(target_domain, estimator, domain_weight=None)

Transfer learning wrapper for regressors.

This is a wrapper for any regressor. It converts the extended labels used in transfer learning into the usual y array, so that a regressor of choice can be trained on them.
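The label conversion can be sketched as below. This is a minimal illustration, not pyRiemann's code: it assumes each extended label is a string of the form `<domain>/<value>` (the layout believed to be produced by `pyriemann.transfer.encode_domains`), and `decode_extended_labels` is a hypothetical helper.

```python
import numpy as np


def decode_extended_labels(y_enc):
    """Split '<domain>/<value>' strings into (domains, float targets).

    Hypothetical helper: assumes the '<domain>/<value>' encoding.
    """
    domains, values = [], []
    for label in y_enc:
        # split on the last '/' so the numeric value is always the final field
        domain, value = str(label).rsplit("/", 1)
        domains.append(domain)
        values.append(float(value))  # regression targets must be numeric
    return np.array(domains), np.array(values, dtype=float)


y_enc = np.array(["source/1.5", "source/2.0", "target/0.25"])
domains, y = decode_extended_labels(y_enc)
print(domains.tolist())  # ['source', 'source', 'target']
print(y.tolist())        # [1.5, 2.0, 0.25]
```

The decoded `y` is the ordinary target array a regressor expects, while `domains` is kept aside to build per-domain weights.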

Parameters
target_domain : str

Domain to consider as target.

estimator : BaseRegressor

The regressor to apply to the matrices.

domain_weight : None | dict, default=None

Weights used to combine matrices from each domain when training the regressor. The dict maps each domain name (key) to the weight assigned to that domain (value). If None, equal weights are used.
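One way to read `domain_weight` is as a per-matrix sample weight, where every matrix inherits the weight of its domain. The sketch below assumes that interpretation, assumes "equal weights" means weight 1.0 for every matrix, and gives weight 0.0 to domains missing from the dict; `domain_to_sample_weight` is hypothetical and not part of pyRiemann.

```python
import numpy as np


def domain_to_sample_weight(domains, domain_weight=None):
    """Expand a {domain: weight} dict into one weight per matrix.

    Hypothetical helper. With domain_weight=None every matrix gets 1.0.
    Domains absent from the dict get 0.0, i.e. they are ignored.
    """
    domains = np.asarray(domains)
    if domain_weight is None:
        return np.ones(len(domains))
    weights = np.zeros(len(domains))
    for name, w in domain_weight.items():
        weights[domains == name] = w  # all matrices of a domain share its weight
    return weights


domains = np.array(["source", "source", "target"])
print(domain_to_sample_weight(domains).tolist())  # [1.0, 1.0, 1.0]
print(domain_to_sample_weight(domains, {"source": 0.2, "target": 1.0}).tolist())
# [0.2, 0.2, 1.0]
```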

Notes

New in version 0.3.1.

__init__(target_domain, estimator, domain_weight=None)

Initialize the TLRegressor.

fit(X, y_enc)

Fit TLRegressor.

Parameters
X : ndarray, shape (n_matrices, n_channels, n_channels)

Set of SPD matrices.

y_enc : ndarray, shape (n_matrices,)

Extended labels for each matrix.

Returns
self : TLRegressor instance

The TLRegressor instance.
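The fit step (decode the extended labels, derive weights, train the inner regressor) can be sketched end to end. This is a toy stand-in, not the library's implementation: weighted ridge regression on the upper-triangular entries of each matrix plays the role of the user-chosen `estimator`, and `upper_vectorize` and `fit_weighted_ridge` are hypothetical names.

```python
import numpy as np


def upper_vectorize(X):
    """Flatten each symmetric matrix into its upper-triangular entries."""
    rows, cols = np.triu_indices(X.shape[-1])
    return X[:, rows, cols]


def fit_weighted_ridge(X, y_enc, domain_weight=None, alpha=1e-3):
    """Toy stand-in for the fit step: decode, weight, train.

    Hypothetical sketch: weighted ridge regression replaces the
    user-chosen regressor.
    """
    # 1. decode '<domain>/<value>' extended labels (assumed encoding)
    pairs = [str(lab).rsplit("/", 1) for lab in y_enc]
    domains = [d for d, _ in pairs]
    y = np.array([float(v) for _, v in pairs])
    # 2. one sample weight per matrix (1.0 each when no dict is given)
    if domain_weight is None:
        w = np.ones(len(y))
    else:
        w = np.array([domain_weight.get(d, 0.0) for d in domains])
    # 3. solve (A^T W A + alpha R) beta = A^T W y; the intercept is not penalized
    A = np.column_stack([np.ones(len(y)), upper_vectorize(X)])
    AtW = A.T * w  # scales column i of A^T by w[i], i.e. A^T W
    reg = alpha * np.eye(A.shape[1])
    reg[0, 0] = 0.0
    return np.linalg.solve(AtW @ A + reg, AtW @ y)


rng = np.random.default_rng(0)
B = rng.standard_normal((40, 3, 3))
X = B @ B.transpose(0, 2, 1) + 3 * np.eye(3)  # 40 SPD matrices
y = X[:, 0, 0] + 2 * X[:, 1, 1]               # synthetic linear target
domains = np.where(np.arange(40) < 30, "source", "target")
y_enc = [f"{d}/{float(v)}" for d, v in zip(domains, y)]

beta = fit_weighted_ridge(X, y_enc, {"source": 0.5, "target": 1.0})
y_hat = np.column_stack([np.ones(40), upper_vectorize(X)]) @ beta
```

With real data, a pyRiemann regressor such as `pyriemann.regression.SVR` would be passed as `estimator` instead of this hand-written ridge.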

get_params(deep=True)

Get parameters for this estimator.

Parameters
deep : bool, default=True

If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns
params : dict

Parameter names mapped to their values.
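With `deep=True`, parameters of sub-estimators such as `estimator` are reported under `<component>__<parameter>` keys. The sketch below imitates that scikit-learn convention in plain Python; `ToyRegressor`, its parameter `C`, and `deep_get_params` are hypothetical and only illustrate the key layout.

```python
from types import SimpleNamespace


class ToyRegressor:
    """Hypothetical inner estimator exposing a single parameter, C."""

    def __init__(self, C=1.0):
        self.C = C

    def get_params(self, deep=True):
        return {"C": self.C}


def deep_get_params(obj, param_names, deep=True):
    """Sketch of the '<component>__<parameter>' flattening done by deep=True."""
    params = {}
    for name in param_names:
        value = getattr(obj, name)
        params[name] = value
        # only attributes that look like estimators contribute nested entries
        if deep and hasattr(value, "get_params"):
            for sub_name, sub_value in value.get_params(deep=True).items():
                params[f"{name}__{sub_name}"] = sub_value
    return params


# hypothetical object carrying TLRegressor-like attributes
holder = SimpleNamespace(target_domain="target",
                         estimator=ToyRegressor(C=10.0),
                         domain_weight=None)
names = ["target_domain", "estimator", "domain_weight"]
print(sorted(deep_get_params(holder, names)))
# ['domain_weight', 'estimator', 'estimator__C', 'target_domain']
print(sorted(deep_get_params(holder, names, deep=False)))
# ['domain_weight', 'estimator', 'target_domain']
```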

predict(X)

Get the predictions.

Parameters
X : ndarray, shape (n_matrices, n_channels, n_channels)

Set of SPD matrices.

Returns
pred : ndarray, shape (n_matrices,)

Predictions for each matrix according to the estimator.
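Unlike `fit`, `predict` receives only matrices, with no extended labels, so a wrapper can forward them directly to the trained estimator. The sketch below shows that wrapper pattern with a deliberately trivial inner regressor; `WeightedMeanRegressor` and `ToyTLRegressor` are hypothetical classes, not pyRiemann's implementation, and the `<domain>/<value>` label layout is an assumption.

```python
import numpy as np


class WeightedMeanRegressor:
    """Hypothetical inner regressor: always predicts the weighted mean target."""

    def fit(self, X, y, sample_weight=None):
        self.mean_ = np.average(y, weights=sample_weight)
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


class ToyTLRegressor:
    """Hypothetical illustration of the wrapper pattern."""

    def __init__(self, target_domain, estimator, domain_weight=None):
        self.target_domain = target_domain
        self.estimator = estimator
        self.domain_weight = domain_weight

    def fit(self, X, y_enc):
        # decode '<domain>/<value>' labels, then train with per-matrix weights
        pairs = [str(lab).rsplit("/", 1) for lab in y_enc]
        y = np.array([float(v) for _, v in pairs])
        w = None
        if self.domain_weight is not None:
            w = np.array([self.domain_weight.get(d, 0.0) for d, _ in pairs])
        self.estimator.fit(X, y, sample_weight=w)
        return self

    def predict(self, X):
        # no labels at prediction time: matrices go straight to the estimator
        return self.estimator.predict(X)


X = np.stack([np.eye(2)] * 3)
y_enc = ["src/1.0", "src/3.0", "tgt/5.0"]
reg = ToyTLRegressor("tgt", WeightedMeanRegressor(), {"src": 1.0, "tgt": 2.0})
reg.fit(X, y_enc)
print(reg.predict(X[:2]).tolist())  # [3.5, 3.5], i.e. (1 + 3 + 2 * 5) / 4
```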

score(X, y_enc)

Return the coefficient of determination of the prediction.

Parameters
X : ndarray, shape (n_matrices, n_channels, n_channels)

Test set of SPD matrices.

y_enc : ndarray, shape (n_matrices,)

Extended true values for each matrix.

Returns
score : float

R² of self.predict(X) with respect to the true values decoded from y_enc.
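The coefficient of determination is R² = 1 - SS_res / SS_tot, where SS_res is the sum of squared prediction errors and SS_tot the sum of squared deviations of the true values from their mean. The sketch below computes it after decoding the extended labels; `r2_from_extended_labels` is a hypothetical helper, and the `<domain>/<value>` label layout is an assumption.

```python
import numpy as np


def r2_from_extended_labels(y_pred, y_enc):
    """R^2 = 1 - SS_res / SS_tot against decoded '<domain>/<value>' labels.

    Hypothetical helper for illustration.
    """
    y_true = np.array([float(str(lab).rsplit("/", 1)[1]) for lab in y_enc])
    ss_res = np.sum((y_true - np.asarray(y_pred)) ** 2)  # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)       # total sum of squares
    return 1.0 - ss_res / ss_tot


y_enc = ["target/1.0", "target/2.0", "target/3.0"]
print(r2_from_extended_labels([1.0, 2.0, 3.0], y_enc))  # 1.0 (perfect fit)
print(r2_from_extended_labels([1.0, 2.0, 4.0], y_enc))  # 0.5 = 1 - 1 / 2
```

A constant prediction equal to the mean of the true values yields 0.0, and worse predictions give negative scores.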

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters
**params : dict

Estimator parameters.

Returns
self : estimator instance

Estimator instance.
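The `<component>__<parameter>` routing can be sketched as follows: a key containing `__` is split at the first separator and the remainder is forwarded to the named sub-object. This is a simplified imitation of the scikit-learn convention, without the key validation the real method performs; `set_nested_params` and the parameter `C` are hypothetical.

```python
from types import SimpleNamespace


def set_nested_params(obj, **params):
    """Sketch of '<component>__<parameter>' routing (no key validation)."""
    for key, value in params.items():
        component, sep, sub_key = key.partition("__")
        if sep:
            # nested key: forward the remainder to the sub-object
            set_nested_params(getattr(obj, component), **{sub_key: value})
        else:
            setattr(obj, key, value)
    return obj


# hypothetical object carrying TLRegressor-like attributes
wrapper = SimpleNamespace(target_domain="target",
                          estimator=SimpleNamespace(C=1.0))
set_nested_params(wrapper, estimator__C=10.0, target_domain="subject_2")
print(wrapper.estimator.C, wrapper.target_domain)  # 10.0 subject_2
```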