diff --git a/README.rst b/README.rst index 5c988c03..5beb8236 100644 --- a/README.rst +++ b/README.rst @@ -11,16 +11,12 @@ DWI_ML .. image:: https://img.shields.io/badge/License-MIT-yellow.svg :target: https://github.com/scil-vital/dwi_ml/blob/master/LICENSE -Welcome to the `Sherbrooke Connectivity Imaging Lab (SCIL)`_ and -`Videos & Images Theory and Analytics Laboratory (VITAL)`_ joint DWI_ML -toolkit ! +Welcome to the `SCIL`_ and `VITAL`_ joint DWI_ML toolkit ! -Links -===== +Documentation +============= -* `Getting started: installation and download `_ -* `Complete documentation with a step-by-step process explanation `_ -* `Contribute/Submit a patch `_ +* For the full documentation, including installation and usage, see here: https://dwi-ml.readthedocs.io/en/latest/. * `Issue tracking `_ About @@ -40,11 +36,10 @@ file. Citation ======== -If you use DWI_ML in your dMRI data analysis, please cite the toolkit and -provide a link to it. +If you use DWI_ML in your dMRI data analysis, please cite the toolkit and provide a link to it. .. Links .. Involved labs -.. _`Sherbrooke Connectivity Imaging Lab (SCIL)`: http://scil.dinf.usherbrooke.ca -.. _`Videos & Images Theory and Analytics Laboratory (VITAL)`: http://vital.dinf.usherbrooke.ca +.. _`SCIL`: http://scil.dinf.usherbrooke.ca +.. 
_`VITAL`: http://vital.dinf.usherbrooke.ca diff --git a/docs/_static/images/Learn2track.png b/docs/_static/images/Learn2track.png index b631a295..f02ef974 100644 Binary files a/docs/_static/images/Learn2track.png and b/docs/_static/images/Learn2track.png differ diff --git a/docs/_static/images/logo_dwi_ml_emma.png b/docs/_static/images/logo_dwi_ml_emma.png new file mode 100644 index 00000000..7dc6dc50 Binary files /dev/null and b/docs/_static/images/logo_dwi_ml_emma.png differ diff --git a/docs/_static/images/logo_dwi_ml_emma_avec_texte.png b/docs/_static/images/logo_dwi_ml_emma_avec_texte.png new file mode 100644 index 00000000..62a2c135 Binary files /dev/null and b/docs/_static/images/logo_dwi_ml_emma_avec_texte.png differ diff --git a/docs/_static/my_style.css b/docs/_static/my_style.css index 47880c5f..60abce31 100644 --- a/docs/_static/my_style.css +++ b/docs/_static/my_style.css @@ -67,4 +67,8 @@ footer { .rst-content li { padding: 0.1em 0; +} + +dt { + margin-bottom: 0 !important; } \ No newline at end of file diff --git a/docs/for_developers/data_management/index.rst b/docs/for_developers/data_management/index.rst index c06a0d2b..3200b897 100644 --- a/docs/for_developers/data_management/index.rst +++ b/docs/for_developers/data_management/index.rst @@ -1,3 +1,4 @@ +.. _data_management_index: Understanding our data management ================================= diff --git a/docs/for_developers/hdf5/advanced_hdf5_organization.rst b/docs/for_developers/hdf5/advanced_hdf5_organization.rst index 077e1e6c..2fd23d7b 100644 --- a/docs/for_developers/hdf5/advanced_hdf5_organization.rst +++ b/docs/for_developers/hdf5/advanced_hdf5_organization.rst @@ -1,4 +1,4 @@ -.. _ref_creating_hdf5: +.. _creating_hdf5: The hdf5 structure ================== diff --git a/docs/for_developers/models/index.rst b/docs/for_developers/models/index.rst index 5a8922e4..3b80643f 100644 --- a/docs/for_developers/models/index.rst +++ b/docs/for_developers/models/index.rst @@ -1,3 +1,5 @@ +.. 
_create_your_model: + Create your own model ===================== diff --git a/docs/for_developers/testing/general_testing.rst b/docs/for_developers/testing/general_testing.rst index 8c0c402e..b108a20c 100644 --- a/docs/for_developers/testing/general_testing.rst +++ b/docs/for_developers/testing/general_testing.rst @@ -1,3 +1,5 @@ +.. _model_testing: + General testing of a model -------------------------- diff --git a/docs/for_developers/training/trainers_details.rst b/docs/for_developers/training/trainers_details.rst index cf09c918..7305605d 100644 --- a/docs/for_developers/training/trainers_details.rst +++ b/docs/for_developers/training/trainers_details.rst @@ -1,3 +1,4 @@ +.. _trainers_details: Trainers: the code explained ============================ diff --git a/docs/for_users/models/denoising_models.rst b/docs/for_users/models/denoising_models.rst new file mode 100644 index 00000000..06874f6d --- /dev/null +++ b/docs/for_users/models/denoising_models.rst @@ -0,0 +1,7 @@ +.. _denoising_models: + + +Denoising models +================ + +Coming soon: Autoencoder (AE) model! diff --git a/docs/for_users/models/our_models.rst b/docs/for_users/models/our_models.rst new file mode 100644 index 00000000..835c68bf --- /dev/null +++ b/docs/for_users/models/our_models.rst @@ -0,0 +1,11 @@ +.. _our_models: + +Our models +========== + + .. toctree:: + :maxdepth: 1 + :caption: Our models + + tractography_models + denoising_models diff --git a/docs/for_users/our_models.rst b/docs/for_users/models/tractography_models.rst similarity index 55% rename from docs/for_users/our_models.rst rename to docs/for_users/models/tractography_models.rst index da247ea7..f69f4c41 100644 --- a/docs/for_users/our_models.rst +++ b/docs/for_users/models/tractography_models.rst @@ -1,34 +1,36 @@ -.. _our_models: +.. 
_tractography_models: -Our models -========== +Tractography models +=================== -Our library currently only offers two models, both for the task of tracking in the brain, but it could eventually hold models for other tasks. +For more explanation on how to use models for tracking, see :ref:`user_tracking`. -Denoising models ----------------- +TractographyTransformers (tt) +***************************** -Coming soon: Autoencoder (AE) model! +This uses transformers and should be the subject of an upcoming publication. -Tractography models -------------------- -For more explanation on how to use models for tracking, see :ref:`user_tracking`. + .. image:: /_static/images/Transformers.png + :align: center + :width: 600 + + +To use this model, run script `tt_track_from_model.py`. . To learn more, run:: + + tt_track_from_model --help + Learn2track (l2t) ***************** + This is a refactored version of the code prepared by authors of `Poulin2017 `_. .. image:: /_static/images/Learn2track.png + :align: center :width: 500 -To use this model, run script `l2t_track_from_model`. - -TractographyTransformers (tt) -***************************** - -This uses transformers and should be the subject of an upcoming publication. +To use this model, run script `l2t_track_from_model`. To learn more, run:: - ToDo: Add picture + l2t_track_from_model --help -To use this model, run script `tt_track_from_model.py`. 
diff --git a/docs/getting_started.rst b/docs/getting_started.rst index e22afb03..2280f69a 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -4,7 +4,7 @@ Getting started: download and installation Downloading dwi_ml ****************** -To use the DWI_ML toolkit you will need to clone the repository and install the required dependencies:: +To use the DWI_ML toolkit you will need to clone the `GitHub repository `_ and install the required dependencies:: git clone https://github.com/scil-vital/dwi_ml.git @@ -29,12 +29,12 @@ We strongly recommend working in a virtual environment to install all dependenci **Creating a Comet account**: -- The toolkit uses `comet_ml `_. It is a python library that creates an "Experiment" (ex, training a model with a given set of hyperparameters) which automatically creates many types of logs online. It requires user to set an API key in $HOME/.comet.config with contents: +- The toolkit uses `comet_ml `_. It is a python library that creates an "Experiment" (ex, training a model with a given set of hyperparameters) which automatically creates many types of logs online. It requires user to set an API key in $HOME/.comet.config with contents:: | [comet] | api_key=YOUR-API-KEY -Alternatively, you can add it as an environment variable. Add this to your $HOME/.bashrc file. +Alternatively, you can add it as an environment variable. Add this to your $HOME/.bashrc file:: | export COMET_API_KEY=YOUR-API-KEY diff --git a/docs/index.rst b/docs/index.rst index f07ff94d..7f844c49 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,43 +1,78 @@ Welcome to DWI_ML documentation! ================================ -This website is a guide to the github repository from the SCIL-VITAL organisation: https://github.com/scil-vital/dwi_ml/. +This website is a guide to the github repository from the SCIL-VITAL organisation: https://github.com/scil-vital/dwi_ml/. 
DWI_ML is a toolkit for Diffusion Magnetic Resonance Imaging (dMRI) analysis +using machine learning and deep learning methods. It is mostly focused on the tractography derivatives of dMRI. + + + .. image:: /_static/images/logo_dwi_ml_emma_avec_texte.png + :align: center + :width: 500 In this doc, we will present you everything included in this library for you to become either a developer or a user. Note that to get a full understanding of every line of code, you can browse further in each section. -Getting started ---------------- +On this page: + + - :ref:`section_install` + - :ref:`section_users` + - :ref:`section_advanced_users` + - :ref:`section_developers` + +.. _section_install: + +1. Installing dwi_ml +-------------------- .. toctree:: :maxdepth: 1 :titlesonly: - :caption: Getting started + :caption: Installing dwi_ml getting_started -Explanations for users of pre-trained models --------------------------------------------- +.. _section_users: + +2. Explanations for users of pre-trained models (Learn2track, Transformers) +--------------------------------------------------------------------------- Pages in this section explain how to use our scripts to use our pre-trained models. -- **Models**: If you want to use our pre-trained models, you may contact us for access to the models learned weights. They will be available online once publications are accepted. +- **1. Downloading models**: If you want to use our pre-trained models, you may contact us for access to the models learned weights. They will be available online once publications are accepted. -- **Using hdf5**: In most cases, data must be organized correctly as a hdf5 before usage. See our page :ref:`hdf5_usage` for an explanation. + - :ref:`our_models` + - :ref:`tractography_models` + - :ref:`denoising_models` - .. toctree:: - :maxdepth: 2 - :caption: Explanations for users (pre-trained) +- **2. Organizing your data**: In most cases, data must be organized correctly as a hdf5 before usage. 
Follow the link below for an explanation. + + - :ref:`hdf5_usage` + +- **3. Using our models to perform tractography**: Use our models to track on your own subjects! + + - :ref:`user_tracking` + +.. --------------------Hidden toctree: --------------- - for_users/our_models - for_users/hdf5 - for_users/tracking +.. toctree:: + :maxdepth: 2 + :hidden: + :caption: Explanations for users (pre-trained) -Explanations for users of pre-coded models ------------------------------------------- + for_users/models/our_models + for_users/hdf5 + for_users/tracking + +------------------------------ + +.. _section_advanced_users: + +3. Explanations for advanced users: train a model with your own hyperparameters +------------------------------------------------------------------------------- Pages in this section are useful if you want to train a model based on pre-existing code, such as Learn2track or TractographyTransformers, using your favorite set of hyperparameters. +(Improved documentation coming soon!) .. toctree:: :maxdepth: 2 @@ -46,28 +81,45 @@ Pages in this section are useful if you want to train a model based on pre-exist for_users/from_start_to_finish for_users/visu_logs -Explanations for developers ---------------------------- +.. _section_developers: + +4. Explanations for developers: create your own model +----------------------------------------------------- Page in this section explain more in details how the code is implemented in python. -- **Models**: The first aspect to explore are our models. Discover how you can create your model to fit with our structure. Many parent classes are available for you: if your model inherits from them, they will have access to everything each one offers. For instance, some models have instructions on how to receive inputs from MRI data, prepare inputs in a neighborhood, and use embedding. 
Other models have access to many options of loss functions for the context of tractography (cosine similarity, classification, Gaussian loss, Fisher von Mises, etc.). +- **1. Create your models**: The first aspect to explore are our models. Discover how you can create your model to fit with our structure. Many parent classes are available for you: if your model inherits from them, they will have access to everything each one offers. For instance, some models have instructions on how to receive inputs from MRI data, prepare inputs in a neighborhood, and use embedding. Other models have access to many options of loss functions for the context of tractography (cosine similarity, classification, Gaussian loss, Fisher von Mises, etc.). -- **Using hdf5**: Our library has been organized to use data in the hdf5 format. Our hdf5 data organization should probably be enough for your needs (see explanations on :ref:`hdf5_usage`), but for more + - :ref:`create_your_model` -- **Training a model**: Then, take a look at how we have implemented our trainers for an efficient management of heavy data. +- **2. Explore our hdf5 organization**: Our library has been organized to use data in the hdf5 format. Our hdf5 data organization should probably be enough for your needs. -- **Using your trained models**: Discover our objects allowing to perform a full tractography from a tractography-model. + - :ref:`hdf5_usage` + - :ref:`creating_hdf5` +- **3. Train your model**: Take a look at how we have implemented our trainers for an efficient management of heavy data. Note that our trainer uses Data Management classes such as our BathLoader and BatchSampler. See below for more information. - .. 
toctree:: - :maxdepth: 1 - :caption: Explanations for developers - - for_developers/models/index - for_developers/hdf5/advanced_hdf5_organization - for_developers/training/training - for_developers/training/trainers_details - for_developers/data_management/index - for_developers/testing/general_testing - for_developers/testing/tracking_objects \ No newline at end of file + - :ref:`trainers` + - :ref:`trainers_details` + - :ref:`data_management_index` + +- **4. Use your trained model**: Discover our objects allowing to perform a full tractography from a tractography-model. + + - :ref:`model_testing` + - :ref:`user_tracking` + - :ref:`tracking` + +.. --------------------Hidden toctree: --------------- + +.. toctree:: + :maxdepth: 1 + :caption: Explanations for developers + :hidden: + + for_developers/models/index + for_developers/hdf5/advanced_hdf5_organization + for_developers/training/training + for_developers/training/trainers_details + for_developers/data_management/index + for_developers/testing/general_testing + for_developers/testing/tracking_objects \ No newline at end of file diff --git a/src/dwi_ml/cache/__init__.py b/src/dwi_ml/cli/__init__.py similarity index 100% rename from src/dwi_ml/cache/__init__.py rename to src/dwi_ml/cli/__init__.py diff --git a/src/dwi_ml/cli/ae_train_model.py b/src/dwi_ml/cli/ae_train_model.py index fb1c8b6b..0ba95a6d 100755 --- a/src/dwi_ml/cli/ae_train_model.py +++ b/src/dwi_ml/cli/ae_train_model.py @@ -17,19 +17,19 @@ from scilpy.io.utils import (assert_inputs_exist, assert_outputs_exist, add_verbose_arg) -from dwi_ml.data.dataset.utils import prepare_multisubjectdataset -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_memory_args -from dwi_ml.models.projects.ae_models import ModelAE -from dwi_ml.training.trainers import DWIMLTrainer -from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, - prepare_batch_sampler) 
-from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) -from dwi_ml.training.utils.trainer import (add_training_args, run_experiment, - format_lr) -from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader -from dwi_ml.training.utils.experiment import ( +from dwi_ml.general.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.general.experiment_utils.prints import format_dict_to_str +from dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.io_utils import add_memory_args +from dwi_ml.projects.AE.ae_models import ModelAE +from dwi_ml.general.training.trainers import DWIMLTrainer +from dwi_ml.general.training.utils.batch_samplers import (add_args_batch_sampler, + prepare_batch_sampler) +from dwi_ml.general.training.utils.batch_loaders import (add_args_batch_loader) +from dwi_ml.general.training.utils.trainer import (add_training_args, run_experiment, + format_lr) +from dwi_ml.general.training.batch_loaders import DWIMLStreamlinesBatchLoader +from dwi_ml.general.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) diff --git a/src/dwi_ml/cli/dwiml_compute_connectivity_matrix_from_blocs.py b/src/dwi_ml/cli/dwiml_compute_connectivity_matrix_from_blocs.py index 402dc4db..11bd71ed 100644 --- a/src/dwi_ml/cli/dwiml_compute_connectivity_matrix_from_blocs.py +++ b/src/dwi_ml/cli/dwiml_compute_connectivity_matrix_from_blocs.py @@ -14,8 +14,8 @@ from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, \ add_verbose_arg, add_overwrite_arg -from dwi_ml.data.hdf5.utils import format_nb_blocs_connectivity -from dwi_ml.data.processing.streamlines.post_processing import \ +from dwi_ml.general.data.hdf5.utils import format_nb_blocs_connectivity +from dwi_ml.general.data.processing.streamlines.post_processing import \ compute_triu_connectivity_from_blocs, \ find_streamlines_with_chosen_connectivity, prepare_figure_connectivity diff --git 
a/src/dwi_ml/cli/dwiml_compute_connectivity_matrix_from_labels.py b/src/dwi_ml/cli/dwiml_compute_connectivity_matrix_from_labels.py index f1ab8e9a..3e2aca63 100644 --- a/src/dwi_ml/cli/dwiml_compute_connectivity_matrix_from_labels.py +++ b/src/dwi_ml/cli/dwiml_compute_connectivity_matrix_from_labels.py @@ -21,7 +21,7 @@ from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, \ add_verbose_arg, add_overwrite_arg -from dwi_ml.data.processing.streamlines.post_processing import \ +from dwi_ml.general.data.processing.streamlines.post_processing import \ find_streamlines_with_chosen_connectivity, \ compute_triu_connectivity_from_labels, prepare_figure_connectivity diff --git a/src/dwi_ml/cli/dwiml_create_hdf5_dataset.py b/src/dwi_ml/cli/dwiml_create_hdf5_dataset.py index b5e244d6..fe444593 100644 --- a/src/dwi_ml/cli/dwiml_create_hdf5_dataset.py +++ b/src/dwi_ml/cli/dwiml_create_hdf5_dataset.py @@ -30,10 +30,10 @@ from dipy.io.stateful_tractogram import set_sft_logger_level -from dwi_ml.data.hdf5.hdf5_creation import HDF5Creator -from dwi_ml.data.hdf5.utils import ( +from dwi_ml.general.data.hdf5.hdf5_creation import HDF5Creator +from dwi_ml.general.data.hdf5.utils import ( add_hdf5_creation_args, add_streamline_processing_args) -from dwi_ml.experiment_utils.timer import Timer +from dwi_ml.general.experiment_utils.timer import Timer def _initialize_intermediate_subdir(hdf5_file, save_intermediate): diff --git a/src/dwi_ml/cli/dwiml_divide_volume_into_blocs.py b/src/dwi_ml/cli/dwiml_divide_volume_into_blocs.py index 26df5706..7e967fab 100644 --- a/src/dwi_ml/cli/dwiml_divide_volume_into_blocs.py +++ b/src/dwi_ml/cli/dwiml_divide_volume_into_blocs.py @@ -8,7 +8,7 @@ from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, \ add_overwrite_arg -from dwi_ml.data.hdf5.utils import format_nb_blocs_connectivity +from dwi_ml.general.data.hdf5.utils import format_nb_blocs_connectivity def _build_arg_parser(): diff --git 
a/src/dwi_ml/cli/dwiml_hdf5_extract_data.py b/src/dwi_ml/cli/dwiml_hdf5_extract_data.py index e96835f0..24c73f3a 100644 --- a/src/dwi_ml/cli/dwiml_hdf5_extract_data.py +++ b/src/dwi_ml/cli/dwiml_hdf5_extract_data.py @@ -5,16 +5,16 @@ import h5py import nibabel as nib import numpy as np -from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin +from dipy.io.stateful_tractogram import StatefulTractogram from dipy.io.streamline import save_tractogram from matplotlib import pyplot as plt from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, \ add_overwrite_arg, add_verbose_arg -from dwi_ml.data.dataset.streamline_containers import \ +from dwi_ml.general.data.dataset.streamline_containers import \ load_all_streamlines_from_hdf, load_streamlines_attributes_from_hdf -from dwi_ml.data.processing.streamlines.post_processing import \ +from dwi_ml.general.data.processing.streamlines.post_processing import \ prepare_figure_connectivity diff --git a/src/dwi_ml/cli/dwiml_send_value_to_comet_from_log.py b/src/dwi_ml/cli/dwiml_send_value_to_comet_from_log.py index 9aacb095..d1d5281e 100644 --- a/src/dwi_ml/cli/dwiml_send_value_to_comet_from_log.py +++ b/src/dwi_ml/cli/dwiml_send_value_to_comet_from_log.py @@ -15,7 +15,7 @@ from comet_ml import ExistingExperiment from scilpy.io.utils import assert_inputs_exist -from dwi_ml.training.trainers import DWIMLTrainer +from dwi_ml.general.training.trainers import DWIMLTrainer def _build_arg_parser(): diff --git a/src/dwi_ml/cli/dwiml_send_value_to_comet_manually.py b/src/dwi_ml/cli/dwiml_send_value_to_comet_manually.py index 7a486729..d98ca975 100644 --- a/src/dwi_ml/cli/dwiml_send_value_to_comet_manually.py +++ b/src/dwi_ml/cli/dwiml_send_value_to_comet_manually.py @@ -9,7 +9,7 @@ from comet_ml import ExistingExperiment -from dwi_ml.training.trainers import DWIMLTrainer +from dwi_ml.general.training.trainers import DWIMLTrainer def _build_arg_parser(): diff --git a/src/dwi_ml/cli/dwiml_visualize_logs.py 
b/src/dwi_ml/cli/dwiml_visualize_logs.py index e8807748..e45f536d 100644 --- a/src/dwi_ml/cli/dwiml_visualize_logs.py +++ b/src/dwi_ml/cli/dwiml_visualize_logs.py @@ -39,7 +39,7 @@ from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, assert_outputs_exist) -from dwi_ml.viz.logs_plots import visualize_logs +from dwi_ml.general.viz.logs_plots import visualize_logs def _build_arg_parser(): diff --git a/src/dwi_ml/cli/dwiml_visualize_noise_on_streamlines.py b/src/dwi_ml/cli/dwiml_visualize_noise_on_streamlines.py index 6a9d15f5..1706c041 100644 --- a/src/dwi_ml/cli/dwiml_visualize_noise_on_streamlines.py +++ b/src/dwi_ml/cli/dwiml_visualize_noise_on_streamlines.py @@ -20,12 +20,12 @@ from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, \ add_overwrite_arg -from dwi_ml.data.processing.streamlines.data_augmentation import \ +from dwi_ml.general.data.processing.streamlines.data_augmentation import \ resample_or_compress -from dwi_ml.data.processing.utils import add_noise_to_tensor -from dwi_ml.testing.utils import add_args_testing_subj_hdf5, \ +from dwi_ml.general.data.processing.utils import add_noise_to_tensor +from dwi_ml.general.testing.utils import add_args_testing_subj_hdf5, \ prepare_dataset_one_subj -from dwi_ml.training.utils.batch_loaders import add_args_batch_loader +from dwi_ml.general.training.utils.batch_loaders import add_args_batch_loader def prepare_arg_parser(): diff --git a/src/dwi_ml/cli/l2t_resume_training_from_checkpoint.py b/src/dwi_ml/cli/l2t_resume_training_from_checkpoint.py index 6e3ec354..8017d591 100644 --- a/src/dwi_ml/cli/l2t_resume_training_from_checkpoint.py +++ b/src/dwi_ml/cli/l2t_resume_training_from_checkpoint.py @@ -7,7 +7,6 @@ # comet_ml not used, but comet_ml requires to be imported before torch. # See bug report here https://github.com/Lightning-AI/lightning/issues/5829 # Importing now to solve issues later. 
-import comet_ml # Also, after upgrading torch, I now have a lot of warnings: # FutureWarning: `torch.distributed.reduce_op` is deprecated, please use @@ -21,14 +20,15 @@ from scilpy.io.utils import add_verbose_arg -from dwi_ml.data.dataset.utils import prepare_multisubjectdataset -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput -from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler -from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer -from dwi_ml.training.utils.experiment import add_args_resuming_experiment -from dwi_ml.training.utils.trainer import run_experiment +from dwi_ml.general.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.training.batch_loaders import DWIMLBatchLoaderOneInput +from dwi_ml.general.training.batch_samplers import DWIMLBatchIDSampler +from dwi_ml.general.training.utils.experiment import add_args_resuming_experiment +from dwi_ml.general.training.utils.trainer import run_experiment + +from dwi_ml.projects.Learn2track.learn2track_model import Learn2TrackModel +from dwi_ml.projects.Learn2track.learn2track_trainer import Learn2TrackTrainer def prepare_arg_parser(): diff --git a/src/dwi_ml/cli/l2t_track_from_model.py b/src/dwi_ml/cli/l2t_track_from_model.py index 37dd497e..c207cbbf 100644 --- a/src/dwi_ml/cli/l2t_track_from_model.py +++ b/src/dwi_ml/cli/l2t_track_from_model.py @@ -20,16 +20,17 @@ verify_streamline_length_options, verify_seed_options, add_out_options) -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.testing.utils import prepare_dataset_one_subj, \ +from dwi_ml.general.experiment_utils.prints import format_dict_to_str +from 
dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.testing.utils import prepare_dataset_one_subj, \ find_hdf5_associated_to_experiment -from dwi_ml.tracking.projects.learn2track_tracker import RecurrentTracker -from dwi_ml.tracking.tracking_mask import TrackingMask -from dwi_ml.tracking.io_utils import (add_tracking_options, - prepare_seed_generator, - prepare_tracking_mask, track_and_save) +from dwi_ml.general.tracking.tracking_mask import TrackingMask +from dwi_ml.general.tracking.io_utils import (add_tracking_options, + prepare_seed_generator, + prepare_tracking_mask, track_and_save) + +from dwi_ml.projects.Learn2track.learn2track_model import Learn2TrackModel +from dwi_ml.projects.Learn2track.learn2track_tracker import RecurrentTracker # Also, after upgrading torch, I now have a lot of warnings: diff --git a/src/dwi_ml/cli/l2t_train_model.py b/src/dwi_ml/cli/l2t_train_model.py index 8cf5baae..94df403e 100755 --- a/src/dwi_ml/cli/l2t_train_model.py +++ b/src/dwi_ml/cli/l2t_train_model.py @@ -11,7 +11,6 @@ # comet_ml not used, but comet_ml requires to be imported before torch. # See bug report here https://github.com/Lightning-AI/lightning/issues/5829 # Importing now to solve issues later. 
-import comet_ml # Also, after upgrading torch, I now have a lot of warnings: # FutureWarning: `torch.distributed.reduce_op` is deprecated, please use @@ -27,23 +26,24 @@ from scilpy.io.utils import (add_verbose_arg, assert_inputs_exist, assert_outputs_exist) -from dwi_ml.data.dataset.utils import prepare_multisubjectdataset -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_memory_args -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.models.projects.learn2track_utils import add_model_args -from dwi_ml.models.utils.direction_getters import check_args_direction_getter -from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer -from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, - prepare_batch_sampler) -from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader, - prepare_batch_loader) -from dwi_ml.training.utils.experiment import ( +from dwi_ml.general.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.general.experiment_utils.prints import format_dict_to_str +from dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.io_utils import add_memory_args +from dwi_ml.general.models.utils.direction_getters import check_args_direction_getter +from dwi_ml.general.training.utils.batch_samplers import (add_args_batch_sampler, + prepare_batch_sampler) +from dwi_ml.general.training.utils.batch_loaders import (add_args_batch_loader, + prepare_batch_loader) +from dwi_ml.general.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) -from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ +from dwi_ml.general.training.utils.trainer import run_experiment, add_training_args, \ format_lr +from dwi_ml.projects.Learn2track.learn2track_model import Learn2TrackModel +from dwi_ml.projects.Learn2track.learn2track_utils import 
add_model_args +from dwi_ml.projects.Learn2track.learn2track_trainer import Learn2TrackTrainer + def prepare_arg_parser(): p = argparse.ArgumentParser(description=__doc__, diff --git a/src/dwi_ml/cli/l2t_update_deprecated_exp.py b/src/dwi_ml/cli/l2t_update_deprecated_exp.py index 69b24bc1..44f13f79 100644 --- a/src/dwi_ml/cli/l2t_update_deprecated_exp.py +++ b/src/dwi_ml/cli/l2t_update_deprecated_exp.py @@ -17,12 +17,13 @@ import torch from scilpy.io.utils import add_verbose_arg -from dwi_ml.data.dataset.utils import prepare_multisubjectdataset -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput -from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler -from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer +from dwi_ml.general.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.general.experiment_utils.prints import format_dict_to_str +from dwi_ml.general.training.batch_loaders import DWIMLBatchLoaderOneInput +from dwi_ml.general.training.batch_samplers import DWIMLBatchIDSampler + +from dwi_ml.projects.Learn2track.learn2track_model import Learn2TrackModel +from dwi_ml.projects.Learn2track.learn2track_trainer import Learn2TrackTrainer def prepare_arg_parser(): @@ -254,7 +255,7 @@ def load_checkpoint_and_fix(args, model): dataset, model, checkpoint_state['batch_loader_params']) experiments_path, experiment_name = os.path.split(args.out_experiment) if experiments_path == '': - experiments_path = './' + experiments_path = '/' try: _ = Learn2TrackTrainer.init_from_checkpoint( diff --git a/src/dwi_ml/cli/l2t_visualize_loss.py b/src/dwi_ml/cli/l2t_visualize_loss.py index b998c02b..8c299da9 100644 --- a/src/dwi_ml/cli/l2t_visualize_loss.py +++ b/src/dwi_ml/cli/l2t_visualize_loss.py @@ -16,12 +16,13 @@ import torch -from dwi_ml.io_utils import add_arg_existing_experiment_path -from 
dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.testing.testers import TesterOneInput -from dwi_ml.testing.utils import add_args_testing_subj_hdf5 -from dwi_ml.testing.visu_loss import run_all_visu_loss -from dwi_ml.testing.visu_loss_utils import prepare_args_visu_loss, visu_checks +from dwi_ml.general.io_utils import add_arg_existing_experiment_path +from dwi_ml.general.testing.testers import TesterOneInput +from dwi_ml.general.testing.utils import add_args_testing_subj_hdf5 +from dwi_ml.general.testing.visu_loss import run_all_visu_loss +from dwi_ml.general.testing.visu_loss_utils import prepare_args_visu_loss, visu_checks + +from dwi_ml.projects.Learn2track.learn2track_model import Learn2TrackModel def prepare_argparser(): diff --git a/src/dwi_ml/cli/l2t_visualize_weights_evolution.py b/src/dwi_ml/cli/l2t_visualize_weights_evolution.py index cbf9a69a..1b5a7407 100644 --- a/src/dwi_ml/cli/l2t_visualize_weights_evolution.py +++ b/src/dwi_ml/cli/l2t_visualize_weights_evolution.py @@ -15,7 +15,7 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable import numpy as np -from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer +from dwi_ml.projects.Learn2track.learn2track_trainer import Learn2TrackTrainer def prepare_arg_parser(): diff --git a/src/dwi_ml/data/__init__.py b/src/dwi_ml/cli/tests/__init__.py similarity index 100% rename from src/dwi_ml/data/__init__.py rename to src/dwi_ml/cli/tests/__init__.py diff --git a/src/dwi_ml/cli/tests/test_all_steps_l2t.py b/src/dwi_ml/cli/tests/test_all_steps_l2t.py index 71544030..66367f75 100644 --- a/src/dwi_ml/cli/tests/test_all_steps_l2t.py +++ b/src/dwi_ml/cli/tests/test_all_steps_l2t.py @@ -6,10 +6,10 @@ import torch -from dwi_ml.unit_tests.utils.expected_values import ( +from dwi_ml.general.unit_tests.utils.expected_values import ( TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_STREAMLINE_GROUPS, TEST_EXPECTED_SUBJ_NAMES) -from 
dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import fetch_testing_data data_dir = fetch_testing_data() experiment_name = 'test_experiment' diff --git a/src/dwi_ml/cli/tests/test_all_steps_tto.py b/src/dwi_ml/cli/tests/test_all_steps_tto.py index 2f7ccf7b..ccaacb76 100644 --- a/src/dwi_ml/cli/tests/test_all_steps_tto.py +++ b/src/dwi_ml/cli/tests/test_all_steps_tto.py @@ -7,10 +7,10 @@ import torch -from dwi_ml.unit_tests.utils.expected_values import ( +from dwi_ml.general.unit_tests.utils.expected_values import ( TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_STREAMLINE_GROUPS, TEST_EXPECTED_SUBJ_NAMES) -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import fetch_testing_data data_dir = fetch_testing_data() tmp_dir = tempfile.TemporaryDirectory() diff --git a/src/dwi_ml/cli/tests/test_all_steps_tts.py b/src/dwi_ml/cli/tests/test_all_steps_tts.py index 1bb74bf0..325cc37a 100644 --- a/src/dwi_ml/cli/tests/test_all_steps_tts.py +++ b/src/dwi_ml/cli/tests/test_all_steps_tts.py @@ -5,10 +5,10 @@ import pytest import tempfile -from dwi_ml.unit_tests.utils.expected_values import \ +from dwi_ml.general.unit_tests.utils.expected_values import \ (TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_STREAMLINE_GROUPS, TEST_EXPECTED_SUBJ_NAMES) -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import fetch_testing_data data_dir = fetch_testing_data() tmp_dir = tempfile.TemporaryDirectory() diff --git a/src/dwi_ml/cli/tests/test_all_steps_ttst.py b/src/dwi_ml/cli/tests/test_all_steps_ttst.py index 1ac4e5ac..d000e7f3 100644 --- a/src/dwi_ml/cli/tests/test_all_steps_ttst.py +++ b/src/dwi_ml/cli/tests/test_all_steps_ttst.py @@ -5,10 +5,10 @@ import pytest import tempfile -from 
dwi_ml.unit_tests.utils.expected_values import \ +from dwi_ml.general.unit_tests.utils.expected_values import \ (TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_STREAMLINE_GROUPS, TEST_EXPECTED_SUBJ_NAMES) -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import fetch_testing_data data_dir = fetch_testing_data() tmp_dir = tempfile.TemporaryDirectory() diff --git a/src/dwi_ml/cli/tests/test_compute_connectivity_matrix_from_blocs.py b/src/dwi_ml/cli/tests/test_compute_connectivity_matrix_from_blocs.py index 78a02565..652a8d66 100644 --- a/src/dwi_ml/cli/tests/test_compute_connectivity_matrix_from_blocs.py +++ b/src/dwi_ml/cli/tests/test_compute_connectivity_matrix_from_blocs.py @@ -4,7 +4,7 @@ import os import tempfile -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import fetch_testing_data data_dir = fetch_testing_data() tmp_dir = tempfile.TemporaryDirectory() diff --git a/src/dwi_ml/cli/tests/test_compute_connectivity_score.py b/src/dwi_ml/cli/tests/test_compute_connectivity_score.py index e1745616..f7e46568 100644 --- a/src/dwi_ml/cli/tests/test_compute_connectivity_score.py +++ b/src/dwi_ml/cli/tests/test_compute_connectivity_score.py @@ -3,7 +3,7 @@ import os import tempfile -from dwi_ml.unit_tests.utils.data_and_models_for_tests import \ +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import \ fetch_testing_data data_dir = fetch_testing_data() diff --git a/src/dwi_ml/cli/tests/test_create_hdf5_dataset.py b/src/dwi_ml/cli/tests/test_create_hdf5_dataset.py index 311d0038..3cd616db 100644 --- a/src/dwi_ml/cli/tests/test_create_hdf5_dataset.py +++ b/src/dwi_ml/cli/tests/test_create_hdf5_dataset.py @@ -4,7 +4,7 @@ import os import tempfile -from dwi_ml.unit_tests.utils.data_and_models_for_tests import \ +from 
dwi_ml.general.unit_tests.utils.data_and_models_for_tests import \ fetch_testing_data data_dir = fetch_testing_data() diff --git a/src/dwi_ml/cli/tests/test_divide_volume_into_blocs.py b/src/dwi_ml/cli/tests/test_divide_volume_into_blocs.py index 8085f29c..b03951c6 100644 --- a/src/dwi_ml/cli/tests/test_divide_volume_into_blocs.py +++ b/src/dwi_ml/cli/tests/test_divide_volume_into_blocs.py @@ -3,7 +3,7 @@ import os import tempfile -from dwi_ml.unit_tests.utils.data_and_models_for_tests import \ +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import \ fetch_testing_data data_dir = fetch_testing_data() diff --git a/src/dwi_ml/cli/tests/test_print_hdf5_architecture.py b/src/dwi_ml/cli/tests/test_print_hdf5_architecture.py index 760baa75..4a911cb4 100644 --- a/src/dwi_ml/cli/tests/test_print_hdf5_architecture.py +++ b/src/dwi_ml/cli/tests/test_print_hdf5_architecture.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- import os -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import fetch_testing_data data_dir = fetch_testing_data() experiment_name = 'test_experiment' diff --git a/src/dwi_ml/cli/tests/test_scil_score_ismrm_Renauld2023.py b/src/dwi_ml/cli/tests/test_scil_score_ismrm_Renauld2023.py index b2b3c46c..047da3d4 100644 --- a/src/dwi_ml/cli/tests/test_scil_score_ismrm_Renauld2023.py +++ b/src/dwi_ml/cli/tests/test_scil_score_ismrm_Renauld2023.py @@ -4,7 +4,7 @@ import subprocess import tempfile -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import fetch_testing_data data_dir = fetch_testing_data() tmp_dir = tempfile.TemporaryDirectory() diff --git a/src/dwi_ml/cli/tt_resume_training_from_checkpoint.py b/src/dwi_ml/cli/tt_resume_training_from_checkpoint.py index 76b573f5..e09871cb 100644 --- a/src/dwi_ml/cli/tt_resume_training_from_checkpoint.py +++ 
b/src/dwi_ml/cli/tt_resume_training_from_checkpoint.py @@ -7,7 +7,6 @@ # comet_ml not used, but comet_ml requires to be imported before torch. # See bug report here https://github.com/Lightning-AI/lightning/issues/5829 # Importing now to solve issues later. -import comet_ml # Also, after upgrading torch, I now have a lot of warnings: # FutureWarning: `torch.distributed.reduce_op` is deprecated, please use @@ -21,15 +20,16 @@ from scilpy.io.utils import add_verbose_arg -from dwi_ml.data.dataset.utils import prepare_multisubjectdataset -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import verify_which_model_in_path -from dwi_ml.models.projects.transformer_models import find_transformer_class -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput -from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler -from dwi_ml.training.projects.transformer_trainer import TransformerTrainer -from dwi_ml.training.utils.experiment import add_args_resuming_experiment -from dwi_ml.training.utils.trainer import run_experiment +from dwi_ml.general.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.io_utils import verify_which_model_in_path +from dwi_ml.general.training.batch_loaders import DWIMLBatchLoaderOneInput +from dwi_ml.general.training.batch_samplers import DWIMLBatchIDSampler +from dwi_ml.general.training.utils.experiment import add_args_resuming_experiment +from dwi_ml.general.training.utils.trainer import run_experiment + +from dwi_ml.projects.Transformers.transformer_models import find_transformer_class +from dwi_ml.projects.Transformers.transformer_trainer import TransformerTrainer def prepare_arg_parser(): diff --git a/src/dwi_ml/cli/tt_track_from_model.py b/src/dwi_ml/cli/tt_track_from_model.py index 26065912..a4d7c4a2 100644 --- a/src/dwi_ml/cli/tt_track_from_model.py +++ b/src/dwi_ml/cli/tt_track_from_model.py @@ -21,19 +21,20 @@ 
verify_streamline_length_options, verify_seed_options, add_out_options) -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import verify_which_model_in_path -from dwi_ml.models.projects.transformer_models import find_transformer_class -from dwi_ml.testing.utils import prepare_dataset_one_subj, \ +from dwi_ml.general.experiment_utils.prints import format_dict_to_str +from dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.io_utils import verify_which_model_in_path +from dwi_ml.general.testing.utils import prepare_dataset_one_subj, \ find_hdf5_associated_to_experiment -from dwi_ml.tracking.projects.transformer_tracker import \ +from dwi_ml.general.tracking.tracking_mask import TrackingMask +from dwi_ml.general.tracking.io_utils import (add_tracking_options, + prepare_seed_generator, + prepare_tracking_mask, + track_and_save) + +from dwi_ml.projects.Transformers.transformer_models import find_transformer_class +from dwi_ml.projects.Transformers.transformer_tracker import \ TransformerTracker -from dwi_ml.tracking.tracking_mask import TrackingMask -from dwi_ml.tracking.io_utils import (add_tracking_options, - prepare_seed_generator, - prepare_tracking_mask, - track_and_save) # Also, after upgrading torch, I now have a lot of warnings: # FutureWarning: `torch.distributed.reduce_op` is deprecated, please use diff --git a/src/dwi_ml/cli/tt_train_model.py b/src/dwi_ml/cli/tt_train_model.py index 127ef65c..dc1451f1 100755 --- a/src/dwi_ml/cli/tt_train_model.py +++ b/src/dwi_ml/cli/tt_train_model.py @@ -11,7 +11,6 @@ # comet_ml not used, but comet_ml requires to be imported before torch. # See bug report here https://github.com/Lightning-AI/lightning/issues/5829 # Importing now to solve issues later. 
-import comet_ml # Also, after upgrading torch, I now have a lot of warnings: # FutureWarning: `torch.distributed.reduce_op` is deprecated, please use @@ -27,25 +26,26 @@ from scilpy.io.utils import (add_verbose_arg, assert_inputs_exist, assert_outputs_exist) -from dwi_ml.data.dataset.utils import prepare_multisubjectdataset -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_memory_args -from dwi_ml.models.projects.transformer_models import ( +from dwi_ml.general.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.general.experiment_utils.prints import format_dict_to_str +from dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.io_utils import add_memory_args +from dwi_ml.general.models.utils.direction_getters import check_args_direction_getter +from dwi_ml.general.training.utils.batch_samplers import (add_args_batch_sampler, + prepare_batch_sampler) +from dwi_ml.general.training.utils.batch_loaders import (add_args_batch_loader, + prepare_batch_loader) +from dwi_ml.general.training.utils.experiment import ( + add_mandatory_args_experiment_and_hdf5_path) +from dwi_ml.general.training.utils.trainer import (add_training_args, run_experiment, + format_lr) + +from dwi_ml.projects.Transformers.transformer_models import ( OriginalTransformerModel, TransformerSrcAndTgtModel, TransformerSrcOnlyModel) -from dwi_ml.models.projects.transformers_utils import ( +from dwi_ml.projects.Transformers.transformers_utils import ( add_transformers_model_args) -from dwi_ml.models.utils.direction_getters import check_args_direction_getter -from dwi_ml.training.projects.transformer_trainer import TransformerTrainer -from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, - prepare_batch_sampler) -from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader, - prepare_batch_loader) -from dwi_ml.training.utils.experiment import ( - 
add_mandatory_args_experiment_and_hdf5_path) -from dwi_ml.training.utils.trainer import (add_training_args, run_experiment, - format_lr) +from dwi_ml.projects.Transformers.transformer_trainer import TransformerTrainer def prepare_arg_parser(): diff --git a/src/dwi_ml/cli/tt_update_deprecated_exp.py b/src/dwi_ml/cli/tt_update_deprecated_exp.py index c097beae..f1b4f33c 100644 --- a/src/dwi_ml/cli/tt_update_deprecated_exp.py +++ b/src/dwi_ml/cli/tt_update_deprecated_exp.py @@ -15,17 +15,18 @@ import numpy as np -from dwi_ml.training.utils.monitoring import BatchHistoryMonitor +from dwi_ml.general.training.utils.monitoring import BatchHistoryMonitor import torch -from dwi_ml.io_utils import verify_which_model_in_path -from dwi_ml.models.projects.transformer_models import find_transformer_class -from dwi_ml.training.projects.transformer_trainer import TransformerTrainer from scilpy.io.utils import add_verbose_arg -from dwi_ml.data.dataset.utils import prepare_multisubjectdataset -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput -from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler +from dwi_ml.general.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.general.io_utils import verify_which_model_in_path +from dwi_ml.general.experiment_utils.prints import format_dict_to_str +from dwi_ml.general.training.batch_loaders import DWIMLBatchLoaderOneInput +from dwi_ml.general.training.batch_samplers import DWIMLBatchIDSampler + +from dwi_ml.projects.Transformers.transformer_models import find_transformer_class +from dwi_ml.projects.Transformers.transformer_trainer import TransformerTrainer def prepare_arg_parser(): @@ -253,7 +254,7 @@ def load_checkpoint_and_fix(args, model): dataset, model, checkpoint_state['batch_loader_params']) experiments_path, experiment_name = os.path.split(args.out_experiment) if experiments_path == '': - experiments_path = './' + experiments_path = 
'./' try: _ = TransformerTrainer.init_from_checkpoint( diff --git a/src/dwi_ml/cli/tt_visualize_loss.py b/src/dwi_ml/cli/tt_visualize_loss.py index 5f8dde60..5a74ea72 100644 --- a/src/dwi_ml/cli/tt_visualize_loss.py +++ b/src/dwi_ml/cli/tt_visualize_loss.py @@ -18,13 +18,14 @@ from scilpy.io.utils import assert_inputs_exist -from dwi_ml.io_utils import (add_arg_existing_experiment_path, - verify_which_model_in_path) -from dwi_ml.models.projects.transformer_models import find_transformer_class -from dwi_ml.testing.testers import TesterOneInput -from dwi_ml.testing.utils import add_args_testing_subj_hdf5 -from dwi_ml.testing.visu_loss import run_all_visu_loss -from dwi_ml.testing.visu_loss_utils import prepare_args_visu_loss, visu_checks +from dwi_ml.general.io_utils import (add_arg_existing_experiment_path, + verify_which_model_in_path) +from dwi_ml.general.testing.testers import TesterOneInput +from dwi_ml.general.testing.utils import add_args_testing_subj_hdf5 +from dwi_ml.general.testing.visu_loss import run_all_visu_loss +from dwi_ml.general.testing.visu_loss_utils import prepare_args_visu_loss, visu_checks + +from dwi_ml.projects.Transformers.transformer_models import find_transformer_class def prepare_argparser(): diff --git a/src/dwi_ml/cli/tt_visualize_weights.py b/src/dwi_ml/cli/tt_visualize_weights.py index cb5d131b..2af8a13f 100644 --- a/src/dwi_ml/cli/tt_visualize_weights.py +++ b/src/dwi_ml/cli/tt_visualize_weights.py @@ -7,10 +7,10 @@ from scilpy.io.utils import assert_outputs_exist -from dwi_ml.testing.projects.tt_visu_argparser import \ +from dwi_ml.projects.Transformers.tester.tt_visu_argparser import \ build_argparser_transformer_visu -from dwi_ml.testing.projects.tt_visu_main import tt_visualize_weights_main -from dwi_ml.testing.projects.tt_visu_utils import +from dwi_ml.projects.Transformers.tester.tt_visu_main import tt_visualize_weights_main +from dwi_ml.projects.Transformers.tester.tt_visu_utils import 
get_out_dir_and_create, \ get_config_filename diff --git a/src/dwi_ml/data/dataset/__init__.py b/src/dwi_ml/general/__init__.py similarity index 100% rename from src/dwi_ml/data/dataset/__init__.py rename to src/dwi_ml/general/__init__.py diff --git a/src/dwi_ml/data/hdf5/__init__.py b/src/dwi_ml/general/cache/__init__.py similarity index 100% rename from src/dwi_ml/data/hdf5/__init__.py rename to src/dwi_ml/general/cache/__init__.py diff --git a/src/dwi_ml/cache/cache_manager.py b/src/dwi_ml/general/cache/cache_manager.py similarity index 100% rename from src/dwi_ml/cache/cache_manager.py rename to src/dwi_ml/general/cache/cache_manager.py diff --git a/src/dwi_ml/data/processing/__init__.py b/src/dwi_ml/general/data/__init__.py similarity index 100% rename from src/dwi_ml/data/processing/__init__.py rename to src/dwi_ml/general/data/__init__.py diff --git a/src/dwi_ml/data/processing/dwi/__init__.py b/src/dwi_ml/general/data/dataset/__init__.py similarity index 100% rename from src/dwi_ml/data/processing/dwi/__init__.py rename to src/dwi_ml/general/data/dataset/__init__.py diff --git a/src/dwi_ml/data/dataset/checks_for_groups.py b/src/dwi_ml/general/data/dataset/checks_for_groups.py similarity index 100% rename from src/dwi_ml/data/dataset/checks_for_groups.py rename to src/dwi_ml/general/data/dataset/checks_for_groups.py diff --git a/src/dwi_ml/data/dataset/mri_data_containers.py b/src/dwi_ml/general/data/dataset/mri_data_containers.py similarity index 100% rename from src/dwi_ml/data/dataset/mri_data_containers.py rename to src/dwi_ml/general/data/dataset/mri_data_containers.py diff --git a/src/dwi_ml/data/dataset/multi_subject_containers.py b/src/dwi_ml/general/data/dataset/multi_subject_containers.py similarity index 98% rename from src/dwi_ml/data/dataset/multi_subject_containers.py rename to src/dwi_ml/general/data/dataset/multi_subject_containers.py index 5aa999c3..64b55b92 100644 --- a/src/dwi_ml/data/dataset/multi_subject_containers.py +++ 
b/src/dwi_ml/general/data/dataset/multi_subject_containers.py @@ -11,13 +11,13 @@ from tqdm import tqdm from tqdm.contrib.logging import logging_redirect_tqdm -from dwi_ml.cache.cache_manager import SingleThreadCacheManager -from dwi_ml.data.dataset.checks_for_groups import prepare_groups_info -from dwi_ml.data.dataset.mri_data_containers import MRIDataAbstract -from dwi_ml.data.dataset.subjectdata_list_containers import ( +from dwi_ml.general.cache.cache_manager import SingleThreadCacheManager +from dwi_ml.general.data.dataset.checks_for_groups import prepare_groups_info +from dwi_ml.general.data.dataset.mri_data_containers import MRIDataAbstract +from dwi_ml.general.data.dataset.subjectdata_list_containers import ( LazySubjectsDataList, SubjectsDataList) -from dwi_ml.data.dataset.single_subject_containers import (LazySubjectData, - SubjectData) +from dwi_ml.general.data.dataset.single_subject_containers import ( + LazySubjectData, SubjectData) logger = logging.getLogger('dataset_logger') diff --git a/src/dwi_ml/data/dataset/single_subject_containers.py b/src/dwi_ml/general/data/dataset/single_subject_containers.py similarity index 96% rename from src/dwi_ml/data/dataset/single_subject_containers.py rename to src/dwi_ml/general/data/dataset/single_subject_containers.py index fbd9b6cc..c1be6fcd 100644 --- a/src/dwi_ml/data/dataset/single_subject_containers.py +++ b/src/dwi_ml/general/data/dataset/single_subject_containers.py @@ -2,10 +2,11 @@ import logging from typing import List, Union -from dwi_ml.data.dataset.mri_data_containers import (LazyMRIData, MRIData, - MRIDataAbstract) -from dwi_ml.data.dataset.streamline_containers import LazySFTData, SFTData -from dwi_ml.data.dataset.checks_for_groups import prepare_groups_info +from dwi_ml.general.data.dataset.mri_data_containers import ( + LazyMRIData, MRIData, MRIDataAbstract) +from dwi_ml.general.data.dataset.streamline_containers import ( + LazySFTData, SFTData) +from dwi_ml.general.data.dataset.checks_for_groups 
import prepare_groups_info logger = logging.getLogger('dataset_logger') diff --git a/src/dwi_ml/data/dataset/streamline_containers.py b/src/dwi_ml/general/data/dataset/streamline_containers.py similarity index 100% rename from src/dwi_ml/data/dataset/streamline_containers.py rename to src/dwi_ml/general/data/dataset/streamline_containers.py diff --git a/src/dwi_ml/data/dataset/subjectdata_list_containers.py b/src/dwi_ml/general/data/dataset/subjectdata_list_containers.py similarity index 98% rename from src/dwi_ml/data/dataset/subjectdata_list_containers.py rename to src/dwi_ml/general/data/dataset/subjectdata_list_containers.py index 7d203ccd..6c061cc4 100644 --- a/src/dwi_ml/data/dataset/subjectdata_list_containers.py +++ b/src/dwi_ml/general/data/dataset/subjectdata_list_containers.py @@ -3,7 +3,7 @@ import os import h5py -from dwi_ml.data.dataset.single_subject_containers import ( +from dwi_ml.general.data.dataset.single_subject_containers import ( LazySubjectData, SubjectDataAbstract, SubjectData) logger = logging.getLogger('dataset_logger') diff --git a/src/dwi_ml/data/dataset/utils.py b/src/dwi_ml/general/data/dataset/utils.py similarity index 88% rename from src/dwi_ml/data/dataset/utils.py rename to src/dwi_ml/general/data/dataset/utils.py index d8a7b0e9..b27079dc 100644 --- a/src/dwi_ml/data/dataset/utils.py +++ b/src/dwi_ml/general/data/dataset/utils.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- import logging -from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset -from dwi_ml.experiment_utils.timer import Timer +from dwi_ml.general.data.dataset.multi_subject_containers import \ + MultiSubjectDataset +from dwi_ml.general.experiment_utils.timer import Timer def prepare_multisubjectdataset(args, load_training=True, load_validation=True, diff --git a/src/dwi_ml/data/processing/space/__init__.py b/src/dwi_ml/general/data/hdf5/__init__.py similarity index 100% rename from src/dwi_ml/data/processing/space/__init__.py rename to 
src/dwi_ml/general/data/hdf5/__init__.py diff --git a/src/dwi_ml/data/hdf5/hdf5_creation.py b/src/dwi_ml/general/data/hdf5/hdf5_creation.py similarity index 99% rename from src/dwi_ml/data/hdf5/hdf5_creation.py rename to src/dwi_ml/general/data/hdf5/hdf5_creation.py index 006b16fe..6b6615fa 100644 --- a/src/dwi_ml/data/hdf5/hdf5_creation.py +++ b/src/dwi_ml/general/data/hdf5/hdf5_creation.py @@ -13,8 +13,8 @@ import h5py from scilpy.image.labels import get_data_as_labels -from dwi_ml.data.hdf5.utils import format_nb_blocs_connectivity -from dwi_ml.data.processing.streamlines.data_augmentation import \ +from dwi_ml.general.data.hdf5.utils import format_nb_blocs_connectivity +from dwi_ml.general.data.processing.streamlines.data_augmentation import \ resample_or_compress from nested_lookup import nested_lookup import nibabel as nib @@ -22,8 +22,8 @@ from scilpy.tractograms.tractogram_operations import concatenate_sft -from dwi_ml.data.io import load_file_to4d -from dwi_ml.data.processing.dwi.dwi import standardize_data +from dwi_ml.general.data.io import load_file_to4d +from dwi_ml.general.data.processing.dwi.dwi import standardize_data def format_filelist(filenames, enforce_presence, folder=None) -> List[str]: diff --git a/src/dwi_ml/data/hdf5/utils.py b/src/dwi_ml/general/data/hdf5/utils.py similarity index 98% rename from src/dwi_ml/data/hdf5/utils.py rename to src/dwi_ml/general/data/hdf5/utils.py index 993b7f26..88c186a0 100644 --- a/src/dwi_ml/data/hdf5/utils.py +++ b/src/dwi_ml/general/data/hdf5/utils.py @@ -2,7 +2,7 @@ from argparse import ArgumentParser from typing import List -from dwi_ml.io_utils import add_resample_or_compress_arg +from dwi_ml.general.io_utils import add_resample_or_compress_arg def format_nb_blocs_connectivity(connectivity_nb_blocs) -> List: diff --git a/src/dwi_ml/data/io.py b/src/dwi_ml/general/data/io.py similarity index 100% rename from src/dwi_ml/data/io.py rename to src/dwi_ml/general/data/io.py diff --git 
a/src/dwi_ml/data/processing/streamlines/__init__.py b/src/dwi_ml/general/data/processing/__init__.py similarity index 100% rename from src/dwi_ml/data/processing/streamlines/__init__.py rename to src/dwi_ml/general/data/processing/__init__.py diff --git a/src/dwi_ml/data/processing/volume/__init__.py b/src/dwi_ml/general/data/processing/dwi/__init__.py similarity index 100% rename from src/dwi_ml/data/processing/volume/__init__.py rename to src/dwi_ml/general/data/processing/dwi/__init__.py diff --git a/src/dwi_ml/data/processing/dwi/dwi.py b/src/dwi_ml/general/data/processing/dwi/dwi.py similarity index 100% rename from src/dwi_ml/data/processing/dwi/dwi.py rename to src/dwi_ml/general/data/processing/dwi/dwi.py diff --git a/src/dwi_ml/experiment_utils/__init__.py b/src/dwi_ml/general/data/processing/space/__init__.py similarity index 100% rename from src/dwi_ml/experiment_utils/__init__.py rename to src/dwi_ml/general/data/processing/space/__init__.py diff --git a/src/dwi_ml/data/processing/space/neighborhood.py b/src/dwi_ml/general/data/processing/space/neighborhood.py similarity index 100% rename from src/dwi_ml/data/processing/space/neighborhood.py rename to src/dwi_ml/general/data/processing/space/neighborhood.py diff --git a/src/dwi_ml/data/processing/space/world_to_vox.py b/src/dwi_ml/general/data/processing/space/world_to_vox.py similarity index 100% rename from src/dwi_ml/data/processing/space/world_to_vox.py rename to src/dwi_ml/general/data/processing/space/world_to_vox.py diff --git a/src/dwi_ml/models/__init__.py b/src/dwi_ml/general/data/processing/streamlines/__init__.py similarity index 100% rename from src/dwi_ml/models/__init__.py rename to src/dwi_ml/general/data/processing/streamlines/__init__.py diff --git a/src/dwi_ml/data/processing/streamlines/data_augmentation.py b/src/dwi_ml/general/data/processing/streamlines/data_augmentation.py similarity index 100% rename from src/dwi_ml/data/processing/streamlines/data_augmentation.py rename to 
src/dwi_ml/general/data/processing/streamlines/data_augmentation.py diff --git a/src/dwi_ml/data/processing/streamlines/post_processing.py b/src/dwi_ml/general/data/processing/streamlines/post_processing.py similarity index 100% rename from src/dwi_ml/data/processing/streamlines/post_processing.py rename to src/dwi_ml/general/data/processing/streamlines/post_processing.py diff --git a/src/dwi_ml/data/processing/streamlines/sos_eos_management.py b/src/dwi_ml/general/data/processing/streamlines/sos_eos_management.py similarity index 99% rename from src/dwi_ml/data/processing/streamlines/sos_eos_management.py rename to src/dwi_ml/general/data/processing/streamlines/sos_eos_management.py index e0e5303d..1153d84d 100644 --- a/src/dwi_ml/data/processing/streamlines/sos_eos_management.py +++ b/src/dwi_ml/general/data/processing/streamlines/sos_eos_management.py @@ -17,7 +17,7 @@ import torch from torch.nn.functional import one_hot, pad -from dwi_ml.data.spheres import TorchSphere +from dwi_ml.general.data.spheres import TorchSphere def convert_dirs_to_class(batch_dirs: List[torch.Tensor], diff --git a/src/dwi_ml/data/processing/utils.py b/src/dwi_ml/general/data/processing/utils.py similarity index 100% rename from src/dwi_ml/data/processing/utils.py rename to src/dwi_ml/general/data/processing/utils.py diff --git a/src/dwi_ml/models/projects/__init__.py b/src/dwi_ml/general/data/processing/volume/__init__.py similarity index 100% rename from src/dwi_ml/models/projects/__init__.py rename to src/dwi_ml/general/data/processing/volume/__init__.py diff --git a/src/dwi_ml/data/processing/volume/interpolation.py b/src/dwi_ml/general/data/processing/volume/interpolation.py similarity index 99% rename from src/dwi_ml/data/processing/volume/interpolation.py rename to src/dwi_ml/general/data/processing/volume/interpolation.py index 79b51ab7..7bc6717e 100644 --- a/src/dwi_ml/data/processing/volume/interpolation.py +++ b/src/dwi_ml/general/data/processing/volume/interpolation.py @@ 
-4,7 +4,7 @@ import torch import numpy as np -from dwi_ml.data.processing.space.neighborhood import \ +from dwi_ml.general.data.processing.space.neighborhood import \ extend_coordinates_with_neighborhood B1 = np.array([[1, 0, 0, 0, 0, 0, 0, 0], diff --git a/src/dwi_ml/data/spheres.py b/src/dwi_ml/general/data/spheres.py similarity index 100% rename from src/dwi_ml/data/spheres.py rename to src/dwi_ml/general/data/spheres.py diff --git a/src/dwi_ml/models/utils/__init__.py b/src/dwi_ml/general/experiment_utils/__init__.py similarity index 100% rename from src/dwi_ml/models/utils/__init__.py rename to src/dwi_ml/general/experiment_utils/__init__.py diff --git a/src/dwi_ml/experiment_utils/memory.py b/src/dwi_ml/general/experiment_utils/memory.py similarity index 100% rename from src/dwi_ml/experiment_utils/memory.py rename to src/dwi_ml/general/experiment_utils/memory.py diff --git a/src/dwi_ml/experiment_utils/prints.py b/src/dwi_ml/general/experiment_utils/prints.py similarity index 100% rename from src/dwi_ml/experiment_utils/prints.py rename to src/dwi_ml/general/experiment_utils/prints.py diff --git a/src/dwi_ml/experiment_utils/timer.py b/src/dwi_ml/general/experiment_utils/timer.py similarity index 100% rename from src/dwi_ml/experiment_utils/timer.py rename to src/dwi_ml/general/experiment_utils/timer.py diff --git a/src/dwi_ml/experiment_utils/tqdm_logging.py b/src/dwi_ml/general/experiment_utils/tqdm_logging.py similarity index 100% rename from src/dwi_ml/experiment_utils/tqdm_logging.py rename to src/dwi_ml/general/experiment_utils/tqdm_logging.py diff --git a/src/dwi_ml/io_utils.py b/src/dwi_ml/general/io_utils.py similarity index 100% rename from src/dwi_ml/io_utils.py rename to src/dwi_ml/general/io_utils.py diff --git a/src/dwi_ml/testing/__init__.py b/src/dwi_ml/general/models/__init__.py similarity index 100% rename from src/dwi_ml/testing/__init__.py rename to src/dwi_ml/general/models/__init__.py diff --git a/src/dwi_ml/testing/projects/__init__.py 
b/src/dwi_ml/general/models/main_layers/__init__.py similarity index 100% rename from src/dwi_ml/testing/projects/__init__.py rename to src/dwi_ml/general/models/main_layers/__init__.py diff --git a/src/dwi_ml/models/direction_getter_models.py b/src/dwi_ml/general/models/main_layers/direction_getter_models.py similarity index 99% rename from src/dwi_ml/models/direction_getter_models.py rename to src/dwi_ml/general/models/main_layers/direction_getter_models.py index 8116e1b5..f3b7c0f8 100644 --- a/src/dwi_ml/models/direction_getter_models.py +++ b/src/dwi_ml/general/models/main_layers/direction_getter_models.py @@ -12,13 +12,13 @@ KLDivLoss) from torch.nn.modules.distance import PairwiseDistance -from dwi_ml.data.processing.streamlines.post_processing import \ +from dwi_ml.general.data.processing.streamlines.post_processing import \ normalize_directions, compute_directions -from dwi_ml.data.processing.streamlines.sos_eos_management import \ +from dwi_ml.general.data.processing.streamlines.sos_eos_management import \ add_label_as_last_dim, convert_dirs_to_class -from dwi_ml.data.spheres import TorchSphere -from dwi_ml.models.utils.gaussians import independent_gaussian_log_prob -from dwi_ml.models.utils.fisher_von_mises import fisher_von_mises_log_prob +from dwi_ml.general.data.spheres import TorchSphere +from dwi_ml.general.models.utils.gaussians import independent_gaussian_log_prob +from dwi_ml.general.models.utils.fisher_von_mises import fisher_von_mises_log_prob """ The complete formulas and explanations are available in our doc: diff --git a/src/dwi_ml/models/embeddings.py b/src/dwi_ml/general/models/main_layers/embeddings.py similarity index 100% rename from src/dwi_ml/models/embeddings.py rename to src/dwi_ml/general/models/main_layers/embeddings.py diff --git a/src/dwi_ml/models/positional_encoding.py b/src/dwi_ml/general/models/main_layers/positional_encoding.py similarity index 100% rename from src/dwi_ml/models/positional_encoding.py rename to 
src/dwi_ml/general/models/main_layers/positional_encoding.py diff --git a/src/dwi_ml/models/stacked_rnn.py b/src/dwi_ml/general/models/main_layers/stacked_rnn.py similarity index 98% rename from src/dwi_ml/models/stacked_rnn.py rename to src/dwi_ml/general/models/main_layers/stacked_rnn.py index 73741cf5..70320c32 100644 --- a/src/dwi_ml/models/stacked_rnn.py +++ b/src/dwi_ml/general/models/main_layers/stacked_rnn.py @@ -9,8 +9,6 @@ keys_to_rnn_class = {'lstm': torch.nn.LSTM, 'gru': torch.nn.GRU} -# Note. This logger's logging level can be modified through the main model, -# Learn2trackModel. logger = logging.getLogger('model_logger') # Same logger as main dwi_ml. # Skip connection: In https://arxiv.org/pdf/1308.0850v5.pdf they don't add a @@ -150,7 +148,7 @@ def forward(self, inputs: PackedSequence, Parameters ---------- inputs : PackedSequence - Current implementation of the learn2track model calls this using + Current implementation of the Learn2track model calls this using packed sequence. We run the RNN on the packed data, but the normalization and dropout on their tensor version. 
hidden_states : list[states] diff --git a/src/dwi_ml/models/projects/transformer_sublayers.py b/src/dwi_ml/general/models/main_layers/transformer_sublayers.py similarity index 100% rename from src/dwi_ml/models/projects/transformer_sublayers.py rename to src/dwi_ml/general/models/main_layers/transformer_sublayers.py diff --git a/src/dwi_ml/models/utils/transformers_from_torch.py b/src/dwi_ml/general/models/main_layers/transformers_from_torch.py similarity index 99% rename from src/dwi_ml/models/utils/transformers_from_torch.py rename to src/dwi_ml/general/models/main_layers/transformers_from_torch.py index 0d8cb839..6e1813da 100644 --- a/src/dwi_ml/models/utils/transformers_from_torch.py +++ b/src/dwi_ml/general/models/main_layers/transformers_from_torch.py @@ -17,7 +17,7 @@ from torch.nn import Transformer, TransformerDecoder, TransformerEncoder from torch.nn.modules.transformer import _get_seq_len, _detect_is_causal_mask -from dwi_ml.models.projects.transformer_sublayers import \ +from dwi_ml.general.models.main_layers.transformer_sublayers import \ ModifiedTransformerDecoderLayer, ModifiedTransformerEncoderLayer logger = logging.getLogger('model_logger') diff --git a/src/dwi_ml/tracking/__init__.py b/src/dwi_ml/general/models/main_models/__init__.py similarity index 100% rename from src/dwi_ml/tracking/__init__.py rename to src/dwi_ml/general/models/main_models/__init__.py diff --git a/src/dwi_ml/models/main_abstract_model.py b/src/dwi_ml/general/models/main_models/main_abstract_model.py similarity index 97% rename from src/dwi_ml/models/main_abstract_model.py rename to src/dwi_ml/general/models/main_models/main_abstract_model.py index 433a62be..b9c68fb5 100644 --- a/src/dwi_ml/models/main_abstract_model.py +++ b/src/dwi_ml/general/models/main_models/main_abstract_model.py @@ -7,8 +7,8 @@ import torch -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.io_utils import add_resample_or_compress_arg +from 
dwi_ml.general.experiment_utils.prints import format_dict_to_str +from dwi_ml.general.io_utils import add_resample_or_compress_arg logger = logging.getLogger('model_logger') @@ -149,7 +149,7 @@ def save_params_and_state(self, model_dir): # the new. to_remove = None if os.path.exists(model_dir): - to_remove = os.path.join(model_dir, "..", "model_old") + to_remove = os.path.join(model_dir, "../../..", "model_old") shutil.move(model_dir, to_remove) os.makedirs(model_dir) diff --git a/src/dwi_ml/models/main_models.py b/src/dwi_ml/general/models/main_models/main_models.py similarity index 97% rename from src/dwi_ml/models/main_models.py rename to src/dwi_ml/general/models/main_models/main_models.py index 6e9c54cd..1df53718 100644 --- a/src/dwi_ml/models/main_models.py +++ b/src/dwi_ml/general/models/main_models/main_models.py @@ -7,16 +7,18 @@ import torch from torch import Tensor -from dwi_ml.data.dataset.multi_subject_containers import MultisubjectSubset -from dwi_ml.data.processing.volume.interpolation import \ +from dwi_ml.general.data.dataset.multi_subject_containers import \ + MultisubjectSubset +from dwi_ml.general.data.processing.volume.interpolation import \ interpolate_volume_in_neighborhood -from dwi_ml.data.processing.space.neighborhood import \ +from dwi_ml.general.data.processing.space.neighborhood import \ prepare_neighborhood_vectors -from dwi_ml.models.direction_getter_models import keys_to_direction_getters -from dwi_ml.models.embeddings import (keys_to_embeddings, NNEmbedding, - NoEmbedding) -from dwi_ml.models.main_abstract_model import MainModelAbstract -from dwi_ml.models.utils.direction_getters import add_direction_getter_args +from dwi_ml.general.models.main_layers.direction_getter_models import \ + keys_to_direction_getters +from dwi_ml.general.models.main_layers.embeddings import ( + keys_to_embeddings, NNEmbedding, NoEmbedding) +from dwi_ml.general.models.main_models.main_abstract_model import MainModelAbstract +from 
dwi_ml.general.models.utils.direction_getters import add_direction_getter_args logger = logging.getLogger('model_logger') diff --git a/src/dwi_ml/tracking/projects/__init__.py b/src/dwi_ml/general/models/utils/__init__.py similarity index 100% rename from src/dwi_ml/tracking/projects/__init__.py rename to src/dwi_ml/general/models/utils/__init__.py diff --git a/src/dwi_ml/models/utils/direction_getters.py b/src/dwi_ml/general/models/utils/direction_getters.py similarity index 98% rename from src/dwi_ml/models/utils/direction_getters.py rename to src/dwi_ml/general/models/utils/direction_getters.py index fb41eb0e..e985d86f 100644 --- a/src/dwi_ml/models/utils/direction_getters.py +++ b/src/dwi_ml/general/models/utils/direction_getters.py @@ -2,7 +2,7 @@ import logging from argparse import ArgumentParser -from dwi_ml.models.direction_getter_models import keys_to_direction_getters +from dwi_ml.general.models.main_layers.direction_getter_models import keys_to_direction_getters def add_direction_getter_args(p: ArgumentParser, gaussian_fisher_args=True): diff --git a/src/dwi_ml/models/utils/fisher_von_mises.py b/src/dwi_ml/general/models/utils/fisher_von_mises.py similarity index 100% rename from src/dwi_ml/models/utils/fisher_von_mises.py rename to src/dwi_ml/general/models/utils/fisher_von_mises.py diff --git a/src/dwi_ml/models/utils/gaussians.py b/src/dwi_ml/general/models/utils/gaussians.py similarity index 100% rename from src/dwi_ml/models/utils/gaussians.py rename to src/dwi_ml/general/models/utils/gaussians.py diff --git a/src/dwi_ml/training/__init__.py b/src/dwi_ml/general/testing/__init__.py similarity index 100% rename from src/dwi_ml/training/__init__.py rename to src/dwi_ml/general/testing/__init__.py diff --git a/src/dwi_ml/testing/testers.py b/src/dwi_ml/general/testing/testers.py similarity index 97% rename from src/dwi_ml/testing/testers.py rename to src/dwi_ml/general/testing/testers.py index b6aca66e..23e1ef52 100644 --- 
a/src/dwi_ml/testing/testers.py +++ b/src/dwi_ml/general/testing/testers.py @@ -5,11 +5,11 @@ import torch from tqdm import tqdm -from dwi_ml.data.processing.streamlines.data_augmentation import \ +from dwi_ml.general.data.processing.streamlines.data_augmentation import \ resample_or_compress -from dwi_ml.models.main_models import (ModelWithOneInput, - ModelWithDirectionGetter) -from dwi_ml.testing.utils import prepare_dataset_one_subj +from dwi_ml.general.models.main_models.main_models import ( + ModelWithOneInput, ModelWithDirectionGetter) +from dwi_ml.general.testing.utils import prepare_dataset_one_subj logger = logging.getLogger('tester_logger') diff --git a/src/dwi_ml/testing/utils.py b/src/dwi_ml/general/testing/utils.py similarity index 95% rename from src/dwi_ml/testing/utils.py rename to src/dwi_ml/general/testing/utils.py index a718817b..0d9f7613 100644 --- a/src/dwi_ml/testing/utils.py +++ b/src/dwi_ml/general/testing/utils.py @@ -7,8 +7,8 @@ import torch -from dwi_ml.data.dataset.multi_subject_containers import (MultiSubjectDataset, - MultisubjectSubset) +from dwi_ml.general.data.dataset.multi_subject_containers import (MultiSubjectDataset, + MultisubjectSubset) def add_args_testing_subj_hdf5(p: ArgumentParser, optional_hdf5=False, diff --git a/src/dwi_ml/testing/visu_loss.py b/src/dwi_ml/general/testing/visu_loss.py similarity index 99% rename from src/dwi_ml/testing/visu_loss.py rename to src/dwi_ml/general/testing/visu_loss.py index 5dd2675c..65b05734 100644 --- a/src/dwi_ml/testing/visu_loss.py +++ b/src/dwi_ml/general/testing/visu_loss.py @@ -16,8 +16,8 @@ # toDo in scilpy2.0: use # from scilpy.tractograms.dps_and_dpp_management import add_data_as_color_dpp -from dwi_ml.models.main_models import ModelWithDirectionGetter -from dwi_ml.testing.testers import load_sft_from_hdf5, \ +from dwi_ml.general.models.main_models.main_models import ModelWithDirectionGetter +from dwi_ml.general.testing.testers import load_sft_from_hdf5, \ 
TesterWithDirectionGetter blue = [2., 75., 252.] diff --git a/src/dwi_ml/testing/visu_loss_utils.py b/src/dwi_ml/general/testing/visu_loss_utils.py similarity index 99% rename from src/dwi_ml/testing/visu_loss_utils.py rename to src/dwi_ml/general/testing/visu_loss_utils.py index 9db78d52..f004d62a 100644 --- a/src/dwi_ml/testing/visu_loss_utils.py +++ b/src/dwi_ml/general/testing/visu_loss_utils.py @@ -8,7 +8,7 @@ assert_inputs_exist, assert_outputs_exist, add_reference_arg, ranged_type) -from dwi_ml.io_utils import add_memory_args +from dwi_ml.general.io_utils import add_memory_args def prepare_args_visu_loss(p: ArgumentParser): diff --git a/src/dwi_ml/training/utils/__init__.py b/src/dwi_ml/general/tracking/__init__.py similarity index 100% rename from src/dwi_ml/training/utils/__init__.py rename to src/dwi_ml/general/tracking/__init__.py diff --git a/src/dwi_ml/tracking/io_utils.py b/src/dwi_ml/general/tracking/io_utils.py similarity index 96% rename from src/dwi_ml/tracking/io_utils.py rename to src/dwi_ml/general/tracking/io_utils.py index 3667068d..6f0e0a07 100644 --- a/src/dwi_ml/tracking/io_utils.py +++ b/src/dwi_ml/general/tracking/io_utils.py @@ -11,11 +11,12 @@ from scilpy.tracking.seed import SeedGenerator -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_arg_existing_experiment_path, add_memory_args -from dwi_ml.testing.utils import add_args_testing_subj_hdf5 -from dwi_ml.tracking.tracking_mask import TrackingMask -from dwi_ml.tracking.tracker import DWIMLAbstractTracker +from dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.io_utils import ( + add_arg_existing_experiment_path, add_memory_args) +from dwi_ml.general.testing.utils import add_args_testing_subj_hdf5 +from dwi_ml.general.tracking.tracking_mask import TrackingMask +from dwi_ml.general.tracking.tracker import DWIMLAbstractTracker ALWAYS_VOX_SPACE = Space.VOX ALWAYS_CORNER = Origin('corner') diff --git a/src/dwi_ml/tracking/propagation.py 
b/src/dwi_ml/general/tracking/propagation.py similarity index 99% rename from src/dwi_ml/tracking/propagation.py rename to src/dwi_ml/general/tracking/propagation.py index 86639d5d..74cc58b3 100644 --- a/src/dwi_ml/tracking/propagation.py +++ b/src/dwi_ml/general/tracking/propagation.py @@ -6,7 +6,7 @@ import torch from torch import Tensor -from dwi_ml.tracking.tracking_mask import TrackingMask +from dwi_ml.general.tracking.tracking_mask import TrackingMask logger = logging.getLogger('tracker_logger') diff --git a/src/dwi_ml/tracking/tracker.py b/src/dwi_ml/general/tracking/tracker.py similarity index 98% rename from src/dwi_ml/tracking/tracker.py rename to src/dwi_ml/general/tracking/tracker.py index 012f4d1c..da72e1ed 100644 --- a/src/dwi_ml/tracking/tracker.py +++ b/src/dwi_ml/general/tracking/tracker.py @@ -10,19 +10,19 @@ from dipy.tracking.streamlinespeed import compress_streamlines import numpy as np import torch -from dwi_ml.tracking.utils import prepare_step_size_vox +from dwi_ml.general.tracking.utils import prepare_step_size_vox from torch import Tensor from tqdm.contrib.logging import tqdm_logging_redirect from scilpy.tracking.seed import SeedGenerator -from dwi_ml.data.dataset.multi_subject_containers import MultisubjectSubset -from dwi_ml.models.direction_getter_models import \ +from dwi_ml.general.data.dataset.multi_subject_containers import MultisubjectSubset +from dwi_ml.general.models.main_layers.direction_getter_models import \ AbstractRegressionDG -from dwi_ml.models.main_models import ModelWithDirectionGetter, \ +from dwi_ml.general.models.main_models.main_models import ModelWithDirectionGetter, \ ModelWithOneInput -from dwi_ml.tracking.propagation import propagate_multiple_lines -from dwi_ml.tracking.tracking_mask import TrackingMask +from dwi_ml.general.tracking.propagation import propagate_multiple_lines +from dwi_ml.general.tracking.tracking_mask import TrackingMask logger = logging.getLogger('tracker_logger') diff --git 
a/src/dwi_ml/tracking/tracking_mask.py b/src/dwi_ml/general/tracking/tracking_mask.py similarity index 97% rename from src/dwi_ml/tracking/tracking_mask.py rename to src/dwi_ml/general/tracking/tracking_mask.py index 84199732..4a5d4e9d 100644 --- a/src/dwi_ml/tracking/tracking_mask.py +++ b/src/dwi_ml/general/tracking/tracking_mask.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import torch -from dwi_ml.data.processing.volume.interpolation import \ +from dwi_ml.general.data.processing.volume.interpolation import \ torch_nearest_neighbor_interpolation, torch_trilinear_interpolation eps = 1e-6 diff --git a/src/dwi_ml/tracking/utils.py b/src/dwi_ml/general/tracking/utils.py similarity index 100% rename from src/dwi_ml/tracking/utils.py rename to src/dwi_ml/general/tracking/utils.py diff --git a/src/dwi_ml/unit_tests/__init__.py b/src/dwi_ml/general/training/__init__.py similarity index 100% rename from src/dwi_ml/unit_tests/__init__.py rename to src/dwi_ml/general/training/__init__.py diff --git a/src/dwi_ml/training/batch_loaders.py b/src/dwi_ml/general/training/batch_loaders.py similarity index 98% rename from src/dwi_ml/training/batch_loaders.py rename to src/dwi_ml/general/training/batch_loaders.py index d22c1029..353a4ece 100644 --- a/src/dwi_ml/training/batch_loaders.py +++ b/src/dwi_ml/general/training/batch_loaders.py @@ -48,12 +48,12 @@ import numpy as np import torch -from dwi_ml.data.dataset.multi_subject_containers import ( +from dwi_ml.general.data.dataset.multi_subject_containers import ( MultiSubjectDataset, MultisubjectSubset) -from dwi_ml.data.processing.streamlines.data_augmentation import ( +from dwi_ml.general.data.processing.streamlines.data_augmentation import ( reverse_streamlines, split_streamlines, resample_or_compress) -from dwi_ml.data.processing.utils import add_noise_to_tensor -from dwi_ml.models.main_models import ModelWithOneInput, \ +from dwi_ml.general.data.processing.utils import add_noise_to_tensor +from 
dwi_ml.general.models.main_models.main_models import ModelWithOneInput, \ ModelWithNeighborhood, MainModelAbstract logger = logging.getLogger('batch_loader_logger') diff --git a/src/dwi_ml/training/batch_samplers.py b/src/dwi_ml/general/training/batch_samplers.py similarity index 99% rename from src/dwi_ml/training/batch_samplers.py rename to src/dwi_ml/general/training/batch_samplers.py index 035ede3c..b3a076af 100644 --- a/src/dwi_ml/training/batch_samplers.py +++ b/src/dwi_ml/general/training/batch_samplers.py @@ -30,8 +30,8 @@ import numpy as np from torch.utils.data import Sampler -from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset -from dwi_ml.experiment_utils.prints import format_dict_to_str +from dwi_ml.general.data.dataset.multi_subject_containers import MultiSubjectDataset +from dwi_ml.general.experiment_utils.prints import format_dict_to_str DEFAULT_CHUNK_SIZE = 256 logger = logging.getLogger('batch_sampler_logger') diff --git a/src/dwi_ml/training/trainers.py b/src/dwi_ml/general/training/trainers.py similarity index 98% rename from src/dwi_ml/training/trainers.py rename to src/dwi_ml/general/training/trainers.py index 52e5bd52..94945d07 100644 --- a/src/dwi_ml/training/trainers.py +++ b/src/dwi_ml/general/training/trainers.py @@ -12,17 +12,17 @@ from torch.utils.data.dataloader import DataLoader from tqdm import tqdm -from dwi_ml.experiment_utils.memory import ( +from dwi_ml.general.experiment_utils.memory import ( log_gpu_per_tensor, log_currently_allocated, log_gpu_general_info, torch_reset_peaks_memory, log_max_allocated) -from dwi_ml.experiment_utils.tqdm_logging import tqdm_logging_redirect -from dwi_ml.models.main_abstract_model import MainModelAbstract -from dwi_ml.models.main_models import ModelWithDirectionGetter -from dwi_ml.training.batch_loaders import ( +from dwi_ml.general.experiment_utils.tqdm_logging import tqdm_logging_redirect +from dwi_ml.general.models.main_models.main_abstract_model import 
MainModelAbstract +from dwi_ml.general.models.main_models.main_models import ModelWithDirectionGetter +from dwi_ml.general.training.batch_loaders import ( DWIMLStreamlinesBatchLoader, DWIMLBatchLoaderOneInput) -from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler -from dwi_ml.training.utils.gradient_norm import compute_gradient_norm -from dwi_ml.training.utils.monitoring import ( +from dwi_ml.general.training.batch_samplers import DWIMLBatchIDSampler +from dwi_ml.general.training.utils.gradient_norm import compute_gradient_norm +from dwi_ml.general.training.utils.monitoring import ( BestEpochMonitor, IterTimer, BatchHistoryMonitor, TimeMonitor, EarlyStoppingError) @@ -795,8 +795,7 @@ def back_propagation(self, loss): logger.debug('*** Computing back propagation') loss.backward() - # Any other steps. Ex: clip gradients. Not implemented here. - # See Learn2track's Trainer for an example. + # Any other steps. Ex: clip gradients. unclipped_grad_norm = self.fix_parameters() # Supervizing the gradient's norm. 
diff --git a/src/dwi_ml/training/trainers_withGV.py b/src/dwi_ml/general/training/trainers_withGV.py similarity index 96% rename from src/dwi_ml/training/trainers_withGV.py rename to src/dwi_ml/general/training/trainers_withGV.py index 342dea99..80c60061 100644 --- a/src/dwi_ml/training/trainers_withGV.py +++ b/src/dwi_ml/general/training/trainers_withGV.py @@ -26,15 +26,15 @@ import torch from torch.nn import PairwiseDistance -from dwi_ml.data.processing.streamlines.post_processing import \ +from dwi_ml.general.data.processing.streamlines.post_processing import \ compute_triu_connectivity_from_blocs, compute_triu_connectivity_from_labels -from dwi_ml.experiment_utils.memory import BYTES_IN_GB -from dwi_ml.models.main_models import ModelWithDirectionGetter -from dwi_ml.tracking.propagation import propagate_multiple_lines -from dwi_ml.tracking.io_utils import prepare_tracking_mask -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput -from dwi_ml.training.trainers import DWIMLTrainerOneInput -from dwi_ml.training.utils.monitoring import BatchHistoryMonitor +from dwi_ml.general.experiment_utils.memory import BYTES_IN_GB +from dwi_ml.general.models.main_models.main_models import ModelWithDirectionGetter +from dwi_ml.general.tracking.propagation import propagate_multiple_lines +from dwi_ml.general.tracking.io_utils import prepare_tracking_mask +from dwi_ml.general.training.batch_loaders import DWIMLBatchLoaderOneInput +from dwi_ml.general.training.trainers import DWIMLTrainerOneInput +from dwi_ml.general.training.utils.monitoring import BatchHistoryMonitor logger = logging.getLogger('train_logger') diff --git a/src/dwi_ml/unit_tests/utils/__init__.py b/src/dwi_ml/general/training/utils/__init__.py similarity index 100% rename from src/dwi_ml/unit_tests/utils/__init__.py rename to src/dwi_ml/general/training/utils/__init__.py diff --git a/src/dwi_ml/training/utils/batch_loaders.py b/src/dwi_ml/general/training/utils/batch_loaders.py similarity index 93% 
rename from src/dwi_ml/training/utils/batch_loaders.py rename to src/dwi_ml/general/training/utils/batch_loaders.py index 24df17a1..7b2c3b08 100644 --- a/src/dwi_ml/training/utils/batch_loaders.py +++ b/src/dwi_ml/general/training/utils/batch_loaders.py @@ -2,9 +2,9 @@ import argparse import logging -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput +from dwi_ml.general.experiment_utils.prints import format_dict_to_str +from dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.training.batch_loaders import DWIMLBatchLoaderOneInput def add_args_batch_loader(p: argparse.ArgumentParser): diff --git a/src/dwi_ml/training/utils/batch_samplers.py b/src/dwi_ml/general/training/utils/batch_samplers.py similarity index 95% rename from src/dwi_ml/training/utils/batch_samplers.py rename to src/dwi_ml/general/training/utils/batch_samplers.py index fc60b94f..63d6a211 100644 --- a/src/dwi_ml/training/utils/batch_samplers.py +++ b/src/dwi_ml/general/training/utils/batch_samplers.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- import argparse -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler +from dwi_ml.general.experiment_utils.timer import Timer +from dwi_ml.general.training.batch_samplers import DWIMLBatchIDSampler def add_args_batch_sampler(p: argparse.ArgumentParser): diff --git a/src/dwi_ml/training/utils/experiment.py b/src/dwi_ml/general/training/utils/experiment.py similarity index 100% rename from src/dwi_ml/training/utils/experiment.py rename to src/dwi_ml/general/training/utils/experiment.py diff --git a/src/dwi_ml/training/utils/gradient_norm.py b/src/dwi_ml/general/training/utils/gradient_norm.py similarity index 100% rename from src/dwi_ml/training/utils/gradient_norm.py rename to src/dwi_ml/general/training/utils/gradient_norm.py diff --git 
a/src/dwi_ml/training/utils/monitoring.py b/src/dwi_ml/general/training/utils/monitoring.py similarity index 100% rename from src/dwi_ml/training/utils/monitoring.py rename to src/dwi_ml/general/training/utils/monitoring.py diff --git a/src/dwi_ml/training/utils/trainer.py b/src/dwi_ml/general/training/utils/trainer.py similarity index 97% rename from src/dwi_ml/training/utils/trainer.py rename to src/dwi_ml/general/training/utils/trainer.py index 88cecae4..f5c62047 100644 --- a/src/dwi_ml/training/utils/trainer.py +++ b/src/dwi_ml/general/training/utils/trainer.py @@ -2,8 +2,8 @@ import argparse import logging -from dwi_ml.training.utils.monitoring import EarlyStoppingError -from dwi_ml.experiment_utils.timer import Timer +from dwi_ml.general.training.utils.monitoring import EarlyStoppingError +from dwi_ml.general.experiment_utils.timer import Timer logger = logging.getLogger('train_logger') diff --git a/src/dwi_ml/unit_tests/visual_tests/__init__.py b/src/dwi_ml/general/unit_tests/__init__.py similarity index 100% rename from src/dwi_ml/unit_tests/visual_tests/__init__.py rename to src/dwi_ml/general/unit_tests/__init__.py diff --git a/src/dwi_ml/unit_tests/test_dataset.py b/src/dwi_ml/general/unit_tests/test_dataset.py similarity index 93% rename from src/dwi_ml/unit_tests/test_dataset.py rename to src/dwi_ml/general/unit_tests/test_dataset.py index 47aaeaf9..fa73777d 100755 --- a/src/dwi_ml/unit_tests/test_dataset.py +++ b/src/dwi_ml/general/unit_tests/test_dataset.py @@ -8,21 +8,21 @@ import numpy as np from dipy.io.stateful_tractogram import StatefulTractogram -from dwi_ml.data.dataset.multi_subject_containers import \ +from dwi_ml.general.data.dataset.multi_subject_containers import \ MultiSubjectDataset, MultisubjectSubset -from dwi_ml.data.dataset.mri_data_containers import MRIData, LazyMRIData -from dwi_ml.data.dataset.single_subject_containers import \ +from dwi_ml.general.data.dataset.mri_data_containers import MRIData, LazyMRIData +from 
dwi_ml.general.data.dataset.single_subject_containers import \ SubjectData, LazySubjectData -from dwi_ml.data.dataset.subjectdata_list_containers import \ +from dwi_ml.general.data.dataset.subjectdata_list_containers import \ SubjectsDataList, LazySubjectsDataList -from dwi_ml.data.dataset.streamline_containers import \ +from dwi_ml.general.data.dataset.streamline_containers import \ SFTData, LazySFTData -from dwi_ml.unit_tests.utils.expected_values import ( +from dwi_ml.general.unit_tests.utils.expected_values import ( TEST_EXPECTED_SUBJ_NAMES, TEST_EXPECTED_STREAMLINE_GROUPS, TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_NB_STREAMLINES, TEST_EXPECTED_MRI_SHAPE, TEST_EXPECTED_NB_SUBJECTS, TEST_EXPECTED_NB_FEATURES) -from dwi_ml.unit_tests.utils.data_and_models_for_tests import \ +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import \ fetch_testing_data dps_key_1 = 'mean_color_dps' diff --git a/src/dwi_ml/unit_tests/test_directionGetter_losses.py b/src/dwi_ml/general/unit_tests/test_directionGetter_losses.py similarity index 99% rename from src/dwi_ml/unit_tests/test_directionGetter_losses.py rename to src/dwi_ml/general/unit_tests/test_directionGetter_losses.py index e6f0719a..0da4d264 100755 --- a/src/dwi_ml/unit_tests/test_directionGetter_losses.py +++ b/src/dwi_ml/general/unit_tests/test_directionGetter_losses.py @@ -10,11 +10,11 @@ from torch.nn.utils.rnn import PackedSequence from torch.nn.functional import softmax -from dwi_ml.models.direction_getter_models import ( +from dwi_ml.general.models.main_layers.direction_getter_models import ( CosineRegressionDG, FisherVonMisesDG, GaussianMixtureDG, L2RegressionDG, SingleGaussianDG, SphereClassificationDG, AbstractDirectionGetterModel) -from dwi_ml.models.utils.fisher_von_mises import ( +from dwi_ml.general.models.utils.fisher_von_mises import ( fisher_von_mises_log_prob_vector) """ diff --git a/src/dwi_ml/unit_tests/test_model_embeddings.py b/src/dwi_ml/general/unit_tests/test_model_embeddings.py 
similarity index 98% rename from src/dwi_ml/unit_tests/test_model_embeddings.py rename to src/dwi_ml/general/unit_tests/test_model_embeddings.py index c8bfb09c..c1d6defe 100644 --- a/src/dwi_ml/unit_tests/test_model_embeddings.py +++ b/src/dwi_ml/general/unit_tests/test_model_embeddings.py @@ -5,7 +5,7 @@ import numpy as np import torch -from dwi_ml.models.embeddings import keys_to_embeddings +from dwi_ml.general.models.main_layers.embeddings import keys_to_embeddings def test_embeddings(): diff --git a/src/dwi_ml/unit_tests/test_model_prepare_batchOneInput.py b/src/dwi_ml/general/unit_tests/test_model_prepare_batchOneInput.py similarity index 76% rename from src/dwi_ml/unit_tests/test_model_prepare_batchOneInput.py rename to src/dwi_ml/general/unit_tests/test_model_prepare_batchOneInput.py index abf0a2d1..b40aaf53 100644 --- a/src/dwi_ml/unit_tests/test_model_prepare_batchOneInput.py +++ b/src/dwi_ml/general/unit_tests/test_model_prepare_batchOneInput.py @@ -7,10 +7,10 @@ import numpy as np import torch -from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset -from dwi_ml.models.main_models import ModelWithOneInput -from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data -from dwi_ml.unit_tests.utils.expected_values import TEST_EXPECTED_MRI_SHAPE +from dwi_ml.general.data.dataset.multi_subject_containers import MultiSubjectDataset +from dwi_ml.general.models.main_models.main_models import ModelWithOneInput +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import fetch_testing_data +from dwi_ml.general.unit_tests.utils.expected_values import TEST_EXPECTED_MRI_SHAPE def test_model_batch(): diff --git a/src/dwi_ml/unit_tests/test_models_positional_encoding.py b/src/dwi_ml/general/unit_tests/test_models_positional_encoding.py similarity index 92% rename from src/dwi_ml/unit_tests/test_models_positional_encoding.py rename to src/dwi_ml/general/unit_tests/test_models_positional_encoding.py index 
58f92918..e467a15a 100644 --- a/src/dwi_ml/unit_tests/test_models_positional_encoding.py +++ b/src/dwi_ml/general/unit_tests/test_models_positional_encoding.py @@ -3,7 +3,7 @@ import math import numpy as np -from dwi_ml.models.positional_encoding import ( +from dwi_ml.general.models.main_layers.positional_encoding import ( SinusoidalPositionalEncoding) diff --git a/src/dwi_ml/unit_tests/test_submethods_connectivity_matrix.py b/src/dwi_ml/general/unit_tests/test_submethods_connectivity_matrix.py similarity index 87% rename from src/dwi_ml/unit_tests/test_submethods_connectivity_matrix.py rename to src/dwi_ml/general/unit_tests/test_submethods_connectivity_matrix.py index 3b8f9793..02ea416b 100644 --- a/src/dwi_ml/unit_tests/test_submethods_connectivity_matrix.py +++ b/src/dwi_ml/general/unit_tests/test_submethods_connectivity_matrix.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- import numpy as np -from dwi_ml.data.processing.streamlines.post_processing import compute_triu_connectivity_from_blocs +from dwi_ml.general.data.processing.streamlines.post_processing import compute_triu_connectivity_from_blocs def test_connectivity(): diff --git a/src/dwi_ml/unit_tests/test_submethods_eos_sos_and_class.py b/src/dwi_ml/general/unit_tests/test_submethods_eos_sos_and_class.py similarity index 96% rename from src/dwi_ml/unit_tests/test_submethods_eos_sos_and_class.py rename to src/dwi_ml/general/unit_tests/test_submethods_eos_sos_and_class.py index 20387476..4066e112 100644 --- a/src/dwi_ml/unit_tests/test_submethods_eos_sos_and_class.py +++ b/src/dwi_ml/general/unit_tests/test_submethods_eos_sos_and_class.py @@ -1,9 +1,9 @@ #!/usr/bin/env python import torch from dipy.data import get_sphere -from dwi_ml.data.processing.streamlines.sos_eos_management import \ +from dwi_ml.general.data.processing.streamlines.sos_eos_management import \ add_label_as_last_dim, add_zeros_sos_eos, convert_dirs_to_class -from dwi_ml.data.spheres import TorchSphere +from dwi_ml.general.data.spheres 
import TorchSphere from matplotlib import pyplot as plt streamline = torch.as_tensor([[1.0, 1, 1], diff --git a/src/dwi_ml/unit_tests/test_submethods_interpolation.py b/src/dwi_ml/general/unit_tests/test_submethods_interpolation.py similarity index 96% rename from src/dwi_ml/unit_tests/test_submethods_interpolation.py rename to src/dwi_ml/general/unit_tests/test_submethods_interpolation.py index 958f6547..ee632ab7 100644 --- a/src/dwi_ml/unit_tests/test_submethods_interpolation.py +++ b/src/dwi_ml/general/unit_tests/test_submethods_interpolation.py @@ -4,9 +4,9 @@ import numpy as np import torch -from dwi_ml.data.processing.space.neighborhood import \ +from dwi_ml.general.data.processing.space.neighborhood import \ prepare_neighborhood_vectors, unflatten_neighborhood -from dwi_ml.data.processing.volume.interpolation import \ +from dwi_ml.general.data.processing.volume.interpolation import \ interpolate_volume_in_neighborhood diff --git a/src/dwi_ml/unit_tests/test_submethods_neighborhood.py b/src/dwi_ml/general/unit_tests/test_submethods_neighborhood.py similarity index 96% rename from src/dwi_ml/unit_tests/test_submethods_neighborhood.py rename to src/dwi_ml/general/unit_tests/test_submethods_neighborhood.py index e319df3a..ec3dc002 100644 --- a/src/dwi_ml/unit_tests/test_submethods_neighborhood.py +++ b/src/dwi_ml/general/unit_tests/test_submethods_neighborhood.py @@ -3,7 +3,7 @@ import torch -from dwi_ml.data.processing.space.neighborhood import \ +from dwi_ml.general.data.processing.space.neighborhood import \ prepare_neighborhood_vectors, extend_coordinates_with_neighborhood diff --git a/src/dwi_ml/unit_tests/test_submethods_packing.py b/src/dwi_ml/general/unit_tests/test_submethods_packing.py similarity index 100% rename from src/dwi_ml/unit_tests/test_submethods_packing.py rename to src/dwi_ml/general/unit_tests/test_submethods_packing.py diff --git a/src/dwi_ml/unit_tests/test_submethods_previous_dirs.py 
b/src/dwi_ml/general/unit_tests/test_submethods_previous_dirs.py similarity index 96% rename from src/dwi_ml/unit_tests/test_submethods_previous_dirs.py rename to src/dwi_ml/general/unit_tests/test_submethods_previous_dirs.py index a8c0bd91..783676af 100755 --- a/src/dwi_ml/unit_tests/test_submethods_previous_dirs.py +++ b/src/dwi_ml/general/unit_tests/test_submethods_previous_dirs.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- import torch -from dwi_ml.data.processing.streamlines.post_processing import \ +from dwi_ml.general.data.processing.streamlines.post_processing import \ compute_n_previous_dirs NB_PREVIOUS_DIRS = 4 diff --git a/src/dwi_ml/unit_tests/test_train_batch_loader.py b/src/dwi_ml/general/unit_tests/test_train_batch_loader.py similarity index 95% rename from src/dwi_ml/unit_tests/test_train_batch_loader.py rename to src/dwi_ml/general/unit_tests/test_train_batch_loader.py index 63125a4f..d4b8232b 100755 --- a/src/dwi_ml/unit_tests/test_train_batch_loader.py +++ b/src/dwi_ml/general/unit_tests/test_train_batch_loader.py @@ -6,10 +6,10 @@ from dipy.io.stateful_tractogram import set_sft_logger_level from torch.utils.data.dataloader import DataLoader -from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset -from dwi_ml.models.main_models import ModelWithOneInput -from dwi_ml.unit_tests.utils.expected_values import TEST_EXPECTED_NB_STREAMLINES -from dwi_ml.unit_tests.utils.data_and_models_for_tests import ( +from dwi_ml.general.data.dataset.multi_subject_containers import MultiSubjectDataset +from dwi_ml.general.models.main_models.main_models import ModelWithOneInput +from dwi_ml.general.unit_tests.utils.expected_values import TEST_EXPECTED_NB_STREAMLINES +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import ( create_test_batch_sampler, create_batch_loader, fetch_testing_data) SPLIT_RATIO = 0.5 diff --git a/src/dwi_ml/unit_tests/test_train_batch_sampler.py b/src/dwi_ml/general/unit_tests/test_train_batch_sampler.py 
similarity index 95% rename from src/dwi_ml/unit_tests/test_train_batch_sampler.py rename to src/dwi_ml/general/unit_tests/test_train_batch_sampler.py index 1f7a414a..628a094b 100755 --- a/src/dwi_ml/unit_tests/test_train_batch_sampler.py +++ b/src/dwi_ml/general/unit_tests/test_train_batch_sampler.py @@ -4,10 +4,10 @@ from dipy.tracking.metrics import length -from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset -from dwi_ml.unit_tests.utils.expected_values import ( +from dwi_ml.general.data.dataset.multi_subject_containers import MultiSubjectDataset +from dwi_ml.general.unit_tests.utils.expected_values import ( TEST_EXPECTED_SUBJ_NAMES, TEST_EXPECTED_NB_STREAMLINES) -from dwi_ml.unit_tests.utils.data_and_models_for_tests import ( +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import ( create_test_batch_sampler, fetch_testing_data) diff --git a/src/dwi_ml/unit_tests/test_train_trainerOneInput.py b/src/dwi_ml/general/unit_tests/test_train_trainerOneInput.py similarity index 93% rename from src/dwi_ml/unit_tests/test_train_trainerOneInput.py rename to src/dwi_ml/general/unit_tests/test_train_trainerOneInput.py index f6f56de4..77cc0220 100644 --- a/src/dwi_ml/unit_tests/test_train_trainerOneInput.py +++ b/src/dwi_ml/general/unit_tests/test_train_trainerOneInput.py @@ -5,9 +5,9 @@ import pytest -from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset -from dwi_ml.training.trainers import DWIMLTrainerOneInput -from dwi_ml.unit_tests.utils.data_and_models_for_tests import ( +from dwi_ml.general.data.dataset.multi_subject_containers import MultiSubjectDataset +from dwi_ml.general.training.trainers import DWIMLTrainerOneInput +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import ( create_test_batch_sampler, create_batch_loader, fetch_testing_data, ModelForTest, TrackingModelForTestWithPD) diff --git a/src/dwi_ml/general/unit_tests/utils/__init__.py 
b/src/dwi_ml/general/unit_tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/src/dwi_ml/general/unit_tests/utils/data_and_models_for_tests.py similarity index 95% rename from src/dwi_ml/unit_tests/utils/data_and_models_for_tests.py rename to src/dwi_ml/general/unit_tests/utils/data_and_models_for_tests.py index 1ac3dfb0..fb782359 100644 --- a/src/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/src/dwi_ml/general/unit_tests/utils/data_and_models_for_tests.py @@ -8,15 +8,15 @@ from scilpy import get_home from scilpy.io.fetcher import fetch_data -from dwi_ml.data.processing.streamlines.post_processing import \ +from dwi_ml.general.data.processing.streamlines.post_processing import \ compute_directions -from dwi_ml.models.main_models import ( +from dwi_ml.general.models.main_models.main_models import ( ModelWithDirectionGetter, ModelWithNeighborhood, ModelWithOneInput, ModelWithPreviousDirections) -from dwi_ml.unit_tests.utils.expected_values import ( +from dwi_ml.general.unit_tests.utils.expected_values import ( TEST_EXPECTED_STREAMLINE_GROUPS, TEST_EXPECTED_VOLUME_GROUPS) -from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput +from dwi_ml.general.training.batch_samplers import DWIMLBatchIDSampler +from dwi_ml.general.training.batch_loaders import DWIMLBatchLoaderOneInput def fetch_testing_data(): diff --git a/src/dwi_ml/unit_tests/utils/expected_values.py b/src/dwi_ml/general/unit_tests/utils/expected_values.py similarity index 100% rename from src/dwi_ml/unit_tests/utils/expected_values.py rename to src/dwi_ml/general/unit_tests/utils/expected_values.py diff --git a/src/dwi_ml/general/unit_tests/visual_tests/__init__.py b/src/dwi_ml/general/unit_tests/visual_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/src/dwi_ml/unit_tests/visual_tests/analyze_batch_loader_visually.py b/src/dwi_ml/general/unit_tests/visual_tests/analyze_batch_loader_visually.py similarity index 93% rename from src/dwi_ml/unit_tests/visual_tests/analyze_batch_loader_visually.py rename to src/dwi_ml/general/unit_tests/visual_tests/analyze_batch_loader_visually.py index 4a872747..ba5caed6 100755 --- a/src/dwi_ml/unit_tests/visual_tests/analyze_batch_loader_visually.py +++ b/src/dwi_ml/general/unit_tests/visual_tests/analyze_batch_loader_visually.py @@ -11,10 +11,10 @@ from dipy.io.stateful_tractogram import StatefulTractogram from dipy.io.streamline import save_tractogram -from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset -from dwi_ml.models.main_models import ModelWithOneInput -from dwi_ml.tracking.io_utils import ALWAYS_VOX_SPACE, ALWAYS_CORNER -from dwi_ml.unit_tests.utils.data_and_models_for_tests import ( +from dwi_ml.general.data.dataset.multi_subject_containers import MultiSubjectDataset +from dwi_ml.general.models.main_models.main_models import ModelWithOneInput +from dwi_ml.general.tracking.io_utils import ALWAYS_VOX_SPACE, ALWAYS_CORNER +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import ( create_test_batch_sampler, create_batch_loader, fetch_testing_data, ModelForTest) diff --git a/src/dwi_ml/general/viz/__init__.py b/src/dwi_ml/general/viz/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dwi_ml/viz/logs_plots.py b/src/dwi_ml/general/viz/logs_plots.py similarity index 100% rename from src/dwi_ml/viz/logs_plots.py rename to src/dwi_ml/general/viz/logs_plots.py diff --git a/src/dwi_ml/projects/AE/__init__.py b/src/dwi_ml/projects/AE/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dwi_ml/models/projects/ae_models.py b/src/dwi_ml/projects/AE/ae_models.py similarity index 98% rename from src/dwi_ml/models/projects/ae_models.py rename to src/dwi_ml/projects/AE/ae_models.py index 
1d9a30b8..0aeceec9 100644 --- a/src/dwi_ml/models/projects/ae_models.py +++ b/src/dwi_ml/projects/AE/ae_models.py @@ -5,7 +5,7 @@ import torch from torch.nn import functional as F -from dwi_ml.models.main_abstract_model import MainModelAbstract +from dwi_ml.general.models.main_models.main_abstract_model import MainModelAbstract class ModelAE(MainModelAbstract): diff --git a/src/dwi_ml/projects/Learn2track/__init__.py b/src/dwi_ml/projects/Learn2track/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dwi_ml/models/projects/learn2track_model.py b/src/dwi_ml/projects/Learn2track/learn2track_model.py similarity index 97% rename from src/dwi_ml/models/projects/learn2track_model.py rename to src/dwi_ml/projects/Learn2track/learn2track_model.py index 0ab6a1b7..246d8350 100644 --- a/src/dwi_ml/models/projects/learn2track_model.py +++ b/src/dwi_ml/projects/Learn2track/learn2track_model.py @@ -6,16 +6,15 @@ import torch from torch.nn.utils.rnn import invert_permutation, PackedSequence, pack_sequence -from dwi_ml.data.processing.space.neighborhood import unflatten_neighborhood -from dwi_ml.data.processing.streamlines.post_processing import \ - compute_directions, normalize_directions, compute_n_previous_dirs -from dwi_ml.data.processing.streamlines.sos_eos_management import \ - convert_dirs_to_class -from dwi_ml.models.embeddings import NoEmbedding -from dwi_ml.models.main_models import ( +from dwi_ml.general.data.processing.space.neighborhood import \ + unflatten_neighborhood +from dwi_ml.general.data.processing.streamlines.post_processing import ( + compute_directions, normalize_directions, compute_n_previous_dirs) +from dwi_ml.general.models.main_layers.embeddings import NoEmbedding +from dwi_ml.general.models.main_models.main_models import ( ModelWithPreviousDirections, ModelWithDirectionGetter, ModelWithNeighborhood, ModelWithOneInput) -from dwi_ml.models.stacked_rnn import StackedRNN +from dwi_ml.general.models.main_layers.stacked_rnn import 
StackedRNN logger = logging.getLogger('model_logger') # Same logger as Super. diff --git a/src/dwi_ml/tracking/projects/learn2track_tracker.py b/src/dwi_ml/projects/Learn2track/learn2track_tracker.py similarity index 95% rename from src/dwi_ml/tracking/projects/learn2track_tracker.py rename to src/dwi_ml/projects/Learn2track/learn2track_tracker.py index 76336ac5..9a77b6c7 100644 --- a/src/dwi_ml/tracking/projects/learn2track_tracker.py +++ b/src/dwi_ml/projects/Learn2track/learn2track_tracker.py @@ -3,8 +3,8 @@ import numpy as np -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.tracking.tracker import DWIMLTrackerOneInput +from dwi_ml.projects.Learn2track.learn2track_model import Learn2TrackModel +from dwi_ml.general.tracking.tracker import DWIMLTrackerOneInput logger = logging.getLogger('tracker_logger') diff --git a/src/dwi_ml/training/projects/learn2track_trainer.py b/src/dwi_ml/projects/Learn2track/learn2track_trainer.py similarity index 94% rename from src/dwi_ml/training/projects/learn2track_trainer.py rename to src/dwi_ml/projects/Learn2track/learn2track_trainer.py index 48deb3f2..57a54c52 100644 --- a/src/dwi_ml/training/projects/learn2track_trainer.py +++ b/src/dwi_ml/projects/Learn2track/learn2track_trainer.py @@ -6,10 +6,10 @@ import numpy as np import torch -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.tracking.io_utils import prepare_tracking_mask -from dwi_ml.tracking.propagation import propagate_multiple_lines -from dwi_ml.training.trainers_withGV import \ +from dwi_ml.projects.Learn2track.learn2track_model import Learn2TrackModel +from dwi_ml.general.tracking.io_utils import prepare_tracking_mask +from dwi_ml.general.tracking.propagation import propagate_multiple_lines +from dwi_ml.general.training.trainers_withGV import \ DWIMLTrainerOneInputWithGVPhase logger = logging.getLogger('trainer_logger') diff --git a/src/dwi_ml/models/projects/learn2track_utils.py 
b/src/dwi_ml/projects/Learn2track/learn2track_utils.py similarity index 96% rename from src/dwi_ml/models/projects/learn2track_utils.py rename to src/dwi_ml/projects/Learn2track/learn2track_utils.py index e401a1c0..5bbd0585 100644 --- a/src/dwi_ml/models/projects/learn2track_utils.py +++ b/src/dwi_ml/projects/Learn2track/learn2track_utils.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import argparse -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel +from dwi_ml.projects.Learn2track.learn2track_model import Learn2TrackModel def add_model_args(p: argparse.ArgumentParser): diff --git a/src/dwi_ml/unit_tests/test_models_learn2track.py b/src/dwi_ml/projects/Learn2track/test_models_learn2track.py similarity index 91% rename from src/dwi_ml/unit_tests/test_models_learn2track.py rename to src/dwi_ml/projects/Learn2track/test_models_learn2track.py index 708ea0e4..c604dcf5 100644 --- a/src/dwi_ml/unit_tests/test_models_learn2track.py +++ b/src/dwi_ml/projects/Learn2track/test_models_learn2track.py @@ -3,10 +3,12 @@ from torch.nn.utils.rnn import pack_sequence -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.models.stacked_rnn import StackedRNN, ADD_SKIP_TO_OUTPUT -from dwi_ml.unit_tests.utils.data_and_models_for_tests import create_test_batch_2lines_4features +from dwi_ml.general.experiment_utils.prints import format_dict_to_str +from dwi_ml.projects.Learn2track.learn2track_model import Learn2TrackModel +from dwi_ml.general.models.main_layers.stacked_rnn import ( + StackedRNN, ADD_SKIP_TO_OUTPUT) +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import \ + create_test_batch_2lines_4features batch_x, _, batch_s, _ = create_test_batch_2lines_4features() diff --git a/src/dwi_ml/projects/Transformers/__init__.py b/src/dwi_ml/projects/Transformers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/src/dwi_ml/unit_tests/test_models_transformers.py b/src/dwi_ml/projects/Transformers/test_models_transformers.py similarity index 97% rename from src/dwi_ml/unit_tests/test_models_transformers.py rename to src/dwi_ml/projects/Transformers/test_models_transformers.py index 5b642e78..7d496622 100644 --- a/src/dwi_ml/unit_tests/test_models_transformers.py +++ b/src/dwi_ml/projects/Transformers/test_models_transformers.py @@ -3,9 +3,9 @@ from torch import isnan, set_printoptions -from dwi_ml.models.projects.transformer_models import ( +from dwi_ml.projects.Transformers.transformer_models import ( OriginalTransformerModel, TransformerSrcAndTgtModel, TransformerSrcOnlyModel) -from dwi_ml.unit_tests.utils.data_and_models_for_tests import create_test_batch_2lines_4features +from dwi_ml.general.unit_tests.utils.data_and_models_for_tests import create_test_batch_2lines_4features (batch_x_various_lengths, batch_x_same_lengths, batch_s_various_lengths, batch_s_same_lengths) = create_test_batch_2lines_4features() diff --git a/src/dwi_ml/projects/Transformers/tester/__init__.py b/src/dwi_ml/projects/Transformers/tester/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dwi_ml/testing/projects/tt_visu_argparser.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_argparser.py similarity index 98% rename from src/dwi_ml/testing/projects/tt_visu_argparser.py rename to src/dwi_ml/projects/Transformers/tester/tt_visu_argparser.py index 9aadbac4..3d944e5f 100644 --- a/src/dwi_ml/testing/projects/tt_visu_argparser.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_argparser.py @@ -54,8 +54,8 @@ from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, add_verbose_arg) -from dwi_ml.io_utils import add_arg_existing_experiment_path, add_memory_args -from dwi_ml.testing.utils import add_args_testing_subj_hdf5 +from dwi_ml.general.io_utils import add_arg_existing_experiment_path, add_memory_args +from dwi_ml.general.testing.utils import 
add_args_testing_subj_hdf5 def build_argparser_transformer_visu(): diff --git a/src/dwi_ml/testing/projects/tt_visu_bertviz.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_bertviz.py similarity index 100% rename from src/dwi_ml/testing/projects/tt_visu_bertviz.py rename to src/dwi_ml/projects/Transformers/tester/tt_visu_bertviz.py diff --git a/src/dwi_ml/testing/projects/tt_visu_colored_sft.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_colored_sft.py similarity index 99% rename from src/dwi_ml/testing/projects/tt_visu_colored_sft.py rename to src/dwi_ml/projects/Transformers/tester/tt_visu_colored_sft.py index f129cf71..61ed0b00 100644 --- a/src/dwi_ml/testing/projects/tt_visu_colored_sft.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_colored_sft.py @@ -9,7 +9,7 @@ from scilpy.viz.color import get_lookup_table -from dwi_ml.testing.projects.tt_visu_utils import ( +from dwi_ml.projects.Transformers.tester.tt_visu_utils import ( get_visu_params_from_options, prepare_projections_from_options) diff --git a/src/dwi_ml/testing/projects/tt_visu_main.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_main.py similarity index 95% rename from src/dwi_ml/testing/projects/tt_visu_main.py rename to src/dwi_ml/projects/Transformers/tester/tt_visu_main.py index 2c706514..3aafaa1b 100644 --- a/src/dwi_ml/testing/projects/tt_visu_main.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_main.py @@ -18,19 +18,19 @@ from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist -from dwi_ml.io_utils import verify_which_model_in_path -from dwi_ml.models.projects.transformer_models import find_transformer_class -from dwi_ml.testing.projects.tt_visu_bertviz import ( +from dwi_ml.general.io_utils import verify_which_model_in_path +from dwi_ml.projects.Transformers.transformer_models import find_transformer_class +from dwi_ml.projects.Transformers.tester.tt_visu_bertviz import ( 
encoder_decoder_show_head_view, encoder_decoder_show_model_view, encoder_show_model_view, encoder_show_head_view) -from dwi_ml.testing.projects.tt_visu_colored_sft import ( +from dwi_ml.projects.Transformers.tester.tt_visu_colored_sft import ( color_sft_duplicate_lines, color_sft_x_y_projections) -from dwi_ml.testing.projects.tt_visu_matrix import show_model_view_as_imshow -from dwi_ml.testing.projects.tt_visu_utils import ( +from dwi_ml.projects.Transformers.tester.tt_visu_matrix import show_model_view_as_imshow +from dwi_ml.projects.Transformers.tester.tt_visu_utils import ( prepare_encoder_tokens, prepare_decoder_tokens, reshape_unpad_rescale_attention, resample_attention_one_line, get_out_dir_and_create) -from dwi_ml.testing.testers import TesterOneInput +from dwi_ml.general.testing.testers import TesterOneInput def tt_visualize_weights_main(args, parser): diff --git a/src/dwi_ml/testing/projects/tt_visu_matrix.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_matrix.py similarity index 98% rename from src/dwi_ml/testing/projects/tt_visu_matrix.py rename to src/dwi_ml/projects/Transformers/tester/tt_visu_matrix.py index cdd5ea8a..8952c72b 100644 --- a/src/dwi_ml/testing/projects/tt_visu_matrix.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_matrix.py @@ -4,7 +4,7 @@ from matplotlib import pyplot as plt from mpl_toolkits.axes_grid1 import make_axes_locatable -from dwi_ml.testing.projects.tt_visu_utils import ( +from dwi_ml.projects.Transformers.tester.tt_visu_utils import ( get_visu_params_from_options, prepare_projections_from_options) diff --git a/src/dwi_ml/testing/projects/tt_visu_submethods.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_submethods.py similarity index 94% rename from src/dwi_ml/testing/projects/tt_visu_submethods.py rename to src/dwi_ml/projects/Transformers/tester/tt_visu_submethods.py index 00a4f93a..8a96d1ec 100644 --- a/src/dwi_ml/testing/projects/tt_visu_submethods.py +++ 
b/src/dwi_ml/projects/Transformers/tester/tt_visu_submethods.py @@ -13,10 +13,10 @@ from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.tractograms.streamline_operations import \ resample_streamlines_step_size -from scilpy.utils.streamlines import compress_sft +from scilpy.tractograms.tractogram_operations import compress_sft -from dwi_ml.models.projects.transformer_models import AbstractTransformerModel -from dwi_ml.testing.utils import prepare_dataset_one_subj +from dwi_ml.projects.Transformers.transformer_models import AbstractTransformerModel +from dwi_ml.general.testing.utils import prepare_dataset_one_subj # Currently, with our quite long sequences compared to their example, this # is a bit ugly. @@ -258,7 +258,7 @@ def tto_show_head_view(encoder_attention, decoder_attention, cross_attention, encoder_tokens=encoder_tokens, decoder_tokens=decoder_tokens) -def show_model_view_as_imshow(attention, tokens_x, tokens_y=None): +def _show_model_view_as_imshow(attention, tokens_x, tokens_y=None): torch.set_printoptions(precision=2, sci_mode=False, linewidth=150) nb_layers = len(attention) @@ -328,14 +328,14 @@ def tto_show_model_view(encoder_attention, decoder_attention, cross_attention, decoder_tokens=decoder_tokens) else: print("ENCODER ATTENTION: ") - show_model_view_as_imshow(encoder_attention, - encoder_tokens, encoder_tokens) + _show_model_view_as_imshow(encoder_attention, + encoder_tokens, encoder_tokens) print("DECODER ATTENTION: ") - show_model_view_as_imshow(decoder_attention, - decoder_tokens, decoder_tokens) + _show_model_view_as_imshow(decoder_attention, + decoder_tokens, decoder_tokens) print("CROSS ATTENTION: ") - show_model_view_as_imshow(cross_attention, - encoder_tokens, decoder_tokens) + _show_model_view_as_imshow(cross_attention, + encoder_tokens, decoder_tokens) def ttst_show_model_view(encoder_attention, tokens): @@ -354,4 +354,4 @@ def ttst_show_model_view(encoder_attention, tokens): else: print("ENCODER ATTENTION: ") - 
show_model_view_as_imshow(encoder_attention, tokens) + _show_model_view_as_imshow(encoder_attention, tokens) diff --git a/src/dwi_ml/testing/projects/tt_visu_utils.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_utils.py similarity index 100% rename from src/dwi_ml/testing/projects/tt_visu_utils.py rename to src/dwi_ml/projects/Transformers/tester/tt_visu_utils.py diff --git a/src/dwi_ml/testing/projects/tt_visualize_weights.ipynb b/src/dwi_ml/projects/Transformers/tester/tt_visualize_weights.ipynb similarity index 92% rename from src/dwi_ml/testing/projects/tt_visualize_weights.ipynb rename to src/dwi_ml/projects/Transformers/tester/tt_visualize_weights.ipynb index 13666ee6..aaf60ee3 100644 --- a/src/dwi_ml/testing/projects/tt_visualize_weights.ipynb +++ b/src/dwi_ml/projects/Transformers/tester/tt_visualize_weights.ipynb @@ -22,16 +22,13 @@ "metadata": {}, "outputs": [], "source": [ - "import argparse\n", + "\n", "import os\n", - "from os.path import dirname, join\n", "import sys\n", "\n", - "from IPython.display import HTML\n", - "\n", "from scilpy.io.fetcher import get_home as get_scilpy_folder\n", "\n", - "from dwi_ml.testing.projects.tt_visu_main import \\\n", + "from dwi_ml.general.testing.projects import \\\n", " (build_argparser_transformer_visu, get_config_filename, tt_visualize_weights_main) \n" ] }, diff --git a/src/dwi_ml/models/projects/transformer_models.py b/src/dwi_ml/projects/Transformers/transformer_models.py similarity index 98% rename from src/dwi_ml/models/projects/transformer_models.py rename to src/dwi_ml/projects/Transformers/transformer_models.py index 14685eb0..2b21b09e 100644 --- a/src/dwi_ml/models/projects/transformer_models.py +++ b/src/dwi_ml/projects/Transformers/transformer_models.py @@ -8,17 +8,17 @@ from torch.nn import Dropout, Transformer from torch.nn.functional import pad -from dwi_ml.data.processing.streamlines.sos_eos_management import \ +from dwi_ml.general.data.processing.streamlines.sos_eos_management import \ 
add_label_as_last_dim, convert_dirs_to_class -from dwi_ml.data.processing.streamlines.post_processing import \ +from dwi_ml.general.data.processing.streamlines.post_processing import \ compute_directions -from dwi_ml.data.spheres import TorchSphere -from dwi_ml.models.embeddings import keys_to_embeddings -from dwi_ml.models.main_models import (ModelWithDirectionGetter, - ModelWithNeighborhood, - ModelWithOneInput) -from dwi_ml.models.positional_encoding import keys_to_positional_encodings -from dwi_ml.models.utils.transformers_from_torch import ( +from dwi_ml.general.data.spheres import TorchSphere +from dwi_ml.general.models.main_layers.embeddings import keys_to_embeddings +from dwi_ml.general.models.main_models.main_models import ( + ModelWithDirectionGetter, ModelWithNeighborhood, ModelWithOneInput) +from dwi_ml.general.models.main_layers.positional_encoding import \ + keys_to_positional_encodings +from dwi_ml.general.models.main_layers.transformers_from_torch import ( ModifiedTransformer, ModifiedTransformerEncoder, ModifiedTransformerEncoderLayer, ModifiedTransformerDecoder, ModifiedTransformerDecoderLayer) diff --git a/src/dwi_ml/tracking/projects/transformer_tracker.py b/src/dwi_ml/projects/Transformers/transformer_tracker.py similarity index 77% rename from src/dwi_ml/tracking/projects/transformer_tracker.py rename to src/dwi_ml/projects/Transformers/transformer_tracker.py index b21d063b..bfbba859 100644 --- a/src/dwi_ml/tracking/projects/transformer_tracker.py +++ b/src/dwi_ml/projects/Transformers/transformer_tracker.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from dwi_ml.models.projects.transformer_models import AbstractTransformerModel -from dwi_ml.tracking.tracker import ( +from dwi_ml.projects.Transformers.transformer_models import AbstractTransformerModel +from dwi_ml.general.tracking.tracker import ( DWIMLTrackerOneInput, DWIMLTrackerFromWholeStreamline) diff --git a/src/dwi_ml/training/projects/transformer_trainer.py 
b/src/dwi_ml/projects/Transformers/transformer_trainer.py similarity index 95% rename from src/dwi_ml/training/projects/transformer_trainer.py rename to src/dwi_ml/projects/Transformers/transformer_trainer.py index bf3b0080..bc98a2cd 100644 --- a/src/dwi_ml/training/projects/transformer_trainer.py +++ b/src/dwi_ml/projects/Transformers/transformer_trainer.py @@ -6,10 +6,10 @@ import numpy as np import torch -from dwi_ml.tracking.io_utils import prepare_tracking_mask -from dwi_ml.tracking.propagation import propagate_multiple_lines +from dwi_ml.general.tracking.io_utils import prepare_tracking_mask +from dwi_ml.general.tracking.propagation import propagate_multiple_lines -from dwi_ml.training.trainers_withGV import \ +from dwi_ml.general.training.trainers_withGV import \ DWIMLTrainerOneInputWithGVPhase diff --git a/src/dwi_ml/models/projects/transformers_utils.py b/src/dwi_ml/projects/Transformers/transformers_utils.py similarity index 97% rename from src/dwi_ml/models/projects/transformers_utils.py rename to src/dwi_ml/projects/Transformers/transformers_utils.py index 64a2e34a..81712042 100644 --- a/src/dwi_ml/models/projects/transformers_utils.py +++ b/src/dwi_ml/projects/Transformers/transformers_utils.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -from dwi_ml.models.positional_encoding import ( +from dwi_ml.general.models.main_layers.positional_encoding import ( keys_to_positional_encodings) -from dwi_ml.models.projects.transformer_models import AbstractTransformerModel +from dwi_ml.projects.Transformers.transformer_models import AbstractTransformerModel sphere_choices = ['symmetric362', 'symmetric642', 'symmetric724', 'repulsion724', 'repulsion100', 'repulsion200'] diff --git a/src/dwi_ml/projects/__init__.py b/src/dwi_ml/projects/__init__.py new file mode 100644 index 00000000..e69de29b