diff --git a/docs/_static/images/Transformers.png b/docs/_static/images/Transformers.png new file mode 100644 index 00000000..efd3432d Binary files /dev/null and b/docs/_static/images/Transformers.png differ diff --git a/docs/_static/images/create_your_model.png b/docs/_static/images/create_your_model.png index da47c718..af6f7b8e 100644 Binary files a/docs/_static/images/create_your_model.png and b/docs/_static/images/create_your_model.png differ diff --git a/docs/_static/images/create_your_model2.png b/docs/_static/images/create_your_model2.png new file mode 100644 index 00000000..c03064f2 Binary files /dev/null and b/docs/_static/images/create_your_model2.png differ diff --git a/docs/_static/my_style.css b/docs/_static/my_style.css index 60abce31..f7cd2def 100644 --- a/docs/_static/my_style.css +++ b/docs/_static/my_style.css @@ -12,15 +12,11 @@ text-decoration: line-through; } -/* Hiding captions on the main page. Did not allow to add comments under the -caption. Adding subtitles instead, and we don't have duplicate caption. But we -still want the caption in the sidebar. - */ -.wy-nav-content .caption-text { - display: none; -} +/* ------------------- + Changes prepared by Emmanuelle Renauld: + -------------------*/ -/* Code shows in red when writing with `code`. Changing to gray */ +/* Code shows in red when writing with ``code``. Changing to gray */ code.literal { color: dimgray !important; } @@ -46,7 +42,7 @@ div class rst-content background-color: #eeefef; } -/*Now tables colors are off*/ +/*Now tables colors are off. Fixing.*/ table .row-odd { background-color: white } @@ -55,6 +51,15 @@ table .row-even { background-color: #bfbdbd; } +/* Creating a "centered" container.*/ +.centered { + text-align: center; +} +.centered table { + margin-left: auto; + margin-right: auto; +} + /* Adding the same padding to footers */ footer { padding: 1em; diff --git a/docs/for_developers/data_management/Advanced_data_containers.rst b/docs/for_developers/data_management/Advanced_data_containers.rst index de1656c9..64fc5fb8 100644 --- a/docs/for_developers/data_management/Advanced_data_containers.rst +++ b/docs/for_developers/data_management/Advanced_data_containers.rst @@ -6,7 +6,8 @@ The MultisubjectDataset Here is how our data is organized to allow torch to use them through Dataloaders. All of the following can be used as lazy instead. Then, the data is only loaded when needed. -- **(Lazy)MultisubjectDataset** +(Lazy)MultisubjectDataset +------------------------- - subcontainers: @@ -20,9 +21,10 @@ Here is how our data is organized to allow torch to use them through Dataloaders - methods: - - *.load_data()*: loads the training and validation sets. + - *.load_data()*: loads the training and validation sets (if not lazy), or loads only the information about each subject. -- **Multisubjectsubset** +Multisubjectsubset +------------------ - subcontainers: @@ -36,12 +38,13 @@ Here is how our data is organized to allow torch to use them through Dataloaders - methods: - - *.get_volume()*: gets a specific mri volume (ID corresponds to the group ID in the config_file) from a specific subject. + - *.get_volume(id)*: gets a specific mri volume (contained in the hdf5) from a specific subject. - *.get_volume_verify_cache()*: same, but if data was lazy, checks the volume cache first. If it was not cached, loads it and sends it to the cache. - *.__getitem__()*: used by the dataloader. Does not do anything per say, simply returns the sampled streamline id. The batch sampler will do the job of actually loading the data. -- **(Lazy)SubjectsDataList** +(Lazy)SubjectsDataList +---------------------- - subcontainers: @@ -59,7 +62,8 @@ Here is how our data is organized to allow torch to use them through Dataloaders - *.getitem_with_handle()*: same, but in the lazy case, adds a hdf5 handle first to allow loading. You probably won't need this method: used in the get_volume() method of the MultisubjectSubset. -- **(Lazy)SubjectData** +(Lazy)SubjectData +----------------- - subcontainers: @@ -67,6 +71,7 @@ Here is how our data is organized to allow torch to use them through Dataloaders - .sft_data_list: list of **SFTData** (lazy or not) - other attributes: + - .volume_groups, .streamline_group, .subject_id: general attributes found in the hdf5 file. - methods: @@ -75,7 +80,8 @@ Here is how our data is organized to allow torch to use them through Dataloaders - *.with_handle()*: useful only in the lazy case. Adds hdf_handle to the subject to allow loading. -**MRIData** +MRIData +------- - attributes: @@ -89,7 +95,8 @@ Here is how our data is organized to allow torch to use them through Dataloaders - *.as_tensor()*: gets the data. -**SFTData** +SFTData +------- - attributes: diff --git a/docs/for_developers/data_management/BatchLoader.rst b/docs/for_developers/data_management/BatchLoader.rst index b9c5b63e..54873891 100644 --- a/docs/for_developers/data_management/BatchLoader.rst +++ b/docs/for_developers/data_management/BatchLoader.rst @@ -3,4 +3,23 @@ Batch loader ============ -toDo \ No newline at end of file +These classes define how batches of streamlines are loaded from a MultiSubjectDataset and how data augmentation is applied. Two main types of batch loaders are implemented: + +DWIMLStreamlinesBatchLoader +--------------------------- + +Loads augmented streamlines only (no MRI volumes). Loads streamlines from a dataset, applies optional data augmentation (resampling, cutting, reversing, noise), and returns them in voxel/corner space. + +Methods: + - ``set_context``: Sets whether the loader operates on training or validation data. Also configures which noise augmentation applies. + - ``load_batch_streamlines(streamline_ids_per_subj)``: Loads the streamlines for each subject, applies: resampling or compression, splitting, reversing, conversion to voxel + corner coordinates + +DWIMLBatchLoaderOneInput +------------------------ + +Child class of DWIMLStreamlinesBatchLoader. Additionnally communicates with the model to prepare input volume(s) under each point of each streamline and performs trilinear interpolation, and, optionnally, neighborhood extraction. + +Methods: + - ``load_batch_inputs(batch_streamlines, ids_per_subj)`` + +Note: Must be used with a model with inputs, uses the model’s method: ``prepare_batch_one_input()``. \ No newline at end of file diff --git a/docs/for_developers/data_management/BatchSampler.rst b/docs/for_developers/data_management/BatchSampler.rst index 7c6eaebc..e6e38f47 100644 --- a/docs/for_developers/data_management/BatchSampler.rst +++ b/docs/for_developers/data_management/BatchSampler.rst @@ -3,45 +3,11 @@ Batch sampler ============= -These classes defines how to sample the streamlines available in the -MultiSubjectData. +These classes defines how to sample the streamlines available in the MultiSubjectData. You are encouraged to contribute to dwi_ml by adding any child class here. -**AbstractBatchSampler:** -- Defines the __iter__ method: +DWIMLBatchIDSampler +------------------- - - Finds a list of streamlines ids and associated subj that you can later load in your favorite way. +- Defines the __iter__ method: It finds a list of streamlines ids and associated subjects that you can later load in your favorite way. It limits the number of subjects per batch and orders streamlines by subjects to make sure you don't need to load a full new volume at each new streamline. -- Define the load_batch method: - - - Loads the streamlines associated to sampled ids. Can resample them. - - - Performs data augmentation (on-the-fly to avoid having to multiply data on disk) (ex: splitting, reversing, adding noise). - -Child class : **BatchStreamlinesSamplerOneInput:** - -- Redefines the load_batch method: - - - Now also loads the input data under each point of the streamline (and possibly its neighborhood), for one input volume. - -You are encouraged to contribute to dwi_ml by adding any child class here. - - - - - For instance, the BatchSequenceSampler creates batches of streamline ids that will be used for each batch iteratively through *__iter__*. Those chosen streamlines can then be loaded and processed with *load_batch*, which also uses data augmentation. - - - For instance, the BatchSequencesSamplerOneInputVolume then uses the generated streamlines to load one underlying input for each timestep of each streamline. - - - The BatchSampler uses the **MultiSubjectDataset** (or lazy) - - - Creates a list of subjects. Using *self.load_data()*, it loops on all subjects (for either the training set or the validation set) and loads the data from the hdf5 (lazily or not). - - - The list is a **DataListForTorch** (or lazy). It contains the subjects but also common parameters such as the feature sizes of each input. - - - The elements are **SubjectData** (or lazy) - - - They contain both the volumes and the streamlines, with the subject ID. - - - Volumes are **MRIData** (or lazy). They contain the data and affine. - - - Streamlines are **SFTData** (or lazy). They contain all information necessary to create a Stateful Tractogram from the streamlines. In the lazy version, streamlines are not loaded all together but read when needed from the **LazyStreamlinesGetter**. diff --git a/docs/for_developers/models/directionGetters.rst b/docs/for_developers/models/directionGetters.rst index 1b66c1a8..10f38bb1 100644 --- a/docs/for_developers/models/directionGetters.rst +++ b/docs/for_developers/models/directionGetters.rst @@ -1,9 +1,9 @@ .. _direction_getters: -Create a tractography model: use a DirectionGetter -================================================== +The DirectionGetter Layer +========================= -Direction getter layers should be used as last layer of any streamline generation model for the tractography task. They define the format of the output and possible associated loss functions. +``DirectionGetter`` layers should be used as last layer of any streamline generation model for the tractography task. They define the format of the output and possible associated loss functions. General architecture -------------------- @@ -21,6 +21,8 @@ Regression models These models use regression to learn directly a direction, formatted as a coordinate [x, y, z]. If EOS is used, then the direction is formatted as [x, y, z, eos], where eos is a probability between 0 and 1. +.. container:: centered + +------------------------+---------------------------------+ | Shape of the output | A vector of length 3 or 4 | +------------------------+---------------------------------+ @@ -51,6 +53,7 @@ Classification models This model uses classification by formatting directions as a choice of direction among a list of discrete points on the sphere (Ex: ``dipy.data.get_sphere('symmetric724')``). Each point is a class. If EOS is used, it represents an additional class. +.. container:: centered +------------------------+----------------------------------+ | Shape of the output | A vector of length K (nb class) | @@ -80,6 +83,8 @@ This is a regression model that learns parameters representing a *probability fu - **SingleGaussianDG**: the model is a 2-layer NN for the means and a 2-layer NN for the variances. The output is 6 parameters: 3 means (x, y, z) and 3 variances (x, y, z). If EOS is used, it is a 7th learned value. +.. container:: centered + +------------------------+-------------------------------------------------+ | Shape of the output | A vector of length 6 or 7 | +------------------------+-------------------------------------------------+ @@ -92,6 +97,8 @@ This is a regression model that learns parameters representing a *probability fu - **GaussianMixtureDG**: In this case, the models learns to represent the function probability as a mixture of N Gaussians, possibly representing direction choices in the case of fiber crossing and other special configurations. The loss is again the negative log-likelihood. Note that the model is a 2-layer NN for the mean and a 2-layer NN for the variance, for each of N Gaussians. The output is N * (6 parameters: 3 means (x, y, z) and 3 variances (x, y, z) plus a mixture parameter for each, giving the probability that the right direction would be given by this Gaussian. +.. container:: centered + +------------------------+-------------------------------------------------+ | Shape of the output | A vector of length N * (7 or 8) | +------------------------+-------------------------------------------------+ @@ -109,6 +116,12 @@ Note that tyically, in the literature, Gaussian mixtures are used with expectati .. [1] Some code is available `in this blog `_. .. [2] See the `Tractoinferno paper `_ +.. toctree:: + :maxdepth: 1 + :caption: Full formulas (Gaussian + FvM) + :hidden: + + formulas_fisher_gaussian Fisher von mises models @@ -120,6 +133,7 @@ See the detailed mathematics in :ref:`ref_formulas`. - **FisherVonMisesDG**: The loss is the negative log-likelihood. Note that the model is a 2-layer NN for the means and a 2-layer NN for the variances. See :ref:`ref_formulas` for the complete formulas. The output is 4 parameters: 3 for the means and one for kappa. If EOS is used, it is a 5th learned value. +.. container:: centered +------------------------+-------------------------------------------------+ | Shape of the output | A vector of length 4 or 5. | @@ -133,12 +147,6 @@ See the detailed mathematics in :ref:`ref_formulas`. **FisherVonMisesMixtureDG**: Not implemented yet. -Other ideas -''''''''''' - -An equivalent model could learn to represent the direction on the sphere to learn a normalized direction. The means would be 2D (phi, rho), and the variances too. This has not been implemented yet. - - %%%%%%%%%%%%%%%%%%% .. [3] Directional Statistics (Mardia and Jupp, 1999)), implemented `here `_. diff --git a/docs/for_developers/models/formulas_fisher_gaussian.rst b/docs/for_developers/models/formulas_fisher_gaussian.rst index a9f36af3..6229cc46 100644 --- a/docs/for_developers/models/formulas_fisher_gaussian.rst +++ b/docs/for_developers/models/formulas_fisher_gaussian.rst @@ -10,9 +10,9 @@ Log-likelihood of a single Gaussian - N(x) = The normal distribution, with mean mu and variance sigma^2. - C = the covariance matrix. It indicates the relations between each all dimensions x,y,z. C is diagonal if the axis are independant, with the variances on the diagonal. - - d = the dimension of data (here, 3: x,y,z) + - d = the dimension of data (here, 3: [x,y,z]) - We do pose that the axis are independant. WHY CAN WE POSE THAT???? + We pose that the axis are independant. We assume the x, y, and z components are independent because the model predicts each directional axis separately, and there is no prior reason to expect correlations between them in the learned representation. This simplifies the covariance structure to a diagonal matrix without significantly affecting the model’s performance. **Formulas:** diff --git a/docs/for_developers/models/index.rst b/docs/for_developers/models/index.rst index 3b80643f..33a84b35 100644 --- a/docs/for_developers/models/index.rst +++ b/docs/for_developers/models/index.rst @@ -1,7 +1,9 @@ .. _create_your_model: -Create your own model -===================== +Creating your own model +======================= + +You are welcome to add your own project's model in dwi_ml! .. toctree:: :maxdepth: 1 @@ -11,4 +13,3 @@ Create your own model main_abstract_model other_main_models directionGetters - formulas_fisher_gaussian diff --git a/docs/for_developers/models/main_abstract_model.rst b/docs/for_developers/models/main_abstract_model.rst index 4bcf4f6f..eb13a691 100644 --- a/docs/for_developers/models/main_abstract_model.rst +++ b/docs/for_developers/models/main_abstract_model.rst @@ -1,11 +1,9 @@ .. _main_abstract_model: -Create your own model: use MainModelAbstract -============================================ +Inherit from MainModelAbstract +============================== -You are welcome to add your own project's model in dwi_ml! - -Projects using diffusion imaging and tractography streamlines are usually quite heavy in memory. For this reason, dwi_ml is not only a space for models, but it also includes smart management of data loading during training (see :ref:`trainers` for more information). Our training objects and model objects are thus intertwined. For this reason, you should always make your model a child class of our **MainModelAbstract**. +Projects using diffusion imaging and tractography streamlines are usually quite heavy in memory. For this reason, dwi_ml is not only a space for models, but it also includes smart management of data loading during training. Our training objects (see :ref:`trainers` for more information) and model objects are thus intertwined. For this reason, you should always make your model a child class of our **MainModelAbstract**. The MainModelAbstract class @@ -48,17 +46,18 @@ Our model: Where to start? --------------- -As a first step, try to implement a child class of ``dwi.models.main_models.MainModelAbstract`` and see how you would implement the two following methods: ``forward()`` (the core of your model) and ``compute_loss()``. - -Once you are more experimented, you can explore how to use methods already implemented in dwi_ml to improve your model, using, for instance, our models used for tractographie generation, or our models able to add a neighborhood to each input point, etc. +As a first step, try to implement a child class of ``dwi.general.models.main_models.main_abstract_model.MainModelAbstract`` and see how you would implement the two following methods: ``forward()`` (the core of your model) and ``compute_loss()``. -1. Create a new file in ``src/dwi_ml/models/projects`` named my_project.py +1. Create a new file in ``src/dwi_ml/projects/my_project`` named my_project.py 2. Start your project like this: .. image:: /_static/images/create_your_model.png + :align: center :width: 500 3. Learn to use your model in our Trainer (see page :ref:`trainers`). -4. Before coding everything from scratch in our model, verify if it could inherit from our other models (see page :ref:`other_main_models`) to benefit from their methods. +4. Discover the optional parameters of the ``MainModelAbstract`` model. + +5. Before coding everything else from scratch in our model, verify if it could inherit from our other models (see page :ref:`other_main_models`) to benefit from their methods. diff --git a/docs/for_developers/models/other_main_models.rst b/docs/for_developers/models/other_main_models.rst index fbbf4c4b..093f0344 100644 --- a/docs/for_developers/models/other_main_models.rst +++ b/docs/for_developers/models/other_main_models.rst @@ -1,10 +1,22 @@ .. _other_main_models: -Create your own model: Inherit from our models -============================================== +You may inherit from many models! +================================= + +All models in dwi_ml should be a child class of our MainModelAbstract. See :ref:`main_abstract_model` first. We have also prepared classes to help with common usages. Your model could easily inherit from them to benefit from what they have to offer. Each of them is a child of our main abstract model (as your own model should be, see +Where to start? +--------------- + +You can read on python multi-parent inheritance, or you can test it yourself! For instance, the model below inherits from all our classes and uses ``super.__init__(...)`` to set all parameters at once! + + .. image:: /_static/images/create_your_model2.png + :align: center + :width: 600 + +.. note:: You noticed that we did not explicitely inherit from the ``MainModelAbstract``? It's because all our models are themselves children of it. So it still is a child of ``MainModelAbstract``. General models -------------- @@ -12,15 +24,15 @@ General models ``ModelWithPreviousDirections`` ******************************* -In models, streamlines are often used as targets. But you may also need to use them as input. If your model iterates on streamline points and needs a fixed number of inputs at the time, you could use the N previous directions at each position of the streamline (see for instance in Poulin et al. 2017). This model adds parameters for the previous direction, plus embedding choices. + In models, streamlines are often used as targets. But you may also need to use them as input. If your model iterates on streamline points and needs a fixed number of inputs at the time, you could use the N previous directions at each position of the streamline (see for instance in Poulin et al. 2017). This model adds parameters for the previous direction, plus embedding choices. ``ModelWithDirectionGetter`` **************************** -This is our model intented for tractography models. It defines a layer of what we call the "directionGetter", which outputs a chosen direction for the next step of tractography propagation, in many possible formats, and knows how to compute the loss function accordingly. See the page :ref:`direction_getters` for more information. + This is our model intented for tractography models. It defines a layer of what we call the "directionGetter", which outputs a chosen direction for the next step of tractography propagation, in many possible formats, and knows how to compute the loss function accordingly. See the page :ref:`direction_getters` for more information. -It also contains a ``get_tracking_directions`` method, which should be implemented in your project to use this model for tractography. + It also contains a ``get_tracking_directions`` method, which should be implemented in your project to use this model for tractography. Models for usage with DWI inputs @@ -29,11 +41,11 @@ Models for usage with DWI inputs ``ModelWithNeighborhood`` ************************* -Neighborhood usage: the class adds parameters to deal with a few choices of neighborhood definitions. + Neighborhood usage: the class adds parameters to deal with a few choices of neighborhood definitions. ``ModelWithOneInput`` ********************* -The ``MainAbstractModel`` makes no assumption of the type of data required. In this model here, we add the parameters necessary to add one input volume (ex: underlying dMRI data), choose this model, together with the DWIMLTrainerOneInput, and the volume will be interpolated and send to your model's forward method. Note that if you want to use many images as input, such as the FA, the T1, the dMRI, etc., this can still be considered as "one volume", if your prepare your hdf5 data accordingly by concatenating the images. + The ``MainAbstractModel`` makes no assumption of the type of input data required, only that it uses streamlines. In this model here, we added the parameters necessary to add one input volume (ex: the underlying dMRI data). Choose this model, together with the ``DWIMLTrainerOneInput``, and the volume will be interpolated and send to your model's forward method. Note that if you want to use many images as input, such as the FA, the T1, the dMRI, etc., this can still be considered as "one volume", if your prepare your hdf5 data accordingly by concatenating the images. -It defines parameters to add an embedding layer. + It also defines parameters to add an embedding layer. diff --git a/docs/for_developers/testing/general_testing.rst b/docs/for_developers/testing/general_testing.rst deleted file mode 100644 index b108a20c..00000000 --- a/docs/for_developers/testing/general_testing.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. _model_testing: - -General testing of a model --------------------------- - -This step depends on your model and your choice of metrics, but in generative models, you probably want to track on new data and verify the quality of your reconstruction. We have prepared a script that allows you to track from a model. diff --git a/docs/for_developers/testing/tracking_objects.rst b/docs/for_developers/testing/tracking_objects.rst index 2ca99b7c..9b797793 100644 --- a/docs/for_developers/testing/tracking_objects.rst +++ b/docs/for_developers/testing/tracking_objects.rst @@ -5,29 +5,23 @@ Tracking with your model For tracking, you may observe how scripts `l2t_track_from_model` or `tt_track_from_model` work. They use two main objects, the Tracker and the Propagator, similarly as in scilpy. -Tracker -------- +Trackers +-------- -toDO +- ``DWIMLAbstractTracker``: Performs tractography: starts from a seeding mask, and, at each point, advances one step using the model's output. -Propagator ----------- +- ``DWIMLTrackerFromWholeStreamline``: Child class in cases where we need to send the whole streamline to the model in order to generate the next point's position. We need to copy them in memory here as long as the streamline is not finished being tracked. -toDo +- ``DWIMLTrackerOneInput``: Child class where the dMRI input must be interpolated at each point (using the BatchLoader) and sent as input to the model. Can be combined with DWIMLTrackerFromWholeStreamline. -Similarities with scilpy ------------------------- - -If you are familiar with scilpy, here is a comparison of our Tracker and Propagator to theirs: -**Similarities:** - -- ToDo +Differences with scilpy +------------------------ -**Differences:** +If you are familiar with scilpy (`Renauld 2023 `_), you will notice similarity in the code. Here is a comparison of our Tracker to theirs: - In scilpy, the *theta* parameter defines an aperture cone inside which the next direction can be sampled. Here, sampling is not as straightforward. Ex, in the case of regression, the next direction is directly obtained from the model. Instead, theta is used as a stopping criterion. -- In scilpy, at each propagation step, the propagator uses the local model (ex, DTI, fODF) to decide the next direction. Here, the propagator sends data as input to the machine learning model. The model may receive additional inputs as compared to classical tractography (ex, the hidden states in RNNs, or the full beginning of the streamline in Transformers). +- In scilpy, at each propagation step, the propagator uses the local model (ex, DTI, fODF) to decide the next direction. Here, we send data as input to the machine learning model. The model may receive additional inputs as compared to classical tractography (ex, the hidden states in RNNs, or the full beginning of the streamline in Transformers). -- GPU processing: As dwi_ml users tend to use GPU/CPU more than scilpy users, we offer a GPU options, where many streamlines are created simultaneously, to take advantage of the GPU capacities. In scilpy, CPU is always used, although possibly with parallel processes (not fully implemented yet). +- GPU processing: We offer a GPU option, where many streamlines are created simultaneously, to take advantage of the GPU capacities. Our GPU option uses torch, whereas in scilpy, the GPU option uses openCL and a different implementation. diff --git a/docs/for_developers/training/trainers_details.rst b/docs/for_developers/training/trainers_details.rst deleted file mode 100644 index 7305605d..00000000 --- a/docs/for_developers/training/trainers_details.rst +++ /dev/null @@ -1,52 +0,0 @@ -.. _trainers_details: - -Trainers: the code explained -============================ - -THIS SECTION IS UNDER CONSTRUCTION. - -- They have a ``train_and_validate`` method that can be used to iterate on epochs (until a maximum number of iteration is reached, or a maximum number of bad epochs based on some loss). -- They save a checkpoint folder after each epoch, containing all information to resume the training any time. -- When a minimum loss value is reached, the model's parameters and states are save in a best_model folder. -- They save a good quantity of logs, both as numpy arrays (.npy logs) and online using Comet.ml. -- They know how to deal with the ``BatchSampler`` (which samples a list of streamlines to get for each batch) and with the ``BatchLoader`` (which gets data and performs data augmentation operations, if any). -- They prepare torch's optimizer (ex, Adam, SGD, RAdam), define the learning rate, etc. - -The ``train_and_validate``'s action, in short, is: - -.. code-block:: python - - for epoch in range(nb_epochs): - set_the_learning_rate - self.train_one_epoch() - self.validate_one_epoch() - if this_is_the_best_epoch: - save_best_model - save_checkpoint - -Where ``train_one_epoch`` does: - -.. code-block:: python - - for batch in batches: - self.run_one_batch() - self.back_propagation() - -And ``validate_one_epoch`` runs the batch but does not do the back-propagation. - -Finally, ``run_one_batch`` is not implemented in the ``DWIMLAbstractTrainer`` class, as it depends on your model. - - -Putting it all together -*********************** - -This class's main method is *train_and_validate()*: - -- Creates torch DataLoaders from the data_loaders. Collate_fn will be the sampler.load_batch() method, and the dataset will be sampler.source_data. - -- Trains each epoch by using compute_batch_loss, which should be implemented in each project's child class, on each batch. Saves the loss evolution and gradient norm in a log file. - -- Validates each epoch (also by using compute_batch_loss on each batch, but skipping the backpropagation step). Saves the loss evolution in a log file. - -After each epoch, a checkpoint is saved with current parameters. Training can be continued from a checkpoint using the script resume_training_from_checkpoint.py. - diff --git a/docs/for_developers/training/training.rst b/docs/for_developers/training/training.rst index 4dfaf47d..a213867c 100644 --- a/docs/for_developers/training/training.rst +++ b/docs/for_developers/training/training.rst @@ -3,46 +3,35 @@ Training your model =================== -If your model fits well with our structures, you can use our Trainer. If your model does not have specific needs, our Trainer should already be sufficient for you, and you can read section Using our Trainer below. Else, if you need to modify something, we explain our class more in detail below. +If your model fits well with our structures and does not have specific needs, our Trainers should already be sufficient for you. -The trainer: +Advantages of using our trainers +-------------------------------- -- Runs training and validation for all batchs, for a chosen number of epochs. -- Saves the state of the model and of the optimizer in a checkpoint directory, to allow resuming your experiment if it was stopped prematurely. +- **Checkpoints**: Our trainers save the model state at each epoch if it is the best one so far, in a folder best_model. They also always save the model state and optimizer state in a checkpoint. This way, if anything happens and your training is stopped, you can continue training from the latest checkpoint. +- **Data Management**: Our trainers know how to interact with your data in the HDF5 and your model. For instance, it can use the BatchSampler to sample streamlines at each batch, and the BatchLoader to interpolate the diffusion data at each coordinate. This way, your model class stays as simple as possible, purely AI-based layers, without the rigmarole and shenanigans of data management. -1. Our choices of trainers --------------------------- +- **Logs and visu**: They save many metrics as logs on your computer, which you can visualize with our scripts. It also sends data to comet.ml. See :ref:`visu_logs` for more information. -``DWIMLTrainer`` -************************ - -This is the main class. For every batch, it loads the chosen streamlines and uses the model, as explained in section 2 below. - -``DWIMLTrainerOneInput`` -************************ - -This trainer additionally loads one volume group and accessed the coordinates at each point of your streamlines, or possibly in a neighborhood at each coordinate. Of note, this is done as a separate step, and not through torch's DataLoaders (see explanation in :ref:`batch_loaders`), because interpolation of data is faster through GPU, if you have access, but DataLoaders always work on CPU. - -This trainer is expected to be used with a child of ``ModelWithOneInput`` (see page :ref:`other_main_models`). - -``DWIMLTrainerOneInputWithGVPhase`` -*********************************** - -We will soon publish how we have used a new generation-validation phase to supervise our models. +- **Heavy data - ready**: They can manage GPU usage and selecting sampling to limit the loading of heavy data. +- **Training options**: They prepare torch's optimizer (ex, Adam, SGD, RAdam), define the learning rate, etc. -2. Using a trainer for your model ---------------------------------- +Overview of the process +----------------------- This is an example of basic script that you could create to train your model with our trainer. It will require: - Your model -- An instance of our object ``MultiSubjectDataset``: it knows how to get data in the hdf5, possibly in a lazy way. See :ref:`ref_data_containers` for more information. -- An instance of a ``BatchSampler``: it knows how to sample a list of chosen streamlines for a batch. See :ref:`batch_sampler` for more information. -- An instance of a ``BatchLoader``: it knows how to load the data using the ``MultiSubjectDataset``, and how to modify the streamlines based on your model's requirements, for instance, adding noise or compressing / changing the step size / reversing / splitting the streamlines. See :ref:`batch_loader` for more information. +- An instance of our object ``MultiSubjectDataset``: the Trainer knows how to get data in the hdf5, possibly in a lazy way, and store it in a MultiSubjectDataset. See :ref:`ref_data_containers` for more information. +- An instance of a ``BatchSampler``: the Trainer knows how to sample a list of chosen streamlines for a batch. See :ref:`batch_sampler` for more information. +- An instance of a ``BatchLoader``: the Trainer knows how to load the data using the ``MultiSubjectDataset``, and how to modify the streamlines based on your model's requirements, for instance adding noise or compressing, changing the step size, and reversing or splitting the streamlines. See :ref:`batch_loaders` for more information. -Your final python script could look like:: +For instance, if you need a dMRI input, your final python script could look like this: + +.. code-block:: python + :linenos: # Loading the data, possibly with lazy option dataset = MultiSubjectDataset(hdf5_file) @@ -70,10 +59,101 @@ Your final python script could look like:: # Run the training! trainer.train_and_validate() +Once all objects are ready, the Trainer's method ``train_and_validate`` can be used to iterate on epochs until a maximum number of iteration is reached, or a maximum number of bad epochs based on some loss. + + +Our choices of trainers +----------------------- + +``DWIMLTrainer`` +************************ + +This is the main class. For every batch, it loads the chosen streamlines and uses the model, as explained in section 2 below. + +``DWIMLTrainerOneInput`` +************************ + +This trainer additionally loads one volume group and accessed the coordinates at each point of your streamlines, or possibly in a neighborhood at each coordinate. Of note, this is done as a separate step, and not through torch's DataLoaders (see explanation in :ref:`batch_loaders`), because interpolation of data is faster through GPU, if you have access, but DataLoaders always work on CPU. + +This trainer is expected to be used with a child of ``ModelWithOneInput`` (see page :ref:`other_main_models`). + +``DWIMLTrainerOneInputWithGVPhase`` +*********************************** + +We will soon publish how we have used a new generation-validation phase to supervise our models. + + +Trainers: the code explained +---------------------------- + +The Trainer's main method is ``train_and_validate``. It is summarized below. + +.. code-block:: python + :linenos: + + def self.train_and_validate(): + for epoch in range(nb_epochs): + # 1) set the learning rate + ... + + # 2) Train + self.train_one_epoch() + + # 3) Validate + self.validate_one_epoch() + + # 4) Save the model if it's the best epoch + if this_is_the_best_epoch: + ... + + # 5) Save a checkpoint + self.save_checkpoint() + +Other steps managed in this method include creating the torch DataLoader from the data_loaders. The DataLoader's collate_fn will be the sampler's load_batch() method. + +The ``train_one_epoch`` method and ``validate_one_epoch`` are similar, but validation excludes back-propagation. + +.. code-block:: python + :linenos: + + def self.train_one_epoch(): + for batch in batches: + self.run_one_batch() + + # If training: back-prop includes: + # - clip gradients + # - update torch's optimizer: self.optimizer.step() + # - reset torch's gradients: self.optimizer.zero_grad(set_to_none=True) + self.back_propagation() + +Finally, ``run_one_batch`` depends on your model. For instance, in ``DWIMLTrainerOneInput``, it interpolates the input at each point and calls the model: + +.. code-block:: python + :linenos: + + def self.run_one_batch(): + # 1) Send data to GPU if available + ... + + # 2) Formats the streamlines if required by the model + # ex: SOS, EOS + ... + + # 3) Interpolate the input (done in the BatchLoader) + batch_inputs = self.batch_loader.load_batch_inputs( + streamlines, ids_per_subj) + + # 4) Data augmentation if required + streamlines = self.batch_loader.add_noise_streamlines_forward( + streamlines, self.device) + # 5) Call the model + model_outputs = self.model(batch_inputs, streamlines_f) + # 6) Compute the loss + mean_loss, n = self.model.compute_loss(model_outputs, targets, + average_results=True) -3. Visualizing logs ---------------------- + return mean_loss, n -See :ref:`visu_logs`. \ No newline at end of file +If this is not right for you, you can override the DWIMLTrainer and re-code this last method. diff --git a/docs/for_users/from_start_to_finish.rst b/docs/for_users/from_start_to_finish.rst deleted file mode 100644 index 880233fb..00000000 --- a/docs/for_users/from_start_to_finish.rst +++ /dev/null @@ -1,68 +0,0 @@ - -Training and using: from start to finish -======================================== - -If you want, you can use our scripts to train our models with a new set of hyperparameters. No matter the model, the process will probably contain the following steps: - -1. Creating a hdf5 file. Our library works with data in the hdf5 format. See :ref:`hdf5_usage` for more information. - -2. Training the model. At each epoch, the script saves the model state if it is the best one so far, in a folder ``best_model``, but also always saves the model state and optimizer state in a checkpoint. This way, if anything happens and your training is stopped, you can continue training from the latest checkpoint. - -3. Visualizing the logs to make sure you are satisfied with the results. For more information, see :ref:`visu_logs` for more information. - -4. Using your newly trained model! For tractography models, this uses scripts such as ``**_track_from_model``. See :ref:`user_tracking` for more information. - -Denoising models ----------------- - -Coming soon: Autoencoder (AE) model! - -Tractography models -------------------- - -Learn2track (l2t) -***************** - -Full steps:: - - # Create a hdf5 file - dwiml_create_hdf5_dataset $input_folder $out_file $config_file \ - $training_subjs $validation_subjs $testing_subjs - - # Train a model. Play with options! Here are the mandatory inputs: - l2t_train_model $saving_path $experiment_name $hdf5_file \ - $input_group_name $streamline_group_name - - # If you want to train your model a little more... - l2t_resume_training_from_checkpoint $saving_path $experiment_name \ - --new_patience 10 --new_max_epochs 300 - - # Visualize the logs - dwiml_visualize_logs $saving_path - - # See which points of your validation streamlines have the worst loss - l2t_visualise_loss $saving_path $hdf5_file $subj $input_group_name - - # Once happy, use your final model to track from it! - l2t_track_from_model $saving_path $subj $input_group $out_tractgram $seeding_mask_group - - - -TractographyTransformers (tt) -***************************** - -Full steps:: - - dwiml_create_hdf5_dataset $input_folder $out_file $config_file \ - $training_subjs $validation_subjs $testing_subjs - - tt_train_model ... - - tt_resume_training_from_checkpoint ... - - tt_track_from_model ... - - dwiml_visualize_logs ... - tt_visualize_loss ... - tt_visualize_weights ... - diff --git a/docs/for_users/from_start_to_finish_denoising.rst b/docs/for_users/from_start_to_finish_denoising.rst new file mode 100644 index 00000000..2e53ed9c --- /dev/null +++ b/docs/for_users/from_start_to_finish_denoising.rst @@ -0,0 +1,6 @@ +.. _from_start_to_finish_denoising: + +Training denoising models: from start to finish +=============================================== + +Coming soon: Autoencoder (AE) model! \ No newline at end of file diff --git a/docs/for_users/from_start_to_finish_tracking.rst b/docs/for_users/from_start_to_finish_tracking.rst new file mode 100644 index 00000000..24dbf1fe --- /dev/null +++ b/docs/for_users/from_start_to_finish_tracking.rst @@ -0,0 +1,100 @@ +.. _from_start_to_finish_tracking: + +Training tracking models: from start to finish +============================================== + +If you want, you can use our scripts to train our models with a new set of hyperparameters! + +Overview of the process +*********************** + +No matter the model, the process will probably contain the following steps: + +1. **Creating a hdf5 file**. + Our library works with data in the hdf5 format. See :ref:`hdf5_usage` for more information. + +2. **Training the model**. + At each epoch, the script saves the model state if it is the best one so far, in a folder ``best_model``, but also always saves the model state and optimizer state in a checkpoint. This way, if anything happens and your training is stopped, you can continue training from the latest checkpoint. + + To learn more about training options, see the help. For instance, ``l2t_train_model --help`` or ``tt_train_model --help``. + +3. **Visualizing the logs**. + We have tools to help you supervise the results. See :ref:`visu_logs` for more information. + +4. **Visualizing the loss**. + You can use your favorite .trk visualizer (ex, Mi-Brain) to view the local loss along points of your streamlines. + +5. **Using your newly trained model!** + See :ref:`tractography_models` for more information. + + +Learn2track (l2t) +***************** + +Here is what your bash script will look like for Learn2track. For each step, you will have many options to define! + +.. code-block:: bash + + # Create a hdf5 file + # Most options are given through the config file + dwiml_create_hdf5_dataset $input_folder $out_file $config_file \ + $training_subjs $validation_subjs $testing_subjs + + # Train a model. + # Play with options! Here are the mandatory inputs: + l2t_train_model $saving_path $experiment_name $hdf5_file \ + $input_group_name $streamline_group_name + + # If you want to train your model a little more... + l2t_resume_training_from_checkpoint $saving_path $experiment_name \ + --new_patience 10 --new_max_epochs 300 + + # Visualize the logs + dwiml_visualize_logs $saving_path/$experiment_name + + # See which points of your training streamlines have the worst loss + l2t_visualise_loss $saving_path/$experiment_name $hdf5_file $subj $input_group_name \ + --out_prefix colored_tractogram --subset training \ + --use_gpu --batch_size 400 --streamlines_group $streamlines \ + --colormap turbo --show_now --compute_histogram --save_colored_tractogram \ + --save_colored_best_and_worst 0.5 + + # Once happy, use your final model to track from it! + l2t_track_from_model $saving_path/$experiment_name $subj $input_group $out_tractgram $seeding_mask_group + + + +TractoTransformers (tt) +*********************** + +Here is watch your bash script will look like for TractoTransformers. For each step, you will have many options to define! + +.. code-block:: bash + + # Create a hdf5 file + # Most options are given through the config file + dwiml_create_hdf5_dataset $input_folder $out_file $config_file \ + $training_subjs $validation_subjs $testing_subjs + + # Train a model. + # Play with options! Here are the mandatory inputs: + tt_train_model $saving_path $experiment_name $hdf5_file \ + $input_group_name $streamline_group_name + + # If you want to train your model a little more... + tt_resume_training_from_checkpoint $saving_path $experiment_name \ + --new_patience 10 --new_max_epochs 300 + + # Visualize the logs + dwiml_visualize_logs $saving_path/$experiment_name + + # See which points of your streamlines have the worst loss + tt_visualize_loss $saving_path/$experiment_name $hdf5_file $subj $input_group_name + + # Once happy, use your final model to track from it! + tt_track_from_model $saving_path/$experiment_name $subj $input_group $out_tractgram $seeding_mask_group + + # Visualize where the attention focuses for your tractogram! + tt_visualize_weights $saving_path/$experiment_name $hdf5_file $subj $input_group_name $out_tractogram + + diff --git a/docs/for_users/hdf5.rst b/docs/for_users/hdf5.rst index 1f5441bc..060dd7d9 100644 --- a/docs/for_users/hdf5.rst +++ b/docs/for_users/hdf5.rst @@ -67,28 +67,27 @@ This folder is the most important one and must be organized in a very precise wa Preparing the config file ************************* -To create the hdf5 file, you will need a config file such as below. HDF groups will be created accordingly for each subject in the hdf5. +To create the hdf5 file, you will need a configuration file such as below. HDF groups will be created accordingly for each subject in the hdf5. Here is an example of a working config file. -.. code-block:: bash +.. code-block:: json { "input": { "type": "volume", - "files": ["dwi/dwi.nii.gz", "anat/t1.nii.gz", "dwi/*__dwi.nii.gz], --> Will get, for instance, subX__dwi.nii.gz + "files": ["dwi/dwi.nii.gz", "anat/t1.nii.gz", "dwi/*__dwi.nii.gz"], "standardization": "all", - "std_mask": [masks/some_mask.nii.gz] + "std_mask": ["masks/some_mask.nii.gz"] }, "target": { "type": "streamlines", - "files": ["tractograms/bundle1.trk", "tractograms/wholebrain.trk", "tractograms/*__wholebrain.trk"], ----> Will get, for instance, sub1000__bundle1.trk + "files": ["tractograms/bundle1.trk", "tractograms/wholebrain.trk", "tractograms/*__wholebrain.trk"], "connectivity_matrix": "my_file.npy", - "connectivity_nb_blocs": 6 ---> OR - "connectivity_labels": labels_volume_group, - "dps_keys": ['dps1', 'dps2'] + "connectivity_labels": "labels_volume_group", + "dps_keys": ["dps1", "dps2"] } "bad_streamlines": { "type": "streamlines", - "files": ["bad_tractograms/*"] ---> Will get all trk and tck files. + "files": ["bad_tractograms/*"] } "wm_mask": { "type": "volume", @@ -101,9 +100,9 @@ To create the hdf5 file, you will need a config file such as below. HDF groups w General group attributes in the config file: """""""""""""""""""""""""""""""""""""""""""" -Each group key will become the group's **name** in the hdf5. It can be anything you want. We suggest you keep it significative, ex 'input_volume', 'target_volume', 'target_directions'. In other scripts (ex, l2t_train_model.py, tt_train_model.py, etc), you will often be asked for the labels given to your groups. +Each group key will become the group's **name** in the hdf5. It can be anything you want. We suggest you keep it significative, ex 'input_volume', 'target_volume', 'target_directions'. -Each group may have a number of parameters: +Each group must have the following parameters: - **"type"**: It must be recognized in dwi_ml. Currently, accepted datatype are: @@ -126,9 +125,8 @@ Additional attributes for volume groups: - "per_file", to apply it independently on each file included in the group. - "none", to skip this step (default) -****A note about data standardization** - -If all voxel were to be used, most of them would probably contain the background of the data, bringing the mean and std probably very close to 0. Thus, non-zero voxels only are used to compute the mean and std, or voxels inside the provided mask if any. If a mask is provided, voxels outside the mask could have been set to NaN, but the simpler choice made here was to simply modify all voxels [ data = (data - mean) / std ], even voxels outside the mask, with the mean and std of voxels in the mask. Mask name is provided through the config file. It is formatted as a list: if many files are listed, the union of the binary masks will be used. +.. note:: + **A note about data standardization:** If all voxel were to be used, most of them would probably contain the background of the data, bringing the mean and std probably very close to 0. Thus, non-zero voxels only are used to compute the mean and std, or voxels inside the provided mask if any. If a mask is provided, voxels outside the mask could have been set to NaN, but the simpler choice made here was to simply modify all voxels [ data = (data - mean) / std ], even voxels outside the mask, with the mean and std of voxels in the mask. Mask name is provided through the config file. It is formatted as a list: if many files are listed, the union of the binary masks will be used. Additional attributes for streamlines groups: diff --git a/docs/for_users/models/denoising_models.rst b/docs/for_users/models/denoising_models.rst index 06874f6d..9aff1771 100644 --- a/docs/for_users/models/denoising_models.rst +++ b/docs/for_users/models/denoising_models.rst @@ -1,7 +1,7 @@ .. _denoising_models: -Denoising models -================ +Using denoising models +====================== Coming soon: Autoencoder (AE) model! diff --git a/docs/for_users/models/our_models.rst b/docs/for_users/models/our_models.rst deleted file mode 100644 index 835c68bf..00000000 --- a/docs/for_users/models/our_models.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. _our_models: - -Our models -========== - - .. toctree:: - :maxdepth: 1 - :caption: Our models - - tractography_models - denoising_models diff --git a/docs/for_users/models/tractography_models.rst b/docs/for_users/models/tractography_models.rst index f69f4c41..f74c1c5a 100644 --- a/docs/for_users/models/tractography_models.rst +++ b/docs/for_users/models/tractography_models.rst @@ -1,14 +1,79 @@ .. _tractography_models: -Tractography models -=================== +Using tractography models +========================= -For more explanation on how to use models for tracking, see :ref:`user_tracking`. +In both cases, the input data must be formatted as an hdf5. See :ref:`hdf5_usage` for more information. -TractographyTransformers (tt) -***************************** -This uses transformers and should be the subject of an upcoming publication. +Hdf5 preparation for our two tractography models +************************************************ + +Examples below suppose that your hdf5 has the following properties: + + - A subject $subj on which you want to run the tractography. + - An hdf5 volume group called 'inputs', which contains the input data for the model + - An hdf5 volume group called 'seeding_mask', which contains a binary mask for the seeds placement (ex, the WM-GM interface). + - An hdf5 volume group called 'tracking_mask', which contains a binary mask for the tractography (ex, the white matter). + +The config file below can be used for Learn2track and TractoTransformer: + +.. code-block:: json + + { + "input": { + "type": "volume", + "files": ["anat/*__T1w.nii.gz", "dwi/*__fa.nii.gz", "dwi/*__fodf.nii.gz"], + "standardization": "per_file", + "std_mask": ["masks/*__brain_mask.nii.gz"] + }, + "wm_mask": { + "type": "volume", + "files": ["masks/*__mask_wm.nii.gz"], + "standardization": "none" + }, + "interface_mask": { + "type": "volume", + "files": ["other_masks/*__interface.nii.gz"], + "standardization": "none" + }, + } + +Tracking from a model +********************* + +AI-based tractography is similar to classical tractography, except the inputs are sent throughout the model at each point. This can be heavier in memory and requires some adaptation in the code. Our tracking scripts offer the possibility to track many streamlines at once, in order to make it efficient. All streamlines are sent in the model at once. Options below are common to our two models: + + - ``--use_gpu``: If your computer has a GPU, we strongly recommand using it. Then, you may also track many streamlines at once for an efficient usage of the model. Use option ``--simultaneous_tracking``. For instance, with 10Gb GPU, we could launch ~500 streamlines at the time. + - ``--algo``: Choices are 'det' or 'prob'. Depending on the choice of output (regression, classification, etc.), the model may not support probabilistic tractography. + - ``-max_length``: Useful to avoid tracking long sequences. ``--min_length`` is used as post-tractography filtering. + - ``--eos_stop``: If your model was created with an EOS option (end of streamline), you can use the learned EOS probability as an additionnal stopping criteria during tracking. + - ``--tracking_mask``: This option is facultative if you use ``--eos_stop``. + - ``--discard_last``: If a tracking mask is used, default is to stop when leaving the mask, but to keep the last point. You can ensure all points are inside the tracking mask with option discard_last. + - ``-npv``: Number of seeds per voxel in the seeding mask. + - ``--help``: Use the help to learn more about other options. + +Required parameters are: + + - ``experiment_path``: Where your experiment folder is. It should contain a "best_model" sub-folder. + - ``subj_id``: A subject in your hdf5 file. + - ``input_group``: The name of the hdf5 volume group to use as input in the model. + - ``out_tractogram``: The ouput filename. + - ``seeding_mask_group``: The name of the hdf5 volume group to use as seeding mask. + +You can use your favorite .trk visualizer (ex, Mi-Brain) to view the local loss along points of your streamlines. This example supposes your hdf5 has a subject $subj, a volume group 'input' as input to the model, and examples of streamlines in the streamlines group 'streamlines': + +.. code-block:: bash + + l2t_visualize_loss my_experiment $hdf5 $subj input --out_prefix colored_tractogram \ + --use_gpu --batch_size 400 --streamlines_group streamlines \ + --colormap turbo --show_now --compute_histogram --save_colored_tractogram \ + --save_colored_best_and_worst 0.5 + +TractoTransformers (tt) +*********************** + +This uses transformers and should be the subject of an upcoming publication. Its name, TractoTransformer, reflects that this model is similar to the one proposed in `Waizman et al., 2025 `_. Others have also published Transformer models for Tractography, but did not name their model. .. image:: /_static/images/Transformers.png @@ -16,9 +81,16 @@ This uses transformers and should be the subject of an upcoming publication. :width: 600 -To use this model, run script `tt_track_from_model.py`. . To learn more, run:: +To use this model, run script ``tt_track_from_model.py``. For instance: + +.. code-block:: bash - tt_track_from_model --help + tt_track_from_model -f -v \ + --algo det --min_length 10 --max_length 200 \ + --npv 1 --use_gpu --simultaneous_tracking 500 --discard_last \ + --tracking_mask_group $wm_mask --eos_stop 0.5 --hdf5 $hdf5_file \ + $experiment_folder $tracking_subj $input out_tractogram.trk \ + $interface_mask Learn2track (l2t) @@ -30,7 +102,13 @@ This is a refactored version of the code prepared by authors of `Poulin2017 `_. Click on the key icon, copy value to the clipboard and save it in your file in $HOME. +An API (application programming interface) is a code that gets passed in by applications, containing information to identify its user, for instance. To get an API key, see ``_. Click on the key icon, copy value to the clipboard and save it in your file in $HOME. Installing dwi_ml diff --git a/docs/index.rst b/docs/index.rst index 7f844c49..ac4547bc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,7 +9,7 @@ using machine learning and deep learning methods. It is mostly focused on the tr :align: center :width: 500 -In this doc, we will present you everything included in this library for you to become either a developer or a user. Note that to get a full understanding of every line of code, you can browse further in each section. +In this doc, we will present you everything included in this library for you to become either a developer or a user. On this page: @@ -39,17 +39,17 @@ Pages in this section explain how to use our scripts to use our pre-trained mode - **1. Downloading models**: If you want to use our pre-trained models, you may contact us for access to the models learned weights. They will be available online once publications are accepted. - - :ref:`our_models` - - :ref:`tractography_models` - - :ref:`denoising_models` - - **2. Organizing your data**: In most cases, data must be organized correctly as a hdf5 before usage. Follow the link below for an explanation. - :ref:`hdf5_usage` - **3. Using our models to perform tractography**: Use our models to track on your own subjects! - - :ref:`user_tracking` + - :ref:`tractography_models` + +- **OR, Using our models to denoise your tractograms**: (upcoming) + + - :ref:`denoising_models` .. --------------------Hidden toctree: --------------- @@ -59,11 +59,9 @@ Pages in this section explain how to use our scripts to use our pre-trained mode :hidden: :caption: Explanations for users (pre-trained) - for_users/models/our_models for_users/hdf5 - for_users/tracking - ------------------------------- + for_users/models/tractography_models + for_users/models/denoising_models .. _section_advanced_users: @@ -72,13 +70,13 @@ Pages in this section explain how to use our scripts to use our pre-trained mode Pages in this section are useful if you want to train a model based on pre-existing code, such as Learn2track or TractographyTransformers, using your favorite set of hyperparameters. -(Improved documentation coming soon!) .. toctree:: :maxdepth: 2 - :caption: Explanations for users (re-train) + :caption: Explanations for explorers - for_users/from_start_to_finish + for_users/from_start_to_finish_tracking + for_users/from_start_to_finish_denoising for_users/visu_logs .. _section_developers: @@ -92,6 +90,10 @@ Page in this section explain more in details how the code is implemented in pyth - :ref:`create_your_model` + - :ref:`main_abstract_model` + - :ref:`other_main_models` + - :ref:`direction_getters` + - **2. Explore our hdf5 organization**: Our library has been organized to use data in the hdf5 format. Our hdf5 data organization should probably be enough for your needs. - :ref:`hdf5_usage` @@ -100,26 +102,24 @@ Page in this section explain more in details how the code is implemented in pyth - **3. Train your model**: Take a look at how we have implemented our trainers for an efficient management of heavy data. Note that our trainer uses Data Management classes such as our BathLoader and BatchSampler. See below for more information. - :ref:`trainers` - - :ref:`trainers_details` - :ref:`data_management_index` -- **4. Use your trained model**: Discover our objects allowing to perform a full tractography from a tractography-model. + - :ref:`ref_data_containers` + - :ref:`batch_sampler` + - :ref:`batch_loaders` + +- **4. Use your trained model**: This step depends on your model. For tractography models, discover our objects allowing to perform a full tractography from a tractography model. You can also see our pages for Learntrack and TractoTransformer usage: :ref:`tractography_models`. - - :ref:`model_testing` - - :ref:`user_tracking` - :ref:`tracking` -.. --------------------Hidden toctree: --------------- .. toctree:: - :maxdepth: 1 + :maxdepth: 3 :caption: Explanations for developers :hidden: for_developers/models/index for_developers/hdf5/advanced_hdf5_organization for_developers/training/training - for_developers/training/trainers_details for_developers/data_management/index - for_developers/testing/general_testing for_developers/testing/tracking_objects \ No newline at end of file diff --git a/docs/print_toctree.py b/docs/print_toctree.py new file mode 100644 index 00000000..fdbab019 --- /dev/null +++ b/docs/print_toctree.py @@ -0,0 +1,69 @@ +from sphinx.application import Sphinx + +""" +This is useful to get a map of the website. + +USAGE: + >> sphinx-build -b dummy docs/ _build + >> python docs/print_toctree.py +""" + +app = Sphinx( + srcdir="docs", + confdir="docs", + outdir="_build", + doctreedir="_build/.doctrees", + buildername="dummy", +) + +app.build() +env = app.env + + +def print_sections(source, indent="", is_last=False): + # 1. Prepare connector. + connector = "└── " if is_last else "├── " + + # 2. Captions? + doctree = env.get_doctree(source) + txt = str(doctree) + if 'toctree caption=' in txt: + tmp = txt.split('toctree caption="') + for section in range(1, len(tmp)): + info = tmp[section].split('" entries="[') + caption = info[0] + caption = "** " + caption + " **" + print(indent + connector + caption) + + entries = info[1].split("]")[0] + entries = entries.split(',') + for entrie in entries: + if 'None' in entrie: + continue + else: + entrie = entrie.strip(')').strip().strip("'") + print_tree(entrie, " ") + + +def print_tree(filename, indent="", is_last=True): + connector = "└── " if is_last else "├── " + + # Get the title if it exists, else will print the filename + title = env.titles.get(filename) + if title is not None: + display_name = title.astext() + else: + display_name = filename + + # Print the current file + print(indent + connector + display_name) + + # Now print its children + children = env.toctree_includes.get(filename, []) + for i, child in enumerate(children): + last = i == len(children) - 1 + new_prefix = indent + (" " if is_last else "│ ") + print_tree(child, new_prefix, last) + +print("Site structure:\n") +print_sections("index") diff --git a/src/dwi_ml/general/training/trainers.py b/src/dwi_ml/general/training/trainers.py index 94945d07..531436da 100644 --- a/src/dwi_ml/general/training/trainers.py +++ b/src/dwi_ml/general/training/trainers.py @@ -1203,14 +1203,15 @@ def run_one_batch(self, targets, ids_per_subj): n: int Total number of points for this batch. """ + # 1) Send to GPU # Dataloader always works on CPU. Sending to right device. # (model is already moved). targets = [s.to(self.device, non_blocking=True, dtype=torch.float) for s in targets] - # Getting the inputs points from the volumes. - # Uses the model's method, with the batch_loader's data. - # Possibly skipping the last point if not useful. + # 2) Format the streamlines + # Possibly skipping the last point if not useful (no EOS). Avoids + # interpolation for no reason at that point. streamlines_f = targets if isinstance(self.model, ModelWithDirectionGetter) and \ not self.model.direction_getter.add_eos: @@ -1218,25 +1219,29 @@ def run_one_batch(self, targets, ids_per_subj): # associated target direction. streamlines_f = [s[:-1, :] for s in streamlines_f] - # Batch inputs is already the right length. Models don't need to - # discard the last point if no EOS. Avoid interpolation for no reason. + # 3) Interpolate the inputs batch_inputs = self.batch_loader.load_batch_inputs( streamlines_f, ids_per_subj) - logger.debug('*** Computing forward propagation') - # todo Possibly add noise to inputs here. Not ready - # Now add noise to streamlines for the forward pass + # 4) Data augmentation: noise is done AFTER interpolation + # Adds noise to streamlines for the forward pass only # (batch loader will do it depending on training / valid) streamlines_f = self.batch_loader.add_noise_streamlines_forward( streamlines_f, self.device) + + # 5) Call the model + logger.debug('*** Computing forward propagation') model_outputs = self.model(batch_inputs, streamlines_f) del streamlines_f - logger.debug('*** Computing loss') + # 6) Other option for data augmentation: noise on the targets only. # Add noise to targets. # (batch loader will do it depending on training / valid) targets = self.batch_loader.add_noise_streamlines_loss(targets, self.device) + + # 6) Compute the loss + logger.debug('*** Computing loss') mean_loss, n = self.model.compute_loss(model_outputs, targets, average_results=True)