pyriemann.transfer.TLSplitter

class pyriemann.transfer.TLSplitter(target_domain, cv)

Class for handling the cross-validation splits of multi-domain data.

This is a wrapper around sklearn’s cross-validation iterators [1] that keeps the domain information attached to the data points. The data from the source domains are always fully available in the training partition, whereas the splits are made only on the data points from the target domain.
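
The behaviour described above can be sketched with plain NumPy and scikit-learn. This is an illustration only, not pyriemann's actual implementation: `target_only_splits` is a hypothetical helper, and the domains are passed as a separate array instead of being decoded from extended labels.

```python
import numpy as np
from sklearn.model_selection import KFold


def target_only_splits(domains, target_domain, cv):
    """Hypothetical helper: split target points, keep all source points in train."""
    domains = np.asarray(domains)
    idx_source = np.flatnonzero(domains != target_domain)
    idx_target = np.flatnonzero(domains == target_domain)
    # The wrapped iterator only ever sees the target data points.
    for sub_train, sub_test in cv.split(idx_target):
        train = np.concatenate([idx_source, idx_target[sub_train]])
        test = idx_target[sub_test]
        yield train, test


# Toy layout: 4 source points followed by 6 target points.
domains = ["source"] * 4 + ["target"] * 6
splits = list(target_only_splits(domains, "target", KFold(n_splits=3)))
for train, test in splits:
    print("train:", train, "test:", test)
```

Every training partition contains all source indices, and every test partition contains target indices only.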

Parameters
target_domain : str

Domain considered as target.

cv : None | BaseCrossValidator | BaseShuffleSplit, default=None

An instance of a cross-validation iterator from sklearn.

Notes

New in version 0.4.

References

[1] https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators

__init__(target_domain, cv)
get_n_splits(X=None, y=None)

Returns the number of splitting iterations in the cross-validator.

Parameters
X : object

Ignored, exists for compatibility.

y : object

Ignored, exists for compatibility.

Returns
n_splits : int

Number of splitting iterations in the cross-validator.

split(X, y)

Generate indices to split data into training and test set.

Parameters
X : ndarray, shape (n_matrices, n_channels, n_channels)

Set of SPD matrices.

y : ndarray, shape (n_matrices,)

Extended labels for each matrix, encoding both its domain and its class.

Yields
train : ndarray

The training set indices for that split.

test : ndarray

The testing set indices for that split.