Comparison of pipelines for transfer learning

We compare the classification performance of the MDM classifier under different strategies for transfer learning.

Matrices are simulated from a toy model based on the Riemannian Gaussian distribution. The statistical differences between the source and target distributions are controlled by a set of parameters governing the distance between the centers of the two datasets, the angle of rotation between the means of each class, and the difference in dispersion of the matrices of each dataset.
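
To illustrate this kind of shift, the following minimal sketch (with illustrative names X and A) applies the congruence transform used by the simulation to displace a domain; note that it maps an SPD matrix to another SPD matrix:

import numpy as np

rng = np.random.default_rng(42)
n_channels = 4

# A random SPD matrix: B @ B.T plus a small ridge for positive definiteness
B = rng.standard_normal((n_channels, n_channels))
X = B @ B.T + 1e-6 * np.eye(n_channels)

# An invertible matrix A shifts the domain via the congruence X <- A @ X @ A.T
A = rng.standard_normal((n_channels, n_channels))
X_shifted = A @ X @ A.T

# The congruence preserves symmetry and positive definiteness
print(np.allclose(X_shifted, X_shifted.T))        # True
print(np.all(np.linalg.eigvalsh(X_shifted) > 0))  # True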

from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.pipeline import make_pipeline

from pyriemann.classification import MDM
from pyriemann.datasets.simulated import make_classification_transfer
from pyriemann.tangentspace import TangentSpace
from pyriemann.transfer import (
    TLSplitter,
    TLDummy,
    TLCenter,
    TLScale,
    TLRotate,
    TLClassifier,
    MDWM,
)

Pipelines

We consider several pipelines for transfer learning:

  • calib: use only data from the target-train partition; the classifier is trained only with matrices from the target domain.

  • dummy: no transfer learning at all, i.e. no transformation of the data between domains; the classifier is trained only with matrices from the source domain.

  • rct: re-center the data of each domain to the identity matrix [1]; the classifier is trained only with matrices from the source domain (see the sketch after this list).

  • rpa: match the statistical distributions in a semi-supervised way with Riemannian Procrustes Analysis (RPA) [2]: center, stretch and rotate the matrices in the manifold. The classifier is trained with matrices from both source and target domains.

  • mdwm: improve the MDM classifier with a weighting strategy, yielding the minimum distance to weighted mean (MDWM) classifier [3].

  • tsa: align tangent vectors by Procrustes analysis [4]: center, normalize and rotate the vectors in the tangent space.
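
To make the rct step concrete, here is a minimal sketch of per-domain recentering, assuming only that X is an array of SPD matrices from a single domain (recenter_to_identity is an illustrative helper, not the library API; in the pipelines below, TLCenter performs this operation for each domain):

from pyriemann.utils.base import invsqrtm
from pyriemann.utils.mean import mean_riemann

def recenter_to_identity(X):
    # Whiten a set of SPD matrices by their geometric mean M: after the
    # congruence M^{-1/2} @ X @ M^{-1/2}, the Riemannian mean of the
    # recentered set is the identity matrix.
    M_invsqrt = invsqrtm(mean_riemann(X))
    return M_invsqrt @ X @ M_invsqrt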

methods = ["calib", "dummy", "rct", "rpa", "mdwm", "tsa"]
scores = {meth: [] for meth in methods}

# Base classifier to consider in manifold
clf_base = MDM()

# Choose seed for reproducible results
seed = 100

# Create a dataset with two domains, each with two classes.
# Both datasets are generated by the same generative procedure based on the
# Riemannian Gaussian distribution, and one of them is transformed by a
# matrix A, i.e. X <- A @ X @ A.T
X_enc, y_enc = make_classification_transfer(
    n_matrices=100,
    class_sep=0.75,
    class_disp=1.0,
    domain_sep=5.0,
    theta=3*np.pi/5,
    random_state=seed,
)
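
# Illustrative check: labels in y_enc encode both the domain and the class
# of each matrix as strings such as "source_domain/1", and decode_domains
# (from pyriemann.transfer) recovers the class labels and domains.
from pyriemann.transfer import decode_domains

_, y_dec, domains = decode_domains(X_enc, y_enc)
print(np.unique(domains))  # ['source_domain' 'target_domain']
print(np.unique(y_dec))    # the two class labels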

# Object for splitting the datasets into training and validation partitions.
# The training set is composed of all matrices from the source domain,
# plus a partition of the target domain whose size we can control.
target_domain = "target_domain"
n_splits = 5  # how many times to split the target domain into train/test
tl_cv = TLSplitter(
    target_domain=target_domain,
    cv=StratifiedShuffleSplit(n_splits=n_splits, random_state=seed),
)
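
# Note: test partitions yielded by TLSplitter contain only matrices from
# the target domain; all source-domain matrices always go to training.
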
# Vary the proportion of the target domain for training
target_train_frac_array = np.linspace(0.01, 0.20, 10)
for target_train_frac in tqdm(target_train_frac_array):

    # Change fraction of the target training partition
    tl_cv.cv.train_size = target_train_frac

    # Create dict for storing results of this particular CV split
    scores_cv = {meth: [] for meth in scores.keys()}

    # Carry out the cross-validation
    for train_idx, test_idx in tl_cv.split(X_enc, y_enc):

        # Split the dataset into training and testing
        X_enc_train, X_enc_test = X_enc[train_idx], X_enc[test_idx]
        y_enc_train, y_enc_test = y_enc[train_idx], y_enc[test_idx]

        # Calibration
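        # In TLClassifier, domain_weight assigns a sample weight to all
        # matrices of each domain during fit: here, source matrices are
        # ignored and only the target-train matrices are used.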
        pipeline = make_pipeline(
            TLClassifier(
                target_domain=target_domain,
                estimator=clf_base,
                domain_weight={"source_domain": 0.0, "target_domain": 1.0},
            ),
        )

        pipeline.fit(X_enc_train, y_enc_train)
        scores_cv["calib"].append(pipeline.score(X_enc_test, y_enc_test))

        # Dummy
        pipeline = make_pipeline(
            TLDummy(),
            TLClassifier(
                target_domain=target_domain,
                estimator=clf_base,
                domain_weight={"source_domain": 1.0, "target_domain": 0.0},
            ),
        )

        pipeline.fit(X_enc_train, y_enc_train)
        scores_cv["dummy"].append(pipeline.score(X_enc_test, y_enc_test))

        # Recentering pipeline
        pipeline = make_pipeline(
            TLCenter(target_domain=target_domain),
            TLClassifier(
                target_domain=target_domain,
                estimator=clf_base,
                domain_weight={"source_domain": 1.0, "target_domain": 0.0},
            ),
        )

        pipeline.fit(X_enc_train, y_enc_train)
        scores_cv["rct"].append(pipeline.score(X_enc_test, y_enc_test))

        # RPA pipeline
        pipeline = make_pipeline(
            TLCenter(target_domain=target_domain),
            TLScale(
                target_domain=target_domain,
                final_dispersion=1,
                centered_data=True,
            ),
            TLRotate(target_domain=target_domain, metric="euclid"),
            TLClassifier(
                target_domain=target_domain,
                estimator=clf_base,
                domain_weight={"source_domain": 0.5, "target_domain": 0.5},
            ),
        )

        pipeline.fit(X_enc_train, y_enc_train)
        scores_cv["rpa"].append(pipeline.score(X_enc_test, y_enc_test))

        # MDWM pipeline
        domain_tradeoff = 1 - np.exp(-100 * target_train_frac)
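        # This weight on the target domain is ~0.63 at a 1% target-train
        # fraction and saturates near 1.0 beyond ~5%, so the target data
        # quickly dominates the weighted class means.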
        pipeline = MDWM(
            domain_tradeoff=domain_tradeoff,
            target_domain=target_domain,
            metric="riemann",
        )

        pipeline.fit(X_enc_train, y_enc_train)
        scores_cv["mdwm"].append(pipeline.score(X_enc_test, y_enc_test))

        # TSA pipeline
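        # Here, TLCenter, TLScale and TLRotate operate on tangent vectors
        # (the output of TangentSpace) rather than on SPD matrices.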
        pipeline = make_pipeline(
            TangentSpace(metric="riemann"),
            TLCenter(target_domain=target_domain),
            TLScale(target_domain=target_domain),
            TLRotate(target_domain=target_domain),
            TLClassifier(
                target_domain=target_domain,
                estimator=LogisticRegression(),
                domain_weight={"source_domain": 0.5, "target_domain": 0.5},
            ),
        )

        pipeline.fit(X_enc_train, y_enc_train)
        scores_cv["tsa"].append(pipeline.score(X_enc_test, y_enc_test))

    # Get the average score of each pipeline
    for meth in scores.keys():
        scores[meth].append(np.mean(scores_cv[meth]))

# Store the results for each method on this particular seed
for meth in scores.keys():
    scores[meth] = np.array(scores[meth])

At the smallest target-train fractions, the run warns:
"UserWarning: Not enough vectors for target domain".

Results

Plot the results, reproducing Figure 2 of [2].

fig, ax = plt.subplots(figsize=(6.7, 5.7))
for meth in scores.keys():
    ax.plot(
        target_train_frac_array,
        scores[meth],
        label=meth,
        lw=3.0 if meth == "calib" else 2.0,
    )
ax.legend(loc="lower right")
ax.set_ylim(0.5, 0.75)
ax.set_yticks([0.5, 0.6, 0.7])
ax.set_xlim(0.00, 0.21)
ax.set_xticks([0.01, 0.05, 0.10, 0.15, 0.20])
ax.set_xticklabels([1, 5, 10, 15, 20])
ax.set_xlabel("Percentage of training partition in target domain")
ax.set_ylabel("Classification accuracy")
ax.set_title("Comparison of transfer learning pipelines")
plt.show()
Figure: Comparison of transfer learning pipelines.

References

[1] P. Zanini, M. Congedo, C. Jutten, S. Said and Y. Berthoumieu, "Transfer learning: a Riemannian geometry framework with applications to brain-computer interfaces", IEEE Transactions on Biomedical Engineering, 2018.

[2] P. L. C. Rodrigues, C. Jutten and M. Congedo, "Riemannian Procrustes analysis: transfer learning for brain-computer interfaces", IEEE Transactions on Biomedical Engineering, 2019.

[3] E. Kalunga, S. Chevallier and Q. Barthélemy, "Transfer learning for SSVEP-based BCI using Riemannian similarities between users", EUSIPCO, 2018.

[4] A. Bleuzé, J. Mattout and M. Congedo, "Tangent space alignment: transfer learning for brain-computer interface", Frontiers in Human Neuroscience, 2022.
