From 8c02a82575260f937bc3bae1c870bf4c2a334b51 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 2 Mar 2026 15:59:56 -0500 Subject: [PATCH 1/4] Update the update_deprecated scripts --- pyproject.toml | 1 + .../cli/dwiml_send_value_to_comet_from_log.py | 13 +- src/dwi_ml/cli/dwiml_visualize_logs.py | 4 +- src/dwi_ml/cli/l2t_update_deprecated_exp.py | 73 ++++-- src/dwi_ml/cli/tt_update_deprecated_exp.py | 231 ++++++++++++++++++ 5 files changed, 291 insertions(+), 31 deletions(-) create mode 100644 src/dwi_ml/cli/tt_update_deprecated_exp.py diff --git a/pyproject.toml b/pyproject.toml index ca092d6c..5a994414 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,5 +76,6 @@ l2t_visualize_weights_evolution = "dwi_ml.cli.l2t_visualize_weights_evolution:ma tt_resume_training_from_checkpoint = "dwi_ml.cli.tt_resume_training_from_checkpoint:main" tt_track_from_model = "dwi_ml.cli.tt_track_from_model:main" tt_train_model = "dwi_ml.cli.tt_train_model:main" +tt_update_deprecated_exp = "dwi_ml.cli.tt_update_deprecated_exp:main" tt_visualize_loss = "dwi_ml.cli.tt_visualize_loss:main" tt_visualize_weights = "dwi_ml.cli.tt_visualize_weights:main" diff --git a/src/dwi_ml/cli/dwiml_send_value_to_comet_from_log.py b/src/dwi_ml/cli/dwiml_send_value_to_comet_from_log.py index ed89decd..9aacb095 100644 --- a/src/dwi_ml/cli/dwiml_send_value_to_comet_from_log.py +++ b/src/dwi_ml/cli/dwiml_send_value_to_comet_from_log.py @@ -32,10 +32,11 @@ def _build_arg_parser(): help="Comet.ml's metric name(s). Must contain the same " "number of inputs as --logs.\n" "If not set, we will suggest you the probable name(s) " - "but we will not run the script.") - p.add_argument('--use_suggested_name', action='store_true', - help="If set and --metric_name is not set, will run with " - "the suggested name(s).") + "but we will not run the script, unless you add " + "--use_suggested_names") + p.add_argument('--use_suggested_names', action='store_true', + help="If set and --metric_names is not set, guess the " + "metric names.") p.add_argument('--use_best', action='store_true', help="If set, uses only the best value in segment [0, t] " "as value at time t. (Best = lowest).") @@ -67,7 +68,7 @@ def main(): if args.use_best: print("Not sure what to suggest you for --metric_name with " "option --use_best. Probably 'best_loss'!?") - args.use_suggested_name = False + args.use_suggested_names = False else: for log in log_paths: # Based on current implementation of things: @@ -91,7 +92,7 @@ def main(): args.metric_names.append(metric_name) # Possibly stop now - if not args.use_suggested_name: + if not args.use_suggested_names: return # Verify diff --git a/src/dwi_ml/cli/dwiml_visualize_logs.py b/src/dwi_ml/cli/dwiml_visualize_logs.py index e8807748..ac4e8618 100644 --- a/src/dwi_ml/cli/dwiml_visualize_logs.py +++ b/src/dwi_ml/cli/dwiml_visualize_logs.py @@ -152,7 +152,7 @@ def __parse_log_operations(parser, graph): parser.error("Can't understand option --graph graph. The diff " "operation requires two logs, separated by a " "comma.") - logs = [_log.replace(' ', '') for _log in logs] + logs = [_log._replace(' ', '') for _log in logs] return logs, 'diff' elif _graph[0:4] == 'sum(': assert _graph[-1] == ')' @@ -162,7 +162,7 @@ def __parse_log_operations(parser, graph): parser.error("Can't understand option --graph graph. 
The sum " "operation requires two logs, separated by a " "comma.") - logs = [_log.replace(' ', '') for _log in logs] + logs = [_log._replace(' ', '') for _log in logs] return logs, 'sum' op = None diff --git a/src/dwi_ml/cli/l2t_update_deprecated_exp.py b/src/dwi_ml/cli/l2t_update_deprecated_exp.py index a62edce7..833078d8 100644 --- a/src/dwi_ml/cli/l2t_update_deprecated_exp.py +++ b/src/dwi_ml/cli/l2t_update_deprecated_exp.py @@ -3,6 +3,9 @@ """ Copies an existing experiment to another folder, updating deprecated values. + +Useful for Emmanuelle's work! :) + """ import argparse import json @@ -12,7 +15,7 @@ import numpy as np import torch -from scilpy.io.utils import add_overwrite_arg, add_verbose_arg +from scilpy.io.utils import add_verbose_arg from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str @@ -32,14 +35,15 @@ def prepare_arg_parser(): p.add_argument('--new_hdf5', help="Required only if previous hdf5 has been moved.") - add_overwrite_arg(p) add_verbose_arg(p) return p -def replace(params, old_key, new_key): +def _replace(params, old_key, new_key): if old_key in params: + logging.warning("Replacing old key {} with new key {}" + .format(old_key, new_key)) params[new_key] = params[old_key] del params[old_key] else: @@ -49,11 +53,11 @@ def replace(params, old_key, new_key): return params -def fix_deprecated_model_params(params): +def fix_deprecated_learn2track_params(params): # embedding_size --> embedded_size - params = replace(params, 'prev_dirs_embedding_size', + params = _replace(params, 'prev_dirs_embedding_size', 'prev_dirs_embedded_size') - params = replace(params, 'input_embedding_size', + params = _replace(params, 'input_embedding_size', 'input_embedded_size') # deleted size_ratio option @@ -61,17 +65,37 @@ def fix_deprecated_model_params(params): assert params['input_embedding_size_ratio'] is None, \ ("Can't fix deprecated value 'inpute_embedded_size_ratio'; has no" " new equivalent. 
I thought I never used it.") + logging.warning("Deleting option input_embedding_size_ratio") del params['input_embedding_size_ratio'] + # deleted start_from_copy_prev + if 'start_from_copy_prev' in params: + logging.warning("Deleting option start_from_copy_prev") + del params['start_from_copy_prev'] + + # deleted options compress_loss, weight_loss_with_angle (in dg) + if 'compress_loss' in params['dg_args']: + logging.warning("Deleting option compress_loss") + del params['dg_args']['compress_loss'] + if 'compress_eps' in params['dg_args']: + logging.warning("Deleting option compress_eps") + del params['dg_args']['compress_eps'] + if 'weight_loss_with_angle' in params['dg_args']: + logging.warning("Deleting option weight_loss_with_angle") + del params['dg_args']['weight_loss_with_angle'] + # Added cnn options if 'nb_cnn_filters' not in params: + logging.warning("Adding options for CNN") params['nb_cnn_filters'] = None if 'kernel_size' not in params: + logging.warning("Adding options for CNN") params['kernel_size'] = None # Neighborhood management modified r = params['neighborhood_radius'] if isinstance(r, list): + logging.warning("Updating neighborhood radius management") params['neighborhood_radius'] = len(r) params['neighborhood_resolution'] = r[0] if len(r) > 1: @@ -83,11 +107,14 @@ def fix_deprecated_model_params(params): return params -def fix_deprecated_checkpoint_params(checkpoint_state): +def fix_other_checkpoint_params(checkpoint_state): + """ + Updating non-specific params (trainer, batch loader, etc) + """ # 1) Dataset params: better use a --new_hdf5 to fix. # 2) Trainer params : nb_steps --> nb_segments - checkpoint_state['params_for_init'] = replace( + checkpoint_state['params_for_init'] = _replace( checkpoint_state['params_for_init'], 'tracking_phase_nb_steps_init', 'tracking_phase_nb_segments_init') @@ -114,14 +141,20 @@ def fix_deprecated_checkpoint_params(checkpoint_state): return checkpoint_state -def fix_both_model_parameter_json_files(args): - # 1) Loading params from checkpoint's model +def fix_model_parameters_json_file(args): + """ + Updating this specific model's params (learn2track) + """ + + # 1) Loading params from checkpoint's latest model model_dir = os.path.join(args.experiment_path, 'checkpoint', 'model') params = Learn2TrackModel._load_params(model_dir) - # 2) Loading params from best model and verifying that they fit + # 2) Loading params from best model model_dir = os.path.join(args.experiment_path, 'best_model') params2 = Learn2TrackModel._load_params(model_dir) + + # Verifying that they fit assert params == params2, ("Unexpected error. Parameters in the " "checkpoint dir and in the best_model dir " "should be the same. 
Did you modify the " @@ -137,7 +170,7 @@ def fix_both_model_parameter_json_files(args): logging.debug("Loaded params:\n{}".format(format_dict_to_str(params))) # 3) Fixing params - params = fix_deprecated_model_params(params) + params = fix_deprecated_learn2track_params(params) print("\n\n----------------Fixed the model parameters ----------------\n" "Reformated model's params:\n " + format_dict_to_str(params)) @@ -207,7 +240,7 @@ def fix_checkpoint(args, model): argparse.Namespace(**dataset_params)) # Fixing checkpoint - checkpoint_state = fix_deprecated_checkpoint_params(checkpoint_state) + checkpoint_state = fix_other_checkpoint_params(checkpoint_state) checkpoint_dir = os.path.join(args.out_experiment, "checkpoint") torch.save(checkpoint_state, os.path.join(checkpoint_dir, "checkpoint_state.pkl")) @@ -218,7 +251,7 @@ def fix_checkpoint(args, model): batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint( dataset, model, checkpoint_state['batch_loader_params']) experiments_path, experiment_name = os.path.split(args.out_experiment) - trainer = Learn2TrackTrainer.init_from_checkpoint( + _ = Learn2TrackTrainer.init_from_checkpoint( model, experiments_path, experiment_name, batch_sampler, batch_loader, checkpoint_state, new_patience=None, new_max_epochs=None, @@ -232,21 +265,15 @@ def main(): # General logging (ex, scilpy: Warning) logging.getLogger().setLevel(level=logging.WARNING) - # Verify if a checkpoint has been saved. if not os.path.exists(args.experiment_path): raise FileNotFoundError("Experiment not found ({})." .format(args.experiment_path)) if os.path.exists(args.out_experiment): - if args.overwrite: - logging.warning("Careful. Deleting whole dir: {}" - .format(args.out_experiment)) - shutil.rmtree(args.out_experiment) - else: - raise FileExistsError("Out experiment already exists! ({})." - .format(args.out_experiment)) + raise FileExistsError("Out experiment already exists! ({})." + .format(args.out_experiment)) shutil.copytree(args.experiment_path, args.out_experiment) - model = fix_both_model_parameter_json_files(args) + model = fix_model_parameters_json_file(args) fix_checkpoint(args, model) print("Out experiment {} should now be usable!" diff --git a/src/dwi_ml/cli/tt_update_deprecated_exp.py b/src/dwi_ml/cli/tt_update_deprecated_exp.py new file mode 100644 index 00000000..c12fed18 --- /dev/null +++ b/src/dwi_ml/cli/tt_update_deprecated_exp.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Copies an existing experiment to another folder, updating deprecated values. + +Useful for Emmanuelle's work! 
:) + +""" +import argparse +import json +import logging +import os +import shutil + +import torch +from dwi_ml.io_utils import verify_which_model_in_path +from dwi_ml.models.projects.transformer_models import find_transformer_class +from dwi_ml.training.projects.transformer_trainer import TransformerTrainer +from scilpy.io.utils import add_verbose_arg + +from dwi_ml.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.experiment_utils.prints import format_dict_to_str +from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput +from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler + + +def prepare_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + p.add_argument('experiment_path', + help='Name for the experiment.') + p.add_argument('out_experiment', + help="Name of the fixed experiment.") + p.add_argument('--new_hdf5', + help="Required only if previous hdf5 has been moved.") + + add_verbose_arg(p) + + return p + + +def _replace(params, old_key, new_key): + if old_key in params: + logging.warning("Replacing old key {} with new key {}" + .format(old_key, new_key)) + params[new_key] = params[old_key] + del params[old_key] + else: + logging.warning("Expected to find deprecated key {} (to be replaced " + "by {}), but did not find it. Skipping." + .format(old_key, new_key)) + return params + + +def fix_deprecated_transformers_params(params): + # deleted start_from_copy_prev + if 'start_from_copy_prev' in params: + logging.warning("Deleting option start_from_copy_prev") + del params['start_from_copy_prev'] + + # deleted options compress_loss, weight_loss_with_angle (in dg) + if 'compress_loss' in params['dg_args']: + logging.warning("Deleting option compress_loss") + del params['dg_args']['compress_loss'] + if 'compress_eps' in params['dg_args']: + logging.warning("Deleting option compress_eps") + del params['dg_args']['compress_eps'] + if 'weight_loss_with_angle' in params['dg_args']: + logging.warning("Deleting option weight_loss_with_angle") + del params['dg_args']['weight_loss_with_angle'] + + return params + + +def fix_other_checkpoint_params(checkpoint_state): + """ + Updating non-specific params (trainer, batch loader, etc) + """ + # Nothing to do for Transformer, I had not used it before updates. + # See l2t_update_deprecated_exp if I have issues. + + return checkpoint_state + + +def fix_model_parameters_json_file(args): + """ + Updating this specific model's params (Transformers) + """ + + # 1) Loading params from checkpoint's latest model + model_dir = os.path.join(args.experiment_path, 'checkpoint', 'model') + model_type = verify_which_model_in_path(model_dir) + print("Model's class: {}".format(model_type)) + cls = find_transformer_class(model_type) + params = cls._load_params(model_dir) + + # 2) Loading params from best model + model_dir = os.path.join(args.experiment_path, 'best_model') + params2 = cls._load_params(model_dir) + + # Verifying that they fit + assert params == params2, ("Unexpected error. Parameters in the " + "checkpoint dir and in the best_model dir " + "should be the same. 
Did you modify the " + "parameters.json files?\n" + "Checkpoint params: \n" + "{}\n" + "--------------------" + "Best model params: \n" + "{}" + .format(format_dict_to_str(params), + format_dict_to_str(params2))) + del params2 + logging.debug("Loaded params:\n{}".format(format_dict_to_str(params))) + + # 3) Fixing params + params = fix_deprecated_transformers_params(params) + print("\n\n----------------Fixed the model parameters ----------------\n" + "Reformated model's params:\n " + format_dict_to_str(params)) + + # 4) Save fixed params in both parameters files + fixed_checkpoint_model_dir = os.path.join( + args.out_experiment, "checkpoint", "model") + params_in_checkpoint = os.path.join( + fixed_checkpoint_model_dir, "parameters.json") + with open(params_in_checkpoint, 'w') as json_file: + json_file.write(json.dumps(params, indent=4, separators=(',', ': '))) + fixed_best_model_dir = os.path.join(args.out_experiment, "best_model") + params_in_best_model = os.path.join(fixed_best_model_dir, + "parameters.json") + with open(params_in_best_model, 'w') as json_file: + json_file.write(json.dumps(params, indent=4, separators=(',', ': '))) + + # Verify that both models can be loaded + _ = cls.load_model_from_params_and_state( + fixed_checkpoint_model_dir) + model = cls.load_model_from_params_and_state( + fixed_best_model_dir) + + return model + + +def fix_checkpoint(args, model): + # Fixing trainer + + # Loading checkpoint + experiments_path, experiment_name = os.path.split(args.experiment_path) + checkpoint_state = TransformerTrainer.load_params_from_checkpoint( + experiments_path, experiment_name) + + # Verify hdf5 + dataset_params = checkpoint_state['dataset_params']['training set'] + if not os.path.isfile(dataset_params['hdf5_file']): + if args.new_hdf5 is None: + raise ValueError("hdf5 file has been deleted or moved ({})\n" + "Please set a path to a new hdf5.") + else: + # Get the hdf5 + dataset = prepare_multisubjectdataset( + argparse.Namespace(**{'hdf5_file': args.new_hdf5, + 'lazy': True, + 'cache_size': 1})) + # Compare all values + for k, v in dataset_params.items(): + if k not in ['set_name', 'hdf5_file', 'lazy']: + assert dataset.training_set.__getattribute__(k) == v, \ + ("Value {} in old hdf5 (training set) was {} but is " + "{} in the new one!" + .format(k, v, + dataset.training_set.__getattribute__(k))) + assert dataset.validation_set.__getattribute__(k) == v, \ + ("Value {} in old hdf5 (validation set) was {} but is " + "{} in the new one!" + .format(k, v, + dataset.training_set.__getattribute__(k))) + + elif args.new_hdf5 is not None: + raise ValueError("We already have all required information from the " + "hdf5 at {}. We do not need a --new_hdf5.") + else: + # Ensure it was lazy + dataset_params['lazy'] = True + dataset = prepare_multisubjectdataset( + argparse.Namespace(**dataset_params)) + + # Fixing checkpoint + checkpoint_state = fix_other_checkpoint_params(checkpoint_state) + checkpoint_dir = os.path.join(args.out_experiment, "checkpoint") + torch.save(checkpoint_state, + os.path.join(checkpoint_dir, "checkpoint_state.pkl")) + + # Init stuff will succeed if ok. 
+ batch_sampler = DWIMLBatchIDSampler.init_from_checkpoint( + dataset, checkpoint_state['batch_sampler_params']) + batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint( + dataset, model, checkpoint_state['batch_loader_params']) + experiments_path, experiment_name = os.path.split(args.out_experiment) + _ = TransformerTrainer.init_from_checkpoint( + model, experiments_path, experiment_name, + batch_sampler, batch_loader, + checkpoint_state, new_patience=None, new_max_epochs=None, + log_level='WARNING') + + +def main(): + p = prepare_arg_parser() + args = p.parse_args() + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + if not os.path.exists(args.experiment_path): + raise FileNotFoundError("Experiment not found ({})." + .format(args.experiment_path)) + if os.path.exists(args.out_experiment): + raise FileExistsError("Out experiment already exists! ({})." + .format(args.out_experiment)) + shutil.copytree(args.experiment_path, args.out_experiment) + + model = fix_model_parameters_json_file(args) + + if args.experiment_path != args.out_experiment: + fix_checkpoint(args, model) + + print("Out experiment {} should now be usable!" + .format(args.out_experiment)) + + +if __name__ == '__main__': + main() From e29112b0c4b31cf4d62cfe797c0fc5ce20a84e05 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 23 Mar 2026 11:44:49 -0400 Subject: [PATCH 2/4] Tracking changes to allow updating old models and re-use them. --- .../scil_score_ismrm_Renauld2023.sh | 15 +- src/dwi_ml/cli/l2t_update_deprecated_exp.py | 150 +++++++------ src/dwi_ml/cli/tt_update_deprecated_exp.py | 204 ++++++++++++------ .../models/projects/transformer_models.py | 13 -- src/dwi_ml/training/trainers.py | 2 +- 5 files changed, 242 insertions(+), 142 deletions(-) diff --git a/bash_utilities/scil_score_ismrm_Renauld2023.sh b/bash_utilities/scil_score_ismrm_Renauld2023.sh index 74ea7c68..6a067547 100644 --- a/bash_utilities/scil_score_ismrm_Renauld2023.sh +++ b/bash_utilities/scil_score_ismrm_Renauld2023.sh @@ -5,6 +5,10 @@ then echo "-----" echo "Usage: " echo ">> scil_score_ismrm_Renauld2023.sh tractogram out_dir scoring_data" + echo "If you are using an old version of scilpy (<2.2.0), the '.py' extensions" + echo "were required when calling python scripts" + echo "Then, usage becomes:" + echo ">> scil_score_ismrm_Renauld2023.sh tractogram out_dir scoring_data .py " echo "-----" exit fi @@ -12,6 +16,7 @@ fi tractogram=$1 out_dir=$2 scoring_data=$3 +ext=$4 config_file_segmentation=$scoring_data/config_file_segmentation.json config_file_tractometry=$scoring_data/config_file_tractometry.json @@ -31,30 +36,30 @@ fi echo '------------- SEGMENTATION ------------' -scil_tractogram_segment_with_ROI_and_score $tractogram $config_file_segmentation $out_dir --no_empty \ +scil_tractogram_segment_with_ROI_and_score$ext $tractogram $config_file_segmentation $out_dir --no_empty \ --gt_dir $scoring_data --reference $ref --json_prefix tmp_ --no_bbox_check; echo '------------- Merging CC sub-bundles ------------' CC_files=$(ls $out_dir/segmented_VB/CC* 2> /dev/null) if [ "$CC_files" != '' ] then - scil_tractogram_math lazy_concatenate $CC_files $out_dir/segmented_VB/CC_VS.trk; + scil_tractogram_math$ext lazy_concatenate $CC_files $out_dir/segmented_VB/CC_VS.trk; fi echo '------------- Merging ICP left sub-bundles ------------' ICP_left_files=$(ls $out_dir/segmented_VB/ICP_left* 2> /dev/null) if [ "$ICP_left_files" != '' ] then - scil_tractogram_math lazy_concatenate $ICP_left_files 
$out_dir/segmented_VB/ICP_left_VS.trk; + scil_tractogram_math$ext lazy_concatenate $ICP_left_files $out_dir/segmented_VB/ICP_left_VS.trk; fi echo '------------- Merging ICP right sub-bundles ------------' ICP_right_files=$(ls $out_dir/segmented_VB/ICP_right* 2> /dev/null) if [ "$ICP_right_files" != '' ] then - scil_tractogram_math lazy_concatenate $ICP_right_files $out_dir/segmented_VB/ICP_right_VS.trk; + scil_tractogram_math$ext lazy_concatenate $ICP_right_files $out_dir/segmented_VB/ICP_right_VS.trk; fi echo '------------- FINAL SCORING ------------' -scil_bundle_score_many_bundles_one_tractogram $config_file_tractometry $out_dir \ +scil_bundle_score_many_bundles_one_tractogram$ext $config_file_tractometry $out_dir \ --gt_dir $scoring_data --reference $ref --no_bbox_check -v cat $out_dir/results.json \ No newline at end of file diff --git a/src/dwi_ml/cli/l2t_update_deprecated_exp.py b/src/dwi_ml/cli/l2t_update_deprecated_exp.py index 833078d8..69b24bc1 100644 --- a/src/dwi_ml/cli/l2t_update_deprecated_exp.py +++ b/src/dwi_ml/cli/l2t_update_deprecated_exp.py @@ -32,7 +32,7 @@ def prepare_arg_parser(): help='Name for the experiment.') p.add_argument('out_experiment', help="Name of the fixed experiment.") - p.add_argument('--new_hdf5', + p.add_argument('--hdf5_path', help="Required only if previous hdf5 has been moved.") add_verbose_arg(p) @@ -46,14 +46,17 @@ def _replace(params, old_key, new_key): .format(old_key, new_key)) params[new_key] = params[old_key] del params[old_key] - else: - logging.warning("Expected to find deprecated key {} (to be replaced " - "by {}), but did not find it. Skipping." - .format(old_key, new_key)) return params -def fix_deprecated_learn2track_params(params): +def _remove(params, old_key): + if old_key in params: + logging.warning("Deleting option {}".format(old_key)) + del params[old_key] + return params + + +def fix_model_params(params): # embedding_size --> embedded_size params = _replace(params, 'prev_dirs_embedding_size', 'prev_dirs_embedded_size') @@ -69,20 +72,12 @@ def fix_deprecated_learn2track_params(params): del params['input_embedding_size_ratio'] # deleted start_from_copy_prev - if 'start_from_copy_prev' in params: - logging.warning("Deleting option start_from_copy_prev") - del params['start_from_copy_prev'] + params = _remove(params, 'start_from_copy_prev') # deleted options compress_loss, weight_loss_with_angle (in dg) - if 'compress_loss' in params['dg_args']: - logging.warning("Deleting option compress_loss") - del params['dg_args']['compress_loss'] - if 'compress_eps' in params['dg_args']: - logging.warning("Deleting option compress_eps") - del params['dg_args']['compress_eps'] - if 'weight_loss_with_angle' in params['dg_args']: - logging.warning("Deleting option weight_loss_with_angle") - del params['dg_args']['weight_loss_with_angle'] + params['dg_args'] = _remove(params['dg_args'], 'compress_loss') + params['dg_args'] = _remove(params['dg_args'], 'compress_eps') + params['dg_args'] = _remove(params['dg_args'], 'weight_loss_with_angle') # Added cnn options if 'nb_cnn_filters' not in params: @@ -107,16 +102,28 @@ def fix_deprecated_learn2track_params(params): return params -def fix_other_checkpoint_params(checkpoint_state): +def fix_model_state(model_dir): + print("\n--------------- In model state ():".format(model_dir)) + model_state = Learn2TrackModel._load_state(model_dir) + model_state = _replace(model_state, 'input_embedding.linear.weight', + 'input_embedding_layer.linear.weight') + model_state = _replace(model_state, 
'input_embedding.linear.bias', + 'input_embedding_layer.linear.bias') + torch.save(model_state, os.path.join(model_dir, "model_state.pkl")) + + +def fix_checkpoint_params(checkpoint_state): """ Updating non-specific params (trainer, batch loader, etc) """ - # 1) Dataset params: better use a --new_hdf5 to fix. + # 1) Dataset params: better use a --hdf5_path to fix. - # 2) Trainer params : nb_steps --> nb_segments + # 2) Trainer params : nb_steps --> nb_segments, no more lr_decrease_params checkpoint_state['params_for_init'] = _replace( checkpoint_state['params_for_init'], 'tracking_phase_nb_steps_init', 'tracking_phase_nb_segments_init') + checkpoint_state['params_for_init'] = _remove( + checkpoint_state['params_for_init'], 'lr_decrease_params') # 3) Monitors: new var ever_max, ever_min for k in checkpoint_state['current_states'].keys(): @@ -141,7 +148,16 @@ def fix_other_checkpoint_params(checkpoint_state): return checkpoint_state -def fix_model_parameters_json_file(args): +def fix_checkpoint_rng_state(checkpoint_state): + print("\n--------------- In checkpoint state:") + + assert 'torch_cuda_state' in checkpoint_state['current_states'] + checkpoint_state['current_states']['torch_cuda_state'] = None + return checkpoint_state + + + +def load_both_models_checkpoint_best_and_fix(args): """ Updating this specific model's params (learn2track) """ @@ -167,10 +183,10 @@ def fix_model_parameters_json_file(args): .format(format_dict_to_str(params), format_dict_to_str(params2))) del params2 - logging.debug("Loaded params:\n{}".format(format_dict_to_str(params))) + logging.info("Loaded params:\n{}".format(format_dict_to_str(params))) # 3) Fixing params - params = fix_deprecated_learn2track_params(params) + params = fix_model_params(params) print("\n\n----------------Fixed the model parameters ----------------\n" "Reformated model's params:\n " + format_dict_to_str(params)) @@ -187,6 +203,11 @@ def fix_model_parameters_json_file(args): with open(params_in_best_model, 'w') as json_file: json_file.write(json.dumps(params, indent=4, separators=(',', ': '))) + # 5. Fixing model state + fix_model_state(fixed_checkpoint_model_dir) + fix_model_state(fixed_best_model_dir) + + # Verify that both models can be loaded _ = Learn2TrackModel.load_model_from_params_and_state( fixed_checkpoint_model_dir) @@ -196,7 +217,7 @@ def fix_model_parameters_json_file(args): return model -def fix_checkpoint(args, model): +def load_checkpoint_and_fix(args, model): # Fixing trainer # Loading checkpoint @@ -206,41 +227,22 @@ def fix_checkpoint(args, model): # Verify hdf5 dataset_params = checkpoint_state['dataset_params']['training set'] - if not os.path.isfile(dataset_params['hdf5_file']): - if args.new_hdf5 is None: - raise ValueError("hdf5 file has been deleted or moved ({})\n" - "Please set a path to a new hdf5.") - else: - # Get the hdf5 - dataset = prepare_multisubjectdataset( - argparse.Namespace(**{'hdf5_file': args.new_hdf5, - 'lazy': True, - 'cache_size': 1})) - # Compare all values - for k, v in dataset_params.items(): - if k not in ['set_name', 'hdf5_file', 'lazy']: - assert dataset.training_set.__getattribute__(k) == v, \ - ("Value {} in old hdf5 (training set) was {} but is " - "{} in the new one!" - .format(k, v, - dataset.training_set.__getattribute__(k))) - assert dataset.validation_set.__getattribute__(k) == v, \ - ("Value {} in old hdf5 (validation set) was {} but is " - "{} in the new one!" 
- .format(k, v, - dataset.training_set.__getattribute__(k))) - - elif args.new_hdf5 is not None: - raise ValueError("We already have all required information from the " - "hdf5 at {}. We do not need a --new_hdf5.") + if args.hdf5_path is None and not os.path.isfile(dataset_params['hdf5_file']): + raise ValueError("hdf5 file has been deleted or moved ({})\n" + "Please set a path to a new hdf5.") + elif args.hdf5_path is not None: + # Get the hdf5 + dataset = prepare_multisubjectdataset( + argparse.Namespace(**{'hdf5_file': args.hdf5_path, + 'lazy': True, + 'cache_size': 1})) else: - # Ensure it was lazy dataset_params['lazy'] = True dataset = prepare_multisubjectdataset( argparse.Namespace(**dataset_params)) # Fixing checkpoint - checkpoint_state = fix_other_checkpoint_params(checkpoint_state) + checkpoint_state = fix_checkpoint_params(checkpoint_state) checkpoint_dir = os.path.join(args.out_experiment, "checkpoint") torch.save(checkpoint_state, os.path.join(checkpoint_dir, "checkpoint_state.pkl")) @@ -251,11 +253,37 @@ def fix_checkpoint(args, model): batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint( dataset, model, checkpoint_state['batch_loader_params']) experiments_path, experiment_name = os.path.split(args.out_experiment) - _ = Learn2TrackTrainer.init_from_checkpoint( - model, experiments_path, experiment_name, - batch_sampler, batch_loader, - checkpoint_state, new_patience=None, new_max_epochs=None, - log_level='WARNING') + if experiments_path == '': + experiments_path = './' + + try: + _ = Learn2TrackTrainer.init_from_checkpoint( + model, experiments_path, experiment_name, + batch_sampler, batch_loader, + checkpoint_state, new_patience=None, new_max_epochs=None, + log_level='WARNING') + except RuntimeError as e: + if 'RNG state' in str(e): + logging.warning("RNG error in the checkpoint, due to pytorch " + "version, probably. Will ignore the RNG state. " + "Will probably not really influence anything") + checkpoint_path = os.path.join( + experiments_path, experiment_name, "checkpoint", + "checkpoint_state.pkl") + checkpoint_state = torch.load(checkpoint_path, weights_only=False) + checkpoint_state = fix_checkpoint_rng_state(checkpoint_state) + + print("Saving new state as ", checkpoint_path) + torch.save(checkpoint_state, checkpoint_path) + checkpoint_state = torch.load(checkpoint_path, weights_only=False) + + _ = Learn2TrackTrainer.init_from_checkpoint( + model, experiments_path, experiment_name, + batch_sampler, batch_loader, + checkpoint_state, new_patience=None, new_max_epochs=None, + log_level='WARNING') + else: + raise RuntimeError(e) def main(): @@ -273,8 +301,8 @@ def main(): .format(args.out_experiment)) shutil.copytree(args.experiment_path, args.out_experiment) - model = fix_model_parameters_json_file(args) - fix_checkpoint(args, model) + model = load_both_models_checkpoint_best_and_fix(args) + load_checkpoint_and_fix(args, model) print("Out experiment {} should now be usable!" 
.format(args.out_experiment)) diff --git a/src/dwi_ml/cli/tt_update_deprecated_exp.py b/src/dwi_ml/cli/tt_update_deprecated_exp.py index c12fed18..c097beae 100644 --- a/src/dwi_ml/cli/tt_update_deprecated_exp.py +++ b/src/dwi_ml/cli/tt_update_deprecated_exp.py @@ -13,6 +13,9 @@ import os import shutil +import numpy as np + +from dwi_ml.training.utils.monitoring import BatchHistoryMonitor import torch from dwi_ml.io_utils import verify_which_model_in_path from dwi_ml.models.projects.transformer_models import find_transformer_class @@ -32,7 +35,7 @@ def prepare_arg_parser(): help='Name for the experiment.') p.add_argument('out_experiment', help="Name of the fixed experiment.") - p.add_argument('--new_hdf5', + p.add_argument('--hdf5_path', help="Required only if previous hdf5 has been moved.") add_verbose_arg(p) @@ -46,44 +49,111 @@ def _replace(params, old_key, new_key): .format(old_key, new_key)) params[new_key] = params[old_key] del params[old_key] - else: - logging.warning("Expected to find deprecated key {} (to be replaced " - "by {}), but did not find it. Skipping." - .format(old_key, new_key)) return params -def fix_deprecated_transformers_params(params): +def _remove(params, old_key): + if old_key in params: + logging.warning("Deleting option {}".format(old_key)) + del params[old_key] + return params + + +def fix_model_params(params): + print("\n--------------- In model params:") + # deleted start_from_copy_prev - if 'start_from_copy_prev' in params: - logging.warning("Deleting option start_from_copy_prev") - del params['start_from_copy_prev'] + params = _remove(params, 'start_from_copy_prev') # deleted options compress_loss, weight_loss_with_angle (in dg) - if 'compress_loss' in params['dg_args']: - logging.warning("Deleting option compress_loss") - del params['dg_args']['compress_loss'] - if 'compress_eps' in params['dg_args']: - logging.warning("Deleting option compress_eps") - del params['dg_args']['compress_eps'] - if 'weight_loss_with_angle' in params['dg_args']: - logging.warning("Deleting option weight_loss_with_angle") - del params['dg_args']['weight_loss_with_angle'] + params['dg_args'] = _remove(params['dg_args'], 'compress_loss') + params['dg_args'] = _remove(params['dg_args'], 'compress_eps') + params['dg_args'] = _remove(params['dg_args'], 'weight_loss_with_angle') + + # ------ + # Even older models: + # ------ + params = _replace(params, 'embedding_key_x', 'input_embedding_key') + params = _replace(params, 'd_model', 'input_embedded_size') # NOT the same if TTST or TTO model, but I don't think I have any... + + # Added cnn options + if 'nb_cnn_filters' not in params: + logging.warning("Adding options for CNN") + params['nb_cnn_filters'] = None + if 'kernel_size' not in params: + logging.warning("Adding options for CNN") + params['kernel_size'] = None + + # Neighborhood management modified + r = params['neighborhood_radius'] + if isinstance(r, list): + logging.warning("Updating neighborhood radius management") + params['neighborhood_radius'] = len(r) + params['neighborhood_resolution'] = r[0] + if len(r) > 1: + if not np.all(np.diff(r) == r[0]): + raise ValueError("Now, neighborhood must have the same " + "resolution between each layer of " + "neighborhood. 
But got: {}".format(r)) return params +def fix_model_state(cls, model_dir): + print("\n--------------- In model state ():".format(model_dir)) + model_state = cls._load_state(model_dir) + model_state = _replace(model_state, 'embedding_layer_x.linear.weight', + 'input_embedding_layer.linear.weight') + model_state = _replace(model_state, 'embedding_layer_x.linear.bias', + 'input_embedding_layer.linear.bias') + torch.save(model_state, os.path.join(model_dir, "model_state.pkl")) + -def fix_other_checkpoint_params(checkpoint_state): +def fix_checkpoint_params(checkpoint_state): """ Updating non-specific params (trainer, batch loader, etc) """ # Nothing to do for Transformer, I had not used it before updates. # See l2t_update_deprecated_exp if I have issues. + print("\n--------------- In checkpoint params:") + # 1) Dataset params: better use a --hdf5_path to fix. + + # 2) Trainer params : nb_steps --> nb_segments, no more lr_decrease_params + checkpoint_state['params_for_init'] = _replace( + checkpoint_state['params_for_init'], 'tracking_phase_nb_steps_init', + 'tracking_phase_nb_segments_init') + checkpoint_state['params_for_init'] = _remove( + checkpoint_state['params_for_init'], 'lr_decrease_params') + + # 3) Monitors: new var ever_max, ever_min + for k in checkpoint_state['current_states'].keys(): + if isinstance(checkpoint_state['current_states'][k], dict) and \ + 'average_per_epoch' in checkpoint_state['current_states'][k].keys(): + # Found a monitor + if 'ever_min' not in checkpoint_state['current_states'][k]: + logging.warning("Setting false min, max for monitor {}. " + # But not sure this is ever used.... TODO + .format(k)) + checkpoint_state['current_states'][k]['ever_min'] = -np.inf + checkpoint_state['current_states'][k]['ever_max'] = np.inf + + # New monitor + if 'unclipped_grad_norm_monitor_state' not in checkpoint_state['current_states']: + logging.warning("Adding a new monitor (unclipped_grad_norm)") + monitor = BatchHistoryMonitor('unclipped_grad_norm_monitor', + weighted=False) + checkpoint_state['current_states'][monitor.name + '_state'] = monitor.get_state() return checkpoint_state +def fix_checkpoint_rng_state(checkpoint_state): + print("\n--------------- In checkpoint state:") -def fix_model_parameters_json_file(args): + assert 'torch_cuda_state' in checkpoint_state['current_states'] + checkpoint_state['current_states']['torch_cuda_state'] = None + return checkpoint_state + + +def load_both_models_checkpoint_best_and_fix(args): """ Updating this specific model's params (Transformers) """ @@ -112,10 +182,10 @@ def fix_model_parameters_json_file(args): .format(format_dict_to_str(params), format_dict_to_str(params2))) del params2 - logging.debug("Loaded params:\n{}".format(format_dict_to_str(params))) + logging.info("Loaded params:\n{}".format(format_dict_to_str(params))) # 3) Fixing params - params = fix_deprecated_transformers_params(params) + params = fix_model_params(params) print("\n\n----------------Fixed the model parameters ----------------\n" "Reformated model's params:\n " + format_dict_to_str(params)) @@ -132,6 +202,11 @@ def fix_model_parameters_json_file(args): with open(params_in_best_model, 'w') as json_file: json_file.write(json.dumps(params, indent=4, separators=(',', ': '))) + # 5. 
Fixing model state + fix_model_state(cls, fixed_checkpoint_model_dir) + fix_model_state(cls, fixed_best_model_dir) + + # Verify that both models can be loaded _ = cls.load_model_from_params_and_state( fixed_checkpoint_model_dir) @@ -141,7 +216,7 @@ def fix_model_parameters_json_file(args): return model -def fix_checkpoint(args, model): +def load_checkpoint_and_fix(args, model): # Fixing trainer # Loading checkpoint @@ -151,41 +226,22 @@ def fix_checkpoint(args, model): # Verify hdf5 dataset_params = checkpoint_state['dataset_params']['training set'] - if not os.path.isfile(dataset_params['hdf5_file']): - if args.new_hdf5 is None: - raise ValueError("hdf5 file has been deleted or moved ({})\n" - "Please set a path to a new hdf5.") - else: - # Get the hdf5 - dataset = prepare_multisubjectdataset( - argparse.Namespace(**{'hdf5_file': args.new_hdf5, - 'lazy': True, - 'cache_size': 1})) - # Compare all values - for k, v in dataset_params.items(): - if k not in ['set_name', 'hdf5_file', 'lazy']: - assert dataset.training_set.__getattribute__(k) == v, \ - ("Value {} in old hdf5 (training set) was {} but is " - "{} in the new one!" - .format(k, v, - dataset.training_set.__getattribute__(k))) - assert dataset.validation_set.__getattribute__(k) == v, \ - ("Value {} in old hdf5 (validation set) was {} but is " - "{} in the new one!" - .format(k, v, - dataset.training_set.__getattribute__(k))) - - elif args.new_hdf5 is not None: - raise ValueError("We already have all required information from the " - "hdf5 at {}. We do not need a --new_hdf5.") + if args.hdf5_path is None and not os.path.isfile(dataset_params['hdf5_file']): + raise ValueError("hdf5 file has been deleted or moved ({})\n" + "Please set a path to a new hdf5.") + elif args.hdf5_path is not None: + # Get the hdf5 + dataset = prepare_multisubjectdataset( + argparse.Namespace(**{'hdf5_file': args.hdf5_path, + 'lazy': True, + 'cache_size': 1})) else: - # Ensure it was lazy dataset_params['lazy'] = True dataset = prepare_multisubjectdataset( argparse.Namespace(**dataset_params)) # Fixing checkpoint - checkpoint_state = fix_other_checkpoint_params(checkpoint_state) + checkpoint_state = fix_checkpoint_params(checkpoint_state) checkpoint_dir = os.path.join(args.out_experiment, "checkpoint") torch.save(checkpoint_state, os.path.join(checkpoint_dir, "checkpoint_state.pkl")) @@ -196,11 +252,37 @@ def fix_checkpoint(args, model): batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint( dataset, model, checkpoint_state['batch_loader_params']) experiments_path, experiment_name = os.path.split(args.out_experiment) - _ = TransformerTrainer.init_from_checkpoint( - model, experiments_path, experiment_name, - batch_sampler, batch_loader, - checkpoint_state, new_patience=None, new_max_epochs=None, - log_level='WARNING') + if experiments_path == '': + experiments_path = './' + + try: + _ = TransformerTrainer.init_from_checkpoint( + model, experiments_path, experiment_name, + batch_sampler, batch_loader, + checkpoint_state, new_patience=None, new_max_epochs=None, + log_level='WARNING') + except RuntimeError as e: + if 'RNG state' in str(e): + logging.warning("RNG error in the checkpoint, due to pytorch " + "version, probably. Will ignore the RNG state. 
" + "Will probably not really influence anything") + checkpoint_path = os.path.join( + experiments_path, experiment_name, "checkpoint", + "checkpoint_state.pkl") + checkpoint_state = torch.load(checkpoint_path, weights_only=False) + checkpoint_state = fix_checkpoint_rng_state(checkpoint_state) + + print("Saving new state as ", checkpoint_path) + torch.save(checkpoint_state, checkpoint_path) + checkpoint_state = torch.load(checkpoint_path, weights_only=False) + + _ = TransformerTrainer.init_from_checkpoint( + model, experiments_path, experiment_name, + batch_sampler, batch_loader, + checkpoint_state, new_patience=None, new_max_epochs=None, + log_level='WARNING') + else: + raise RuntimeError(e) def main(): @@ -208,7 +290,7 @@ def main(): args = p.parse_args() # General logging (ex, scilpy: Warning) - logging.getLogger().setLevel(level=logging.WARNING) + logging.getLogger().setLevel(args.verbose) if not os.path.exists(args.experiment_path): raise FileNotFoundError("Experiment not found ({})." @@ -218,10 +300,8 @@ def main(): .format(args.out_experiment)) shutil.copytree(args.experiment_path, args.out_experiment) - model = fix_model_parameters_json_file(args) - - if args.experiment_path != args.out_experiment: - fix_checkpoint(args, model) + model = load_both_models_checkpoint_best_and_fix(args) + load_checkpoint_and_fix(args, model) print("Out experiment {} should now be usable!" .format(args.out_experiment)) diff --git a/src/dwi_ml/models/projects/transformer_models.py b/src/dwi_ml/models/projects/transformer_models.py index bbc14aeb..6a9d650d 100644 --- a/src/dwi_ml/models/projects/transformer_models.py +++ b/src/dwi_ml/models/projects/transformer_models.py @@ -277,19 +277,6 @@ def params_for_checkpoint(self): return p - @classmethod - def _load_params(cls, model_dir): - params = super()._load_params(model_dir) - - # d_model now a property method. - if 'd_model' in params: - if isinstance(cls, TransformerSrcOnlyModel): - params['input_embedded_size'] = params['d_model'] - - del params['d_model'] - - return params - def set_context(self, context): # Training, validation: Used by trainer. Nothing special. # Tracking: Used by tracker. Returns only the last point. diff --git a/src/dwi_ml/training/trainers.py b/src/dwi_ml/training/trainers.py index 5a68c167..52e5bd52 100644 --- a/src/dwi_ml/training/trainers.py +++ b/src/dwi_ml/training/trainers.py @@ -518,7 +518,7 @@ def _update_states_from_checkpoint(self, current_states): self.batch_loader.np_rng.set_state(current_states['loader_np_rng_state']) # - torch torch.set_rng_state(current_states['torch_rng_state']) - if self.use_gpu: + if self.use_gpu and current_states['torch_cuda_state'] is not None: torch.cuda.set_rng_state(current_states['torch_cuda_state']) # B. 
Current epoch From ed959c1a34c8ccac4bc84ca83bac0b8fb802a8cd Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 23 Mar 2026 12:04:44 -0400 Subject: [PATCH 3/4] Remove methods not used anymore --- pyproject.toml | 1 - .../cli/dwiml_compute_loss_copy_previous.py | 77 ------------------- .../tests/test_compute_loss_copy_previous.py | 42 ---------- .../models/projects/copy_previous_dirs.py | 63 --------------- .../models/projects/learn2track_model.py | 55 ------------- .../models/projects/transformer_models.py | 49 ------------ 6 files changed, 287 deletions(-) delete mode 100644 src/dwi_ml/cli/dwiml_compute_loss_copy_previous.py delete mode 100644 src/dwi_ml/cli/tests/test_compute_loss_copy_previous.py delete mode 100644 src/dwi_ml/models/projects/copy_previous_dirs.py diff --git a/pyproject.toml b/pyproject.toml index 5a994414..2827838c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,6 @@ ae_train_model = "dwi_ml.cli.ae_train_model:main" dwiml_compute_connectivity_matrix_from_blocs = "dwi_ml.cli.dwiml_compute_connectivity_matrix_from_blocs:main" dwiml_compute_connectivity_matrix_from_labels = "dwi_ml.cli.dwiml_compute_connectivity_matrix_from_labels:main" dwiml_compute_connectivity_score = "dwi_ml.cli.dwiml_compute_connectivity_score:main" -dwiml_compute_loss_copy_previous = "dwi_ml.cli.dwiml_compute_loss_copy_previous:main" dwiml_create_hdf5_dataset = "dwi_ml.cli.dwiml_create_hdf5_dataset:main" dwiml_divide_volume_into_blocs = "dwi_ml.cli.dwiml_divide_volume_into_blocs:main" dwiml_hdf5_extract_data = "dwi_ml.cli.dwiml_hdf5_extract_data:main" diff --git a/src/dwi_ml/cli/dwiml_compute_loss_copy_previous.py b/src/dwi_ml/cli/dwiml_compute_loss_copy_previous.py deleted file mode 100644 index d6dae648..00000000 --- a/src/dwi_ml/cli/dwiml_compute_loss_copy_previous.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -One difficulty of choosing a good loss function for tractography is that -streamlines have the particularity of being smooth. - -Printing the average loss function for a given dataset when we simply copy the -previous direction. - - Target := SFT.streamlines' directions[1:] - Y := Previous directions. - loss = DirectionGetter(Target, Y) -""" -import argparse -import logging - -import torch.nn.functional - -from dwi_ml.io_utils import add_resample_or_compress_arg -from dwi_ml.models.projects.copy_previous_dirs import CopyPrevDirModel -from dwi_ml.models.utils.direction_getters import add_direction_getter_args, \ - check_args_direction_getter -from dwi_ml.testing.testers import TesterWithDirectionGetter -from dwi_ml.testing.utils import add_args_testing_subj_hdf5 -from dwi_ml.testing.visu_loss import run_all_visu_loss -from dwi_ml.testing.visu_loss_utils import prepare_args_visu_loss, visu_checks - -CHOICES = ['cosine-regression', 'l2-regression', 'sphere-classification', - 'smooth-sphere-classification', 'cosine-plus-l2-regression'] - - -def prepare_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - add_args_testing_subj_hdf5(p, ask_input_group=False, - ask_streamlines_group=False) - prepare_args_visu_loss(p) - p.add_argument('--skip_first_point', action='store_true', - help="If set, do not compute the loss at the first point " - "of the streamline. 
\nElse (default) compute it with " - "previous dir = 0.") - add_resample_or_compress_arg(p) - add_direction_getter_args(p) - - return p - - -def main(): - p = prepare_arg_parser() - args = p.parse_args() - logging.getLogger().setLevel(level=args.verbose) - - # Checks - if args.out_dir is None: - p.error("Please specify out_dir, as there is not experiment path for " - "this fake experiment.") - names = visu_checks(args, p) - - # Device - device = (torch.device('cuda') if torch.cuda.is_available() and - args.use_gpu else None) - - # 1. Prepare fake model - dg_args = check_args_direction_getter(args) - model = CopyPrevDirModel(args.dg_key, dg_args, args.skip_first_point, - args.step_size, args.compress_th) - model.set_context('visu') - - # 2. Load data through the tester - tester = TesterWithDirectionGetter(model, args.subj_id, args.hdf5_file, - args.subset, args.batch_size, device) - - run_all_visu_loss(tester, model, args, names) - - -if __name__ == '__main__': - main() diff --git a/src/dwi_ml/cli/tests/test_compute_loss_copy_previous.py b/src/dwi_ml/cli/tests/test_compute_loss_copy_previous.py deleted file mode 100644 index 1f62f70c..00000000 --- a/src/dwi_ml/cli/tests/test_compute_loss_copy_previous.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -import os -import pytest - -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data -from dwi_ml.unit_tests.utils.expected_values import TEST_EXPECTED_SUBJ_NAMES, \ - TEST_EXPECTED_STREAMLINE_GROUPS - -data_dir = fetch_testing_data() - - -def test_help_option(script_runner): - ret = script_runner.run('dwiml_compute_loss_copy_previous', '--help') - assert ret.success - - -@pytest.fixture(scope="session") -def experiments_path(tmp_path_factory): - experiments_path = tmp_path_factory.mktemp("experiment_copy_prev") - return str(experiments_path) - - -def test_running(script_runner, experiments_path): - hdf5_file = os.path.join(data_dir, 'hdf5_file.hdf5') - subj_id = TEST_EXPECTED_SUBJ_NAMES[0] - streamline_group_name = TEST_EXPECTED_STREAMLINE_GROUPS[0] - - prefix = 'fornix_' - out_dir = os.path.join(experiments_path, 'test_visu') - ret = script_runner.run('dwiml_compute_loss_copy_previous', - hdf5_file, subj_id, - '--streamlines_group', streamline_group_name, - '--out_prefix', prefix, - '--out_dir', out_dir, - '--subset', 'training', - '--save_colored_tractogram', - '--save_colored_best_and_worst', - '--save_displacement', '1', - '--batch_size', '100', - '--min_range', '-1', '--max_range', '1') - assert ret.success diff --git a/src/dwi_ml/models/projects/copy_previous_dirs.py b/src/dwi_ml/models/projects/copy_previous_dirs.py deleted file mode 100644 index e9d58a1c..00000000 --- a/src/dwi_ml/models/projects/copy_previous_dirs.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import List - -import torch -from torch.distributions import Categorical - -from dwi_ml.data.processing.streamlines.post_processing import \ - compute_directions -from dwi_ml.data.processing.streamlines.sos_eos_management import \ - convert_dirs_to_class -from dwi_ml.models.main_models import ModelWithDirectionGetter - - -class CopyPrevDirModel(ModelWithDirectionGetter): - def __init__(self, dg_key: str = 'cosine-regression', dg_args: dict = None, - skip_first_point=False, step_size=None, compress_lines=None): - """ - Fake model, not very useful. Used to test value of a loss when copying - the previous direction. 
- """ - super().__init__(dg_key=dg_key, dg_args=dg_args, - experiment_name='TEST', - step_size=step_size, compress_lines=compress_lines) - - # Fake input size, we won't use the forward method. - self.instantiate_direction_getter(dg_input_size=1) - self.skip_first_point = skip_first_point - - def forward(self, inputs, streamlines, **kw): - # Prepare targets and outputs: both out of directions. - # Similar to direction_getter.prepare_targets: - # A) Compute directions. Shift + add a fake first direction. - - # Ignoring inputs. - outputs = compute_directions(streamlines) - - # Output at first position will be 0. We will add it later. - # Faking outputs based on direction getter format - if 'classification' in self.dg_key: - outputs = convert_dirs_to_class( - outputs, self.direction_getter.torch_sphere, add_sos=False, - add_eos=self.direction_getter.add_eos, to_one_hot=True) - outputs = [Categorical(probs=out).logits for out in outputs] - elif not ('regression' in self.dg_key): - raise NotImplementedError - - if not self.skip_first_point: - # Add fake first point. We don't know what output to give. Just - # using zeros everywhere. - outputs = [torch.nn.functional.pad(out, [0, 0, 1, 0]) - for out in outputs] - - return outputs - - def compute_loss(self, model_outputs: List[torch.Tensor], - target_streamlines: List[torch.Tensor], - average_results=True, return_eos_probs=False): - if self.skip_first_point: - target_streamlines = [t[1:] for t in target_streamlines] - - return self.direction_getter.compute_loss( - model_outputs, target_streamlines, average_results, - return_eos_probs) diff --git a/src/dwi_ml/models/projects/learn2track_model.py b/src/dwi_ml/models/projects/learn2track_model.py index ee89db58..0ab6a1b7 100644 --- a/src/dwi_ml/models/projects/learn2track_model.py +++ b/src/dwi_ml/models/projects/learn2track_model.py @@ -275,7 +275,6 @@ def forward(self, x: List[torch.tensor], # ==== 0. Previous dirs. n_prev_dirs = None - copy_prev_dir = 0.0 if self.nb_previous_dirs > 0: dirs = compute_directions(input_streamlines) if self.normalize_prev_dirs: @@ -376,60 +375,6 @@ def forward(self, x: List[torch.tensor], else: return x - def copy_prev_dir(self, dirs): - if 'regression' in self.dg_key: - # Regression: The latest previous dir will be used as skip - # connection on the output. - # Either take dirs and add [0, 0, 0] at each first position. - # Or use pre-computed: - copy_prev_dir = [torch.nn.functional.pad(cp, [0, 0, 1, 0]) - for cp in dirs] - copy_prev_dir = pack_sequence(copy_prev_dir) - copy_prev_dir = copy_prev_dir.data - elif self.dg_key == 'sphere-classification': - # Converting the input directions into classes the same way as - # during loss, but convert to one-hot. - # The first previous dir (0) converts to index 0. - if self.context == 'tracking': - if dirs[0].shape[0] == 0: - copy_prev_dir = torch.zeros( - len(dirs), - len(self.direction_getter.torch_sphere.vertices), - device=self.device) - else: - # Take only the last point. - dirs = [d[-1, :][None, :] for d in dirs] - copy_prev_dir = convert_dirs_to_class( - dirs, self.direction_getter.torch_sphere, - smooth_labels=False, add_sos=False, add_eos=False, - to_one_hot=True) - copy_prev_dir = pack_sequence(copy_prev_dir) - else: - # Take all points. 
- copy_prev_dir = convert_dirs_to_class( - dirs, self.direction_getter.torch_sphere, - smooth_labels=False, add_sos=False, add_eos=False, - to_one_hot=True) - - # Add zeros as previous dir at the first position - copy_prev_dir = [torch.nn.functional.pad(cp, [0, 0, 1, 0]) - for cp in copy_prev_dir] - copy_prev_dir = pack_sequence(copy_prev_dir) - - # Making the one from one-hot important for the sigmoid. - copy_prev_dir = copy_prev_dir.data * 6.0 - - elif self.dg_key == 'smooth-sphere-classification': - raise NotImplementedError - elif 'gaussian' in self.dg_key: - # The mean of the gaussian = the previous dir - raise NotImplementedError - else: - # Fisher: not sure how to do that. - raise NotImplementedError - - return copy_prev_dir - def take_lines_in_hidden_state(self, hidden_states, lines_to_keep): """ Utilitary method to remove a few streamlines from the hidden diff --git a/src/dwi_ml/models/projects/transformer_models.py b/src/dwi_ml/models/projects/transformer_models.py index 6a9d650d..14685eb0 100644 --- a/src/dwi_ml/models/projects/transformer_models.py +++ b/src/dwi_ml/models/projects/transformer_models.py @@ -701,55 +701,6 @@ def _run_position_encoding(self, data): def _run_main_layer_forward(self, data, masks, return_weights): raise NotImplementedError - def format_prev_dir_(self, dirs): - """ - Format the previous direction at each point. (To add to output). - At first coordinate: unkown. Using 0,0,0. - - If output is logits of classes: adding a one-hot vector with value - 6 on the right index (because sigmoid(6) is big). Always using value - 0 for the EOS class, if any. - """ - if 'regression' in self.dg_key: - # Regression: The latest previous dir will be used as skip - # connection on the output. - # Either take dirs and add [0, 0, 0] at each first position. - # Or use pre-computed: - copy_prev_dirs = dirs - elif self.dg_key == 'sphere-classification': - # Converting the input directions into classes the same way as - # during loss, but convert to one-hot. - # The first previous dir (0) converts to index 0. - - # Not necessarily the same class as previous dirs used as input to - # the decoder. - copy_prev_dirs = convert_dirs_to_class( - dirs, self.direction_getter.torch_sphere, smooth_labels=False, - add_sos=False, add_eos=False, to_one_hot=True) - - # Not adding a EOS point, but adding a EOS class with value 0. - if self.direction_getter.add_eos: - copy_prev_dirs = [torch.nn.functional.pad(cp, [0, 1, 0, 0]) - for cp in copy_prev_dirs] - - # Making the one from one-hot important for the sigmoid. - copy_prev_dirs = [c * 6.0 for c in copy_prev_dirs] - - elif self.dg_key == 'smooth-sphere-classification': - raise NotImplementedError - elif 'gaussian' in self.dg_key: - # The mean of the gaussian = the previous dir - raise NotImplementedError - else: - # Fisher: not sure how to do that. 
- raise NotImplementedError - - # Add zeros as previous dir at the first position - copy_prev_dirs = [torch.nn.functional.pad(cp, [0, 0, 1, 0]) - for cp in copy_prev_dirs] - - return copy_prev_dirs - def _run_target_embedding(self, targets, use_padding, batch_max_len): targets = pad_and_stack_batch(targets, use_padding, batch_max_len) targets = self.embedding_layer_t(targets) From c5eb6284d749b32cfe1ee36c25fa2403be16f51f Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 23 Mar 2026 12:08:30 -0400 Subject: [PATCH 4/4] Revert automatic changes --- src/dwi_ml/cli/dwiml_visualize_logs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dwi_ml/cli/dwiml_visualize_logs.py b/src/dwi_ml/cli/dwiml_visualize_logs.py index ac4e8618..e8807748 100644 --- a/src/dwi_ml/cli/dwiml_visualize_logs.py +++ b/src/dwi_ml/cli/dwiml_visualize_logs.py @@ -152,7 +152,7 @@ def __parse_log_operations(parser, graph): parser.error("Can't understand option --graph graph. The diff " "operation requires two logs, separated by a " "comma.") - logs = [_log._replace(' ', '') for _log in logs] + logs = [_log.replace(' ', '') for _log in logs] return logs, 'diff' elif _graph[0:4] == 'sum(': assert _graph[-1] == ')' @@ -162,7 +162,7 @@ def __parse_log_operations(parser, graph): parser.error("Can't understand option --graph graph. The sum " "operation requires two logs, separated by a " "comma.") - logs = [_log._replace(' ', '') for _log in logs] + logs = [_log.replace(' ', '') for _log in logs] return logs, 'sum' op = None
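The pattern applied throughout both update scripts above is the same three-step migration: rename deprecated keys with `_replace`, drop removed options with `_remove`, then re-save and reload to confirm the current code accepts the result. Below is a minimal, self-contained sketch of that pattern; the helper bodies are copied from the diffs above, while the toy dict and its values are illustrative only, not a real parameters.json:

    import logging

    def _replace(params, old_key, new_key):
        # Rename a deprecated key in place, keeping its value.
        if old_key in params:
            logging.warning("Replacing old key {} with new key {}"
                            .format(old_key, new_key))
            params[new_key] = params[old_key]
            del params[old_key]
        return params

    def _remove(params, old_key):
        # Delete an option that no longer exists.
        if old_key in params:
            logging.warning("Deleting option {}".format(old_key))
            del params[old_key]
        return params

    # Toy dict with key names taken from l2t_update_deprecated_exp.py;
    # the values themselves are made up for illustration.
    params = {'prev_dirs_embedding_size': 8,
              'start_from_copy_prev': False,
              'dg_args': {'compress_loss': True}}
    params = _replace(params, 'prev_dirs_embedding_size',
                      'prev_dirs_embedded_size')
    params = _remove(params, 'start_from_copy_prev')
    params['dg_args'] = _remove(params['dg_args'], 'compress_loss')
    print(params)  # {'dg_args': {}, 'prev_dirs_embedded_size': 8}

The model-state fix in patch 2 is the same idea applied to the saved state_dict keys (e.g. input_embedding.linear.weight -> input_embedding_layer.linear.weight) before re-saving model_state.pkl.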