From fec5185ca3aef0adb72fa08f4bc7daf0d34b3a28 Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Tue, 4 Feb 2025 15:22:43 +0100 Subject: [PATCH 01/16] cross dataset working with pyriemann --- examples/cross_dataset.py | 81 +++++++++++++++++++++++++++++++++++++++ examples/testing.py | 8 ++++ 2 files changed, 89 insertions(+) create mode 100644 examples/cross_dataset.py create mode 100644 examples/testing.py diff --git a/examples/cross_dataset.py b/examples/cross_dataset.py new file mode 100644 index 000000000..5edad00de --- /dev/null +++ b/examples/cross_dataset.py @@ -0,0 +1,81 @@ +import logging +import yaml +from moabb.datasets import BNCI2014001, Zhou2016 +from moabb.paradigms import MotorImagery +from moabb.evaluations.evaluations import CrossDatasetEvaluation +from sklearn.pipeline import Pipeline +from pyriemann.estimation import Covariances +from pyriemann.spatialfilters import CSP +from sklearn.svm import SVC + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def create_pipeline() -> Pipeline: + """Create the CSP + SVM pipeline manually.""" + return Pipeline([ + ('covariances', Covariances(estimator='oas')), + ('csp', CSP(nfilter=6)), + ('svc', SVC(kernel='linear')) + ]) + +def main(): + # Define train and test datasets + train_dataset = BNCI2014001() + test_dataset = Zhou2016() + + # Initialize the paradigm + paradigm = MotorImagery(n_classes=2) + + # Initialize the CrossDatasetEvaluation + evaluation = CrossDatasetEvaluation( + paradigm=paradigm, + train_dataset=train_dataset, + test_dataset=test_dataset, + pretrained_model=None, # Not using a pre-trained model + fine_tune=False, # Not fine-tuning + target_channels=None, # Use all channels from train_dataset + sfreq=128, # Target sampling frequency + channel_strategy='zero',# Strategy for handling channels + montage='standard_1020',# EEG montage for SSI + min_channels=3, # Minimum common channels for subset strategy + hdf5_path=None, # Path to save results and models + save_model=False # Do not save models + ) + + # Create the pipeline + pipeline = create_pipeline() + + # Define parameter grid if needed (optional) + param_grid = { + 'svc__C': [0.1, 1, 10], + 'svc__kernel': ['linear'] + } + + # Run the evaluation + results = evaluation.evaluate( + dataset=None, # Not used in CrossDatasetEvaluation + pipelines={'CSP_SVM': pipeline}, + param_grid=param_grid + ) + + # Collect and display results + for res in results: + # Create log message with available information + log_msg = [ + f"Dataset: {res['dataset'].code}", + f"Subject: {res['subject']}", + f"Pipeline: {res['pipeline']}", + f"Score: {res['score']:.4f}", + f"Time: {res['time']:.2f}s" + ] + + # Add session info if available + if 'session' in res: + log_msg.insert(2, f"Session: {res['session']}") + + logger.info(", ".join(log_msg)) + +if __name__ == "__main__": + main() diff --git a/examples/testing.py b/examples/testing.py new file mode 100644 index 000000000..6cfa2807e --- /dev/null +++ b/examples/testing.py @@ -0,0 +1,8 @@ +from moabb.paradigms import MotorImagery +from moabb.datasets import utils + +paradigm = MotorImagery() +compatible_datasets = utils.dataset_search(paradigm=paradigm) +print("\nCompatible datasets for Motor Imagery paradigm:") +for dataset in compatible_datasets: + print(f"- {dataset.code}: {dataset.n_classes} classes") \ No newline at end of file From 6812f125bcee5c47afcf2ee56c354d11fa59ce32 Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Wed, 5 Feb 2025 12:56:04 +0100 Subject: [PATCH 02/16] cross dataset eval --- moabb/evaluations/__init__.py | 1 + moabb/evaluations/evaluations.py | 323 ++++++++++++++++++++++++++++++- 2 files changed, 322 insertions(+), 2 deletions(-) diff --git a/moabb/evaluations/__init__.py b/moabb/evaluations/__init__.py index 453be2ed5..b5db9c26f 100644 --- a/moabb/evaluations/__init__.py +++ b/moabb/evaluations/__init__.py @@ -8,6 +8,7 @@ CrossSessionEvaluation, CrossSubjectEvaluation, WithinSessionEvaluation, + CrossDatasetEvaluation, ) from .splitters import WithinSessionSplitter from .utils import create_save_path, save_model_cv, save_model_list diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index b331aba6e..2451db8ff 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -20,8 +20,9 @@ from moabb.evaluations.base import BaseEvaluation from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list - - +from mne.channels import make_standard_montage +from mne.io.constants import FIFF +import mne try: from codecarbon import EmissionsTracker @@ -780,3 +781,321 @@ def evaluate( def is_valid(self, dataset): return len(dataset.subject_list) > 1 + +class CrossDatasetEvaluation(BaseEvaluation): + """Evaluation class for deep learning models across datasets. + + Parameters + ---------- + train_dataset : Dataset or list of Dataset + Dataset(s) to use for training + test_dataset : Dataset or list of Dataset + Dataset(s) to use for testing + pretrained_model : Optional[BaseEstimator] + Pre-trained model to use (if None, will train from scratch) + fine_tune : bool + Whether to fine-tune the pretrained model on train_dataset + target_channels : list or None + List of channel names to use. If None, will use all channels from train_dataset + sfreq : float + Target sampling frequency for all datasets + channel_strategy : str + Strategy for handling different channel configurations: + - 'zero': Zero-filling for missing channels (default) + - 'ssi': Spherical spline interpolation + - 'subset': Use only common channels across datasets + montage : str or mne.channels.DigMontage + EEG montage to use for SSI. Default is 'standard_1020'. + Can be a string (standard montage name) or custom DigMontage. + min_channels : int + Minimum number of common channels required for subset strategy + """ + def __init__( + self, + train_dataset, + test_dataset, + pretrained_model=None, + fine_tune=True, + target_channels=None, + sfreq=128, + channel_strategy='zero', + montage='standard_1020', + min_channels=3, + **kwargs + + ): + super().__init__(**kwargs) + self.train_dataset = train_dataset if isinstance(train_dataset, list) else [train_dataset] + self.test_dataset = test_dataset if isinstance(test_dataset, list) else [test_dataset] + self.pretrained_model = pretrained_model + self.fine_tune = fine_tune + self.sfreq = sfreq + self.channel_strategy = channel_strategy + self.montage = montage + self.min_channels = min_channels + + # Get channels from paradigm only + all_train_channels = set() + for dataset in self.train_dataset: + # Get channels from paradigm + X, _, _ = self.paradigm.get_data(dataset, [dataset.subject_list[0]], return_epochs=True) + all_train_channels.update(X.ch_names) + + # Set target channels based on strategy + self.target_channels = (target_channels if target_channels is not None + else list(all_train_channels)) + + # Load montage if SSI strategy is selected + if self.channel_strategy == 'ssi': + self._setup_montage() + + # Validate datasets and channel strategy + self._validate_datasets() + + def _setup_montage(self): + """Set up and validate EEG montage for SSI strategy.""" + try: + if isinstance(self.montage, str): + self.montage = make_standard_montage(self.montage) + + # Verify target channels exist in montage + missing_in_montage = set(self.target_channels) - set(self.montage.ch_names) + if missing_in_montage: + raise ValueError( + f"Channels {missing_in_montage} not found in montage. " + "Please use a different montage or channel strategy." + ) + except Exception as e: + raise ValueError(f"Error setting up montage: {str(e)}") + + def _validate_datasets(self): + + """Validate compatibility of train and test datasets.""" + all_datasets = self.train_dataset + self.test_dataset + + # Check paradigm compatibility + for dataset in all_datasets: + if not self.paradigm.is_valid(dataset): + raise ValueError(f"Dataset {dataset.code} not compatible with paradigm") + + # Validate channel strategy + valid_strategies = ['zero', 'ssi', 'subset'] + if self.channel_strategy not in valid_strategies: + raise ValueError(f"Invalid channel strategy. Must be one of {valid_strategies}") + + # For subset strategy, verify common channels exist + if self.channel_strategy == 'subset': + # Get channels from first dataset through paradigm + X, _, _ = self.paradigm.get_data(all_datasets[0], [all_datasets[0].subject_list[0]], return_epochs=True) + common_channels = set(X.ch_names) + + # Get channels from remaining datasets + for dataset in all_datasets[1:]: + X, _, _ = self.paradigm.get_data(dataset, [dataset.subject_list[0]], return_epochs=True) + common_channels &= set(X.ch_names) + + if len(common_channels) < self.min_channels: + raise ValueError( + f"Insufficient common channels found ({len(common_channels)}). " + f"Minimum required: {self.min_channels}" + ) + + self.target_channels = list(common_channels) + + def _resolve_channels(self, epochs, dataset_channels): + """Apply channel resolution strategy to match target channels.""" + try: + if self.channel_strategy == 'subset': + return epochs.pick_channels(self.target_channels, ordered=True) + + # Get missing and extra channels + missing_channels = list(set(self.target_channels) - set(dataset_channels)) + extra_channels = list(set(dataset_channels) - set(self.target_channels)) + + # Remove extra channels if any + if extra_channels: + epochs = epochs.drop_channels(extra_channels) + + # Handle missing channels + if missing_channels: + if self.channel_strategy == 'zero': + # Create new info with all target channels + info = mne.create_info( + ch_names=self.target_channels, + sfreq=epochs.info['sfreq'], + ch_types=['eeg'] * len(self.target_channels) + ) + + # Get current data + data = epochs.get_data() + n_epochs, _, n_times = data.shape + + # Create new data array with correct channel order + new_data = np.zeros((n_epochs, len(self.target_channels), n_times)) + + # Fill in existing channels + valid_channels = set(dataset_channels) & set(self.target_channels) + for ch in valid_channels: + target_idx = self.target_channels.index(ch) + data_idx = dataset_channels.index(ch) + if target_idx < len(self.target_channels) and data_idx < data.shape[1]: + new_data[:, target_idx, :] = data[:, data_idx, :] + + # Create new epochs with zero-filled channels + epochs = mne.EpochsArray(new_data, info) + + elif self.channel_strategy == 'ssi': + # Create temporary info with all channels + info = mne.create_info( + ch_names=self.target_channels, + sfreq=epochs.info['sfreq'], + ch_types=['eeg'] * len(self.target_channels) + ) + + # Get current data + data = epochs.get_data() + n_epochs, _, n_times = data.shape + + # Create temporary data array + temp_data = np.zeros((n_epochs, len(self.target_channels), n_times)) + + # Fill in existing channels + valid_channels = set(dataset_channels) & set(self.target_channels) + for ch in valid_channels: + target_idx = self.target_channels.index(ch) + data_idx = dataset_channels.index(ch) + if target_idx < len(self.target_channels) and data_idx < data.shape[1]: + temp_data[:, target_idx, :] = data[:, data_idx, :] + + # Create temporary epochs + temp_epochs = mne.EpochsArray(temp_data, info) + + # Set montage for interpolation + temp_epochs.set_montage(self.montage) + + # Mark missing channels as bad + temp_epochs.info['bads'] = missing_channels + + # Interpolate missing channels + epochs = temp_epochs.interpolate_bads( + method='spline', + reset_bads=True + ) + + return epochs + + except Exception as e: + raise RuntimeError(f"Error resolving channels: {str(e)}") + + def _prepare_data(self, dataset): + """Prepare data from a dataset with consistent format.""" + X, y, metadata = self.paradigm.get_data( + dataset=dataset, + subjects=dataset.subject_list, + return_epochs=True + ) + + # Apply channel resolution strategy using channels from epochs + X = self._resolve_channels(X, X.ch_names) + + # Resample if needed + if X.info['sfreq'] != self.sfreq: + X = X.resample(self.sfreq) + + return X, y, metadata + + def evaluate(self, dataset, pipelines, groups=None, param_grid=None, **kwargs): + """Evaluate models across datasets. + + Parameters + ---------- + dataset : Dataset + The dataset to evaluate + pipelines : dict + Dictionary of pipelines to evaluate + groups : array-like, optional + Groups for cross-validation + param_grid : dict, optional + Parameters grid for grid search + **kwargs : dict + Additional parameters + + Yields + ------ + dict + Evaluation results + """ + log.info("Starting cross-dataset evaluation") + + # Prepare training data + train_X, train_y, train_metadata = [], [], [] + for train_ds in self.train_dataset: + X, y, meta = self._prepare_data(train_ds) + train_X.append(X) + train_y.extend(y) + train_metadata.append(meta) + + # Combine training data + train_X = mne.concatenate_epochs(train_X) + train_y = np.array(train_y) + + # Convert EpochsArray to numpy array + train_X_data = train_X.get_data() + + # For each test dataset + for test_ds in self.test_dataset: + test_X, test_y, test_metadata = self._prepare_data(test_ds) + # Convert test data to numpy array + test_X_data = test_X.get_data() + + # Evaluate each pipeline + for name, pipeline in pipelines.items(): + if _carbonfootprint: + tracker = EmissionsTracker(save_to_file=False, log_level="error") + tracker.start() + + t_start = time() + + try: + # If using pretrained model, clone it + if self.pretrained_model is not None: + model = clone(self.pretrained_model) + if self.fine_tune: + model.fit(train_X_data, train_y) + else: + # Train from scratch + model = clone(pipeline).fit(train_X_data, train_y) + + # Evaluate on test subjects + for subject in test_ds.subject_list: + subject_mask = test_metadata.subject == subject + subject_X = test_X_data[subject_mask] + subject_y = test_y[subject_mask] + + score = model.score(subject_X, subject_y) + + duration = time() - t_start + + result = { + 'time': duration, + 'dataset': test_ds, + 'subject': subject, + 'score': score, + 'n_samples': len(subject_y), + 'n_channels': subject_X.shape[1], + 'pipeline': name, + 'training_datasets': [ds.code for ds in self.train_dataset], + 'pretrained': self.pretrained_model is not None, + 'fine_tuned': self.fine_tune, + 'channel_strategy': self.channel_strategy + } + + yield result + + except Exception as e: + log.error(f"Error evaluating pipeline {name}: {str(e)}") + raise + + def is_valid(self, dataset): + """Check if dataset is valid for this evaluation.""" + return True # All datasets are valid for cross-dataset evaluation \ No newline at end of file From c203219acb634588a78bf85792ed26b907762492 Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Thu, 6 Feb 2025 18:28:29 +0100 Subject: [PATCH 03/16] changes to match MOABB syntax and format --- examples/cross_dataset.py | 166 ++++++++++------- moabb/evaluations/evaluations.py | 297 ++++++------------------------- 2 files changed, 155 insertions(+), 308 deletions(-) diff --git a/examples/cross_dataset.py b/examples/cross_dataset.py index 5edad00de..374e12379 100644 --- a/examples/cross_dataset.py +++ b/examples/cross_dataset.py @@ -1,81 +1,109 @@ -import logging -import yaml +from moabb import set_log_level from moabb.datasets import BNCI2014001, Zhou2016 from moabb.paradigms import MotorImagery from moabb.evaluations.evaluations import CrossDatasetEvaluation from sklearn.pipeline import Pipeline +from sklearn.preprocessing import FunctionTransformer from pyriemann.estimation import Covariances from pyriemann.spatialfilters import CSP from sklearn.svm import SVC +import matplotlib.pyplot as plt +from moabb.analysis.plotting import score_plot +import pandas as pd +import mne +import logging + +# Configure logging - reduce verbosity +set_log_level("WARNING") # Changed from "info" to "WARNING" +logging.getLogger('mne').setLevel(logging.ERROR) # Reduce MNE logging + +def get_common_channels(datasets): + """Get channels that are available across all datasets.""" + all_channels = [] + for dataset in datasets: + # Get a sample raw from each dataset + subject = dataset.subject_list[0] + raw_dict = dataset.get_data([subject]) + # Navigate through the nested dictionary structure + subject_data = raw_dict[subject] # Get subject's data + first_session = list(subject_data.keys())[0] # Get first session + first_run = list(subject_data[first_session].keys())[0] # Get first run + raw = subject_data[first_session][first_run] # Get raw data + all_channels.append(raw.ch_names) + + # Find common channels across all datasets + common_channels = set.intersection(*map(set, all_channels)) + return sorted(list(common_channels)) -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +def create_pipeline(common_channels) -> Pipeline: + """Create classification pipeline.""" + def raw_to_data(X): + """Convert raw MNE data to numpy array format""" + if hasattr(X, 'get_data'): + # Get only common channels to ensure consistency + picks = mne.pick_channels(X.info['ch_names'], + include=common_channels, + ordered=True) + data = X.get_data() + if data.ndim == 2: + data = data.reshape(1, *data.shape) + data = data[:, picks, :] + return data + return X -def create_pipeline() -> Pipeline: - """Create the CSP + SVM pipeline manually.""" - return Pipeline([ + pipeline = Pipeline([ + ('to_array', FunctionTransformer(raw_to_data)), ('covariances', Covariances(estimator='oas')), - ('csp', CSP(nfilter=6)), - ('svc', SVC(kernel='linear')) + ('csp', CSP(nfilter=4, log=True)), # Changed n_components to nfilter, removed invalid parameters + ('classifier', SVC(kernel='rbf', C=0.1)) ]) -def main(): - # Define train and test datasets - train_dataset = BNCI2014001() - test_dataset = Zhou2016() - - # Initialize the paradigm - paradigm = MotorImagery(n_classes=2) - - # Initialize the CrossDatasetEvaluation - evaluation = CrossDatasetEvaluation( - paradigm=paradigm, - train_dataset=train_dataset, - test_dataset=test_dataset, - pretrained_model=None, # Not using a pre-trained model - fine_tune=False, # Not fine-tuning - target_channels=None, # Use all channels from train_dataset - sfreq=128, # Target sampling frequency - channel_strategy='zero',# Strategy for handling channels - montage='standard_1020',# EEG montage for SSI - min_channels=3, # Minimum common channels for subset strategy - hdf5_path=None, # Path to save results and models - save_model=False # Do not save models - ) - - # Create the pipeline - pipeline = create_pipeline() - - # Define parameter grid if needed (optional) - param_grid = { - 'svc__C': [0.1, 1, 10], - 'svc__kernel': ['linear'] - } - - # Run the evaluation - results = evaluation.evaluate( - dataset=None, # Not used in CrossDatasetEvaluation - pipelines={'CSP_SVM': pipeline}, - param_grid=param_grid - ) - - # Collect and display results - for res in results: - # Create log message with available information - log_msg = [ - f"Dataset: {res['dataset'].code}", - f"Subject: {res['subject']}", - f"Pipeline: {res['pipeline']}", - f"Score: {res['score']:.4f}", - f"Time: {res['time']:.2f}s" - ] - - # Add session info if available - if 'session' in res: - log_msg.insert(2, f"Session: {res['session']}") - - logger.info(", ".join(log_msg)) - -if __name__ == "__main__": - main() + return pipeline + +# Define datasets +train_dataset = BNCI2014001() +test_dataset = Zhou2016() + +# Get common channels across datasets +common_channels = get_common_channels([train_dataset, test_dataset]) +print(f"\nCommon channels across datasets: {common_channels}\n") + +# Initialize the paradigm with common channels +paradigm = MotorImagery( + channels=common_channels, # Use common channels + n_classes=2, + fmin=8, + fmax=32 +) + +# Initialize the CrossDatasetEvaluation +evaluation = CrossDatasetEvaluation( + paradigm=paradigm, + train_dataset=train_dataset, + test_dataset=test_dataset, + hdf5_path="./res_test", + save_model=True +) + +# Run the evaluation +results = [] +for result in evaluation.evaluate( + dataset=None, + pipelines={'CSP_SVM': create_pipeline(common_channels)}, + param_grid=None +): + result['subject'] = 'all' + print(f"Cross-dataset score: {result.get('score', 'N/A'):.3f}") + results.append(result) + +# Convert list of results to DataFrame +results_df = pd.DataFrame(results) +results_df['dataset'] = results_df['dataset'].apply(lambda x: x.__class__.__name__) + +# Print evaluation scores +print("\nCross-dataset evaluation scores:") +print(results_df[['dataset', 'score', 'time']]) + +# Plot the results +score_plot(results_df) +plt.show() \ No newline at end of file diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index 2451db8ff..abc1292c3 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -795,34 +795,17 @@ class CrossDatasetEvaluation(BaseEvaluation): Pre-trained model to use (if None, will train from scratch) fine_tune : bool Whether to fine-tune the pretrained model on train_dataset - target_channels : list or None - List of channel names to use. If None, will use all channels from train_dataset sfreq : float Target sampling frequency for all datasets - channel_strategy : str - Strategy for handling different channel configurations: - - 'zero': Zero-filling for missing channels (default) - - 'ssi': Spherical spline interpolation - - 'subset': Use only common channels across datasets - montage : str or mne.channels.DigMontage - EEG montage to use for SSI. Default is 'standard_1020'. - Can be a string (standard montage name) or custom DigMontage. - min_channels : int - Minimum number of common channels required for subset strategy """ def __init__( - self, - train_dataset, - test_dataset, - pretrained_model=None, - fine_tune=True, - target_channels=None, - sfreq=128, - channel_strategy='zero', - montage='standard_1020', - min_channels=3, - **kwargs - + self, + train_dataset, + test_dataset, + pretrained_model=None, + fine_tune=True, + sfreq=128, + **kwargs ): super().__init__(**kwargs) self.train_dataset = train_dataset if isinstance(train_dataset, list) else [train_dataset] @@ -830,46 +813,11 @@ def __init__( self.pretrained_model = pretrained_model self.fine_tune = fine_tune self.sfreq = sfreq - self.channel_strategy = channel_strategy - self.montage = montage - self.min_channels = min_channels - - # Get channels from paradigm only - all_train_channels = set() - for dataset in self.train_dataset: - # Get channels from paradigm - X, _, _ = self.paradigm.get_data(dataset, [dataset.subject_list[0]], return_epochs=True) - all_train_channels.update(X.ch_names) - - # Set target channels based on strategy - self.target_channels = (target_channels if target_channels is not None - else list(all_train_channels)) - # Load montage if SSI strategy is selected - if self.channel_strategy == 'ssi': - self._setup_montage() - - # Validate datasets and channel strategy + # Validate datasets self._validate_datasets() - def _setup_montage(self): - """Set up and validate EEG montage for SSI strategy.""" - try: - if isinstance(self.montage, str): - self.montage = make_standard_montage(self.montage) - - # Verify target channels exist in montage - missing_in_montage = set(self.target_channels) - set(self.montage.ch_names) - if missing_in_montage: - raise ValueError( - f"Channels {missing_in_montage} not found in montage. " - "Please use a different montage or channel strategy." - ) - except Exception as e: - raise ValueError(f"Error setting up montage: {str(e)}") - def _validate_datasets(self): - """Validate compatibility of train and test datasets.""" all_datasets = self.train_dataset + self.test_dataset @@ -877,115 +825,6 @@ def _validate_datasets(self): for dataset in all_datasets: if not self.paradigm.is_valid(dataset): raise ValueError(f"Dataset {dataset.code} not compatible with paradigm") - - # Validate channel strategy - valid_strategies = ['zero', 'ssi', 'subset'] - if self.channel_strategy not in valid_strategies: - raise ValueError(f"Invalid channel strategy. Must be one of {valid_strategies}") - - # For subset strategy, verify common channels exist - if self.channel_strategy == 'subset': - # Get channels from first dataset through paradigm - X, _, _ = self.paradigm.get_data(all_datasets[0], [all_datasets[0].subject_list[0]], return_epochs=True) - common_channels = set(X.ch_names) - - # Get channels from remaining datasets - for dataset in all_datasets[1:]: - X, _, _ = self.paradigm.get_data(dataset, [dataset.subject_list[0]], return_epochs=True) - common_channels &= set(X.ch_names) - - if len(common_channels) < self.min_channels: - raise ValueError( - f"Insufficient common channels found ({len(common_channels)}). " - f"Minimum required: {self.min_channels}" - ) - - self.target_channels = list(common_channels) - - def _resolve_channels(self, epochs, dataset_channels): - """Apply channel resolution strategy to match target channels.""" - try: - if self.channel_strategy == 'subset': - return epochs.pick_channels(self.target_channels, ordered=True) - - # Get missing and extra channels - missing_channels = list(set(self.target_channels) - set(dataset_channels)) - extra_channels = list(set(dataset_channels) - set(self.target_channels)) - - # Remove extra channels if any - if extra_channels: - epochs = epochs.drop_channels(extra_channels) - - # Handle missing channels - if missing_channels: - if self.channel_strategy == 'zero': - # Create new info with all target channels - info = mne.create_info( - ch_names=self.target_channels, - sfreq=epochs.info['sfreq'], - ch_types=['eeg'] * len(self.target_channels) - ) - - # Get current data - data = epochs.get_data() - n_epochs, _, n_times = data.shape - - # Create new data array with correct channel order - new_data = np.zeros((n_epochs, len(self.target_channels), n_times)) - - # Fill in existing channels - valid_channels = set(dataset_channels) & set(self.target_channels) - for ch in valid_channels: - target_idx = self.target_channels.index(ch) - data_idx = dataset_channels.index(ch) - if target_idx < len(self.target_channels) and data_idx < data.shape[1]: - new_data[:, target_idx, :] = data[:, data_idx, :] - - # Create new epochs with zero-filled channels - epochs = mne.EpochsArray(new_data, info) - - elif self.channel_strategy == 'ssi': - # Create temporary info with all channels - info = mne.create_info( - ch_names=self.target_channels, - sfreq=epochs.info['sfreq'], - ch_types=['eeg'] * len(self.target_channels) - ) - - # Get current data - data = epochs.get_data() - n_epochs, _, n_times = data.shape - - # Create temporary data array - temp_data = np.zeros((n_epochs, len(self.target_channels), n_times)) - - # Fill in existing channels - valid_channels = set(dataset_channels) & set(self.target_channels) - for ch in valid_channels: - target_idx = self.target_channels.index(ch) - data_idx = dataset_channels.index(ch) - if target_idx < len(self.target_channels) and data_idx < data.shape[1]: - temp_data[:, target_idx, :] = data[:, data_idx, :] - - # Create temporary epochs - temp_epochs = mne.EpochsArray(temp_data, info) - - # Set montage for interpolation - temp_epochs.set_montage(self.montage) - - # Mark missing channels as bad - temp_epochs.info['bads'] = missing_channels - - # Interpolate missing channels - epochs = temp_epochs.interpolate_bads( - method='spline', - reset_bads=True - ) - - return epochs - - except Exception as e: - raise RuntimeError(f"Error resolving channels: {str(e)}") def _prepare_data(self, dataset): """Prepare data from a dataset with consistent format.""" @@ -994,10 +833,7 @@ def _prepare_data(self, dataset): subjects=dataset.subject_list, return_epochs=True ) - - # Apply channel resolution strategy using channels from epochs - X = self._resolve_channels(X, X.ch_names) - + # Resample if needed if X.info['sfreq'] != self.sfreq: X = X.resample(self.sfreq) @@ -1005,48 +841,47 @@ def _prepare_data(self, dataset): return X, y, metadata def evaluate(self, dataset, pipelines, groups=None, param_grid=None, **kwargs): - """Evaluate models across datasets. - - Parameters - ---------- - dataset : Dataset - The dataset to evaluate - pipelines : dict - Dictionary of pipelines to evaluate - groups : array-like, optional - Groups for cross-validation - param_grid : dict, optional - Parameters grid for grid search - **kwargs : dict - Additional parameters - - Yields - ------ - dict - Evaluation results - """ + """Evaluate models across datasets.""" log.info("Starting cross-dataset evaluation") # Prepare training data train_X, train_y, train_metadata = [], [], [] for train_ds in self.train_dataset: - X, y, meta = self._prepare_data(train_ds) - train_X.append(X) - train_y.extend(y) - train_metadata.append(meta) - - # Combine training data - train_X = mne.concatenate_epochs(train_X) - train_y = np.array(train_y) + # Get raw data directly from paradigm + raw, labels, events = self.paradigm.get_data( + dataset=train_ds, + subjects=train_ds.subject_list, + return_epochs=False + ) + + # Skip if no events found + if len(events) == 0: + log.warning(f"No events found in training dataset {train_ds.code}, skipping...") + continue + + train_X.append(raw) # Just pass the raw data + train_y.extend(labels) + train_metadata.append({'events': events}) - # Convert EpochsArray to numpy array - train_X_data = train_X.get_data() + if not train_X: + raise ValueError("No valid training data found with events") # For each test dataset for test_ds in self.test_dataset: - test_X, test_y, test_metadata = self._prepare_data(test_ds) - # Convert test data to numpy array - test_X_data = test_X.get_data() + # Get raw data directly from paradigm + raw, labels, events = self.paradigm.get_data( + dataset=test_ds, + subjects=test_ds.subject_list, + return_epochs=False + ) + + # Skip if no events found + if len(events) == 0: + log.warning(f"No events found in test dataset {test_ds.code}, skipping...") + continue + + test_X = raw # Just pass the raw data + test_y = labels # Evaluate each pipeline for name, pipeline in pipelines.items(): @@ -1057,41 +892,25 @@ def evaluate(self, dataset, pipelines, groups=None, param_grid=None, **kwargs): t_start = time() try: - # If using pretrained model, clone it - if self.pretrained_model is not None: - model = clone(self.pretrained_model) - if self.fine_tune: - model.fit(train_X_data, train_y) - else: - # Train from scratch - model = clone(pipeline).fit(train_X_data, train_y) + # Train from scratch + model = clone(pipeline).fit(train_X[0], train_y) # Use first training dataset # Evaluate on test subjects - for subject in test_ds.subject_list: - subject_mask = test_metadata.subject == subject - subject_X = test_X_data[subject_mask] - subject_y = test_y[subject_mask] - - score = model.score(subject_X, subject_y) - - duration = time() - t_start - - result = { - 'time': duration, - 'dataset': test_ds, - 'subject': subject, - 'score': score, - 'n_samples': len(subject_y), - 'n_channels': subject_X.shape[1], - 'pipeline': name, - 'training_datasets': [ds.code for ds in self.train_dataset], - 'pretrained': self.pretrained_model is not None, - 'fine_tuned': self.fine_tune, - 'channel_strategy': self.channel_strategy - } - - yield result - + score = model.score(test_X, test_y) + + duration = time() - t_start + + result = { + 'time': duration, + 'dataset': test_ds, + 'score': score, + 'n_samples': len(test_y), + 'pipeline': name, + 'training_datasets': [ds.code for ds in self.train_dataset] + } + + yield result + except Exception as e: log.error(f"Error evaluating pipeline {name}: {str(e)}") raise From 5228cbcf10132bfb9594293ef6354e2e38a6eabd Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Sat, 8 Feb 2025 17:32:16 +0100 Subject: [PATCH 04/16] deep learning example working, pls make it clean --- examples/cross_dataset_braindecode.py | 272 ++++++++++++++++++++++++++ 1 file changed, 272 insertions(+) create mode 100644 examples/cross_dataset_braindecode.py diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py new file mode 100644 index 000000000..3e94d1ecc --- /dev/null +++ b/examples/cross_dataset_braindecode.py @@ -0,0 +1,272 @@ +""" +Cross-Dataset Brain Decoding with Deep Learning +============================================= +This example shows how to train deep learning models on one dataset +and test on another using Braindecode. +""" + +from braindecode.datasets import MOABBDataset +from numpy import multiply +from braindecode.preprocessing import (Preprocessor, + exponential_moving_standardize, + preprocess, + create_windows_from_events) +from braindecode.models import ShallowFBCSPNet +from braindecode.util import set_random_seeds +from braindecode import EEGClassifier +from skorch.callbacks import LRScheduler +from skorch.helper import predefined_split +import torch +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.lines import Line2D +from sklearn.metrics import confusion_matrix +from braindecode.visualization import plot_confusion_matrix +import logging +import mne +from numpy import array +import numpy as np + +# Configure logging +logging.basicConfig(level=logging.WARNING) + +def get_common_channels(train_dataset, test_dataset): + """Get channels that are available across both datasets.""" + train_chans = train_dataset.datasets[0].raw.ch_names + test_chans = test_dataset.datasets[0].raw.ch_names + common_channels = sorted(list(set(train_chans).intersection(set(test_chans)))) + print(f"\nCommon channels across datasets: {common_channels}\n") + return common_channels + +def get_common_events(train_dataset, test_dataset): + """Get events that are available across both datasets.""" + # Get events from first subject of each dataset + train_events = train_dataset.datasets[0].raw.annotations.description + test_events = test_dataset.datasets[0].raw.annotations.description + + # Find common events + common_events = sorted(list(set(train_events).intersection(set(test_events)))) + print(f"\nCommon events across datasets: {common_events}\n") + + # Create event mapping (event description -> numerical ID) + event_id = {str(event): idx for idx, event in enumerate(common_events)} + print(f"Event mapping: {event_id}\n") + + return event_id + +def standardize_windows(train_dataset, test_dataset, common_channels, event_id): + """Standardize datasets with consistent preprocessing.""" + # Define preprocessing parameters + target_sfreq = 100 # Target sampling frequency + + print("\nInitial dataset properties:") + for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: + for i, ds in enumerate(dataset.datasets): + print(f"{name} dataset {i} sampling rate: {ds.raw.info['sfreq']} Hz") + + # Define preprocessors for standardization + preprocessors = [ + Preprocessor('pick_channels', ch_names=common_channels, ordered=True), + Preprocessor('resample', sfreq=target_sfreq), # Standardize sampling rate + Preprocessor(lambda data: multiply(data, 1e6)), # Convert to microvolts + Preprocessor('filter', l_freq=4., h_freq=38.), # Bandpass filter + Preprocessor(exponential_moving_standardize, # Normalize + factor_new=1e-3, init_block_size=1000) + ] + + # Apply preprocessing + preprocess(train_dataset, preprocessors, n_jobs=-1) + preprocess(test_dataset, preprocessors, n_jobs=-1) + + print("\nAfter resampling:") + for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: + for i, ds in enumerate(dataset.datasets): + print(f"{name} dataset {i} sampling rate: {ds.raw.info['sfreq']} Hz") + + # Fixed window parameters (in seconds) + window_start = -0.5 # Start 0.5s before event + window_duration = 4.0 # Increased from 3.0 to ensure enough samples for kernel + + # Convert to samples based on target frequency + samples_before = int(abs(window_start) * target_sfreq) # 50 samples + samples_after = int(window_duration * target_sfreq) # 400 samples + + print(f"\nWindow configuration:") + print(f"Sampling frequency: {target_sfreq} Hz") + print(f"Window: {window_start}s to {window_duration}s") + print(f"Samples before: {samples_before}") + print(f"Samples after: {samples_after}") + print(f"Total window length: {samples_before + samples_after} samples") + + # Standardize event durations to 0 for both datasets + for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: + for i, ds in enumerate(dataset.datasets): + # Drop last event and standardize durations + events = ds.raw.annotations[:-1] + new_annotations = mne.Annotations( + onset=events.onset, + duration=np.zeros_like(events.duration), # Set all durations to 0 + description=events.description + ) + ds.raw.set_annotations(new_annotations) + + print(f"\n{name} dataset {i}:") + print(f"Number of events: {len(ds.raw.annotations)}") + print(f"Event timings: {ds.raw.annotations.onset[:5]}") + print(f"Event durations: {ds.raw.annotations.duration[:5]}") + + # Create windows with explicit trial_stop_offset_samples + def create_fixed_windows(dataset): + return create_windows_from_events( + dataset, + trial_start_offset_samples=-samples_before, + trial_stop_offset_samples=samples_after, + preload=True, + mapping=event_id, + window_size_samples=samples_before + samples_after, # Force window size + window_stride_samples=samples_before + samples_after # Add stride parameter + ) + + # Create windows + train_windows = create_fixed_windows(train_dataset) + test_windows = create_fixed_windows(test_dataset) + + # Verify window shapes + train_shape = train_windows[0][0].shape + test_shape = test_windows[0][0].shape + print(f"\nActual window shapes:") + print(f"Train windows: {train_shape}") + print(f"Test windows: {test_shape}") + print(f"Expected shape: (9, {samples_before + samples_after})") + + if train_shape != test_shape: + print("\nWindow size mismatch analysis:") + print(f"Train window size: {train_shape[1]} samples") + print(f"Test window size: {test_shape[1]} samples") + print(f"Difference: {train_shape[1] - test_shape[1]} samples") + print(f"In seconds: {(train_shape[1] - test_shape[1])/target_sfreq:.2f}s") + raise ValueError(f"Window shapes don't match: train={train_shape}, test={test_shape}") + + window_length = train_shape[1] + print(f"Final window length: {window_length} samples") + + return train_windows, test_windows, window_length, samples_before, samples_after + +# Load datasets +train_dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[3, 4]) +test_dataset = MOABBDataset(dataset_name="Zhou2016", subject_ids=[1]) + +# Get common channels and events +common_channels = get_common_channels(train_dataset, test_dataset) +event_id = get_common_events(train_dataset, test_dataset) +n_classes = len(event_id) + +# Standardize datasets and get window length +train_windows, test_windows, window_length, samples_before, samples_after = standardize_windows( + train_dataset, test_dataset, common_channels, event_id +) + +# Split training data +splitted = train_windows.split('session') +train_set = splitted['0train'] +valid_set = splitted['1test'] + +# Setup compute device +cuda = torch.cuda.is_available() +device = 'cuda' if cuda else 'cpu' +if cuda: + torch.backends.cudnn.benchmark = True + +# Set random seed +set_random_seeds(seed=20200220, cuda=cuda) + +# Calculate model parameters based on standardized data +n_chans = len(common_channels) +input_window_samples = window_length # 450 samples + +# Create model with adjusted parameters +model = ShallowFBCSPNet( + n_chans, + n_classes, + n_times=input_window_samples, # Use n_times instead of deprecated input_window_samples + n_filters_time=40, + filter_time_length=20, # Reduced from default + pool_time_length=35, # Reduced from default + pool_time_stride=7, # Reduced from default + final_conv_length='auto' # Let the model calculate the appropriate final conv length +) + +if cuda: + model = model.cuda() + +# Create and train classifier +clf = EEGClassifier( + model, + criterion=torch.nn.NLLLoss, + optimizer=torch.optim.AdamW, + train_split=predefined_split(valid_set), + optimizer__lr=0.0625 * 0.01, + optimizer__weight_decay=0, + batch_size=64, + callbacks=[ + "accuracy", + ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=11)), + ], + device=device, + classes=list(range(n_classes)), +) + +# Train the model +_ = clf.fit(train_set, y=None, epochs=12) + +# Get test labels from the windows +y_true = test_windows.get_metadata().target +y_pred = clf.predict(test_windows) +test_accuracy = np.mean(y_true == y_pred) +print(f"\nTest accuracy: {test_accuracy:.4f}") + +# Generate confusion matrix for test set +confusion_mat = confusion_matrix(y_true, y_pred) + +# Plot confusion matrix with dynamic class names +plot_confusion_matrix(confusion_mat, class_names=list(event_id.keys())) +plt.show() + +# Create visualization +fig = plt.figure(figsize=(10, 5)) +plt.plot(clf.history[:, 'train_loss'], label='Training Loss') +plt.plot(clf.history[:, 'valid_loss'], label='Validation Loss') +plt.legend() +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Training and Validation Loss Over Time') +plt.show() + +# Plot training curves +results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'] +df = pd.DataFrame(clf.history[:, results_columns], columns=results_columns, + index=clf.history[:, 'epoch']) + +df = df.assign(train_misclass=100 - 100 * df.train_accuracy, + valid_misclass=100 - 100 * df.valid_accuracy) + +# Plot results +fig, ax1 = plt.subplots(figsize=(8, 3)) +df.loc[:, ['train_loss', 'valid_loss']].plot( + ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False) + +ax1.tick_params(axis='y', labelcolor='tab:blue') +ax1.set_ylabel("Loss", color='tab:blue') + +ax2 = ax1.twinx() +df.loc[:, ['train_misclass', 'valid_misclass']].plot( + ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False) +ax2.tick_params(axis='y', labelcolor='tab:red') +ax2.set_ylabel("Misclassification Rate [%]", color='tab:red') +ax2.set_ylim(ax2.get_ylim()[0], 85) + +handles = [] +handles.append(Line2D([0], [0], color='black', linestyle='-', label='Train')) +handles.append(Line2D([0], [0], color='black', linestyle=':', label='Valid')) +plt.legend(handles, [h.get_label() for h in handles]) +plt.tight_layout() From fecf06cf5dedf42811ed608bf69e4725c38c60d6 Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Mon, 10 Feb 2025 01:38:36 +0100 Subject: [PATCH 05/16] multiple dataset training and testing with braindecode working - please improve on spline channel interpolation and handling of events. Please add dataset augmentation and load balancing if necessary --- examples/cross_dataset_braindecode.py | 211 ++++++++++++++++++++++---- 1 file changed, 178 insertions(+), 33 deletions(-) diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py index 3e94d1ecc..0c1e14460 100644 --- a/examples/cross_dataset_braindecode.py +++ b/examples/cross_dataset_braindecode.py @@ -26,6 +26,7 @@ import mne from numpy import array import numpy as np +from braindecode.datasets import BaseConcatDataset # Configure logging logging.basicConfig(level=logging.WARNING) @@ -38,23 +39,23 @@ def get_common_channels(train_dataset, test_dataset): print(f"\nCommon channels across datasets: {common_channels}\n") return common_channels -def get_common_events(train_dataset, test_dataset): - """Get events that are available across both datasets.""" +def get_all_events(train_dataset, test_dataset): + """Get all unique events across datasets.""" # Get events from first subject of each dataset train_events = train_dataset.datasets[0].raw.annotations.description test_events = test_dataset.datasets[0].raw.annotations.description - # Find common events - common_events = sorted(list(set(train_events).intersection(set(test_events)))) - print(f"\nCommon events across datasets: {common_events}\n") + # Get all unique events + all_events = sorted(list(set(train_events).union(set(test_events)))) + print(f"\nAll unique events across datasets: {all_events}\n") # Create event mapping (event description -> numerical ID) - event_id = {str(event): idx for idx, event in enumerate(common_events)} + event_id = {str(event): idx for idx, event in enumerate(all_events)} print(f"Event mapping: {event_id}\n") return event_id -def standardize_windows(train_dataset, test_dataset, common_channels, event_id): +def standardize_windows(train_dataset, test_dataset, all_channels, event_id): """Standardize datasets with consistent preprocessing.""" # Define preprocessing parameters target_sfreq = 100 # Target sampling frequency @@ -64,13 +65,79 @@ def standardize_windows(train_dataset, test_dataset, common_channels, event_id): for i, ds in enumerate(dataset.datasets): print(f"{name} dataset {i} sampling rate: {ds.raw.info['sfreq']} Hz") - # Define preprocessors for standardization + def interpolate_missing_channels(raw_data, all_channels): + """Interpolate missing channels using spherical spline interpolation.""" + if isinstance(raw_data, np.ndarray): + return raw_data + + if not isinstance(raw_data, mne.io.Raw): + raise TypeError("Expected MNE Raw object") + + missing_channels = [ch for ch in all_channels if ch not in raw_data.ch_names] + existing_channels = raw_data.ch_names + + print("\nChannel Information:") + print(f"Total channels needed: {len(all_channels)}") + print(f"Existing channels: {len(existing_channels)}") + print(f"Missing channels to interpolate: {len(missing_channels)}") + + if missing_channels: + print("\nMissing channels:") + for ch in missing_channels: + print(f"- {ch}") + + # Mark missing channels as bad + raw_data.info['bads'] = missing_channels + + # Add missing channels (temporarily with zeros) + print("\nAdding temporary channels for interpolation...") + raw_data.add_channels([ + mne.io.RawArray( + np.zeros((1, len(raw_data.times))), + mne.create_info([ch], raw_data.info['sfreq'], ['eeg']) + ) for ch in missing_channels + ]) + + # Get data before interpolation + data_before = raw_data.get_data() + + # Interpolate the bad channels + print("Performing spherical spline interpolation...") + raw_data.interpolate_bads(reset_bads=True) + + # Get data after interpolation + data_after = raw_data.get_data() + + # Calculate and print interpolation statistics + for idx, ch in enumerate(missing_channels): + ch_idx = raw_data.ch_names.index(ch) + interpolated_data = data_after[ch_idx] + stats = { + 'mean': np.mean(interpolated_data), + 'std': np.std(interpolated_data), + 'min': np.min(interpolated_data), + 'max': np.max(interpolated_data) + } + print(f"\nInterpolated channel {ch} statistics:") + print(f"- Mean: {stats['mean']:.2f}") + print(f"- Std: {stats['std']:.2f}") + print(f"- Range: [{stats['min']:.2f}, {stats['max']:.2f}]") + + print("\nInterpolation complete.") + else: + print("No channels need interpolation.") + + return raw_data + + # Modified preprocessors to handle missing channels with interpolation preprocessors = [ - Preprocessor('pick_channels', ch_names=common_channels, ordered=True), - Preprocessor('resample', sfreq=target_sfreq), # Standardize sampling rate - Preprocessor(lambda data: multiply(data, 1e6)), # Convert to microvolts - Preprocessor('filter', l_freq=4., h_freq=38.), # Bandpass filter - Preprocessor(exponential_moving_standardize, # Normalize + # Add and interpolate missing channels only for MNE Raw objects + Preprocessor(lambda raw: interpolate_missing_channels(raw, all_channels)), + Preprocessor('pick_channels', ch_names=all_channels, ordered=True), + Preprocessor('resample', sfreq=target_sfreq), + Preprocessor(lambda data: multiply(data, 1e6)), + Preprocessor('filter', l_freq=4., h_freq=38.), + Preprocessor(exponential_moving_standardize, factor_new=1e-3, init_block_size=1000) ] @@ -152,22 +219,100 @@ def create_fixed_windows(dataset): return train_windows, test_windows, window_length, samples_before, samples_after -# Load datasets -train_dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[3, 4]) -test_dataset = MOABBDataset(dataset_name="Zhou2016", subject_ids=[1]) +# Load datasets with validation +def load_and_validate_dataset(dataset_name, subject_ids): + """Load dataset and validate its contents.""" + dataset = MOABBDataset(dataset_name=dataset_name, subject_ids=subject_ids) + + print(f"\nValidating dataset: {dataset_name}") + for i, ds in enumerate(dataset.datasets): + events = ds.raw.annotations + print(f"\nSubject {i+1}:") # Changed from subject_ids[i] to i+1 + print(f"Number of events: {len(events)}") + print(f"Unique event types: {set(events.description)}") + print(f"First 5 event timings: {events.onset[:5]}") + print(f"Sample event descriptions: {list(events.description[:5])}") + + return dataset -# Get common channels and events -common_channels = get_common_channels(train_dataset, test_dataset) -event_id = get_common_events(train_dataset, test_dataset) +# Load datasets with validation +print("\nLoading training datasets...") +train_dataset_1 = load_and_validate_dataset("BNCI2014_001", subject_ids=[1, 2, 3, 4]) +train_dataset_2 = load_and_validate_dataset("BNCI2015_001", subject_ids=[1, 2, 3, 4]) +train_dataset_3 = load_and_validate_dataset("BNCI2014_004", subject_ids=[1, 2, 3, 4]) + +print("\nLoading test dataset...") +test_dataset = load_and_validate_dataset("Zhou2016", subject_ids=[1]) + +# Verify datasets are different +print("\nVerifying dataset uniqueness...") +for name1, ds1 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), + ("Train3", train_dataset_3), ("Test", test_dataset)]: + for name2, ds2 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), + ("Train3", train_dataset_3), ("Test", test_dataset)]: + if name1 < name2: # Compare each pair only once + print(f"\nComparing {name1} vs {name2}:") + # Compare first subject of each dataset + events1 = ds1.datasets[0].raw.annotations + events2 = ds2.datasets[0].raw.annotations + print(f"Event counts: {len(events1)} vs {len(events2)}") + print(f"Event types: {set(events1.description)} vs {set(events2.description)}") + print(f"First timing: {events1.onset[0]} vs {events2.onset[0]}") + +# Get common channels across all datasets +train_chans_1 = train_dataset_1.datasets[0].raw.ch_names +train_chans_2 = train_dataset_2.datasets[0].raw.ch_names +train_chans_3 = train_dataset_3.datasets[0].raw.ch_names +test_chans = test_dataset.datasets[0].raw.ch_names + +common_channels = sorted(list(set(train_chans_1) + .intersection(set(train_chans_2)) + .intersection(set(train_chans_3)) + .intersection(set(test_chans)))) +print(f"\nCommon channels across all datasets: {common_channels}\n") + +# Get all events across all datasets +train_events_1 = train_dataset_1.datasets[0].raw.annotations.description +train_events_2 = train_dataset_2.datasets[0].raw.annotations.description +train_events_3 = train_dataset_3.datasets[0].raw.annotations.description +test_events = test_dataset.datasets[0].raw.annotations.description + +all_events = sorted(list(set(train_events_1) + .union(set(train_events_2)) + .union(set(train_events_3)) + .union(set(test_events)))) +print(f"\nAll unique events across datasets: {all_events}\n") + +event_id = {str(event): idx for idx, event in enumerate(all_events)} +print(f"Event mapping: {event_id}\n") + +# Define number of classes based on all unique events n_classes = len(event_id) +print(f"Number of classes: {n_classes}\n") -# Standardize datasets and get window length -train_windows, test_windows, window_length, samples_before, samples_after = standardize_windows( - train_dataset, test_dataset, common_channels, event_id +# Process all training datasets +train_windows_1, _, _, _, _ = standardize_windows( + train_dataset_1, test_dataset, common_channels, event_id +) +train_windows_2, _, _, _, _ = standardize_windows( + train_dataset_2, test_dataset, common_channels, event_id ) +train_windows_3, _, _, _, _ = standardize_windows( + train_dataset_3, test_dataset, common_channels, event_id +) +train_windows_test, test_windows, window_length, samples_before, samples_after = standardize_windows( + train_dataset_1, test_dataset, common_channels, event_id +) + +# Combine all training windows +combined_train_windows = BaseConcatDataset([ + train_windows_1, + train_windows_2, + train_windows_3 +]) -# Split training data -splitted = train_windows.split('session') +# Split training data using only the combined dataset +splitted = combined_train_windows.split('session') train_set = splitted['0train'] valid_set = splitted['1test'] @@ -184,16 +329,16 @@ def create_fixed_windows(dataset): n_chans = len(common_channels) input_window_samples = window_length # 450 samples -# Create model with adjusted parameters +# Create model with adjusted parameters for all classes model = ShallowFBCSPNet( n_chans, - n_classes, - n_times=input_window_samples, # Use n_times instead of deprecated input_window_samples + n_classes, # This will now be the total number of unique events + n_times=input_window_samples, n_filters_time=40, - filter_time_length=20, # Reduced from default - pool_time_length=35, # Reduced from default - pool_time_stride=7, # Reduced from default - final_conv_length='auto' # Let the model calculate the appropriate final conv length + filter_time_length=20, + pool_time_length=35, + pool_time_stride=7, + final_conv_length='auto' ) if cuda: @@ -217,7 +362,7 @@ def create_fixed_windows(dataset): ) # Train the model -_ = clf.fit(train_set, y=None, epochs=12) +_ = clf.fit(train_set, y=None, epochs=100) # Get test labels from the windows y_true = test_windows.get_metadata().target @@ -228,7 +373,7 @@ def create_fixed_windows(dataset): # Generate confusion matrix for test set confusion_mat = confusion_matrix(y_true, y_pred) -# Plot confusion matrix with dynamic class names +# Plot confusion matrix with all event names plot_confusion_matrix(confusion_mat, class_names=list(event_id.keys())) plt.show() From d9f598c7c46597ad9a0f610ec933b02bca475f17 Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Mon, 24 Feb 2025 18:02:46 +0100 Subject: [PATCH 06/16] Few changes --- examples/cross_dataset_braindecode.py | 2 +- run_cross_dataset.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 run_cross_dataset.py diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py index 0c1e14460..f2e019e74 100644 --- a/examples/cross_dataset_braindecode.py +++ b/examples/cross_dataset_braindecode.py @@ -242,7 +242,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_dataset_3 = load_and_validate_dataset("BNCI2014_004", subject_ids=[1, 2, 3, 4]) print("\nLoading test dataset...") -test_dataset = load_and_validate_dataset("Zhou2016", subject_ids=[1]) +test_dataset = load_and_validate_dataset("Zhou2016", subject_ids=[1, 2, 3]) # Verify datasets are different print("\nVerifying dataset uniqueness...") diff --git a/run_cross_dataset.py b/run_cross_dataset.py new file mode 100644 index 000000000..0519ecba6 --- /dev/null +++ b/run_cross_dataset.py @@ -0,0 +1 @@ + \ No newline at end of file From 44cd428a4d52bf4883da8adc9fc3fb429aae6f0f Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Mon, 24 Feb 2025 18:46:59 +0100 Subject: [PATCH 07/16] cross dataset eval with examples --- examples/cross_dataset.py | 125 ++++++--- examples/cross_dataset_braindecode.py | 351 ++++++++++++++------------ examples/testing.py | 8 - moabb/evaluations/__init__.py | 2 +- moabb/evaluations/evaluations.py | 101 ++++---- run_cross_dataset.py | 1 - 6 files changed, 328 insertions(+), 260 deletions(-) delete mode 100644 examples/testing.py diff --git a/examples/cross_dataset.py b/examples/cross_dataset.py index 374e12379..41e970e22 100644 --- a/examples/cross_dataset.py +++ b/examples/cross_dataset.py @@ -1,49 +1,102 @@ -from moabb import set_log_level -from moabb.datasets import BNCI2014001, Zhou2016 -from moabb.paradigms import MotorImagery -from moabb.evaluations.evaluations import CrossDatasetEvaluation -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import FunctionTransformer +"""Cross-dataset motor imagery classification example. + +This example demonstrates how to perform cross-dataset evaluation using MOABB, +training on one dataset and testing on another. +""" + +# Standard library imports +import logging +from typing import Any, List + +import matplotlib.pyplot as plt + +# Third-party imports +import mne +import numpy as np +import pandas as pd from pyriemann.estimation import Covariances from pyriemann.spatialfilters import CSP +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import FunctionTransformer from sklearn.svm import SVC -import matplotlib.pyplot as plt + +# MOABB imports +from moabb import set_log_level from moabb.analysis.plotting import score_plot -import pandas as pd -import mne -import logging +from moabb.datasets import BNCI2014001, Zhou2016 +from moabb.evaluations.evaluations import CrossDatasetEvaluation +from moabb.paradigms import MotorImagery -# Configure logging - reduce verbosity -set_log_level("WARNING") # Changed from "info" to "WARNING" -logging.getLogger('mne').setLevel(logging.ERROR) # Reduce MNE logging -def get_common_channels(datasets): - """Get channels that are available across all datasets.""" +# Configure logging +set_log_level("WARNING") +logging.getLogger('mne').setLevel(logging.ERROR) + + +def get_common_channels(datasets: List[Any]) -> List[str]: + """Get channels that are available across all datasets. + + Parameters + ---------- + datasets : List[Dataset] + List of MOABB dataset objects to analyze + + Returns + ------- + List[str] + Sorted list of common channel names + """ all_channels = [] for dataset in datasets: # Get a sample raw from each dataset subject = dataset.subject_list[0] raw_dict = dataset.get_data([subject]) + # Navigate through the nested dictionary structure - subject_data = raw_dict[subject] # Get subject's data - first_session = list(subject_data.keys())[0] # Get first session - first_run = list(subject_data[first_session].keys())[0] # Get first run - raw = subject_data[first_session][first_run] # Get raw data + subject_data = raw_dict[subject] + first_session = list(subject_data.keys())[0] + first_run = list(subject_data[first_session].keys())[0] + raw = subject_data[first_session][first_run] + all_channels.append(raw.ch_names) - + # Find common channels across all datasets common_channels = set.intersection(*map(set, all_channels)) return sorted(list(common_channels)) -def create_pipeline(common_channels) -> Pipeline: - """Create classification pipeline.""" - def raw_to_data(X): - """Convert raw MNE data to numpy array format""" + +def create_pipeline(common_channels: List[str]) -> Pipeline: + """Create classification pipeline with CSP and SVM. + + Parameters + ---------- + common_channels : List[str] + List of channel names to use in the pipeline + + Returns + ------- + Pipeline + Sklearn pipeline for classification + """ + def raw_to_data(X: np.ndarray) -> np.ndarray: + """Convert raw MNE data to numpy array format. + + Parameters + ---------- + X : np.ndarray or mne.io.Raw + Input data to convert + + Returns + ------- + np.ndarray + Converted data array + """ if hasattr(X, 'get_data'): - # Get only common channels to ensure consistency - picks = mne.pick_channels(X.info['ch_names'], - include=common_channels, - ordered=True) + picks = mne.pick_channels( + X.info['ch_names'], + include=common_channels, + ordered=True + ) data = X.get_data() if data.ndim == 2: data = data.reshape(1, *data.shape) @@ -54,12 +107,13 @@ def raw_to_data(X): pipeline = Pipeline([ ('to_array', FunctionTransformer(raw_to_data)), ('covariances', Covariances(estimator='oas')), - ('csp', CSP(nfilter=4, log=True)), # Changed n_components to nfilter, removed invalid parameters + ('csp', CSP(nfilter=4, log=True)), ('classifier', SVC(kernel='rbf', C=0.1)) ]) return pipeline + # Define datasets train_dataset = BNCI2014001() test_dataset = Zhou2016() @@ -70,7 +124,7 @@ def raw_to_data(X): # Initialize the paradigm with common channels paradigm = MotorImagery( - channels=common_channels, # Use common channels + channels=common_channels, n_classes=2, fmin=8, fmax=32 @@ -89,16 +143,17 @@ def raw_to_data(X): results = [] for result in evaluation.evaluate( dataset=None, - pipelines={'CSP_SVM': create_pipeline(common_channels)}, - param_grid=None + pipelines={'CSP_SVM': create_pipeline(common_channels)} ): result['subject'] = 'all' print(f"Cross-dataset score: {result.get('score', 'N/A'):.3f}") results.append(result) -# Convert list of results to DataFrame +# Convert results to DataFrame and process results_df = pd.DataFrame(results) -results_df['dataset'] = results_df['dataset'].apply(lambda x: x.__class__.__name__) +results_df['dataset'] = results_df['dataset'].apply( + lambda x: x.__class__.__name__ +) # Print evaluation scores print("\nCross-dataset evaluation scores:") @@ -106,4 +161,4 @@ def raw_to_data(X): # Plot the results score_plot(results_df) -plt.show() \ No newline at end of file +plt.show() diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py index f2e019e74..279fa6229 100644 --- a/examples/cross_dataset_braindecode.py +++ b/examples/cross_dataset_braindecode.py @@ -5,28 +5,30 @@ and test on another using Braindecode. """ -from braindecode.datasets import MOABBDataset -from numpy import multiply -from braindecode.preprocessing import (Preprocessor, - exponential_moving_standardize, - preprocess, - create_windows_from_events) -from braindecode.models import ShallowFBCSPNet -from braindecode.util import set_random_seeds -from braindecode import EEGClassifier -from skorch.callbacks import LRScheduler -from skorch.helper import predefined_split -import torch +import logging +from typing import Dict, List, Tuple + import matplotlib.pyplot as plt +import mne +import numpy as np import pandas as pd +import torch +from braindecode import EEGClassifier +from braindecode.datasets import BaseConcatDataset, MOABBDataset +from braindecode.models import ShallowFBCSPNet +from braindecode.preprocessing import ( + Preprocessor, + create_windows_from_events, + exponential_moving_standardize, + preprocess, +) +from braindecode.util import set_random_seeds +from braindecode.visualization import plot_confusion_matrix from matplotlib.lines import Line2D from sklearn.metrics import confusion_matrix -from braindecode.visualization import plot_confusion_matrix -import logging -import mne -from numpy import array -import numpy as np -from braindecode.datasets import BaseConcatDataset +from skorch.callbacks import LRScheduler +from skorch.helper import predefined_split + # Configure logging logging.basicConfig(level=logging.WARNING) @@ -44,186 +46,217 @@ def get_all_events(train_dataset, test_dataset): # Get events from first subject of each dataset train_events = train_dataset.datasets[0].raw.annotations.description test_events = test_dataset.datasets[0].raw.annotations.description - + # Get all unique events all_events = sorted(list(set(train_events).union(set(test_events)))) print(f"\nAll unique events across datasets: {all_events}\n") - + # Create event mapping (event description -> numerical ID) event_id = {str(event): idx for idx, event in enumerate(all_events)} print(f"Event mapping: {event_id}\n") - + return event_id -def standardize_windows(train_dataset, test_dataset, all_channels, event_id): - """Standardize datasets with consistent preprocessing.""" - # Define preprocessing parameters +def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) -> mne.io.Raw: + """Interpolate missing channels using spherical spline interpolation. + + Parameters + ---------- + raw_data : mne.io.Raw + Raw EEG data to process + all_channels : List[str] + List of all required channel names + + Returns + ------- + mne.io.Raw + Processed data with interpolated channels + + Raises + ------ + TypeError + If raw_data is not an MNE Raw object + """ + if isinstance(raw_data, np.ndarray): + return raw_data + + if not isinstance(raw_data, mne.io.Raw): + raise TypeError("Expected MNE Raw object") + + missing_channels = [ch for ch in all_channels if ch not in raw_data.ch_names] + existing_channels = raw_data.ch_names + + print("\nChannel Information:") + print(f"Total channels needed: {len(all_channels)}") + print(f"Existing channels: {len(existing_channels)}") + print(f"Missing channels to interpolate: {len(missing_channels)}") + + if missing_channels: + print("\nMissing channels:") + for ch in missing_channels: + print(f"- {ch}") + + # Mark missing channels as bad + raw_data.info['bads'] = missing_channels + + # Add missing channels (temporarily with zeros) + print("\nAdding temporary channels for interpolation...") + raw_data.add_channels([ + mne.io.RawArray( + np.zeros((1, len(raw_data.times))), + mne.create_info([ch], raw_data.info['sfreq'], ['eeg']) + ) for ch in missing_channels + ]) + + # Interpolate the bad channels + print("Performing spherical spline interpolation...") + raw_data.interpolate_bads(reset_bads=True) + + # Calculate and print interpolation statistics + data_after = raw_data.get_data() + for ch in missing_channels: + ch_idx = raw_data.ch_names.index(ch) + interpolated_data = data_after[ch_idx] + stats = { + 'mean': np.mean(interpolated_data), + 'std': np.std(interpolated_data), + 'min': np.min(interpolated_data), + 'max': np.max(interpolated_data) + } + print(f"\nInterpolated channel {ch} statistics:") + print(f"- Mean: {stats['mean']:.2f}") + print(f"- Std: {stats['std']:.2f}") + print(f"- Range: [{stats['min']:.2f}, {stats['max']:.2f}]") + + print("\nInterpolation complete.") + else: + print("No channels need interpolation.") + + return raw_data + +def create_fixed_windows( + dataset: BaseConcatDataset, + samples_before: int, + samples_after: int, + event_id: Dict[str, int] +) -> BaseConcatDataset: + """Create windows with consistent size across datasets. + + Parameters + ---------- + dataset : BaseConcatDataset + Dataset to create windows from + samples_before : int + Number of samples before event + samples_after : int + Number of samples after event + event_id : Dict[str, int] + Mapping of event names to numerical IDs + + Returns + ------- + BaseConcatDataset + Windowed dataset + """ + return create_windows_from_events( + dataset, + trial_start_offset_samples=-samples_before, + trial_stop_offset_samples=samples_after, + preload=True, + mapping=event_id, + window_size_samples=samples_before + samples_after, + window_stride_samples=samples_before + samples_after + ) + +def standardize_windows( + train_dataset: BaseConcatDataset, + test_dataset: BaseConcatDataset, + all_channels: List[str], + event_id: Dict[str, int] +) -> Tuple[BaseConcatDataset, BaseConcatDataset, int, int, int]: + """Standardize datasets with consistent preprocessing. + + Parameters + ---------- + train_dataset : BaseConcatDataset + Training dataset to standardize + test_dataset : BaseConcatDataset + Test dataset to standardize + all_channels : List[str] + List of all required channel names + event_id : Dict[str, int] + Mapping of event names to numerical IDs + + Returns + ------- + Tuple[BaseConcatDataset, BaseConcatDataset, int, int, int] + Processed training windows, test windows, window length, + samples before and after + """ target_sfreq = 100 # Target sampling frequency - + print("\nInitial dataset properties:") for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: for i, ds in enumerate(dataset.datasets): print(f"{name} dataset {i} sampling rate: {ds.raw.info['sfreq']} Hz") - - def interpolate_missing_channels(raw_data, all_channels): - """Interpolate missing channels using spherical spline interpolation.""" - if isinstance(raw_data, np.ndarray): - return raw_data - - if not isinstance(raw_data, mne.io.Raw): - raise TypeError("Expected MNE Raw object") - - missing_channels = [ch for ch in all_channels if ch not in raw_data.ch_names] - existing_channels = raw_data.ch_names - - print("\nChannel Information:") - print(f"Total channels needed: {len(all_channels)}") - print(f"Existing channels: {len(existing_channels)}") - print(f"Missing channels to interpolate: {len(missing_channels)}") - - if missing_channels: - print("\nMissing channels:") - for ch in missing_channels: - print(f"- {ch}") - - # Mark missing channels as bad - raw_data.info['bads'] = missing_channels - - # Add missing channels (temporarily with zeros) - print("\nAdding temporary channels for interpolation...") - raw_data.add_channels([ - mne.io.RawArray( - np.zeros((1, len(raw_data.times))), - mne.create_info([ch], raw_data.info['sfreq'], ['eeg']) - ) for ch in missing_channels - ]) - - # Get data before interpolation - data_before = raw_data.get_data() - - # Interpolate the bad channels - print("Performing spherical spline interpolation...") - raw_data.interpolate_bads(reset_bads=True) - - # Get data after interpolation - data_after = raw_data.get_data() - - # Calculate and print interpolation statistics - for idx, ch in enumerate(missing_channels): - ch_idx = raw_data.ch_names.index(ch) - interpolated_data = data_after[ch_idx] - stats = { - 'mean': np.mean(interpolated_data), - 'std': np.std(interpolated_data), - 'min': np.min(interpolated_data), - 'max': np.max(interpolated_data) - } - print(f"\nInterpolated channel {ch} statistics:") - print(f"- Mean: {stats['mean']:.2f}") - print(f"- Std: {stats['std']:.2f}") - print(f"- Range: [{stats['min']:.2f}, {stats['max']:.2f}]") - - print("\nInterpolation complete.") - else: - print("No channels need interpolation.") - - return raw_data - - # Modified preprocessors to handle missing channels with interpolation + + # Define preprocessing pipeline preprocessors = [ - # Add and interpolate missing channels only for MNE Raw objects Preprocessor(lambda raw: interpolate_missing_channels(raw, all_channels)), Preprocessor('pick_channels', ch_names=all_channels, ordered=True), Preprocessor('resample', sfreq=target_sfreq), - Preprocessor(lambda data: multiply(data, 1e6)), + Preprocessor(lambda data: np.multiply(data, 1e6)), Preprocessor('filter', l_freq=4., h_freq=38.), Preprocessor(exponential_moving_standardize, factor_new=1e-3, init_block_size=1000) ] - + # Apply preprocessing preprocess(train_dataset, preprocessors, n_jobs=-1) preprocess(test_dataset, preprocessors, n_jobs=-1) - - print("\nAfter resampling:") - for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: - for i, ds in enumerate(dataset.datasets): - print(f"{name} dataset {i} sampling rate: {ds.raw.info['sfreq']} Hz") - - # Fixed window parameters (in seconds) + + # Define window parameters window_start = -0.5 # Start 0.5s before event - window_duration = 4.0 # Increased from 3.0 to ensure enough samples for kernel - - # Convert to samples based on target frequency - samples_before = int(abs(window_start) * target_sfreq) # 50 samples - samples_after = int(window_duration * target_sfreq) # 400 samples - - print(f"\nWindow configuration:") - print(f"Sampling frequency: {target_sfreq} Hz") - print(f"Window: {window_start}s to {window_duration}s") - print(f"Samples before: {samples_before}") - print(f"Samples after: {samples_after}") - print(f"Total window length: {samples_before + samples_after} samples") - - # Standardize event durations to 0 for both datasets - for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: - for i, ds in enumerate(dataset.datasets): - # Drop last event and standardize durations + window_duration = 4.0 # Window duration in seconds + samples_before = int(abs(window_start) * target_sfreq) + samples_after = int(window_duration * target_sfreq) + + # Standardize event durations + for dataset in [train_dataset, test_dataset]: + for ds in dataset.datasets: events = ds.raw.annotations[:-1] new_annotations = mne.Annotations( onset=events.onset, - duration=np.zeros_like(events.duration), # Set all durations to 0 + duration=np.zeros_like(events.duration), description=events.description ) ds.raw.set_annotations(new_annotations) - - print(f"\n{name} dataset {i}:") - print(f"Number of events: {len(ds.raw.annotations)}") - print(f"Event timings: {ds.raw.annotations.onset[:5]}") - print(f"Event durations: {ds.raw.annotations.duration[:5]}") - - # Create windows with explicit trial_stop_offset_samples - def create_fixed_windows(dataset): - return create_windows_from_events( - dataset, - trial_start_offset_samples=-samples_before, - trial_stop_offset_samples=samples_after, - preload=True, - mapping=event_id, - window_size_samples=samples_before + samples_after, # Force window size - window_stride_samples=samples_before + samples_after # Add stride parameter - ) - - # Create windows - train_windows = create_fixed_windows(train_dataset) - test_windows = create_fixed_windows(test_dataset) - + + # Create and validate windows + train_windows = create_fixed_windows( + train_dataset, samples_before, samples_after, event_id + ) + test_windows = create_fixed_windows( + test_dataset, samples_before, samples_after, event_id + ) + # Verify window shapes train_shape = train_windows[0][0].shape test_shape = test_windows[0][0].shape - print(f"\nActual window shapes:") - print(f"Train windows: {train_shape}") - print(f"Test windows: {test_shape}") - print(f"Expected shape: (9, {samples_before + samples_after})") - + if train_shape != test_shape: - print("\nWindow size mismatch analysis:") - print(f"Train window size: {train_shape[1]} samples") - print(f"Test window size: {test_shape[1]} samples") - print(f"Difference: {train_shape[1] - test_shape[1]} samples") - print(f"In seconds: {(train_shape[1] - test_shape[1])/target_sfreq:.2f}s") - raise ValueError(f"Window shapes don't match: train={train_shape}, test={test_shape}") - + raise ValueError( + f"Window shapes don't match: train={train_shape}, test={test_shape}" + ) + window_length = train_shape[1] - print(f"Final window length: {window_length} samples") - return train_windows, test_windows, window_length, samples_before, samples_after # Load datasets with validation def load_and_validate_dataset(dataset_name, subject_ids): """Load dataset and validate its contents.""" dataset = MOABBDataset(dataset_name=dataset_name, subject_ids=subject_ids) - + print(f"\nValidating dataset: {dataset_name}") for i, ds in enumerate(dataset.datasets): events = ds.raw.annotations @@ -232,7 +265,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): print(f"Unique event types: {set(events.description)}") print(f"First 5 event timings: {events.onset[:5]}") print(f"Sample event descriptions: {list(events.description[:5])}") - + return dataset # Load datasets with validation @@ -246,9 +279,9 @@ def load_and_validate_dataset(dataset_name, subject_ids): # Verify datasets are different print("\nVerifying dataset uniqueness...") -for name1, ds1 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), +for name1, ds1 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), ("Train3", train_dataset_3), ("Test", test_dataset)]: - for name2, ds2 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), + for name2, ds2 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), ("Train3", train_dataset_3), ("Test", test_dataset)]: if name1 < name2: # Compare each pair only once print(f"\nComparing {name1} vs {name2}:") diff --git a/examples/testing.py b/examples/testing.py deleted file mode 100644 index 6cfa2807e..000000000 --- a/examples/testing.py +++ /dev/null @@ -1,8 +0,0 @@ -from moabb.paradigms import MotorImagery -from moabb.datasets import utils - -paradigm = MotorImagery() -compatible_datasets = utils.dataset_search(paradigm=paradigm) -print("\nCompatible datasets for Motor Imagery paradigm:") -for dataset in compatible_datasets: - print(f"- {dataset.code}: {dataset.n_classes} classes") \ No newline at end of file diff --git a/moabb/evaluations/__init__.py b/moabb/evaluations/__init__.py index b5db9c26f..2a330646f 100644 --- a/moabb/evaluations/__init__.py +++ b/moabb/evaluations/__init__.py @@ -5,10 +5,10 @@ # flake8: noqa from .evaluations import ( + CrossDatasetEvaluation, CrossSessionEvaluation, CrossSubjectEvaluation, WithinSessionEvaluation, - CrossDatasetEvaluation, ) from .splitters import WithinSessionSplitter from .utils import create_save_path, save_model_cv, save_model_list diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index abc1292c3..a310ebc00 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -20,9 +20,8 @@ from moabb.evaluations.base import BaseEvaluation from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list -from mne.channels import make_standard_montage -from mne.io.constants import FIFF -import mne + + try: from codecarbon import EmissionsTracker @@ -783,8 +782,8 @@ def is_valid(self, dataset): return len(dataset.subject_list) > 1 class CrossDatasetEvaluation(BaseEvaluation): - """Evaluation class for deep learning models across datasets. - + """Evaluation class for deep learning models across different datasets. Useful for cross-dataset transfer learning. + Parameters ---------- train_dataset : Dataset or list of Dataset @@ -793,10 +792,12 @@ class CrossDatasetEvaluation(BaseEvaluation): Dataset(s) to use for testing pretrained_model : Optional[BaseEstimator] Pre-trained model to use (if None, will train from scratch) - fine_tune : bool + fine_tune : bool, default=True Whether to fine-tune the pretrained model on train_dataset - sfreq : float + sfreq : float, default=128 Target sampling frequency for all datasets + **kwargs : dict + Additional parameters passed to BaseEvaluation (paradigm, n_jobs, etc.) """ def __init__( self, @@ -813,93 +814,81 @@ def __init__( self.pretrained_model = pretrained_model self.fine_tune = fine_tune self.sfreq = sfreq - - # Validate datasets + self._validate_datasets() def _validate_datasets(self): """Validate compatibility of train and test datasets.""" all_datasets = self.train_dataset + self.test_dataset - - # Check paradigm compatibility + for dataset in all_datasets: if not self.paradigm.is_valid(dataset): raise ValueError(f"Dataset {dataset.code} not compatible with paradigm") - def _prepare_data(self, dataset): - """Prepare data from a dataset with consistent format.""" - X, y, metadata = self.paradigm.get_data( - dataset=dataset, - subjects=dataset.subject_list, - return_epochs=True - ) - - # Resample if needed - if X.info['sfreq'] != self.sfreq: - X = X.resample(self.sfreq) - - return X, y, metadata - - def evaluate(self, dataset, pipelines, groups=None, param_grid=None, **kwargs): - """Evaluate models across datasets.""" + def evaluate(self, dataset, pipelines): + """Evaluate models across datasets. + + Parameters + ---------- + dataset : Dataset + Unused but required by interface + pipelines : dict + Dictionary of pipelines to evaluate + + Yields + ------ + dict + Evaluation results containing scores and metadata + """ log.info("Starting cross-dataset evaluation") - + # Prepare training data train_X, train_y, train_metadata = [], [], [] for train_ds in self.train_dataset: - # Get raw data directly from paradigm raw, labels, events = self.paradigm.get_data( dataset=train_ds, subjects=train_ds.subject_list, return_epochs=False ) - - # Skip if no events found + if len(events) == 0: - log.warning(f"No events found in training dataset {train_ds.code}, skipping...") + log.warning(f"No events found in dataset {train_ds.code}, skipping") continue - - train_X.append(raw) # Just pass the raw data + + train_X.append(raw) train_y.extend(labels) train_metadata.append({'events': events}) - + if not train_X: raise ValueError("No valid training data found with events") - - # For each test dataset + + # Evaluate on test datasets for test_ds in self.test_dataset: - # Get raw data directly from paradigm raw, labels, events = self.paradigm.get_data( dataset=test_ds, subjects=test_ds.subject_list, return_epochs=False ) - - # Skip if no events found + if len(events) == 0: - log.warning(f"No events found in test dataset {test_ds.code}, skipping...") + log.warning(f"No events found in dataset {test_ds.code}, skipping") continue - - test_X = raw # Just pass the raw data + + test_X = raw test_y = labels - - # Evaluate each pipeline + for name, pipeline in pipelines.items(): if _carbonfootprint: tracker = EmissionsTracker(save_to_file=False, log_level="error") tracker.start() - + t_start = time() - + try: - # Train from scratch - model = clone(pipeline).fit(train_X[0], train_y) # Use first training dataset - - # Evaluate on test subjects + model = clone(pipeline).fit(train_X[0], train_y) score = model.score(test_X, test_y) - duration = time() - t_start - + result = { 'time': duration, 'dataset': test_ds, @@ -908,13 +897,13 @@ def evaluate(self, dataset, pipelines, groups=None, param_grid=None, **kwargs): 'pipeline': name, 'training_datasets': [ds.code for ds in self.train_dataset] } - + yield result - + except Exception as e: log.error(f"Error evaluating pipeline {name}: {str(e)}") raise def is_valid(self, dataset): """Check if dataset is valid for this evaluation.""" - return True # All datasets are valid for cross-dataset evaluation \ No newline at end of file + return True diff --git a/run_cross_dataset.py b/run_cross_dataset.py index 0519ecba6..e69de29bb 100644 --- a/run_cross_dataset.py +++ b/run_cross_dataset.py @@ -1 +0,0 @@ - \ No newline at end of file From 6369eed49c8c2300688e7aad13620d1067b21bd0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Feb 2025 17:50:32 +0000 Subject: [PATCH 08/16] [pre-commit.ci] auto fixes from pre-commit.com hooks --- examples/cross_dataset.py | 43 +++---- examples/cross_dataset_braindecode.py | 170 ++++++++++++++++---------- moabb/evaluations/evaluations.py | 34 +++--- 3 files changed, 139 insertions(+), 108 deletions(-) diff --git a/examples/cross_dataset.py b/examples/cross_dataset.py index 41e970e22..1772b4f2f 100644 --- a/examples/cross_dataset.py +++ b/examples/cross_dataset.py @@ -30,7 +30,7 @@ # Configure logging set_log_level("WARNING") -logging.getLogger('mne').setLevel(logging.ERROR) +logging.getLogger("mne").setLevel(logging.ERROR) def get_common_channels(datasets: List[Any]) -> List[str]: @@ -78,6 +78,7 @@ def create_pipeline(common_channels: List[str]) -> Pipeline: Pipeline Sklearn pipeline for classification """ + def raw_to_data(X: np.ndarray) -> np.ndarray: """Convert raw MNE data to numpy array format. @@ -91,11 +92,9 @@ def raw_to_data(X: np.ndarray) -> np.ndarray: np.ndarray Converted data array """ - if hasattr(X, 'get_data'): + if hasattr(X, "get_data"): picks = mne.pick_channels( - X.info['ch_names'], - include=common_channels, - ordered=True + X.info["ch_names"], include=common_channels, ordered=True ) data = X.get_data() if data.ndim == 2: @@ -104,12 +103,14 @@ def raw_to_data(X: np.ndarray) -> np.ndarray: return data return X - pipeline = Pipeline([ - ('to_array', FunctionTransformer(raw_to_data)), - ('covariances', Covariances(estimator='oas')), - ('csp', CSP(nfilter=4, log=True)), - ('classifier', SVC(kernel='rbf', C=0.1)) - ]) + pipeline = Pipeline( + [ + ("to_array", FunctionTransformer(raw_to_data)), + ("covariances", Covariances(estimator="oas")), + ("csp", CSP(nfilter=4, log=True)), + ("classifier", SVC(kernel="rbf", C=0.1)), + ] + ) return pipeline @@ -123,12 +124,7 @@ def raw_to_data(X: np.ndarray) -> np.ndarray: print(f"\nCommon channels across datasets: {common_channels}\n") # Initialize the paradigm with common channels -paradigm = MotorImagery( - channels=common_channels, - n_classes=2, - fmin=8, - fmax=32 -) +paradigm = MotorImagery(channels=common_channels, n_classes=2, fmin=8, fmax=32) # Initialize the CrossDatasetEvaluation evaluation = CrossDatasetEvaluation( @@ -136,28 +132,25 @@ def raw_to_data(X: np.ndarray) -> np.ndarray: train_dataset=train_dataset, test_dataset=test_dataset, hdf5_path="./res_test", - save_model=True + save_model=True, ) # Run the evaluation results = [] for result in evaluation.evaluate( - dataset=None, - pipelines={'CSP_SVM': create_pipeline(common_channels)} + dataset=None, pipelines={"CSP_SVM": create_pipeline(common_channels)} ): - result['subject'] = 'all' + result["subject"] = "all" print(f"Cross-dataset score: {result.get('score', 'N/A'):.3f}") results.append(result) # Convert results to DataFrame and process results_df = pd.DataFrame(results) -results_df['dataset'] = results_df['dataset'].apply( - lambda x: x.__class__.__name__ -) +results_df["dataset"] = results_df["dataset"].apply(lambda x: x.__class__.__name__) # Print evaluation scores print("\nCross-dataset evaluation scores:") -print(results_df[['dataset', 'score', 'time']]) +print(results_df[["dataset", "score", "time"]]) # Plot the results score_plot(results_df) diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py index 279fa6229..3e139db67 100644 --- a/examples/cross_dataset_braindecode.py +++ b/examples/cross_dataset_braindecode.py @@ -33,6 +33,7 @@ # Configure logging logging.basicConfig(level=logging.WARNING) + def get_common_channels(train_dataset, test_dataset): """Get channels that are available across both datasets.""" train_chans = train_dataset.datasets[0].raw.ch_names @@ -41,6 +42,7 @@ def get_common_channels(train_dataset, test_dataset): print(f"\nCommon channels across datasets: {common_channels}\n") return common_channels + def get_all_events(train_dataset, test_dataset): """Get all unique events across datasets.""" # Get events from first subject of each dataset @@ -57,7 +59,10 @@ def get_all_events(train_dataset, test_dataset): return event_id -def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) -> mne.io.Raw: + +def interpolate_missing_channels( + raw_data: mne.io.Raw, all_channels: List[str] +) -> mne.io.Raw: """Interpolate missing channels using spherical spline interpolation. Parameters @@ -97,16 +102,19 @@ def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) print(f"- {ch}") # Mark missing channels as bad - raw_data.info['bads'] = missing_channels + raw_data.info["bads"] = missing_channels # Add missing channels (temporarily with zeros) print("\nAdding temporary channels for interpolation...") - raw_data.add_channels([ - mne.io.RawArray( - np.zeros((1, len(raw_data.times))), - mne.create_info([ch], raw_data.info['sfreq'], ['eeg']) - ) for ch in missing_channels - ]) + raw_data.add_channels( + [ + mne.io.RawArray( + np.zeros((1, len(raw_data.times))), + mne.create_info([ch], raw_data.info["sfreq"], ["eeg"]), + ) + for ch in missing_channels + ] + ) # Interpolate the bad channels print("Performing spherical spline interpolation...") @@ -118,10 +126,10 @@ def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) ch_idx = raw_data.ch_names.index(ch) interpolated_data = data_after[ch_idx] stats = { - 'mean': np.mean(interpolated_data), - 'std': np.std(interpolated_data), - 'min': np.min(interpolated_data), - 'max': np.max(interpolated_data) + "mean": np.mean(interpolated_data), + "std": np.std(interpolated_data), + "min": np.min(interpolated_data), + "max": np.max(interpolated_data), } print(f"\nInterpolated channel {ch} statistics:") print(f"- Mean: {stats['mean']:.2f}") @@ -134,11 +142,12 @@ def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) return raw_data + def create_fixed_windows( dataset: BaseConcatDataset, samples_before: int, samples_after: int, - event_id: Dict[str, int] + event_id: Dict[str, int], ) -> BaseConcatDataset: """Create windows with consistent size across datasets. @@ -165,14 +174,15 @@ def create_fixed_windows( preload=True, mapping=event_id, window_size_samples=samples_before + samples_after, - window_stride_samples=samples_before + samples_after + window_stride_samples=samples_before + samples_after, ) + def standardize_windows( train_dataset: BaseConcatDataset, test_dataset: BaseConcatDataset, all_channels: List[str], - event_id: Dict[str, int] + event_id: Dict[str, int], ) -> Tuple[BaseConcatDataset, BaseConcatDataset, int, int, int]: """Standardize datasets with consistent preprocessing. @@ -203,12 +213,13 @@ def standardize_windows( # Define preprocessing pipeline preprocessors = [ Preprocessor(lambda raw: interpolate_missing_channels(raw, all_channels)), - Preprocessor('pick_channels', ch_names=all_channels, ordered=True), - Preprocessor('resample', sfreq=target_sfreq), + Preprocessor("pick_channels", ch_names=all_channels, ordered=True), + Preprocessor("resample", sfreq=target_sfreq), Preprocessor(lambda data: np.multiply(data, 1e6)), - Preprocessor('filter', l_freq=4., h_freq=38.), - Preprocessor(exponential_moving_standardize, - factor_new=1e-3, init_block_size=1000) + Preprocessor("filter", l_freq=4.0, h_freq=38.0), + Preprocessor( + exponential_moving_standardize, factor_new=1e-3, init_block_size=1000 + ), ] # Apply preprocessing @@ -228,7 +239,7 @@ def standardize_windows( new_annotations = mne.Annotations( onset=events.onset, duration=np.zeros_like(events.duration), - description=events.description + description=events.description, ) ds.raw.set_annotations(new_annotations) @@ -252,6 +263,7 @@ def standardize_windows( window_length = train_shape[1] return train_windows, test_windows, window_length, samples_before, samples_after + # Load datasets with validation def load_and_validate_dataset(dataset_name, subject_ids): """Load dataset and validate its contents.""" @@ -268,6 +280,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): return dataset + # Load datasets with validation print("\nLoading training datasets...") train_dataset_1 = load_and_validate_dataset("BNCI2014_001", subject_ids=[1, 2, 3, 4]) @@ -279,17 +292,27 @@ def load_and_validate_dataset(dataset_name, subject_ids): # Verify datasets are different print("\nVerifying dataset uniqueness...") -for name1, ds1 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), - ("Train3", train_dataset_3), ("Test", test_dataset)]: - for name2, ds2 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), - ("Train3", train_dataset_3), ("Test", test_dataset)]: +for name1, ds1 in [ + ("Train1", train_dataset_1), + ("Train2", train_dataset_2), + ("Train3", train_dataset_3), + ("Test", test_dataset), +]: + for name2, ds2 in [ + ("Train1", train_dataset_1), + ("Train2", train_dataset_2), + ("Train3", train_dataset_3), + ("Test", test_dataset), + ]: if name1 < name2: # Compare each pair only once print(f"\nComparing {name1} vs {name2}:") # Compare first subject of each dataset events1 = ds1.datasets[0].raw.annotations events2 = ds2.datasets[0].raw.annotations print(f"Event counts: {len(events1)} vs {len(events2)}") - print(f"Event types: {set(events1.description)} vs {set(events2.description)}") + print( + f"Event types: {set(events1.description)} vs {set(events2.description)}" + ) print(f"First timing: {events1.onset[0]} vs {events2.onset[0]}") # Get common channels across all datasets @@ -298,10 +321,14 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_chans_3 = train_dataset_3.datasets[0].raw.ch_names test_chans = test_dataset.datasets[0].raw.ch_names -common_channels = sorted(list(set(train_chans_1) - .intersection(set(train_chans_2)) - .intersection(set(train_chans_3)) - .intersection(set(test_chans)))) +common_channels = sorted( + list( + set(train_chans_1) + .intersection(set(train_chans_2)) + .intersection(set(train_chans_3)) + .intersection(set(test_chans)) + ) +) print(f"\nCommon channels across all datasets: {common_channels}\n") # Get all events across all datasets @@ -310,10 +337,14 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_events_3 = train_dataset_3.datasets[0].raw.annotations.description test_events = test_dataset.datasets[0].raw.annotations.description -all_events = sorted(list(set(train_events_1) - .union(set(train_events_2)) - .union(set(train_events_3)) - .union(set(test_events)))) +all_events = sorted( + list( + set(train_events_1) + .union(set(train_events_2)) + .union(set(train_events_3)) + .union(set(test_events)) + ) +) print(f"\nAll unique events across datasets: {all_events}\n") event_id = {str(event): idx for idx, event in enumerate(all_events)} @@ -333,25 +364,23 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_windows_3, _, _, _, _ = standardize_windows( train_dataset_3, test_dataset, common_channels, event_id ) -train_windows_test, test_windows, window_length, samples_before, samples_after = standardize_windows( - train_dataset_1, test_dataset, common_channels, event_id +train_windows_test, test_windows, window_length, samples_before, samples_after = ( + standardize_windows(train_dataset_1, test_dataset, common_channels, event_id) ) # Combine all training windows -combined_train_windows = BaseConcatDataset([ - train_windows_1, - train_windows_2, - train_windows_3 -]) +combined_train_windows = BaseConcatDataset( + [train_windows_1, train_windows_2, train_windows_3] +) # Split training data using only the combined dataset -splitted = combined_train_windows.split('session') -train_set = splitted['0train'] -valid_set = splitted['1test'] +splitted = combined_train_windows.split("session") +train_set = splitted["0train"] +valid_set = splitted["1test"] # Setup compute device cuda = torch.cuda.is_available() -device = 'cuda' if cuda else 'cpu' +device = "cuda" if cuda else "cpu" if cuda: torch.backends.cudnn.benchmark = True @@ -371,7 +400,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): filter_time_length=20, pool_time_length=35, pool_time_stride=7, - final_conv_length='auto' + final_conv_length="auto", ) if cuda: @@ -388,7 +417,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): batch_size=64, callbacks=[ "accuracy", - ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=11)), + ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=11)), ], device=device, classes=list(range(n_classes)), @@ -412,39 +441,46 @@ def load_and_validate_dataset(dataset_name, subject_ids): # Create visualization fig = plt.figure(figsize=(10, 5)) -plt.plot(clf.history[:, 'train_loss'], label='Training Loss') -plt.plot(clf.history[:, 'valid_loss'], label='Validation Loss') +plt.plot(clf.history[:, "train_loss"], label="Training Loss") +plt.plot(clf.history[:, "valid_loss"], label="Validation Loss") plt.legend() -plt.xlabel('Epoch') -plt.ylabel('Loss') -plt.title('Training and Validation Loss Over Time') +plt.xlabel("Epoch") +plt.ylabel("Loss") +plt.title("Training and Validation Loss Over Time") plt.show() # Plot training curves -results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'] -df = pd.DataFrame(clf.history[:, results_columns], columns=results_columns, - index=clf.history[:, 'epoch']) +results_columns = ["train_loss", "valid_loss", "train_accuracy", "valid_accuracy"] +df = pd.DataFrame( + clf.history[:, results_columns], + columns=results_columns, + index=clf.history[:, "epoch"], +) -df = df.assign(train_misclass=100 - 100 * df.train_accuracy, - valid_misclass=100 - 100 * df.valid_accuracy) +df = df.assign( + train_misclass=100 - 100 * df.train_accuracy, + valid_misclass=100 - 100 * df.valid_accuracy, +) # Plot results fig, ax1 = plt.subplots(figsize=(8, 3)) -df.loc[:, ['train_loss', 'valid_loss']].plot( - ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False) +df.loc[:, ["train_loss", "valid_loss"]].plot( + ax=ax1, style=["-", ":"], marker="o", color="tab:blue", legend=False +) -ax1.tick_params(axis='y', labelcolor='tab:blue') -ax1.set_ylabel("Loss", color='tab:blue') +ax1.tick_params(axis="y", labelcolor="tab:blue") +ax1.set_ylabel("Loss", color="tab:blue") ax2 = ax1.twinx() -df.loc[:, ['train_misclass', 'valid_misclass']].plot( - ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False) -ax2.tick_params(axis='y', labelcolor='tab:red') -ax2.set_ylabel("Misclassification Rate [%]", color='tab:red') +df.loc[:, ["train_misclass", "valid_misclass"]].plot( + ax=ax2, style=["-", ":"], marker="o", color="tab:red", legend=False +) +ax2.tick_params(axis="y", labelcolor="tab:red") +ax2.set_ylabel("Misclassification Rate [%]", color="tab:red") ax2.set_ylim(ax2.get_ylim()[0], 85) handles = [] -handles.append(Line2D([0], [0], color='black', linestyle='-', label='Train')) -handles.append(Line2D([0], [0], color='black', linestyle=':', label='Valid')) +handles.append(Line2D([0], [0], color="black", linestyle="-", label="Train")) +handles.append(Line2D([0], [0], color="black", linestyle=":", label="Valid")) plt.legend(handles, [h.get_label() for h in handles]) plt.tight_layout() diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index a310ebc00..ad4b4b6b0 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -781,6 +781,7 @@ def evaluate( def is_valid(self, dataset): return len(dataset.subject_list) > 1 + class CrossDatasetEvaluation(BaseEvaluation): """Evaluation class for deep learning models across different datasets. Useful for cross-dataset transfer learning. @@ -799,6 +800,7 @@ class CrossDatasetEvaluation(BaseEvaluation): **kwargs : dict Additional parameters passed to BaseEvaluation (paradigm, n_jobs, etc.) """ + def __init__( self, train_dataset, @@ -806,11 +808,15 @@ def __init__( pretrained_model=None, fine_tune=True, sfreq=128, - **kwargs + **kwargs, ): super().__init__(**kwargs) - self.train_dataset = train_dataset if isinstance(train_dataset, list) else [train_dataset] - self.test_dataset = test_dataset if isinstance(test_dataset, list) else [test_dataset] + self.train_dataset = ( + train_dataset if isinstance(train_dataset, list) else [train_dataset] + ) + self.test_dataset = ( + test_dataset if isinstance(test_dataset, list) else [test_dataset] + ) self.pretrained_model = pretrained_model self.fine_tune = fine_tune self.sfreq = sfreq @@ -846,9 +852,7 @@ def evaluate(self, dataset, pipelines): train_X, train_y, train_metadata = [], [], [] for train_ds in self.train_dataset: raw, labels, events = self.paradigm.get_data( - dataset=train_ds, - subjects=train_ds.subject_list, - return_epochs=False + dataset=train_ds, subjects=train_ds.subject_list, return_epochs=False ) if len(events) == 0: @@ -857,7 +861,7 @@ def evaluate(self, dataset, pipelines): train_X.append(raw) train_y.extend(labels) - train_metadata.append({'events': events}) + train_metadata.append({"events": events}) if not train_X: raise ValueError("No valid training data found with events") @@ -865,9 +869,7 @@ def evaluate(self, dataset, pipelines): # Evaluate on test datasets for test_ds in self.test_dataset: raw, labels, events = self.paradigm.get_data( - dataset=test_ds, - subjects=test_ds.subject_list, - return_epochs=False + dataset=test_ds, subjects=test_ds.subject_list, return_epochs=False ) if len(events) == 0: @@ -890,12 +892,12 @@ def evaluate(self, dataset, pipelines): duration = time() - t_start result = { - 'time': duration, - 'dataset': test_ds, - 'score': score, - 'n_samples': len(test_y), - 'pipeline': name, - 'training_datasets': [ds.code for ds in self.train_dataset] + "time": duration, + "dataset": test_ds, + "score": score, + "n_samples": len(test_y), + "pipeline": name, + "training_datasets": [ds.code for ds in self.train_dataset], } yield result From 6f83dcb8b446403eb3afe073b3e91037a6435620 Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Mon, 24 Feb 2025 19:00:14 +0100 Subject: [PATCH 09/16] cross dataset eval with examples --- examples/cross_dataset_braindecode.py | 218 +++++++++++--------------- 1 file changed, 91 insertions(+), 127 deletions(-) diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py index 3e139db67..caf408d9d 100644 --- a/examples/cross_dataset_braindecode.py +++ b/examples/cross_dataset_braindecode.py @@ -5,30 +5,29 @@ and test on another using Braindecode. """ -import logging -from typing import Dict, List, Tuple - -import matplotlib.pyplot as plt -import mne -import numpy as np -import pandas as pd -import torch -from braindecode import EEGClassifier -from braindecode.datasets import BaseConcatDataset, MOABBDataset +from braindecode.datasets import MOABBDataset +from numpy import multiply +from braindecode.preprocessing import (Preprocessor, + exponential_moving_standardize, + preprocess, + create_windows_from_events) from braindecode.models import ShallowFBCSPNet -from braindecode.preprocessing import ( - Preprocessor, - create_windows_from_events, - exponential_moving_standardize, - preprocess, -) from braindecode.util import set_random_seeds -from braindecode.visualization import plot_confusion_matrix -from matplotlib.lines import Line2D -from sklearn.metrics import confusion_matrix +from braindecode import EEGClassifier from skorch.callbacks import LRScheduler from skorch.helper import predefined_split - +import torch +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.lines import Line2D +from sklearn.metrics import confusion_matrix +from braindecode.visualization import plot_confusion_matrix +import logging +import mne +from numpy import array +import numpy as np +from braindecode.datasets import BaseConcatDataset +from typing import List, Dict, Tuple # Configure logging logging.basicConfig(level=logging.WARNING) @@ -42,27 +41,23 @@ def get_common_channels(train_dataset, test_dataset): print(f"\nCommon channels across datasets: {common_channels}\n") return common_channels - def get_all_events(train_dataset, test_dataset): """Get all unique events across datasets.""" # Get events from first subject of each dataset train_events = train_dataset.datasets[0].raw.annotations.description test_events = test_dataset.datasets[0].raw.annotations.description - + # Get all unique events all_events = sorted(list(set(train_events).union(set(test_events)))) print(f"\nAll unique events across datasets: {all_events}\n") - + # Create event mapping (event description -> numerical ID) event_id = {str(event): idx for idx, event in enumerate(all_events)} print(f"Event mapping: {event_id}\n") - + return event_id - -def interpolate_missing_channels( - raw_data: mne.io.Raw, all_channels: List[str] -) -> mne.io.Raw: +def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) -> mne.io.Raw: """Interpolate missing channels using spherical spline interpolation. Parameters @@ -102,19 +97,16 @@ def interpolate_missing_channels( print(f"- {ch}") # Mark missing channels as bad - raw_data.info["bads"] = missing_channels + raw_data.info['bads'] = missing_channels # Add missing channels (temporarily with zeros) print("\nAdding temporary channels for interpolation...") - raw_data.add_channels( - [ - mne.io.RawArray( - np.zeros((1, len(raw_data.times))), - mne.create_info([ch], raw_data.info["sfreq"], ["eeg"]), - ) - for ch in missing_channels - ] - ) + raw_data.add_channels([ + mne.io.RawArray( + np.zeros((1, len(raw_data.times))), + mne.create_info([ch], raw_data.info['sfreq'], ['eeg']) + ) for ch in missing_channels + ]) # Interpolate the bad channels print("Performing spherical spline interpolation...") @@ -126,10 +118,10 @@ def interpolate_missing_channels( ch_idx = raw_data.ch_names.index(ch) interpolated_data = data_after[ch_idx] stats = { - "mean": np.mean(interpolated_data), - "std": np.std(interpolated_data), - "min": np.min(interpolated_data), - "max": np.max(interpolated_data), + 'mean': np.mean(interpolated_data), + 'std': np.std(interpolated_data), + 'min': np.min(interpolated_data), + 'max': np.max(interpolated_data) } print(f"\nInterpolated channel {ch} statistics:") print(f"- Mean: {stats['mean']:.2f}") @@ -142,12 +134,11 @@ def interpolate_missing_channels( return raw_data - def create_fixed_windows( dataset: BaseConcatDataset, samples_before: int, samples_after: int, - event_id: Dict[str, int], + event_id: Dict[str, int] ) -> BaseConcatDataset: """Create windows with consistent size across datasets. @@ -174,15 +165,14 @@ def create_fixed_windows( preload=True, mapping=event_id, window_size_samples=samples_before + samples_after, - window_stride_samples=samples_before + samples_after, + window_stride_samples=samples_before + samples_after ) - def standardize_windows( train_dataset: BaseConcatDataset, test_dataset: BaseConcatDataset, all_channels: List[str], - event_id: Dict[str, int], + event_id: Dict[str, int] ) -> Tuple[BaseConcatDataset, BaseConcatDataset, int, int, int]: """Standardize datasets with consistent preprocessing. @@ -213,13 +203,12 @@ def standardize_windows( # Define preprocessing pipeline preprocessors = [ Preprocessor(lambda raw: interpolate_missing_channels(raw, all_channels)), - Preprocessor("pick_channels", ch_names=all_channels, ordered=True), - Preprocessor("resample", sfreq=target_sfreq), + Preprocessor('pick_channels', ch_names=all_channels, ordered=True), + Preprocessor('resample', sfreq=target_sfreq), Preprocessor(lambda data: np.multiply(data, 1e6)), - Preprocessor("filter", l_freq=4.0, h_freq=38.0), - Preprocessor( - exponential_moving_standardize, factor_new=1e-3, init_block_size=1000 - ), + Preprocessor('filter', l_freq=4., h_freq=38.), + Preprocessor(exponential_moving_standardize, + factor_new=1e-3, init_block_size=1000) ] # Apply preprocessing @@ -239,7 +228,7 @@ def standardize_windows( new_annotations = mne.Annotations( onset=events.onset, duration=np.zeros_like(events.duration), - description=events.description, + description=events.description ) ds.raw.set_annotations(new_annotations) @@ -263,12 +252,11 @@ def standardize_windows( window_length = train_shape[1] return train_windows, test_windows, window_length, samples_before, samples_after - # Load datasets with validation def load_and_validate_dataset(dataset_name, subject_ids): """Load dataset and validate its contents.""" dataset = MOABBDataset(dataset_name=dataset_name, subject_ids=subject_ids) - + print(f"\nValidating dataset: {dataset_name}") for i, ds in enumerate(dataset.datasets): events = ds.raw.annotations @@ -277,10 +265,9 @@ def load_and_validate_dataset(dataset_name, subject_ids): print(f"Unique event types: {set(events.description)}") print(f"First 5 event timings: {events.onset[:5]}") print(f"Sample event descriptions: {list(events.description[:5])}") - + return dataset - # Load datasets with validation print("\nLoading training datasets...") train_dataset_1 = load_and_validate_dataset("BNCI2014_001", subject_ids=[1, 2, 3, 4]) @@ -292,27 +279,17 @@ def load_and_validate_dataset(dataset_name, subject_ids): # Verify datasets are different print("\nVerifying dataset uniqueness...") -for name1, ds1 in [ - ("Train1", train_dataset_1), - ("Train2", train_dataset_2), - ("Train3", train_dataset_3), - ("Test", test_dataset), -]: - for name2, ds2 in [ - ("Train1", train_dataset_1), - ("Train2", train_dataset_2), - ("Train3", train_dataset_3), - ("Test", test_dataset), - ]: +for name1, ds1 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), + ("Train3", train_dataset_3), ("Test", test_dataset)]: + for name2, ds2 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), + ("Train3", train_dataset_3), ("Test", test_dataset)]: if name1 < name2: # Compare each pair only once print(f"\nComparing {name1} vs {name2}:") # Compare first subject of each dataset events1 = ds1.datasets[0].raw.annotations events2 = ds2.datasets[0].raw.annotations print(f"Event counts: {len(events1)} vs {len(events2)}") - print( - f"Event types: {set(events1.description)} vs {set(events2.description)}" - ) + print(f"Event types: {set(events1.description)} vs {set(events2.description)}") print(f"First timing: {events1.onset[0]} vs {events2.onset[0]}") # Get common channels across all datasets @@ -321,14 +298,10 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_chans_3 = train_dataset_3.datasets[0].raw.ch_names test_chans = test_dataset.datasets[0].raw.ch_names -common_channels = sorted( - list( - set(train_chans_1) - .intersection(set(train_chans_2)) - .intersection(set(train_chans_3)) - .intersection(set(test_chans)) - ) -) +common_channels = sorted(list(set(train_chans_1) + .intersection(set(train_chans_2)) + .intersection(set(train_chans_3)) + .intersection(set(test_chans)))) print(f"\nCommon channels across all datasets: {common_channels}\n") # Get all events across all datasets @@ -337,14 +310,10 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_events_3 = train_dataset_3.datasets[0].raw.annotations.description test_events = test_dataset.datasets[0].raw.annotations.description -all_events = sorted( - list( - set(train_events_1) - .union(set(train_events_2)) - .union(set(train_events_3)) - .union(set(test_events)) - ) -) +all_events = sorted(list(set(train_events_1) + .union(set(train_events_2)) + .union(set(train_events_3)) + .union(set(test_events)))) print(f"\nAll unique events across datasets: {all_events}\n") event_id = {str(event): idx for idx, event in enumerate(all_events)} @@ -364,23 +333,25 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_windows_3, _, _, _, _ = standardize_windows( train_dataset_3, test_dataset, common_channels, event_id ) -train_windows_test, test_windows, window_length, samples_before, samples_after = ( - standardize_windows(train_dataset_1, test_dataset, common_channels, event_id) +train_windows_test, test_windows, window_length, samples_before, samples_after = standardize_windows( + train_dataset_1, test_dataset, common_channels, event_id ) # Combine all training windows -combined_train_windows = BaseConcatDataset( - [train_windows_1, train_windows_2, train_windows_3] -) +combined_train_windows = BaseConcatDataset([ + train_windows_1, + train_windows_2, + train_windows_3 +]) # Split training data using only the combined dataset -splitted = combined_train_windows.split("session") -train_set = splitted["0train"] -valid_set = splitted["1test"] +split = combined_train_windows.split('session') +train_set = split['0train'] +valid_set = split['1test'] # Setup compute device cuda = torch.cuda.is_available() -device = "cuda" if cuda else "cpu" +device = 'cuda' if cuda else 'cpu' if cuda: torch.backends.cudnn.benchmark = True @@ -400,7 +371,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): filter_time_length=20, pool_time_length=35, pool_time_stride=7, - final_conv_length="auto", + final_conv_length='auto' ) if cuda: @@ -417,7 +388,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): batch_size=64, callbacks=[ "accuracy", - ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=11)), + ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=11)), ], device=device, classes=list(range(n_classes)), @@ -441,46 +412,39 @@ def load_and_validate_dataset(dataset_name, subject_ids): # Create visualization fig = plt.figure(figsize=(10, 5)) -plt.plot(clf.history[:, "train_loss"], label="Training Loss") -plt.plot(clf.history[:, "valid_loss"], label="Validation Loss") +plt.plot(clf.history[:, 'train_loss'], label='Training Loss') +plt.plot(clf.history[:, 'valid_loss'], label='Validation Loss') plt.legend() -plt.xlabel("Epoch") -plt.ylabel("Loss") -plt.title("Training and Validation Loss Over Time") +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Training and Validation Loss Over Time') plt.show() # Plot training curves -results_columns = ["train_loss", "valid_loss", "train_accuracy", "valid_accuracy"] -df = pd.DataFrame( - clf.history[:, results_columns], - columns=results_columns, - index=clf.history[:, "epoch"], -) +results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'] +df = pd.DataFrame(clf.history[:, results_columns], columns=results_columns, + index=clf.history[:, 'epoch']) -df = df.assign( - train_misclass=100 - 100 * df.train_accuracy, - valid_misclass=100 - 100 * df.valid_accuracy, -) +df = df.assign(train_misclass=100 - 100 * df.train_accuracy, + valid_misclass=100 - 100 * df.valid_accuracy) # Plot results fig, ax1 = plt.subplots(figsize=(8, 3)) -df.loc[:, ["train_loss", "valid_loss"]].plot( - ax=ax1, style=["-", ":"], marker="o", color="tab:blue", legend=False -) +df.loc[:, ['train_loss', 'valid_loss']].plot( + ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False) -ax1.tick_params(axis="y", labelcolor="tab:blue") -ax1.set_ylabel("Loss", color="tab:blue") +ax1.tick_params(axis='y', labelcolor='tab:blue') +ax1.set_ylabel("Loss", color='tab:blue') ax2 = ax1.twinx() -df.loc[:, ["train_misclass", "valid_misclass"]].plot( - ax=ax2, style=["-", ":"], marker="o", color="tab:red", legend=False -) -ax2.tick_params(axis="y", labelcolor="tab:red") -ax2.set_ylabel("Misclassification Rate [%]", color="tab:red") +df.loc[:, ['train_misclass', 'valid_misclass']].plot( + ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False) +ax2.tick_params(axis='y', labelcolor='tab:red') +ax2.set_ylabel("Misclassification Rate [%]", color='tab:red') ax2.set_ylim(ax2.get_ylim()[0], 85) handles = [] -handles.append(Line2D([0], [0], color="black", linestyle="-", label="Train")) -handles.append(Line2D([0], [0], color="black", linestyle=":", label="Valid")) +handles.append(Line2D([0], [0], color='black', linestyle='-', label='Train')) +handles.append(Line2D([0], [0], color='black', linestyle=':', label='Valid')) plt.legend(handles, [h.get_label() for h in handles]) plt.tight_layout() From 2cec69c70e0cb8776bba28c39ed0eb2522ec3c7a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Feb 2025 18:00:35 +0000 Subject: [PATCH 10/16] [pre-commit.ci] auto fixes from pre-commit.com hooks --- examples/cross_dataset_braindecode.py | 218 +++++++++++++++----------- 1 file changed, 127 insertions(+), 91 deletions(-) diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py index caf408d9d..72cdf7ca8 100644 --- a/examples/cross_dataset_braindecode.py +++ b/examples/cross_dataset_braindecode.py @@ -5,29 +5,30 @@ and test on another using Braindecode. """ -from braindecode.datasets import MOABBDataset -from numpy import multiply -from braindecode.preprocessing import (Preprocessor, - exponential_moving_standardize, - preprocess, - create_windows_from_events) -from braindecode.models import ShallowFBCSPNet -from braindecode.util import set_random_seeds -from braindecode import EEGClassifier -from skorch.callbacks import LRScheduler -from skorch.helper import predefined_split -import torch +import logging +from typing import Dict, List, Tuple + import matplotlib.pyplot as plt +import mne +import numpy as np import pandas as pd +import torch +from braindecode import EEGClassifier +from braindecode.datasets import BaseConcatDataset, MOABBDataset +from braindecode.models import ShallowFBCSPNet +from braindecode.preprocessing import ( + Preprocessor, + create_windows_from_events, + exponential_moving_standardize, + preprocess, +) +from braindecode.util import set_random_seeds +from braindecode.visualization import plot_confusion_matrix from matplotlib.lines import Line2D from sklearn.metrics import confusion_matrix -from braindecode.visualization import plot_confusion_matrix -import logging -import mne -from numpy import array -import numpy as np -from braindecode.datasets import BaseConcatDataset -from typing import List, Dict, Tuple +from skorch.callbacks import LRScheduler +from skorch.helper import predefined_split + # Configure logging logging.basicConfig(level=logging.WARNING) @@ -41,23 +42,27 @@ def get_common_channels(train_dataset, test_dataset): print(f"\nCommon channels across datasets: {common_channels}\n") return common_channels + def get_all_events(train_dataset, test_dataset): """Get all unique events across datasets.""" # Get events from first subject of each dataset train_events = train_dataset.datasets[0].raw.annotations.description test_events = test_dataset.datasets[0].raw.annotations.description - + # Get all unique events all_events = sorted(list(set(train_events).union(set(test_events)))) print(f"\nAll unique events across datasets: {all_events}\n") - + # Create event mapping (event description -> numerical ID) event_id = {str(event): idx for idx, event in enumerate(all_events)} print(f"Event mapping: {event_id}\n") - + return event_id -def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) -> mne.io.Raw: + +def interpolate_missing_channels( + raw_data: mne.io.Raw, all_channels: List[str] +) -> mne.io.Raw: """Interpolate missing channels using spherical spline interpolation. Parameters @@ -97,16 +102,19 @@ def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) print(f"- {ch}") # Mark missing channels as bad - raw_data.info['bads'] = missing_channels + raw_data.info["bads"] = missing_channels # Add missing channels (temporarily with zeros) print("\nAdding temporary channels for interpolation...") - raw_data.add_channels([ - mne.io.RawArray( - np.zeros((1, len(raw_data.times))), - mne.create_info([ch], raw_data.info['sfreq'], ['eeg']) - ) for ch in missing_channels - ]) + raw_data.add_channels( + [ + mne.io.RawArray( + np.zeros((1, len(raw_data.times))), + mne.create_info([ch], raw_data.info["sfreq"], ["eeg"]), + ) + for ch in missing_channels + ] + ) # Interpolate the bad channels print("Performing spherical spline interpolation...") @@ -118,10 +126,10 @@ def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) ch_idx = raw_data.ch_names.index(ch) interpolated_data = data_after[ch_idx] stats = { - 'mean': np.mean(interpolated_data), - 'std': np.std(interpolated_data), - 'min': np.min(interpolated_data), - 'max': np.max(interpolated_data) + "mean": np.mean(interpolated_data), + "std": np.std(interpolated_data), + "min": np.min(interpolated_data), + "max": np.max(interpolated_data), } print(f"\nInterpolated channel {ch} statistics:") print(f"- Mean: {stats['mean']:.2f}") @@ -134,11 +142,12 @@ def interpolate_missing_channels(raw_data: mne.io.Raw, all_channels: List[str]) return raw_data + def create_fixed_windows( dataset: BaseConcatDataset, samples_before: int, samples_after: int, - event_id: Dict[str, int] + event_id: Dict[str, int], ) -> BaseConcatDataset: """Create windows with consistent size across datasets. @@ -165,14 +174,15 @@ def create_fixed_windows( preload=True, mapping=event_id, window_size_samples=samples_before + samples_after, - window_stride_samples=samples_before + samples_after + window_stride_samples=samples_before + samples_after, ) + def standardize_windows( train_dataset: BaseConcatDataset, test_dataset: BaseConcatDataset, all_channels: List[str], - event_id: Dict[str, int] + event_id: Dict[str, int], ) -> Tuple[BaseConcatDataset, BaseConcatDataset, int, int, int]: """Standardize datasets with consistent preprocessing. @@ -203,12 +213,13 @@ def standardize_windows( # Define preprocessing pipeline preprocessors = [ Preprocessor(lambda raw: interpolate_missing_channels(raw, all_channels)), - Preprocessor('pick_channels', ch_names=all_channels, ordered=True), - Preprocessor('resample', sfreq=target_sfreq), + Preprocessor("pick_channels", ch_names=all_channels, ordered=True), + Preprocessor("resample", sfreq=target_sfreq), Preprocessor(lambda data: np.multiply(data, 1e6)), - Preprocessor('filter', l_freq=4., h_freq=38.), - Preprocessor(exponential_moving_standardize, - factor_new=1e-3, init_block_size=1000) + Preprocessor("filter", l_freq=4.0, h_freq=38.0), + Preprocessor( + exponential_moving_standardize, factor_new=1e-3, init_block_size=1000 + ), ] # Apply preprocessing @@ -228,7 +239,7 @@ def standardize_windows( new_annotations = mne.Annotations( onset=events.onset, duration=np.zeros_like(events.duration), - description=events.description + description=events.description, ) ds.raw.set_annotations(new_annotations) @@ -252,11 +263,12 @@ def standardize_windows( window_length = train_shape[1] return train_windows, test_windows, window_length, samples_before, samples_after + # Load datasets with validation def load_and_validate_dataset(dataset_name, subject_ids): """Load dataset and validate its contents.""" dataset = MOABBDataset(dataset_name=dataset_name, subject_ids=subject_ids) - + print(f"\nValidating dataset: {dataset_name}") for i, ds in enumerate(dataset.datasets): events = ds.raw.annotations @@ -265,9 +277,10 @@ def load_and_validate_dataset(dataset_name, subject_ids): print(f"Unique event types: {set(events.description)}") print(f"First 5 event timings: {events.onset[:5]}") print(f"Sample event descriptions: {list(events.description[:5])}") - + return dataset + # Load datasets with validation print("\nLoading training datasets...") train_dataset_1 = load_and_validate_dataset("BNCI2014_001", subject_ids=[1, 2, 3, 4]) @@ -279,17 +292,27 @@ def load_and_validate_dataset(dataset_name, subject_ids): # Verify datasets are different print("\nVerifying dataset uniqueness...") -for name1, ds1 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), - ("Train3", train_dataset_3), ("Test", test_dataset)]: - for name2, ds2 in [("Train1", train_dataset_1), ("Train2", train_dataset_2), - ("Train3", train_dataset_3), ("Test", test_dataset)]: +for name1, ds1 in [ + ("Train1", train_dataset_1), + ("Train2", train_dataset_2), + ("Train3", train_dataset_3), + ("Test", test_dataset), +]: + for name2, ds2 in [ + ("Train1", train_dataset_1), + ("Train2", train_dataset_2), + ("Train3", train_dataset_3), + ("Test", test_dataset), + ]: if name1 < name2: # Compare each pair only once print(f"\nComparing {name1} vs {name2}:") # Compare first subject of each dataset events1 = ds1.datasets[0].raw.annotations events2 = ds2.datasets[0].raw.annotations print(f"Event counts: {len(events1)} vs {len(events2)}") - print(f"Event types: {set(events1.description)} vs {set(events2.description)}") + print( + f"Event types: {set(events1.description)} vs {set(events2.description)}" + ) print(f"First timing: {events1.onset[0]} vs {events2.onset[0]}") # Get common channels across all datasets @@ -298,10 +321,14 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_chans_3 = train_dataset_3.datasets[0].raw.ch_names test_chans = test_dataset.datasets[0].raw.ch_names -common_channels = sorted(list(set(train_chans_1) - .intersection(set(train_chans_2)) - .intersection(set(train_chans_3)) - .intersection(set(test_chans)))) +common_channels = sorted( + list( + set(train_chans_1) + .intersection(set(train_chans_2)) + .intersection(set(train_chans_3)) + .intersection(set(test_chans)) + ) +) print(f"\nCommon channels across all datasets: {common_channels}\n") # Get all events across all datasets @@ -310,10 +337,14 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_events_3 = train_dataset_3.datasets[0].raw.annotations.description test_events = test_dataset.datasets[0].raw.annotations.description -all_events = sorted(list(set(train_events_1) - .union(set(train_events_2)) - .union(set(train_events_3)) - .union(set(test_events)))) +all_events = sorted( + list( + set(train_events_1) + .union(set(train_events_2)) + .union(set(train_events_3)) + .union(set(test_events)) + ) +) print(f"\nAll unique events across datasets: {all_events}\n") event_id = {str(event): idx for idx, event in enumerate(all_events)} @@ -333,25 +364,23 @@ def load_and_validate_dataset(dataset_name, subject_ids): train_windows_3, _, _, _, _ = standardize_windows( train_dataset_3, test_dataset, common_channels, event_id ) -train_windows_test, test_windows, window_length, samples_before, samples_after = standardize_windows( - train_dataset_1, test_dataset, common_channels, event_id +train_windows_test, test_windows, window_length, samples_before, samples_after = ( + standardize_windows(train_dataset_1, test_dataset, common_channels, event_id) ) # Combine all training windows -combined_train_windows = BaseConcatDataset([ - train_windows_1, - train_windows_2, - train_windows_3 -]) +combined_train_windows = BaseConcatDataset( + [train_windows_1, train_windows_2, train_windows_3] +) # Split training data using only the combined dataset -split = combined_train_windows.split('session') -train_set = split['0train'] -valid_set = split['1test'] +split = combined_train_windows.split("session") +train_set = split["0train"] +valid_set = split["1test"] # Setup compute device cuda = torch.cuda.is_available() -device = 'cuda' if cuda else 'cpu' +device = "cuda" if cuda else "cpu" if cuda: torch.backends.cudnn.benchmark = True @@ -371,7 +400,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): filter_time_length=20, pool_time_length=35, pool_time_stride=7, - final_conv_length='auto' + final_conv_length="auto", ) if cuda: @@ -388,7 +417,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): batch_size=64, callbacks=[ "accuracy", - ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=11)), + ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=11)), ], device=device, classes=list(range(n_classes)), @@ -412,39 +441,46 @@ def load_and_validate_dataset(dataset_name, subject_ids): # Create visualization fig = plt.figure(figsize=(10, 5)) -plt.plot(clf.history[:, 'train_loss'], label='Training Loss') -plt.plot(clf.history[:, 'valid_loss'], label='Validation Loss') +plt.plot(clf.history[:, "train_loss"], label="Training Loss") +plt.plot(clf.history[:, "valid_loss"], label="Validation Loss") plt.legend() -plt.xlabel('Epoch') -plt.ylabel('Loss') -plt.title('Training and Validation Loss Over Time') +plt.xlabel("Epoch") +plt.ylabel("Loss") +plt.title("Training and Validation Loss Over Time") plt.show() # Plot training curves -results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'] -df = pd.DataFrame(clf.history[:, results_columns], columns=results_columns, - index=clf.history[:, 'epoch']) +results_columns = ["train_loss", "valid_loss", "train_accuracy", "valid_accuracy"] +df = pd.DataFrame( + clf.history[:, results_columns], + columns=results_columns, + index=clf.history[:, "epoch"], +) -df = df.assign(train_misclass=100 - 100 * df.train_accuracy, - valid_misclass=100 - 100 * df.valid_accuracy) +df = df.assign( + train_misclass=100 - 100 * df.train_accuracy, + valid_misclass=100 - 100 * df.valid_accuracy, +) # Plot results fig, ax1 = plt.subplots(figsize=(8, 3)) -df.loc[:, ['train_loss', 'valid_loss']].plot( - ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False) +df.loc[:, ["train_loss", "valid_loss"]].plot( + ax=ax1, style=["-", ":"], marker="o", color="tab:blue", legend=False +) -ax1.tick_params(axis='y', labelcolor='tab:blue') -ax1.set_ylabel("Loss", color='tab:blue') +ax1.tick_params(axis="y", labelcolor="tab:blue") +ax1.set_ylabel("Loss", color="tab:blue") ax2 = ax1.twinx() -df.loc[:, ['train_misclass', 'valid_misclass']].plot( - ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False) -ax2.tick_params(axis='y', labelcolor='tab:red') -ax2.set_ylabel("Misclassification Rate [%]", color='tab:red') +df.loc[:, ["train_misclass", "valid_misclass"]].plot( + ax=ax2, style=["-", ":"], marker="o", color="tab:red", legend=False +) +ax2.tick_params(axis="y", labelcolor="tab:red") +ax2.set_ylabel("Misclassification Rate [%]", color="tab:red") ax2.set_ylim(ax2.get_ylim()[0], 85) handles = [] -handles.append(Line2D([0], [0], color='black', linestyle='-', label='Train')) -handles.append(Line2D([0], [0], color='black', linestyle=':', label='Valid')) +handles.append(Line2D([0], [0], color="black", linestyle="-", label="Train")) +handles.append(Line2D([0], [0], color="black", linestyle=":", label="Valid")) plt.legend(handles, [h.get_label() for h in handles]) plt.tight_layout() From 023a0a73a78d06b0d713262f0c89f7d697121e25 Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Mon, 3 Mar 2025 20:11:17 +0100 Subject: [PATCH 11/16] fix: resolve merge conflicts in cross dataset example --- examples/cross_dataset_braindecode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py index 72cdf7ca8..cfa6c06c2 100644 --- a/examples/cross_dataset_braindecode.py +++ b/examples/cross_dataset_braindecode.py @@ -296,7 +296,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): ("Train1", train_dataset_1), ("Train2", train_dataset_2), ("Train3", train_dataset_3), - ("Test", test_dataset), + ("Test", test_dataset) ]: for name2, ds2 in [ ("Train1", train_dataset_1), From 4a8bcb915ba56ebca04de877bedc56bb33338c4e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Mar 2025 19:11:36 +0000 Subject: [PATCH 12/16] [pre-commit.ci] auto fixes from pre-commit.com hooks --- examples/cross_dataset_braindecode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py index cfa6c06c2..72cdf7ca8 100644 --- a/examples/cross_dataset_braindecode.py +++ b/examples/cross_dataset_braindecode.py @@ -296,7 +296,7 @@ def load_and_validate_dataset(dataset_name, subject_ids): ("Train1", train_dataset_1), ("Train2", train_dataset_2), ("Train3", train_dataset_3), - ("Test", test_dataset) + ("Test", test_dataset), ]: for name2, ds2 in [ ("Train1", train_dataset_1), From 96a678a023e0875cc144644232373570b599b0a7 Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Mon, 3 Mar 2025 20:36:46 +0100 Subject: [PATCH 13/16] added tests and edited changelog --- docs/source/whats_new.rst | 2 ++ moabb/tests/test_evaluations.py | 61 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index ea5c96109..ceae77543 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -17,6 +17,7 @@ Develop branch Enhancements ~~~~~~~~~~~~ +- Adding :class:`moabb.evaluations.CrossDatasetEvaluation` for cross-dataset evaluation, enabling training on one dataset and testing on another (:gh:`703` by `Ali Imran`_) - Adding :class:`moabb.evaluations.splitters.WithinSessionSplitter` (:gh:`664` by `Bruna Lopes_`) - Update version of pyRiemann to 0.7 (:gh:`671` by `Gregoire Cattan`_) @@ -516,3 +517,4 @@ API changes .. _Yash Chauhan: https://github.com/jiggychauhi .. _Taha Habib: https://github.com/tahatt13 .. _AFF: https://github.com/allwaysFindFood +.. _Ali Imran: https://github.com/ali-sehar diff --git a/moabb/tests/test_evaluations.py b/moabb/tests/test_evaluations.py index 0dc4f98c0..87d04febc 100644 --- a/moabb/tests/test_evaluations.py +++ b/moabb/tests/test_evaluations.py @@ -382,6 +382,67 @@ def test_compatible_dataset(self): assert self.eval.is_valid(dataset=ds) +class Test_CrossDataset: + """Test CrossDatasetEvaluation class.""" + + def setup_method(self): + self.train_ds = FakeDataset( + paradigm="imagery", + event_list=["left_hand", "right_hand"], + n_subjects=2, + n_sessions=2, + ) + self.test_ds = FakeDataset( + paradigm="imagery", + event_list=["left_hand", "right_hand"], + n_subjects=1, # Different number of subjects + n_sessions=3, # Different number of sessions + ) + self.paradigm = FakeImageryParadigm() + self.eval = ev.CrossDatasetEvaluation( + train_dataset=self.train_ds, + test_dataset=self.test_ds, + paradigm=self.paradigm, + ) + + self.test_pipelines = OrderedDict() + self.test_pipelines["dummy"] = make_pipeline( + FunctionTransformer(), + Dummy(strategy="stratified", random_state=42) + ) + + def test_validate_datasets(self): + """Test dataset validation.""" + # Test with compatible dataset + assert self.eval.is_valid(self.test_ds) + + # Test with incompatible dataset (different events) + incompatible_ds = FakeDataset( + paradigm="imagery", + event_list=["left_hand"], # Only one class + n_subjects=1, + ) + + # This should work but return a warning + eval_incompatible = ev.CrossDatasetEvaluation( + train_dataset=self.train_ds, + test_dataset=incompatible_ds, + paradigm=self.paradigm, + ) + assert eval_incompatible.is_valid(incompatible_ds) + + def test_evaluate(self): + """Test basic evaluation functionality.""" + results = list(self.eval.evaluate(dataset=None, pipelines=self.test_pipelines)) + + # Check results structure + assert len(results) > 0 + assert "score" in results[0] + assert "training_datasets" in results[0] + assert isinstance(results[0]["training_datasets"], list) + assert self.train_ds.code in results[0]["training_datasets"] + + class UtilEvaluation: def test_save_model_cv(self): model = Dummy() From d6d1d0079d0cfb47e405f089da6d8789e5cc8667 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Mar 2025 19:37:04 +0000 Subject: [PATCH 14/16] [pre-commit.ci] auto fixes from pre-commit.com hooks --- moabb/tests/test_evaluations.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/moabb/tests/test_evaluations.py b/moabb/tests/test_evaluations.py index 87d04febc..1bf2ce1c5 100644 --- a/moabb/tests/test_evaluations.py +++ b/moabb/tests/test_evaluations.py @@ -407,8 +407,7 @@ def setup_method(self): self.test_pipelines = OrderedDict() self.test_pipelines["dummy"] = make_pipeline( - FunctionTransformer(), - Dummy(strategy="stratified", random_state=42) + FunctionTransformer(), Dummy(strategy="stratified", random_state=42) ) def test_validate_datasets(self): From 0f90c4bcb86a3915be06a5f62d6b11932a8e2f0f Mon Sep 17 00:00:00 2001 From: ali-sehar Date: Tue, 11 Mar 2025 20:27:12 +0100 Subject: [PATCH 15/16] Using match all --- docs/source/whats_new.rst | 3 +- .../plot_cross_dataset.py} | 105 ++- .../plot_cross_dataset_braindecode.py | 650 ++++++++++++++++++ examples/cross_dataset_braindecode.py | 486 ------------- 4 files changed, 719 insertions(+), 525 deletions(-) rename examples/{cross_dataset.py => advanced_examples/plot_cross_dataset.py} (53%) create mode 100644 examples/advanced_examples/plot_cross_dataset_braindecode.py delete mode 100644 examples/cross_dataset_braindecode.py diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index 93c5ccaa3..d5e8bf5b3 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -22,7 +22,7 @@ Enhancements - Creating the meta information for the BIDS converted datasets (:gh:`688` by `Bruno Aristimunha`_) - Adding :class:`moabb.dataset.Beetl2021A` and :class:`moabb.dataset.Beetl2021B`(:gh:`675` by `Samuel Boehm_`) - Adding :class:`moabb.dataset.base.BaseBIDSDataset` and :class:`moabb.dataset.base.LocalBIDSDataset` (:gh:`724` by `Pierre Guetschel`_) - +- Adding :class:`moabb.evaluations.CrossDatasetEvaluation` for cross-dataset evaluation, enabling training on one dataset and testing on another (:gh:`703` by `Ali Imran`_) Bugs ~~~~ @@ -544,3 +544,4 @@ API changes .. _AFF: https://github.com/allwaysFindFood .. _Marco Congedo: https://github.com/Marco-Congedo .. _Samuel Boehm: https://github.com/Samuel-Boehm +.. _Ali Imran: https://github.com/EazyAl diff --git a/examples/cross_dataset.py b/examples/advanced_examples/plot_cross_dataset.py similarity index 53% rename from examples/cross_dataset.py rename to examples/advanced_examples/plot_cross_dataset.py index 1772b4f2f..9bc0c2ec8 100644 --- a/examples/cross_dataset.py +++ b/examples/advanced_examples/plot_cross_dataset.py @@ -6,14 +6,15 @@ # Standard library imports import logging -from typing import Any, List - -import matplotlib.pyplot as plt +from typing import List # Third-party imports +import matplotlib.pyplot as plt import mne import numpy as np import pandas as pd +from mne.io import RawArray +from mne.io.cnt.cnt import RawCNT from pyriemann.estimation import Covariances from pyriemann.spatialfilters import CSP from sklearn.pipeline import Pipeline @@ -33,38 +34,6 @@ logging.getLogger("mne").setLevel(logging.ERROR) -def get_common_channels(datasets: List[Any]) -> List[str]: - """Get channels that are available across all datasets. - - Parameters - ---------- - datasets : List[Dataset] - List of MOABB dataset objects to analyze - - Returns - ------- - List[str] - Sorted list of common channel names - """ - all_channels = [] - for dataset in datasets: - # Get a sample raw from each dataset - subject = dataset.subject_list[0] - raw_dict = dataset.get_data([subject]) - - # Navigate through the nested dictionary structure - subject_data = raw_dict[subject] - first_session = list(subject_data.keys())[0] - first_run = list(subject_data[first_session].keys())[0] - raw = subject_data[first_session][first_run] - - all_channels.append(raw.ch_names) - - # Find common channels across all datasets - common_channels = set.intersection(*map(set, all_channels)) - return sorted(list(common_channels)) - - def create_pipeline(common_channels: List[str]) -> Pipeline: """Create classification pipeline with CSP and SVM. @@ -119,9 +88,69 @@ def raw_to_data(X: np.ndarray) -> np.ndarray: train_dataset = BNCI2014001() test_dataset = Zhou2016() -# Get common channels across datasets -common_channels = get_common_channels([train_dataset, test_dataset]) -print(f"\nCommon channels across datasets: {common_channels}\n") +# Create a dictionary of datasets for easier handling +datasets_dict = {"train_dataset": train_dataset, "test_dataset": test_dataset} + +# Get the list of channels from each dataset before matching +print("\nChannels before matching:") +for ds_name, ds in datasets_dict.items(): + try: + # Load data for first subject to get channel information + data = ds.get_data([ds.subject_list[0]]) # Get data for first subject + first_subject = list(data.keys())[0] + first_session = list(data[first_subject].keys())[0] + first_run = list(data[first_subject][first_session].keys())[0] + run_data = data[first_subject][first_session][first_run] + + if isinstance(run_data, (RawArray, RawCNT)): + channels = run_data.info["ch_names"] + else: + # Assuming the channels are stored in the dataset class after loading + channels = ds.channels + print(f"{ds_name}: {channels}") + except Exception as e: + print(f"Error getting channels for {ds_name}: {str(e)}") + +# Use MOABB's match_all for channel handling +print("\nMatching channels across datasets...") +paradigm = MotorImagery() + +# Apply match_all to all datasets +all_datasets = list(datasets_dict.values()) +paradigm.match_all(all_datasets, channel_merge_strategy="intersect") + +# Get channels from all datasets after matching to ensure we have the correct intersection +all_channels_after_matching = [] +print("\nChannels after matching:") +for i, (ds_name, _) in enumerate(datasets_dict.items()): + ds = all_datasets[i] # Get the matched dataset + try: + data = ds.get_data([ds.subject_list[0]]) + subject = list(data.keys())[0] + session = list(data[subject].keys())[0] + run = list(data[subject][session].keys())[0] + run_data = data[subject][session][run] + + if isinstance(run_data, (RawArray, RawCNT)): + channels = run_data.info["ch_names"] + else: + channels = ds.channels + all_channels_after_matching.append(set(channels)) + print(f"{ds_name}: {channels}") + except Exception as e: + print(f"Error getting channels for {ds_name} after matching: {str(e)}") + +# Get the intersection of all channel sets +common_channels = sorted(list(set.intersection(*all_channels_after_matching))) +print(f"\nCommon channels after matching: {common_channels}") +print(f"Number of common channels: {len(common_channels)}") + +# Update the datasets_dict with the matched datasets +for i, (name, _) in enumerate(datasets_dict.items()): + datasets_dict[name] = all_datasets[i] + +train_dataset = datasets_dict["train_dataset"] +test_dataset = datasets_dict["test_dataset"] # Initialize the paradigm with common channels paradigm = MotorImagery(channels=common_channels, n_classes=2, fmin=8, fmax=32) diff --git a/examples/advanced_examples/plot_cross_dataset_braindecode.py b/examples/advanced_examples/plot_cross_dataset_braindecode.py new file mode 100644 index 000000000..8968e2a45 --- /dev/null +++ b/examples/advanced_examples/plot_cross_dataset_braindecode.py @@ -0,0 +1,650 @@ +""" +Cross-Dataset Brain Decoding with Deep Learning +============================================= +This example shows how to train deep learning models on one dataset +and test on another using Braindecode. +""" + +import logging +from typing import Dict, List, Tuple + +import matplotlib.pyplot as plt +import mne +import numpy as np +import pandas as pd +import torch +from braindecode import EEGClassifier +from braindecode.datasets import BaseConcatDataset, BaseDataset +from braindecode.models import ShallowFBCSPNet +from braindecode.preprocessing import ( + Preprocessor, + create_windows_from_events, + exponential_moving_standardize, + preprocess, +) +from braindecode.util import set_random_seeds +from braindecode.visualization import plot_confusion_matrix +from matplotlib.lines import Line2D +from mne.io import RawArray +from mne.io.cnt.cnt import RawCNT +from sklearn.metrics import confusion_matrix +from skorch.callbacks import LRScheduler +from skorch.helper import predefined_split + +from moabb.datasets import ( + BNCI2014_001, + BNCI2014_004, + BNCI2015_001, + Zhou2016, +) +from moabb.paradigms import MotorImagery + + +# Configure logging +logging.basicConfig(level=logging.WARNING) + + +def get_all_events(train_dataset, test_dataset): + """Get all unique events across datasets.""" + # Get events from first subject of each dataset + train_events = train_dataset.datasets[0].raw.annotations.description + test_events = test_dataset.datasets[0].raw.annotations.description + + # Get all unique events + all_events = sorted(list(set(train_events).union(set(test_events)))) + print(f"\nAll unique events across datasets: {all_events}\n") + + # Create event mapping (event description -> numerical ID) + event_id = {str(event): idx for idx, event in enumerate(all_events)} + print(f"Event mapping: {event_id}\n") + + return event_id + + +def create_fixed_windows( + dataset: BaseConcatDataset, + samples_before: int, + samples_after: int, + event_id: Dict[str, int], +) -> BaseConcatDataset: + """Create windows with consistent size across datasets. + + Parameters + ---------- + dataset : BaseConcatDataset + Dataset to create windows from + samples_before : int + Number of samples before event + samples_after : int + Number of samples after event + event_id : Dict[str, int] + Mapping of event names to numerical IDs + + Returns + ------- + BaseConcatDataset + Windowed dataset + """ + return create_windows_from_events( + dataset, + trial_start_offset_samples=-samples_before, + trial_stop_offset_samples=samples_after, + preload=True, + mapping=event_id, + window_size_samples=samples_before + samples_after, + window_stride_samples=samples_before + samples_after, + ) + + +def standardize_windows( + train_dataset: BaseConcatDataset, + test_dataset: BaseConcatDataset, + all_channels: List[str], + event_id: Dict[str, int], +) -> Tuple[BaseConcatDataset, BaseConcatDataset, int, int, int]: + """Standardize datasets with consistent preprocessing. + + Parameters + ---------- + train_dataset : BaseConcatDataset + Training dataset to standardize + test_dataset : BaseConcatDataset + Test dataset to standardize + all_channels : List[str] + List of all required channel names + event_id : Dict[str, int] + Mapping of event names to numerical IDs + + Returns + ------- + Tuple[BaseConcatDataset, BaseConcatDataset, int, int, int] + Processed training windows, test windows, window length, + samples before and after + """ + target_sfreq = 100 # Target sampling frequency + + print("\nInitial dataset properties:") + for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: + for i, ds in enumerate(dataset.datasets): + print(f"{name} dataset {i}:") + print(f" Sampling rate: {ds.raw.info['sfreq']} Hz") + print(f" Number of channels: {len(ds.raw.info['ch_names'])}") + print(f" Channel names: {ds.raw.info['ch_names']}") + + # Get the actual available channels (intersection of all datasets) + available_channels = set(train_dataset.datasets[0].raw.info['ch_names']) + for ds in train_dataset.datasets[1:] + test_dataset.datasets: + available_channels = available_channels.intersection(ds.raw.info['ch_names']) + available_channels = sorted(list(available_channels)) + + print(f"\nCommon channels across all datasets ({len(available_channels)}): {available_channels}") + + # Verify all datasets have the same number of channels + for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: + for i, ds in enumerate(dataset.datasets): + if len(ds.raw.info['ch_names']) != len(available_channels): + print(f"Warning: {name} dataset {i} has {len(ds.raw.info['ch_names'])} channels, " + f"expected {len(available_channels)}") + + # Define preprocessing pipeline using only available channels + preprocessors = [ + Preprocessor("pick_channels", ch_names=available_channels, ordered=True), + Preprocessor("resample", sfreq=target_sfreq), + Preprocessor(lambda data: np.multiply(data, 1e6)), + Preprocessor("filter", l_freq=4.0, h_freq=38.0), + Preprocessor( + exponential_moving_standardize, factor_new=1e-3, init_block_size=1000 + ), + ] + + # Apply preprocessing + print("\nApplying preprocessing...") + preprocess(train_dataset, preprocessors, n_jobs=-1) + preprocess(test_dataset, preprocessors, n_jobs=-1) + + # Verify channel counts after preprocessing + print("\nVerifying channel counts after preprocessing:") + for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: + for i, ds in enumerate(dataset.datasets): + n_channels = len(ds.raw.info['ch_names']) + print(f"{name} dataset {i} has {n_channels} channels") + if n_channels != len(available_channels): + raise ValueError(f"Channel count mismatch in {name} dataset {i}: " + f"got {n_channels}, expected {len(available_channels)}") + + # Define window parameters + window_start = -0.5 # Start 0.5s before event + window_duration = 4.0 # Window duration in seconds + samples_before = int(abs(window_start) * target_sfreq) + samples_after = int(window_duration * target_sfreq) + + # Standardize event durations + for dataset in [train_dataset, test_dataset]: + for ds in dataset.datasets: + events = ds.raw.annotations[:-1] + new_annotations = mne.Annotations( + onset=events.onset, + duration=np.zeros_like(events.duration), + description=events.description, + ) + ds.raw.set_annotations(new_annotations) + + # Create and validate windows + print("\nCreating windows...") + train_windows = create_fixed_windows( + train_dataset, samples_before, samples_after, event_id + ) + test_windows = create_fixed_windows( + test_dataset, samples_before, samples_after, event_id + ) + + # Verify window shapes + train_shape = train_windows[0][0].shape + test_shape = test_windows[0][0].shape + print(f"\nWindow shapes - Train: {train_shape}, Test: {test_shape}") + + if train_shape != test_shape: + raise ValueError( + f"Window shapes don't match: train={train_shape}, test={test_shape}" + ) + + window_length = train_shape[1] + return train_windows, test_windows, window_length, samples_before, samples_after + + +# Load datasets with validation +def load_and_validate_dataset(dataset_class, subject_ids): + """Load dataset and validate its contents. + + Parameters + ---------- + dataset_class : MOABB Dataset class + The dataset class to instantiate + subject_ids : list + List of subject IDs to include + + Returns + ------- + dataset : MOABB Dataset + The loaded and validated dataset + """ + # Initialize the dataset + dataset = dataset_class() + dataset.subject_list = subject_ids + + print(f"\nValidating dataset: {dataset.__class__.__name__}") + + try: + # Get data in MOABB format + data = dataset.get_data() + + # Validate each subject's data + for subject_id in subject_ids: + print(f"\nSubject {subject_id}:") + if subject_id not in data: + print(f"No data found for subject {subject_id}") + continue + + subject_data = data[subject_id] + for session_name, session_data in subject_data.items(): + print(f"Session: {session_name}") + for run_name, run_data in session_data.items(): + print(f"Run: {run_name}") + + # Handle both dictionary and MNE Raw object formats + if isinstance(run_data, (RawArray, RawCNT)): + data_array = run_data.get_data() + events = mne.events_from_annotations(run_data)[0] + print(f"Data shape: {data_array.shape}") + print(f"Number of events: {len(events)}") + if len(events) > 0: + print(f"Event types: {np.unique(events[:, -1])}") + elif isinstance(run_data, dict): + if 'data' in run_data and 'events' in run_data: + data_array = run_data['data'] + events_array = run_data['events'] + print(f"Data shape: {data_array.shape}") + print(f"Events shape: {events_array.shape}") + if events_array.size > 0: + print(f"Event types: {np.unique(events_array[:, -1])}") + else: + print("Warning: Run data missing required keys") + else: + print(f"Warning: Unexpected run_data type: {type(run_data)}") + + except Exception as e: + print(f"Error loading dataset: {str(e)}") + raise + + return dataset + + +# The conversion between MOABB and Braindecode formats is necessary because: +# 1. MOABB's match_all provides robust channel matching and interpolation +# 2. Braindecode's training pipeline expects its own data format +# 3. We need to preserve both the benefits of MOABB's preprocessing and Braindecode's training + +def convert_moabb_to_braindecode(moabb_dataset, subject_ids, channels): + """Convert MOABB dataset format to Braindecode format. + + Parameters + ---------- + moabb_dataset : MOABB Dataset + The MOABB dataset to convert + subject_ids : list + List of subject IDs to include + channels : list + List of channels to use + + Returns + ------- + BaseConcatDataset + Dataset in Braindecode format + """ + # Get data in MOABB format + moabb_data = moabb_dataset.get_data() + + # Create list to hold all raw objects with their subject IDs + raw_objects = [] + descriptions = [] # To keep track of metadata for each raw object + + # Iterate through all subjects and their sessions/runs + for subject_id in subject_ids: + if subject_id not in moabb_data: + print(f"Warning: No data found for subject {subject_id}") + continue + + subject_data = moabb_data[subject_id] + for session_name, session_data in subject_data.items(): + for run_name, run_data in session_data.items(): + try: + if isinstance(run_data, (RawArray, RawCNT)): + # If it's already an MNE Raw object, pick only the common channels + raw = run_data.copy().pick_channels(channels, ordered=True) + raw_objects.append(raw) + descriptions.append({ + 'subject': subject_id, + 'session': session_name, + 'run': run_name, + 'dataset_name': moabb_dataset.__class__.__name__ + }) + elif isinstance(run_data, dict) and 'data' in run_data and 'events' in run_data: + # If it's a dictionary, create MNE Raw object with only common channels + X = run_data['data'] + events = run_data['events'] + + # Create MNE RawArray with only common channels + sfreq = moabb_dataset.interval[2] + info = mne.create_info(ch_names=channels, sfreq=sfreq, ch_types=['eeg'] * len(channels)) + raw = mne.io.RawArray(X.T, info) + + # Convert events to annotations + onset = events[:, 0] / sfreq + duration = np.zeros_like(onset) + description = events[:, -1].astype(str) + annot = mne.Annotations(onset=onset, duration=duration, description=description) + raw.set_annotations(annot) + + raw_objects.append(raw) + descriptions.append({ + 'subject': subject_id, + 'session': session_name, + 'run': run_name, + 'dataset_name': moabb_dataset.__class__.__name__ + }) + else: + print(f"Warning: Invalid run data format for subject {subject_id}, session {session_name}, run {run_name}") + except Exception as e: + print(f"Warning: Error processing run data for subject {subject_id}: {str(e)}") + + if not raw_objects: + raise ValueError("No valid data found to convert") + + # Convert to Braindecode format with proper descriptions + return BaseConcatDataset([ + BaseDataset(raw, description=description) + for raw, description in zip(raw_objects, descriptions) + ]) + + +# Load datasets in MOABB format first +# This allows us to use MOABB's robust dataset handling and preprocessing +print("\nLoading training datasets...") +train_dataset_1_moabb = load_and_validate_dataset(BNCI2014_001, [1, 2, 3, 4]) +train_dataset_2_moabb = load_and_validate_dataset(BNCI2015_001, [1, 2, 3, 4]) +train_dataset_3_moabb = load_and_validate_dataset(BNCI2014_004, [1, 2, 3, 4]) + +print("\nLoading test dataset...") +test_dataset_moabb = load_and_validate_dataset(Zhou2016, [1, 2, 3]) + +# Use MOABB's match_all for channel handling +# This is a crucial step that: +# 1. Ensures all datasets have the same channels +# 2. Handles different channel names and positions +# 3. Interpolates missing channels when needed +# 4. Maintains data quality across datasets +paradigm = MotorImagery() +all_datasets = [ + train_dataset_1_moabb, + train_dataset_2_moabb, + train_dataset_3_moabb, + test_dataset_moabb +] + +# Get the list of channels from each dataset before matching +print("\nChannels before matching:") +for ds in all_datasets: + # Load data for first subject to get channel information + data = ds.get_data([ds.subject_list[0]]) # Get data for first subject + # Get first session and run + first_subject = list(data.keys())[0] + first_session = list(data[first_subject].keys())[0] + first_run = list(data[first_subject][first_session].keys())[0] + run_data = data[first_subject][first_session][first_run] + + if isinstance(run_data, (RawArray, RawCNT)): + channels = run_data.info['ch_names'] + else: + # Assuming the channels are stored in the dataset class after loading + channels = ds.channels + print(f"{ds.__class__.__name__}: {channels}") + +paradigm.match_all(all_datasets, channel_merge_strategy="intersect") + +# Get channels from all datasets after matching to ensure we have the correct intersection +all_channels_after_matching = [] +print("\nGetting channels from each dataset after matching:") +for ds in all_datasets: + data = ds.get_data([ds.subject_list[0]]) + subject = list(data.keys())[0] + session = list(data[subject].keys())[0] + run = list(data[subject][session].keys())[0] + run_data = data[subject][session][run] + + if isinstance(run_data, (RawArray, RawCNT)): + channels = run_data.info['ch_names'] + else: + channels = ds.channels + all_channels_after_matching.append(set(channels)) + print(f"{ds.__class__.__name__}: {channels}") + +# Get the intersection of all channel sets +common_channels = sorted(list(set.intersection(*all_channels_after_matching))) +print(f"\nActual common channels after matching: {common_channels}") +print(f"Number of common channels: {len(common_channels)}") + +# Convert matched datasets to Braindecode format +print("\nConverting datasets to Braindecode format...") +print(f"Using {len(common_channels)} common channels: {common_channels}") + +# Convert datasets using common channels +train_dataset_1 = convert_moabb_to_braindecode(train_dataset_1_moabb, [1, 2, 3, 4], common_channels) +train_dataset_2 = convert_moabb_to_braindecode(train_dataset_2_moabb, [1, 2, 3, 4], common_channels) +train_dataset_3 = convert_moabb_to_braindecode(train_dataset_3_moabb, [1, 2, 3, 4], common_channels) +test_dataset = convert_moabb_to_braindecode(test_dataset_moabb, [1, 2, 3], common_channels) + +# Verify channel counts in converted datasets +print("\nVerifying channel counts in converted datasets:") +for name, dataset in [ + ("Train 1", train_dataset_1), + ("Train 2", train_dataset_2), + ("Train 3", train_dataset_3), + ("Test", test_dataset) +]: + for i, ds in enumerate(dataset.datasets): + n_channels = len(ds.raw.info['ch_names']) + print(f"{name} dataset {i} has {n_channels} channels") + if n_channels != len(common_channels): + raise ValueError(f"Channel count mismatch in {name} dataset {i}: " + f"got {n_channels}, expected {len(common_channels)}") + +# Get all events across all datasets +train_events_1 = train_dataset_1.datasets[0].raw.annotations.description +train_events_2 = train_dataset_2.datasets[0].raw.annotations.description +train_events_3 = train_dataset_3.datasets[0].raw.annotations.description +test_events = test_dataset.datasets[0].raw.annotations.description + +all_events = sorted( + list( + set(train_events_1) + .union(set(train_events_2)) + .union(set(train_events_3)) + .union(set(test_events)) + ) +) +print(f"\nAll unique events across datasets: {all_events}\n") + +event_id = {str(event): idx for idx, event in enumerate(all_events)} +print(f"Event mapping: {event_id}\n") + +# Define number of classes based on all unique events +n_classes = len(event_id) +print(f"Number of classes: {n_classes}\n") + +# Process all training datasets using the common channels +print("\nProcessing training datasets...") +print(f"Using {len(common_channels)} common channels: {common_channels}") + +# Process datasets one at a time to ensure consistent channel counts +train_windows_list = [] +for i, (train_ds, name) in enumerate([ + (train_dataset_1, "Dataset 1"), + (train_dataset_2, "Dataset 2"), + (train_dataset_3, "Dataset 3") +]): + print(f"\nProcessing training {name}...") + # Verify channel count before processing + for ds in train_ds.datasets: + if len(ds.raw.info['ch_names']) != len(common_channels): + print(f"Warning: {name} has {len(ds.raw.info['ch_names'])} channels before processing") + print(f"Current channels: {ds.raw.info['ch_names']}") + + # Process the dataset + windows, _, _, _, _ = standardize_windows( + train_ds, test_dataset, common_channels, event_id + ) + + # Verify window dimensions + for w_idx, window in enumerate(windows): + if window[0].shape[0] != len(common_channels): + raise ValueError( + f"Window {w_idx} in {name} has {window[0].shape[0]} channels, " + f"expected {len(common_channels)}" + ) + + train_windows_list.append(windows) + +print("\nProcessing test dataset...") +train_windows_test, test_windows, window_length, samples_before, samples_after = ( + standardize_windows(train_dataset_1, test_dataset, common_channels, event_id) +) + +# Verify all window shapes before combining +print("\nVerifying window shapes:") +for i, windows in enumerate(train_windows_list): + print(f"Training dataset {i+1} window shape: {windows[0][0].shape}") +print(f"Test dataset window shape: {test_windows[0][0].shape}") + +# Combine all training windows +print("\nCombining training windows...") +combined_train_windows = BaseConcatDataset(train_windows_list) + +# Verify combined dataset +print(f"Combined training set size: {len(combined_train_windows)}") +print(f"First window shape: {combined_train_windows[0][0].shape}") +print(f"Last window shape: {combined_train_windows[-1][0].shape}") + +# Split training data using only the combined dataset +split = combined_train_windows.split("session") +train_set = split["0train"] +valid_set = split["1test"] + +print(f"\nTraining set size: {len(train_set)}") +print(f"Validation set size: {len(valid_set)}") + +# Setup compute device +cuda = torch.cuda.is_available() +device = "cuda" if cuda else "cpu" +if cuda: + torch.backends.cudnn.benchmark = True + +# Set random seed +set_random_seeds(seed=20200220, cuda=cuda) + +# Calculate model parameters based on standardized data +n_chans = len(common_channels) +input_window_samples = window_length # 450 samples + +# Create model with adjusted parameters for all classes +model = ShallowFBCSPNet( + n_chans, + n_classes, # This will now be the total number of unique events + n_times=input_window_samples, + n_filters_time=40, + filter_time_length=20, + pool_time_length=35, + pool_time_stride=7, + final_conv_length="auto", +) + +if cuda: + model = model.cuda() + +# Create and train classifier +clf = EEGClassifier( + model, + criterion=torch.nn.NLLLoss, + optimizer=torch.optim.AdamW, + train_split=predefined_split(valid_set), + optimizer__lr=0.0625 * 0.01, + optimizer__weight_decay=0, + batch_size=64, + callbacks=[ + "accuracy", + ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=11)), + ], + device=device, + classes=list(range(n_classes)), +) + +# Train the model +_ = clf.fit(train_set, y=None, epochs=100) + +# Get test labels from the windows +y_true = test_windows.get_metadata().target +y_pred = clf.predict(test_windows) +test_accuracy = np.mean(y_true == y_pred) +print(f"\nTest accuracy: {test_accuracy:.4f}") + +# Generate confusion matrix for test set +confusion_mat = confusion_matrix(y_true, y_pred) + +# Plot confusion matrix with all event names +plot_confusion_matrix(confusion_mat, class_names=list(event_id.keys())) +plt.show() + +# Create visualization +fig = plt.figure(figsize=(10, 5)) +plt.plot(clf.history[:, "train_loss"], label="Training Loss") +plt.plot(clf.history[:, "valid_loss"], label="Validation Loss") +plt.legend() +plt.xlabel("Epoch") +plt.ylabel("Loss") +plt.title("Training and Validation Loss Over Time") +plt.show() + +# Plot training curves +results_columns = ["train_loss", "valid_loss", "train_accuracy", "valid_accuracy"] +df = pd.DataFrame( + clf.history[:, results_columns], + columns=results_columns, + index=clf.history[:, "epoch"], +) + +df = df.assign( + train_misclass=100 - 100 * df.train_accuracy, + valid_misclass=100 - 100 * df.valid_accuracy, +) + +# Plot results +fig, ax1 = plt.subplots(figsize=(8, 3)) +df.loc[:, ["train_loss", "valid_loss"]].plot( + ax=ax1, style=["-", ":"], marker="o", color="tab:blue", legend=False +) + +ax1.tick_params(axis="y", labelcolor="tab:blue") +ax1.set_ylabel("Loss", color="tab:blue") + +ax2 = ax1.twinx() +df.loc[:, ["train_misclass", "valid_misclass"]].plot( + ax=ax2, style=["-", ":"], marker="o", color="tab:red", legend=False +) +ax2.tick_params(axis="y", labelcolor="tab:red") +ax2.set_ylabel("Misclassification Rate [%]", color="tab:red") +ax2.set_ylim(ax2.get_ylim()[0], 85) + +handles = [] +handles.append(Line2D([0], [0], color="black", linestyle="-", label="Train")) +handles.append(Line2D([0], [0], color="black", linestyle=":", label="Valid")) +plt.legend(handles, [h.get_label() for h in handles]) +plt.tight_layout() diff --git a/examples/cross_dataset_braindecode.py b/examples/cross_dataset_braindecode.py deleted file mode 100644 index 72cdf7ca8..000000000 --- a/examples/cross_dataset_braindecode.py +++ /dev/null @@ -1,486 +0,0 @@ -""" -Cross-Dataset Brain Decoding with Deep Learning -============================================= -This example shows how to train deep learning models on one dataset -and test on another using Braindecode. -""" - -import logging -from typing import Dict, List, Tuple - -import matplotlib.pyplot as plt -import mne -import numpy as np -import pandas as pd -import torch -from braindecode import EEGClassifier -from braindecode.datasets import BaseConcatDataset, MOABBDataset -from braindecode.models import ShallowFBCSPNet -from braindecode.preprocessing import ( - Preprocessor, - create_windows_from_events, - exponential_moving_standardize, - preprocess, -) -from braindecode.util import set_random_seeds -from braindecode.visualization import plot_confusion_matrix -from matplotlib.lines import Line2D -from sklearn.metrics import confusion_matrix -from skorch.callbacks import LRScheduler -from skorch.helper import predefined_split - - -# Configure logging -logging.basicConfig(level=logging.WARNING) - - -def get_common_channels(train_dataset, test_dataset): - """Get channels that are available across both datasets.""" - train_chans = train_dataset.datasets[0].raw.ch_names - test_chans = test_dataset.datasets[0].raw.ch_names - common_channels = sorted(list(set(train_chans).intersection(set(test_chans)))) - print(f"\nCommon channels across datasets: {common_channels}\n") - return common_channels - - -def get_all_events(train_dataset, test_dataset): - """Get all unique events across datasets.""" - # Get events from first subject of each dataset - train_events = train_dataset.datasets[0].raw.annotations.description - test_events = test_dataset.datasets[0].raw.annotations.description - - # Get all unique events - all_events = sorted(list(set(train_events).union(set(test_events)))) - print(f"\nAll unique events across datasets: {all_events}\n") - - # Create event mapping (event description -> numerical ID) - event_id = {str(event): idx for idx, event in enumerate(all_events)} - print(f"Event mapping: {event_id}\n") - - return event_id - - -def interpolate_missing_channels( - raw_data: mne.io.Raw, all_channels: List[str] -) -> mne.io.Raw: - """Interpolate missing channels using spherical spline interpolation. - - Parameters - ---------- - raw_data : mne.io.Raw - Raw EEG data to process - all_channels : List[str] - List of all required channel names - - Returns - ------- - mne.io.Raw - Processed data with interpolated channels - - Raises - ------ - TypeError - If raw_data is not an MNE Raw object - """ - if isinstance(raw_data, np.ndarray): - return raw_data - - if not isinstance(raw_data, mne.io.Raw): - raise TypeError("Expected MNE Raw object") - - missing_channels = [ch for ch in all_channels if ch not in raw_data.ch_names] - existing_channels = raw_data.ch_names - - print("\nChannel Information:") - print(f"Total channels needed: {len(all_channels)}") - print(f"Existing channels: {len(existing_channels)}") - print(f"Missing channels to interpolate: {len(missing_channels)}") - - if missing_channels: - print("\nMissing channels:") - for ch in missing_channels: - print(f"- {ch}") - - # Mark missing channels as bad - raw_data.info["bads"] = missing_channels - - # Add missing channels (temporarily with zeros) - print("\nAdding temporary channels for interpolation...") - raw_data.add_channels( - [ - mne.io.RawArray( - np.zeros((1, len(raw_data.times))), - mne.create_info([ch], raw_data.info["sfreq"], ["eeg"]), - ) - for ch in missing_channels - ] - ) - - # Interpolate the bad channels - print("Performing spherical spline interpolation...") - raw_data.interpolate_bads(reset_bads=True) - - # Calculate and print interpolation statistics - data_after = raw_data.get_data() - for ch in missing_channels: - ch_idx = raw_data.ch_names.index(ch) - interpolated_data = data_after[ch_idx] - stats = { - "mean": np.mean(interpolated_data), - "std": np.std(interpolated_data), - "min": np.min(interpolated_data), - "max": np.max(interpolated_data), - } - print(f"\nInterpolated channel {ch} statistics:") - print(f"- Mean: {stats['mean']:.2f}") - print(f"- Std: {stats['std']:.2f}") - print(f"- Range: [{stats['min']:.2f}, {stats['max']:.2f}]") - - print("\nInterpolation complete.") - else: - print("No channels need interpolation.") - - return raw_data - - -def create_fixed_windows( - dataset: BaseConcatDataset, - samples_before: int, - samples_after: int, - event_id: Dict[str, int], -) -> BaseConcatDataset: - """Create windows with consistent size across datasets. - - Parameters - ---------- - dataset : BaseConcatDataset - Dataset to create windows from - samples_before : int - Number of samples before event - samples_after : int - Number of samples after event - event_id : Dict[str, int] - Mapping of event names to numerical IDs - - Returns - ------- - BaseConcatDataset - Windowed dataset - """ - return create_windows_from_events( - dataset, - trial_start_offset_samples=-samples_before, - trial_stop_offset_samples=samples_after, - preload=True, - mapping=event_id, - window_size_samples=samples_before + samples_after, - window_stride_samples=samples_before + samples_after, - ) - - -def standardize_windows( - train_dataset: BaseConcatDataset, - test_dataset: BaseConcatDataset, - all_channels: List[str], - event_id: Dict[str, int], -) -> Tuple[BaseConcatDataset, BaseConcatDataset, int, int, int]: - """Standardize datasets with consistent preprocessing. - - Parameters - ---------- - train_dataset : BaseConcatDataset - Training dataset to standardize - test_dataset : BaseConcatDataset - Test dataset to standardize - all_channels : List[str] - List of all required channel names - event_id : Dict[str, int] - Mapping of event names to numerical IDs - - Returns - ------- - Tuple[BaseConcatDataset, BaseConcatDataset, int, int, int] - Processed training windows, test windows, window length, - samples before and after - """ - target_sfreq = 100 # Target sampling frequency - - print("\nInitial dataset properties:") - for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: - for i, ds in enumerate(dataset.datasets): - print(f"{name} dataset {i} sampling rate: {ds.raw.info['sfreq']} Hz") - - # Define preprocessing pipeline - preprocessors = [ - Preprocessor(lambda raw: interpolate_missing_channels(raw, all_channels)), - Preprocessor("pick_channels", ch_names=all_channels, ordered=True), - Preprocessor("resample", sfreq=target_sfreq), - Preprocessor(lambda data: np.multiply(data, 1e6)), - Preprocessor("filter", l_freq=4.0, h_freq=38.0), - Preprocessor( - exponential_moving_standardize, factor_new=1e-3, init_block_size=1000 - ), - ] - - # Apply preprocessing - preprocess(train_dataset, preprocessors, n_jobs=-1) - preprocess(test_dataset, preprocessors, n_jobs=-1) - - # Define window parameters - window_start = -0.5 # Start 0.5s before event - window_duration = 4.0 # Window duration in seconds - samples_before = int(abs(window_start) * target_sfreq) - samples_after = int(window_duration * target_sfreq) - - # Standardize event durations - for dataset in [train_dataset, test_dataset]: - for ds in dataset.datasets: - events = ds.raw.annotations[:-1] - new_annotations = mne.Annotations( - onset=events.onset, - duration=np.zeros_like(events.duration), - description=events.description, - ) - ds.raw.set_annotations(new_annotations) - - # Create and validate windows - train_windows = create_fixed_windows( - train_dataset, samples_before, samples_after, event_id - ) - test_windows = create_fixed_windows( - test_dataset, samples_before, samples_after, event_id - ) - - # Verify window shapes - train_shape = train_windows[0][0].shape - test_shape = test_windows[0][0].shape - - if train_shape != test_shape: - raise ValueError( - f"Window shapes don't match: train={train_shape}, test={test_shape}" - ) - - window_length = train_shape[1] - return train_windows, test_windows, window_length, samples_before, samples_after - - -# Load datasets with validation -def load_and_validate_dataset(dataset_name, subject_ids): - """Load dataset and validate its contents.""" - dataset = MOABBDataset(dataset_name=dataset_name, subject_ids=subject_ids) - - print(f"\nValidating dataset: {dataset_name}") - for i, ds in enumerate(dataset.datasets): - events = ds.raw.annotations - print(f"\nSubject {i+1}:") # Changed from subject_ids[i] to i+1 - print(f"Number of events: {len(events)}") - print(f"Unique event types: {set(events.description)}") - print(f"First 5 event timings: {events.onset[:5]}") - print(f"Sample event descriptions: {list(events.description[:5])}") - - return dataset - - -# Load datasets with validation -print("\nLoading training datasets...") -train_dataset_1 = load_and_validate_dataset("BNCI2014_001", subject_ids=[1, 2, 3, 4]) -train_dataset_2 = load_and_validate_dataset("BNCI2015_001", subject_ids=[1, 2, 3, 4]) -train_dataset_3 = load_and_validate_dataset("BNCI2014_004", subject_ids=[1, 2, 3, 4]) - -print("\nLoading test dataset...") -test_dataset = load_and_validate_dataset("Zhou2016", subject_ids=[1, 2, 3]) - -# Verify datasets are different -print("\nVerifying dataset uniqueness...") -for name1, ds1 in [ - ("Train1", train_dataset_1), - ("Train2", train_dataset_2), - ("Train3", train_dataset_3), - ("Test", test_dataset), -]: - for name2, ds2 in [ - ("Train1", train_dataset_1), - ("Train2", train_dataset_2), - ("Train3", train_dataset_3), - ("Test", test_dataset), - ]: - if name1 < name2: # Compare each pair only once - print(f"\nComparing {name1} vs {name2}:") - # Compare first subject of each dataset - events1 = ds1.datasets[0].raw.annotations - events2 = ds2.datasets[0].raw.annotations - print(f"Event counts: {len(events1)} vs {len(events2)}") - print( - f"Event types: {set(events1.description)} vs {set(events2.description)}" - ) - print(f"First timing: {events1.onset[0]} vs {events2.onset[0]}") - -# Get common channels across all datasets -train_chans_1 = train_dataset_1.datasets[0].raw.ch_names -train_chans_2 = train_dataset_2.datasets[0].raw.ch_names -train_chans_3 = train_dataset_3.datasets[0].raw.ch_names -test_chans = test_dataset.datasets[0].raw.ch_names - -common_channels = sorted( - list( - set(train_chans_1) - .intersection(set(train_chans_2)) - .intersection(set(train_chans_3)) - .intersection(set(test_chans)) - ) -) -print(f"\nCommon channels across all datasets: {common_channels}\n") - -# Get all events across all datasets -train_events_1 = train_dataset_1.datasets[0].raw.annotations.description -train_events_2 = train_dataset_2.datasets[0].raw.annotations.description -train_events_3 = train_dataset_3.datasets[0].raw.annotations.description -test_events = test_dataset.datasets[0].raw.annotations.description - -all_events = sorted( - list( - set(train_events_1) - .union(set(train_events_2)) - .union(set(train_events_3)) - .union(set(test_events)) - ) -) -print(f"\nAll unique events across datasets: {all_events}\n") - -event_id = {str(event): idx for idx, event in enumerate(all_events)} -print(f"Event mapping: {event_id}\n") - -# Define number of classes based on all unique events -n_classes = len(event_id) -print(f"Number of classes: {n_classes}\n") - -# Process all training datasets -train_windows_1, _, _, _, _ = standardize_windows( - train_dataset_1, test_dataset, common_channels, event_id -) -train_windows_2, _, _, _, _ = standardize_windows( - train_dataset_2, test_dataset, common_channels, event_id -) -train_windows_3, _, _, _, _ = standardize_windows( - train_dataset_3, test_dataset, common_channels, event_id -) -train_windows_test, test_windows, window_length, samples_before, samples_after = ( - standardize_windows(train_dataset_1, test_dataset, common_channels, event_id) -) - -# Combine all training windows -combined_train_windows = BaseConcatDataset( - [train_windows_1, train_windows_2, train_windows_3] -) - -# Split training data using only the combined dataset -split = combined_train_windows.split("session") -train_set = split["0train"] -valid_set = split["1test"] - -# Setup compute device -cuda = torch.cuda.is_available() -device = "cuda" if cuda else "cpu" -if cuda: - torch.backends.cudnn.benchmark = True - -# Set random seed -set_random_seeds(seed=20200220, cuda=cuda) - -# Calculate model parameters based on standardized data -n_chans = len(common_channels) -input_window_samples = window_length # 450 samples - -# Create model with adjusted parameters for all classes -model = ShallowFBCSPNet( - n_chans, - n_classes, # This will now be the total number of unique events - n_times=input_window_samples, - n_filters_time=40, - filter_time_length=20, - pool_time_length=35, - pool_time_stride=7, - final_conv_length="auto", -) - -if cuda: - model = model.cuda() - -# Create and train classifier -clf = EEGClassifier( - model, - criterion=torch.nn.NLLLoss, - optimizer=torch.optim.AdamW, - train_split=predefined_split(valid_set), - optimizer__lr=0.0625 * 0.01, - optimizer__weight_decay=0, - batch_size=64, - callbacks=[ - "accuracy", - ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=11)), - ], - device=device, - classes=list(range(n_classes)), -) - -# Train the model -_ = clf.fit(train_set, y=None, epochs=100) - -# Get test labels from the windows -y_true = test_windows.get_metadata().target -y_pred = clf.predict(test_windows) -test_accuracy = np.mean(y_true == y_pred) -print(f"\nTest accuracy: {test_accuracy:.4f}") - -# Generate confusion matrix for test set -confusion_mat = confusion_matrix(y_true, y_pred) - -# Plot confusion matrix with all event names -plot_confusion_matrix(confusion_mat, class_names=list(event_id.keys())) -plt.show() - -# Create visualization -fig = plt.figure(figsize=(10, 5)) -plt.plot(clf.history[:, "train_loss"], label="Training Loss") -plt.plot(clf.history[:, "valid_loss"], label="Validation Loss") -plt.legend() -plt.xlabel("Epoch") -plt.ylabel("Loss") -plt.title("Training and Validation Loss Over Time") -plt.show() - -# Plot training curves -results_columns = ["train_loss", "valid_loss", "train_accuracy", "valid_accuracy"] -df = pd.DataFrame( - clf.history[:, results_columns], - columns=results_columns, - index=clf.history[:, "epoch"], -) - -df = df.assign( - train_misclass=100 - 100 * df.train_accuracy, - valid_misclass=100 - 100 * df.valid_accuracy, -) - -# Plot results -fig, ax1 = plt.subplots(figsize=(8, 3)) -df.loc[:, ["train_loss", "valid_loss"]].plot( - ax=ax1, style=["-", ":"], marker="o", color="tab:blue", legend=False -) - -ax1.tick_params(axis="y", labelcolor="tab:blue") -ax1.set_ylabel("Loss", color="tab:blue") - -ax2 = ax1.twinx() -df.loc[:, ["train_misclass", "valid_misclass"]].plot( - ax=ax2, style=["-", ":"], marker="o", color="tab:red", legend=False -) -ax2.tick_params(axis="y", labelcolor="tab:red") -ax2.set_ylabel("Misclassification Rate [%]", color="tab:red") -ax2.set_ylim(ax2.get_ylim()[0], 85) - -handles = [] -handles.append(Line2D([0], [0], color="black", linestyle="-", label="Train")) -handles.append(Line2D([0], [0], color="black", linestyle=":", label="Valid")) -plt.legend(handles, [h.get_label() for h in handles]) -plt.tight_layout() From ae8ac234f8e645491dd82754ebb5bd45cfa34fc4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 19:27:39 +0000 Subject: [PATCH 16/16] [pre-commit.ci] auto fixes from pre-commit.com hooks --- .../plot_cross_dataset_braindecode.py | 147 +++++++++++------- 1 file changed, 94 insertions(+), 53 deletions(-) diff --git a/examples/advanced_examples/plot_cross_dataset_braindecode.py b/examples/advanced_examples/plot_cross_dataset_braindecode.py index 8968e2a45..f2812c1f6 100644 --- a/examples/advanced_examples/plot_cross_dataset_braindecode.py +++ b/examples/advanced_examples/plot_cross_dataset_braindecode.py @@ -132,19 +132,23 @@ def standardize_windows( print(f" Channel names: {ds.raw.info['ch_names']}") # Get the actual available channels (intersection of all datasets) - available_channels = set(train_dataset.datasets[0].raw.info['ch_names']) + available_channels = set(train_dataset.datasets[0].raw.info["ch_names"]) for ds in train_dataset.datasets[1:] + test_dataset.datasets: - available_channels = available_channels.intersection(ds.raw.info['ch_names']) + available_channels = available_channels.intersection(ds.raw.info["ch_names"]) available_channels = sorted(list(available_channels)) - print(f"\nCommon channels across all datasets ({len(available_channels)}): {available_channels}") + print( + f"\nCommon channels across all datasets ({len(available_channels)}): {available_channels}" + ) # Verify all datasets have the same number of channels for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: for i, ds in enumerate(dataset.datasets): - if len(ds.raw.info['ch_names']) != len(available_channels): - print(f"Warning: {name} dataset {i} has {len(ds.raw.info['ch_names'])} channels, " - f"expected {len(available_channels)}") + if len(ds.raw.info["ch_names"]) != len(available_channels): + print( + f"Warning: {name} dataset {i} has {len(ds.raw.info['ch_names'])} channels, " + f"expected {len(available_channels)}" + ) # Define preprocessing pipeline using only available channels preprocessors = [ @@ -166,11 +170,13 @@ def standardize_windows( print("\nVerifying channel counts after preprocessing:") for name, dataset in [("Train", train_dataset), ("Test", test_dataset)]: for i, ds in enumerate(dataset.datasets): - n_channels = len(ds.raw.info['ch_names']) + n_channels = len(ds.raw.info["ch_names"]) print(f"{name} dataset {i} has {n_channels} channels") if n_channels != len(available_channels): - raise ValueError(f"Channel count mismatch in {name} dataset {i}: " - f"got {n_channels}, expected {len(available_channels)}") + raise ValueError( + f"Channel count mismatch in {name} dataset {i}: " + f"got {n_channels}, expected {len(available_channels)}" + ) # Define window parameters window_start = -0.5 # Start 0.5s before event @@ -260,9 +266,9 @@ def load_and_validate_dataset(dataset_class, subject_ids): if len(events) > 0: print(f"Event types: {np.unique(events[:, -1])}") elif isinstance(run_data, dict): - if 'data' in run_data and 'events' in run_data: - data_array = run_data['data'] - events_array = run_data['events'] + if "data" in run_data and "events" in run_data: + data_array = run_data["data"] + events_array = run_data["events"] print(f"Data shape: {data_array.shape}") print(f"Events shape: {events_array.shape}") if events_array.size > 0: @@ -284,6 +290,7 @@ def load_and_validate_dataset(dataset_class, subject_ids): # 2. Braindecode's training pipeline expects its own data format # 3. We need to preserve both the benefits of MOABB's preprocessing and Braindecode's training + def convert_moabb_to_braindecode(moabb_dataset, subject_ids, channels): """Convert MOABB dataset format to Braindecode format. @@ -322,49 +329,69 @@ def convert_moabb_to_braindecode(moabb_dataset, subject_ids, channels): # If it's already an MNE Raw object, pick only the common channels raw = run_data.copy().pick_channels(channels, ordered=True) raw_objects.append(raw) - descriptions.append({ - 'subject': subject_id, - 'session': session_name, - 'run': run_name, - 'dataset_name': moabb_dataset.__class__.__name__ - }) - elif isinstance(run_data, dict) and 'data' in run_data and 'events' in run_data: + descriptions.append( + { + "subject": subject_id, + "session": session_name, + "run": run_name, + "dataset_name": moabb_dataset.__class__.__name__, + } + ) + elif ( + isinstance(run_data, dict) + and "data" in run_data + and "events" in run_data + ): # If it's a dictionary, create MNE Raw object with only common channels - X = run_data['data'] - events = run_data['events'] + X = run_data["data"] + events = run_data["events"] # Create MNE RawArray with only common channels sfreq = moabb_dataset.interval[2] - info = mne.create_info(ch_names=channels, sfreq=sfreq, ch_types=['eeg'] * len(channels)) + info = mne.create_info( + ch_names=channels, + sfreq=sfreq, + ch_types=["eeg"] * len(channels), + ) raw = mne.io.RawArray(X.T, info) # Convert events to annotations onset = events[:, 0] / sfreq duration = np.zeros_like(onset) description = events[:, -1].astype(str) - annot = mne.Annotations(onset=onset, duration=duration, description=description) + annot = mne.Annotations( + onset=onset, duration=duration, description=description + ) raw.set_annotations(annot) raw_objects.append(raw) - descriptions.append({ - 'subject': subject_id, - 'session': session_name, - 'run': run_name, - 'dataset_name': moabb_dataset.__class__.__name__ - }) + descriptions.append( + { + "subject": subject_id, + "session": session_name, + "run": run_name, + "dataset_name": moabb_dataset.__class__.__name__, + } + ) else: - print(f"Warning: Invalid run data format for subject {subject_id}, session {session_name}, run {run_name}") + print( + f"Warning: Invalid run data format for subject {subject_id}, session {session_name}, run {run_name}" + ) except Exception as e: - print(f"Warning: Error processing run data for subject {subject_id}: {str(e)}") + print( + f"Warning: Error processing run data for subject {subject_id}: {str(e)}" + ) if not raw_objects: raise ValueError("No valid data found to convert") # Convert to Braindecode format with proper descriptions - return BaseConcatDataset([ - BaseDataset(raw, description=description) - for raw, description in zip(raw_objects, descriptions) - ]) + return BaseConcatDataset( + [ + BaseDataset(raw, description=description) + for raw, description in zip(raw_objects, descriptions) + ] + ) # Load datasets in MOABB format first @@ -388,7 +415,7 @@ def convert_moabb_to_braindecode(moabb_dataset, subject_ids, channels): train_dataset_1_moabb, train_dataset_2_moabb, train_dataset_3_moabb, - test_dataset_moabb + test_dataset_moabb, ] # Get the list of channels from each dataset before matching @@ -403,7 +430,7 @@ def convert_moabb_to_braindecode(moabb_dataset, subject_ids, channels): run_data = data[first_subject][first_session][first_run] if isinstance(run_data, (RawArray, RawCNT)): - channels = run_data.info['ch_names'] + channels = run_data.info["ch_names"] else: # Assuming the channels are stored in the dataset class after loading channels = ds.channels @@ -422,7 +449,7 @@ def convert_moabb_to_braindecode(moabb_dataset, subject_ids, channels): run_data = data[subject][session][run] if isinstance(run_data, (RawArray, RawCNT)): - channels = run_data.info['ch_names'] + channels = run_data.info["ch_names"] else: channels = ds.channels all_channels_after_matching.append(set(channels)) @@ -438,10 +465,18 @@ def convert_moabb_to_braindecode(moabb_dataset, subject_ids, channels): print(f"Using {len(common_channels)} common channels: {common_channels}") # Convert datasets using common channels -train_dataset_1 = convert_moabb_to_braindecode(train_dataset_1_moabb, [1, 2, 3, 4], common_channels) -train_dataset_2 = convert_moabb_to_braindecode(train_dataset_2_moabb, [1, 2, 3, 4], common_channels) -train_dataset_3 = convert_moabb_to_braindecode(train_dataset_3_moabb, [1, 2, 3, 4], common_channels) -test_dataset = convert_moabb_to_braindecode(test_dataset_moabb, [1, 2, 3], common_channels) +train_dataset_1 = convert_moabb_to_braindecode( + train_dataset_1_moabb, [1, 2, 3, 4], common_channels +) +train_dataset_2 = convert_moabb_to_braindecode( + train_dataset_2_moabb, [1, 2, 3, 4], common_channels +) +train_dataset_3 = convert_moabb_to_braindecode( + train_dataset_3_moabb, [1, 2, 3, 4], common_channels +) +test_dataset = convert_moabb_to_braindecode( + test_dataset_moabb, [1, 2, 3], common_channels +) # Verify channel counts in converted datasets print("\nVerifying channel counts in converted datasets:") @@ -449,14 +484,16 @@ def convert_moabb_to_braindecode(moabb_dataset, subject_ids, channels): ("Train 1", train_dataset_1), ("Train 2", train_dataset_2), ("Train 3", train_dataset_3), - ("Test", test_dataset) + ("Test", test_dataset), ]: for i, ds in enumerate(dataset.datasets): - n_channels = len(ds.raw.info['ch_names']) + n_channels = len(ds.raw.info["ch_names"]) print(f"{name} dataset {i} has {n_channels} channels") if n_channels != len(common_channels): - raise ValueError(f"Channel count mismatch in {name} dataset {i}: " - f"got {n_channels}, expected {len(common_channels)}") + raise ValueError( + f"Channel count mismatch in {name} dataset {i}: " + f"got {n_channels}, expected {len(common_channels)}" + ) # Get all events across all datasets train_events_1 = train_dataset_1.datasets[0].raw.annotations.description @@ -487,16 +524,20 @@ def convert_moabb_to_braindecode(moabb_dataset, subject_ids, channels): # Process datasets one at a time to ensure consistent channel counts train_windows_list = [] -for i, (train_ds, name) in enumerate([ - (train_dataset_1, "Dataset 1"), - (train_dataset_2, "Dataset 2"), - (train_dataset_3, "Dataset 3") -]): +for i, (train_ds, name) in enumerate( + [ + (train_dataset_1, "Dataset 1"), + (train_dataset_2, "Dataset 2"), + (train_dataset_3, "Dataset 3"), + ] +): print(f"\nProcessing training {name}...") # Verify channel count before processing for ds in train_ds.datasets: - if len(ds.raw.info['ch_names']) != len(common_channels): - print(f"Warning: {name} has {len(ds.raw.info['ch_names'])} channels before processing") + if len(ds.raw.info["ch_names"]) != len(common_channels): + print( + f"Warning: {name} has {len(ds.raw.info['ch_names'])} channels before processing" + ) print(f"Current channels: {ds.raw.info['ch_names']}") # Process the dataset