diff --git a/src/dwi_ml/cli/dwiml_divide_volume_into_blocs.py b/src/dwi_ml/cli/dwiml_divide_volume_into_blocs.py index 7e967fab..91837d0d 100644 --- a/src/dwi_ml/cli/dwiml_divide_volume_into_blocs.py +++ b/src/dwi_ml/cli/dwiml_divide_volume_into_blocs.py @@ -1,5 +1,15 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- + +""" +Utility script to allow investigating the "blocs" if we divide a volume +by L x P x H blocs. + +Useful to understand what the option 'connectivity_nb_blocs' does in the config +file of the HDF5 creation +(see here https://dwi-ml.readthedocs.io/en/latest/for_users/hdf5.html). + +""" import argparse import nibabel as nib @@ -15,20 +25,23 @@ def _build_arg_parser(): p = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawTextHelpFormatter) p.add_argument('in_image', metavar='IN_FILE', - help='Input file name, in nifti format.') + help='Input file name, in nifti format. Any reference file.') p.add_argument('out_filename', help='name of the output file, which will be saved as a ' - 'text file.') + 'nifti file.') p.add_argument('nb_blocs', nargs='+', type=int, help="Number of blocs. Either a single int, or a list of " "3 values.") + p.add_argument('--shuffle_colors', action='store_true', + help="If set, will randomly shuffle the label of blocs in " + "the volume.") add_overwrite_arg(p) return p -def color_mri_connectivity_blocs(nb_blocs, volume_size): +def color_mri_connectivity_blocs(nb_blocs, volume_size, shuffle_colors): # For tracking coordinates: we can work with float. # Here, dividing as ints. @@ -37,13 +50,24 @@ def color_mri_connectivity_blocs(nb_blocs, volume_size): sizex, sizey, sizez = (volume_size / nb_blocs).astype(int) print("Coloring into blocs of size: ", sizex, sizey, sizez) + # Preparing colors. + nb_blocs_total = np.prod(nb_blocs) + all_colors = np.arange(nb_blocs_total) + + # Shuffling. Else, in Mi-Brain, the colors of blocs one beside the other are + # the same color. 
+ if shuffle_colors: + all_colors = np.random.permutation(all_colors) + + all_colors = all_colors.reshape(nb_blocs) + final_volume = np.zeros(volume_size) for i in range(nb_blocs[0]): for j in range(nb_blocs[1]): for k in range(nb_blocs[2]): final_volume[i*sizex: (i+1)*sizex, j*sizey: (j+1)*sizey, - k*sizez: (k+1)*sizez] = i + 10*j + 100*k + k*sizez: (k+1)*sizez] = all_colors[i, j, k] return final_volume @@ -61,7 +85,8 @@ def main(): volume = nib.load(args.in_image) # Processing - final_volume = color_mri_connectivity_blocs(args.nb_blocs, volume.shape) + final_volume = color_mri_connectivity_blocs(args.nb_blocs, volume.shape, + args.shuffle_colors) # Saving img = nib.Nifti1Image(final_volume, volume.affine) diff --git a/src/dwi_ml/cli/dwiml_visualize_logs.py b/src/dwi_ml/cli/dwiml_visualize_logs.py index e45f536d..d572e4bc 100644 --- a/src/dwi_ml/cli/dwiml_visualize_logs.py +++ b/src/dwi_ml/cli/dwiml_visualize_logs.py @@ -8,22 +8,31 @@ The number of graphs per figure can be modified with --nb_plots_per_fig. +Specifying which logs +--------------------- You may also specify the logs you want to plot. Use the option --graph, with -the graph's title and the name(s) of the logs to add to one graph. This option -can be used many times. Ex: +the graph's title and the name(s) of the logs to add to one graph. Ex: >> --graph "Training loss" train_loss_monitor_per_epoch ->> --graph "Some plot" log1 --graph "Two plots" log2 log3 +>> --graph "Two plots" log2 log3 +You may use many times the option --graph. + +Specifying y-axis limits +------------------------ Optionally, you can add the 2-valued ylims for each graph. These value will supersede the given --ylim, if any. >> --graph "Some plot" log1 0 100 +Operations on logs +------------------ Finally, you can also supply operations to apply to your logs, amongst: ['diff', 'sum']. 
Ex: >> --graph "Training minus validation" diff(log1, log2) 0 100 +>> --graph "Value1 plus value2" sum(log1, log2) 0 100 + ** Note that we only accept one operation per graph. The following is not supported: ->> --graph "Training minus validation" diff(log1, log2) log 3 0 100 +>> --graph "Training minus validation" diff(log1, log2) log3 0 100 ------------------------------ """ @@ -62,7 +71,8 @@ def _build_arg_parser(): g = p.add_argument_group("Figure options") g.add_argument("--graph", action='append', nargs='+', dest='graphs', - help="See description above for usage.") + help="See description above for usage." + "If not set, will plot all logs in current directory.") g.add_argument("--nb_plots_per_fig", type=int, default=3, metavar='n', help="Number of (rows) of plot per figure. Default: 3.") g.add_argument('--xlim', type=int, metavar='epoch_max', @@ -86,7 +96,18 @@ def _build_arg_parser(): def _parse_graphs_arg(parser, args): - """Parse args.graphs""" + """ + Parse args.graphs. Possible options: + --graph title log_name ylim1 ylim2 + --graph title diff(log1, log2) ylim1 ylim2 + + Returns + ------- + graphs_titles: list + graphs_logs: list + graphs_ylims: list + graph_operations: list + """ if args.graphs is None: return None, None, None, None else: @@ -121,7 +142,7 @@ def _parse_graphs_arg(parser, args): _ylims = None # Verify if user gave an operation (diff or sum) - _logs, operation = __parse_log_operations(parser, _logs) + _logs, operation = _parse_graphs_arg_operations(parser, _logs) # Remove .npy to log names if added by user. for i, log in enumerate(_logs): @@ -141,7 +162,11 @@ def _parse_graphs_arg(parser, args): return graphs_titles, graphs_logs, graphs_ylims, graph_operations -def __parse_log_operations(parser, graph): +def _parse_graphs_arg_operations(parser, graph): + """ + Used when parsing option args.graphs. Checking if the operation uses + diff or sum. 
+ """ if len(graph) == 1: _graph = graph[0] if _graph[0:5] == 'diff(': @@ -171,7 +196,7 @@ def __parse_log_operations(parser, graph): def _load_all_logs(parser, args, logs_path, previous_graphs): """ - Load all logs in an experiment's dir (no option --graph was supplied) + Load all logs in an experiment's dir (if no option --graph was supplied) """ logging.debug("Loading all logs for that experiment.") files_to_load = list(logs_path.glob('*.npy')) @@ -203,7 +228,8 @@ def _load_all_logs(parser, args, logs_path, previous_graphs): def _load_chosen_logs(parser, args, logs_path, parsed_graphs, graph_operations): """ - Load only logs specified through --graph + Load only logs specified through if option --graph was given. + (args.graphs is already parsed) """ logging.debug("Loading only chosen logs for that experiment") this_exp_dict = {} @@ -275,7 +301,9 @@ def main(): pil_logger = logging.getLogger('PIL') pil_logger.setLevel('WARNING') + # --------------- # Verifications + # --------------- if not (args.save_figures or args.show_now): parser.error("This script will plot nothing. Choose either --show_now " "or --save_figures.") @@ -291,9 +319,12 @@ def main(): # Prefix given, but directory does not exist. parser.error("Output dir for figures does not exist.") - # Loop on all experiments + # --------------- + # Loop on all experiments to load all logs + # --------------- loaded_logs = {} # dict of dicts for i, exp_path in enumerate(args.experiments): + # Verifications for this experiment if not pathlib.Path(exp_path).exists(): raise ValueError("Experiment folder does not exist: {}" @@ -314,7 +345,9 @@ def main(): graphs_logs, graphs_operation) loaded_logs[exp_name] = this_exp_dict + # --------------- # Formatting the final graphs choice. 
+ # --------------- if args.graphs is None: graphs_titles = graphs_logs graphs_logs = [[log] for log in graphs_logs] @@ -330,6 +363,9 @@ def main(): graphs_ylims = [args.ylims if ylim is None else ylim for ylim in graphs_ylims] + # --------------- + # MAIN CALL: Plotting everything + # --------------- _args = (loaded_logs, graphs_titles, graphs_logs, graphs_ylims, args.nb_plots_per_fig) kwargs = {'xlim': args.xlim, diff --git a/src/dwi_ml/cli/dwiml_visualize_logs_correlation.py b/src/dwi_ml/cli/dwiml_visualize_logs_correlation.py index b1ecfee2..8f1f9a98 100644 --- a/src/dwi_ml/cli/dwiml_visualize_logs_correlation.py +++ b/src/dwi_ml/cli/dwiml_visualize_logs_correlation.py @@ -31,6 +31,15 @@ def _build_arg_parser(): p.add_argument('--ignore_first_epochs', type=int, metavar='n', help="If set, ignores the first n epochs of each " "experiment.") + p.add_argument('--show_individual_logs', action='store_true', + help="If set, shows individual logs as well as the " + "correlation graph (3 graphs total)") + p.add_argument('--show_first_order_fit', action='store_true', + help="If set, shows first order fit.") + p.add_argument('--show_second_order_fit', action='store_true', + help="If set, show the quadratic fit.") + p.add_argument('--xlim', nargs=2, type=float) + p.add_argument('--ylim', nargs=2, type=float) add_overwrite_arg(p) add_verbose_arg(p) @@ -54,7 +63,9 @@ def _load_chosen_logs(parser, args, logs_path): def _compute_correlations(loaded_dicts, log1_key, log2_key, - name_log1, name_log2, first_epoch): + name_log1, name_log2, first_epoch, xlim, ylim, + show_individual, + show_first_order, show_second_order): # One color per experiment jet = plt.get_cmap('jet') exp_names = list(loaded_dicts.keys()) @@ -64,7 +75,11 @@ def _compute_correlations(loaded_dicts, log1_key, log2_key, all_x = [] all_y = [] labels = [] - fig, axs = plt.subplots(3, 1) + if show_individual: + fig, axs = plt.subplots(3, 1) + else: + fig, axs = plt.subplots(1, 1) + axs = [axs] for i, exp in 
enumerate(loaded_dicts.keys()): color_val = scalar_map.to_rgba(i) x = loaded_dicts[exp][log1_key][first_epoch:] @@ -74,30 +89,70 @@ def _compute_correlations(loaded_dicts, log1_key, log2_key, corr = np.corrcoef(x, y)[0][1] print("Correlation for exp {} is {}".format(exp, corr)) - axs[0].scatter(epochs, x, color=color_val, s=10) - axs[1].scatter(epochs, y, color=color_val, s=10) - axs[2].scatter(x, y, color=color_val, s=10) + if show_individual: + axs[0].scatter(epochs, x, color=color_val, s=10) + axs[1].scatter(epochs, y, color=color_val, s=10) + axs[2].scatter(x, y, color=color_val, s=10) + axs[0].set_ylabel(name_log1) + axs[0].set_xlabel("Epochs") + axs[1].set_ylabel(name_log2) + axs[1].set_xlabel("Epochs") + ax_corr = axs[2] + else: + axs[0].scatter(x, y, color=color_val, s=10) + ax_corr = axs[0] labels.append(exp + ':{:.2f}'.format(corr)) all_x.extend(x) all_y.extend(y) corr = np.corrcoef(all_x, all_y)[0][1] - b, m = np.polynomial.polynomial.polyfit(all_x, all_y, 1) - xx = np.linspace(np.min(all_x), np.max(all_x), 100) - axs[2].plot(xx, b + m * xx, color='k', - label="y={:.4f}x + {:.4f}".format(m, b)) - axs[2].legend() - - print("Correlation over all experiments is {}".format(corr)) - axs[0].set_ylabel(name_log1) - axs[0].set_xlabel("Epochs") - axs[1].set_ylabel(name_log2) - axs[1].set_xlabel("Epochs") - axs[2].set_xlabel(name_log1) - axs[2].set_ylabel(name_log2) - axs[2].set_title("Correlation between {} and {} = {:.4f}" - .format(name_log1, name_log2, corr)) + idx = np.argsort(all_x) + all_x = np.asarray(all_x)[idx] + all_y = np.asarray(all_y)[idx] + titre = ("Correlation between {} and {} = {:.3f}." 
+ .format(name_log1, name_log2, corr)) + print("Correlation over all experiments is {}.".format(corr)) + + # Linear fitting + x_line = np.linspace(min(all_x), max(all_x), 200) + if show_first_order: + # Fit + coef_lin = np.polyfit(all_x, all_y, 1) + + # Residuals + y_lin = np.polyval(coef_lin, all_x) + res_lin = all_y - y_lin + mse_lin = np.mean(res_lin**2) + + # Nice plot + y_lin_line = np.polyval(coef_lin, x_line) + ax_corr.plot(x_line, y_lin_line, color='k') + titre += "\nMSE (linear): {:.1e}".format(mse_lin) + + # Quadratic fitting + if show_second_order: + # Fit + coef_quad = np.polyfit(all_x, all_y, 2) + + # Residuals + y_quad = np.polyval(coef_quad, all_x) + res_quad = all_y - y_quad + mse_quad = np.mean(res_quad ** 2) + + # Nice plot + y_quad_line = np.polyval(coef_quad, x_line) + ax_corr.plot(x_line, y_quad_line, color='k') + titre += "\nMSE (quadratic): {:.1e}".format(mse_quad) + + + ax_corr.set_xlabel(name_log1) + ax_corr.set_ylabel(name_log2) + ax_corr.set_title(titre) + if xlim: + ax_corr.set_xlim(*xlim) + if ylim: + ax_corr.set_ylim(*ylim) # The xlabels and titles overlap # plt.tight_layout() # Makes the subplots very thin. 
@@ -139,7 +194,11 @@ def main(): name_log2 = args.rename_log2 or args.log2 first_epoch = args.ignore_first_epochs or 0 _compute_correlations(loaded_logs, args.log1, args.log2, - name_log1, name_log2, first_epoch) + name_log1, name_log2, first_epoch, + args.xlim, args.ylim, + args.show_individual_logs, + args.show_first_order_fit, + args.show_second_order_fit) if __name__ == '__main__': diff --git a/src/dwi_ml/cli/tt_visualize_weights.py b/src/dwi_ml/cli/tt_visualize_weights.py index 2af8a13f..94c32231 100644 --- a/src/dwi_ml/cli/tt_visualize_weights.py +++ b/src/dwi_ml/cli/tt_visualize_weights.py @@ -58,7 +58,7 @@ def main(): # 1) Finding the jupyter notebook dwi_ml_dir = dirname(dirname(__file__)) raw_ipynb_filename = os.path.join( - dwi_ml_dir, 'dwi_ml/testing/projects/tt_visualize_weights.ipynb') + dwi_ml_dir, 'projects/Transformers/tester/tt_visualize_weights.ipynb') if not os.path.isfile(raw_ipynb_filename): raise ValueError( "We could not find the jupyter notebook file. Probably a " @@ -68,12 +68,14 @@ def main(): # 2) Verify that output dir exists but not the html output files. 
args = get_out_dir_and_create(args) - out_html_filename = args.out_prefix + 'tt_bertviz.html' - out_html_file = os.path.join(args.out_dir, out_html_filename) + bertviz_out = os.path.join(args.out_dir, 'bertviz') + out_html_filename = os.path.join( + bertviz_out, args.out_prefix + 'tt_bertviz.html') + out_html_file = os.path.join(bertviz_out, out_html_filename) out_ipynb_file = os.path.join( - args.out_dir, args.out_prefix + 'tt_bertviz.ipynb') + bertviz_out, args.out_prefix + 'tt_bertviz.ipynb') out_config_file = os.path.join( - args.out_dir, args.out_prefix + 'tt_bertviz.config') + bertviz_out, args.out_prefix + 'tt_bertviz.config') assert_outputs_exist(parser, args, [out_html_file, out_ipynb_file, out_config_file]) diff --git a/src/dwi_ml/general/models/main_layers/transformer_sublayers.py b/src/dwi_ml/general/models/main_layers/transformer_sublayers.py index e2c65731..77431b0f 100644 --- a/src/dwi_ml/general/models/main_layers/transformer_sublayers.py +++ b/src/dwi_ml/general/models/main_layers/transformer_sublayers.py @@ -29,10 +29,11 @@ def do_not_share_linear_weights(attn: MultiheadAttention, d_model): # Overriding some parameters in the self attention. # Ugly but.... Torch does not have a parameter to NOT share linear - # weights. In their code, their only NOT share weights when dimensions + # weights. In their code, they only NOT share weights when dimensions # are not the same. This is not our case. This is saved in their # parameter _qkv_same_embed_dim. By changing this, we change their - # forward call to the MultiHeadAttention in self.self_attn. + # forward call to the MultiHeadAttention in self.self_attn. Heavier in + # memory because it thinks they are not the same size. 
attn._qkv_same_embed_dim = False attn.q_proj_weight = Parameter( torch.empty((d_model, d_model), **factory_kwargs)) diff --git a/src/dwi_ml/general/models/main_layers/transformers_from_torch.py b/src/dwi_ml/general/models/main_layers/transformers_from_torch.py index 6e1813da..30d44adf 100644 --- a/src/dwi_ml/general/models/main_layers/transformers_from_torch.py +++ b/src/dwi_ml/general/models/main_layers/transformers_from_torch.py @@ -29,7 +29,14 @@ def __init__(self, encoder_layer, *args, **kw): raise ValueError("Encoder layer should be of type {}. Got {}" .format(ModifiedTransformerEncoderLayer.__name__, type(encoder_layer))) - super().__init__(encoder_layer, *args, **kw) + + # Need to enforce enable_nested_tensor=False. This is because it thinks + # that the Q, K, V matrices are not the same size. Not true but see + # comment in the ModifiedTransformerEncoderLayer + # Else, we get warnings that it gets set to False automatically. + # Annoying. + super().__init__(encoder_layer, *args, **kw, + enable_nested_tensor=False) def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, diff --git a/src/dwi_ml/general/models/main_models/main_abstract_model.py b/src/dwi_ml/general/models/main_models/main_abstract_model.py index b9c68fb5..946afb4a 100644 --- a/src/dwi_ml/general/models/main_models/main_abstract_model.py +++ b/src/dwi_ml/general/models/main_models/main_abstract_model.py @@ -187,6 +187,7 @@ def load_model_from_params_and_state(cls, model_dir, logger.debug("Loading model from saved parameters:" + format_dict_to_str(params)) params.update(log_level=log_level) + print("Loading a model:", cls.__name__) model = cls(**params) model_state = cls._load_state(model_dir) diff --git a/src/dwi_ml/general/viz/logs_plots.py b/src/dwi_ml/general/viz/logs_plots.py index 1474bff0..3a1d8d8d 100644 --- a/src/dwi_ml/general/viz/logs_plots.py +++ b/src/dwi_ml/general/viz/logs_plots.py @@ -71,7 +71,7 @@ def _plot_one_graph(ax, logs, 
exp_names, chosen_log_keys, scalar_map, style = log_styles[i % nb_styles] ax.plot([], [], label='{}: {}'.format(log_key, style)) - ax.set_title(title) + ax.set_title(title, fontsize=20) ax.legend() if xlim is not None: ax.set_xlim([0, xlim]) @@ -154,6 +154,7 @@ def visualize_logs(logs_data: Dict[str, Dict[str, np.ndarray]], _plot_one_graph(axs[ax], logs_data, exp_names, graphs_logs[i], scalar_map, writer, remove_outliers, graphs_titles[i], xlim, graphs_ylim[i]) + axs[ax].axhline(0) plt.tight_layout() if save_figs is not None: diff --git a/src/dwi_ml/projects/Transformers/tester/tt_visu_argparser.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_argparser.py index 3d944e5f..6dbca4e2 100644 --- a/src/dwi_ml/projects/Transformers/tester/tt_visu_argparser.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_argparser.py @@ -1,22 +1,24 @@ # -*- coding: utf-8 -*- """ -TT weights visualisation choices. +TT weights vizualisation choices. + +** Example output names below are for the encoder. The same outputs are named +decoder_ or cross_ for other attention types. + + ** Currently, options bertviz and as_matrices use only the first + streamline in the data. Then, a file _single_streamline.trk is saved. Output options. Choose any number (at least one). - ** Example output names below are for the encoder. The same outputs - are named decoder_ or cross_ for other attention types. - ** See the --prefix option. Below are listed the output suffixes. Suffixes - with * include layer and head information. - ** Currently, options bertviz and as_matrices use only the first - streamline in the data. Then, a file _single_streamline.trk is saved. 1) 'as_matrices': Shows attention as matrices. - Outputs: _matrix_encoder.png - N.B. If bertviz is also chosen, matrices will show in the html file. + 2) 'bertviz': Shows attention using bertviz head_view visualisation. 
- Outputs: _bertviz.html, _bertviz.ipynb, _bertviz.config - - Will create a html file that can be viewed (see --out_dir) + - The main output is the html file. + 3) 'color_multi_length': Saves a colored tractogram. Streamlines are duplicated at all lengths. Color of the streamline of length n is the weight of each point when getting the next direction at point n. This is @@ -25,6 +27,7 @@ a specific point. - Outputs: _encoder_colored_multi_length_*.trk, _encoder_colored_multi_length_cbar.png + 4) 'color_x_y_summary': Saves two colored tractogram: - Projection on x: This is a measure of the importance of each point on the streamline. @@ -44,6 +47,7 @@ 1 = looked very far. - The projection technique depends on the chosen rescaling options. _encoder_colored_looked_far*.trk + 5) 'bertviz_locally': Run the bertviz without using jupyter. (Debugging purposes. Output will not show, but html stuff will print in the terminal.) @@ -109,6 +113,13 @@ def build_argparser_transformer_visu(): '--out_dir', metavar='d', help="Output directory where to save the output files.\n" "Default: experiment_path/visu_weights") + g.add_argument( + '--cmap', default='jet', + help='A colormap choice recognized by matplotlib.\n' + 'Suggestions: - viridis: from blue to yellow\n' + ' - jet: from blue to yellow to red\n' + ' - gnuplot2: from black to yellow\n' + ' - CMRmap: idem') # -------------- # Options @@ -141,14 +152,10 @@ def build_argparser_transformer_visu(): gg = g.add_mutually_exclusive_group() gg.add_argument('--group_heads', action='store_true', help="If true, average all heads (per layer, per " - "attention type).\n" - "To regroup using maximum instead, use " - "--group_with_max") + "attention type).") gg.add_argument('--group_all', action='store_true', help="If true, average all heads in all layers (per " - "attention type).\n" - "To regroup using maximum instead, use " - "--group_with_max") + "attention type).") g.add_argument('--group_with_max', action='store_true', help="Default 
grouping option is to average heads. Use " "this option to group \nhead using their maximal " @@ -158,14 +165,12 @@ def build_argparser_transformer_visu(): "(max of the rank use usefullness).") g = p.add_argument_group("Matrices options") - g.add_argument('--show_now', action='store_true', - help="If set, shows the matrices on screen. Else, only " - "saves them.") g.add_argument('--resample_plots', type=int, metavar='nb', dest='resample_nb', - help="Streamlines will be sampled (nb points) as decided " - "by the model. \nHowever, attention shown as a matrix " - "can be resampled.") + help="Streamlines themselves will be sampled (nb points)\n" + "as decided by the model's hyperparameters.\n" + "However, attention shown as a matrix can be resampled " + "afterwards.") g = add_memory_args(p) g.add_argument('--batch_size', type=int, metavar='n', diff --git a/src/dwi_ml/projects/Transformers/tester/tt_visu_bertviz.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_bertviz.py index c288ccf0..fd04b771 100644 --- a/src/dwi_ml/projects/Transformers/tester/tt_visu_bertviz.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_bertviz.py @@ -5,7 +5,7 @@ # Currently, with our quite long sequences compared to their example, this # is a bit ugly. 
-SHOW_MODEL_VIEW = False +SHOW_MODEL_VIEW = True def print_head_view_help(): @@ -17,7 +17,7 @@ def print_head_view_help(): print(" -- Single-click on any of the colored tiles to toggle selection " "of the corresponding attention head.") print(" -- Click on the Layer drop-down to change the model layer " - "(zero-indexed).") + "(zero-indexed).\n\n") def print_model_view_help(): @@ -102,3 +102,6 @@ def encoder_show_model_view(encoder_attention, tokens): tmp_e = [encoder_attention[i][:, head, :, :][:, None, :, :] for i in range(nb_layers)] model_view(encoder_attention=tmp_e, encoder_tokens=tokens) + else: + print("Option not ready, sorry") + diff --git a/src/dwi_ml/projects/Transformers/tester/tt_visu_colored_sft.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_colored_sft.py index 61ed0b00..bee8b42c 100644 --- a/src/dwi_ml/projects/Transformers/tester/tt_visu_colored_sft.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_colored_sft.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import logging from copy import deepcopy from typing import Tuple @@ -10,37 +11,68 @@ from scilpy.viz.color import get_lookup_table from dwi_ml.projects.Transformers.tester.tt_visu_utils import ( - get_visu_params_from_options, - prepare_projections_from_options) + get_min_max_from_options, + prepare_projections_from_options, get_rescale_name, + get_explanation_projections) def color_sft_duplicate_lines( sft: StatefulTractogram, lengths, prefix_name: str, attentions_per_line: list, attention_names: Tuple, - average_heads, average_layers, group_with_max, explanation): + average_heads: bool, average_layers: bool, group_with_max: bool, + explanation_rescaling: str, cmap: str): """ Saves the whole weight matrix on streamlines of all lengths. 
- Output name is: + Parameters + ---------- + sft: StatefulTractogram + The input streamlines + lengths: list + The length of each streamline + prefix_name: str + The saving directory + prefix + attentions_per_line: list[list] + The attention weights + attention_names: Tuple[str] + The attention names, ex, encoder, decoder, cross + average_heads: bool + If True, we will average heads + average_layers: bool + If True, we will average layers. average_heads must also be true. + group_with_max: bool + If True, we will do a max-pooling rather than an average. + explanation_rescaling: str + Text coming from our rescaling function. Will be used to add in the + plot titles. + cmap: str + The chosen colormap + + Returns + ------- + (Nothing. Outputs are saved on disk) + + Saves + ----- prefix_name_colored_sft_encoder_lN_hM.trk, - where N is the layer, M is the head. OR: + where N is the layer, M is the head. OR: prefix_name_colored_sft_encoder_lN_meanH.trk - with option --group_heads (or _maxH). OR: + with option --group_heads (or _maxH). OR: prefix_name_colored_sft_encoder_meanL_meanH.trk - with option --group_all (or _maxL_maxH). + with option --group_all (or _maxL_maxH). """ - # Avoid duplicating the attention in-place. + # Avoid duplicating the attention in-place for the main method. attentions_per_line = deepcopy(attentions_per_line) - # Supposing the same nb layers / head for each type of attention. + # Supposing that we have same nb layers / head for each type of attention. nb_layers = len(attentions_per_line[0][0]) nb_heads = attentions_per_line[0][0][0].shape[0] # Duplicating! - # Using s[0:current_point], so starting at current point 2, to have - # [s0, s1]. Else, cannot visualize a single point. - # (Anyway at point 0: Always looking at point 0 only) + # Using s[0:current_point], so starting at current point 1 (ie #2 in + # the range), to have [s0, s1]. Else, cannot visualize a single point. 
+ # (Anyways at point 0: Always looking at point 0 only) remaining_streamlines = sft.streamlines whole_sft = None for current_point in range(2, max(lengths) + 1): @@ -52,12 +84,12 @@ def color_sft_duplicate_lines( zip(att_type, remaining_streamlines) if len(s) >= current_point] - # Removing shorter streamlines for list of streamlines + # Removing shorter streamlines from the list of streamlines remaining_streamlines = [s for s in remaining_streamlines if len(s) >= current_point] # Saving first part of streamlines, up to current_point: - # = "At current_point: which point did we look at?" + # = "At current_point: which points did we look at?" tmp_sft = sft.from_sft([s[0:current_point] for s in remaining_streamlines], sft) @@ -66,15 +98,18 @@ def color_sft_duplicate_lines( # (Careful. Nibabel force names to be <18 character) # encoder_lX_hX for att_name, att_type in zip(attention_names, attentions_per_line): + + # Looping on layers. If average: layer 0 is actually the average for layer in range(nb_layers): if average_layers: if group_with_max: - layer_prefix = '_maxL' + layer_suffix = '_maxL' else: - layer_prefix = '_meanL' + layer_suffix = '_meanL' else: - layer_prefix = 'l{}'.format(layer) + layer_suffix = 'l{}'.format(layer) + # Looping on heads! If average: head 0 is actually the average for head in range(nb_heads): if average_heads: if group_with_max: @@ -93,7 +128,7 @@ def color_sft_duplicate_lines( 0:current_point][:, None] for line_att in att_type] - dpp_name = att_name + layer_prefix + head_suffix + dpp_name = att_name + layer_suffix + head_suffix tmp_sft.data_per_point[dpp_name] = dpp if whole_sft is None: @@ -101,11 +136,11 @@ def color_sft_duplicate_lines( else: whole_sft = whole_sft + tmp_sft - print(" **The initial {} streamlines were transformed into {} " + print("**The initial {} streamlines were transformed into {} " "streamlines of \n" - " variable lengths. 
Color for streamline i of length N is the " - "attention's value \n" - " at each point when deciding the next direction at point N." + "variable lengths. Color for streamline i of length N is the " + "attention's \n" + "value at each point when deciding the next direction at point N." .format(len(sft), len(whole_sft.streamlines))) del sft del tmp_sft @@ -126,23 +161,26 @@ def color_sft_duplicate_lines( # Not using a fixed vmax, vmin. Uses the bundle's data. # Easier to view. colored_sft, cbar_fig = _color_sft_from_dpp( - colored_sft, key, prepare_fig=True, title=explanation) + colored_sft, key, cmap=cmap, + prepare_fig=True, title=explanation_rescaling) filename_prefix = prefix_name + '_colored_multi_length_' + key filename_trk = filename_prefix + '.trk' filename_cbar = filename_prefix + '_cbar.png' - print("Saving {} with dpp: {}" - .format(filename_trk, list(colored_sft.data_per_point.keys()))) + logging.info("Saving {} with dpp: {}" + .format(filename_trk, + list(colored_sft.data_per_point.keys()))) save_tractogram(colored_sft, filename_trk, bbox_valid_check=False) plt.savefig(filename_cbar) + plt.close() def color_sft_x_y_projections( sft: StatefulTractogram, prefix_name: str, attentions_per_line: list, attention_names: Tuple, average_heads, average_layers, group_with_max, - rescale_0_1, rescale_non_lin, rescale_z, explanation): + rescale_0_1, rescale_non_lin, rescale_z, explanation, cmap): """ Saves one tractogram per "projection": @@ -154,19 +192,54 @@ def color_sft_x_y_projections( - looked_far - max_pos - nb_looked + + Parameters + ---------- + sft: StatefulTractogram + prefix_name: str + Contains the output directory + prefix + attentions_per_line: list + attention_names: Tuple + average_heads: bool + Only used to format the title + average_layers: bool + Only used to format the title + group_with_max: bool + Only used to format the title + rescale_0_1: bool + Used to choose best colorbar options + rescale_non_lin: bool + Used to choose best colorbar 
options + rescale_z: bool + Used to choose best colorbar options + explanation: str + Coming from the rescaling. Will be added in the title. + cmap: str + Matplotlib colormap. """ - (options_main, options_range_length, explanation_part2, - rescale_name, thresh) = get_visu_params_from_options( - rescale_0_1, rescale_non_lin, rescale_z) + rescale_name = get_rescale_name(rescale_0_1, rescale_non_lin, rescale_z) + explanation_part2 = get_explanation_projections(rescale_name) + min_max_attention_values, min_max_position = \ + get_min_max_from_options(rescale_name) explanation += '\n' + explanation_part2 + print("\n\n" + explanation_part2) + # Supposing the same nb layers / head for each type of attention. nb_layers = len(attentions_per_line[0][0]) nb_heads = attentions_per_line[0][0][0].shape[0] - for i, att_type in enumerate(attentions_per_line): + # Looping on attentions. Ex: Decoder, encoder, cross. + print("\nProcessing...") + for i, attentions in enumerate(attentions_per_line): + print("Attention: ", attention_names[i]) + + # Looping on layer for layer in range(nb_layers): + print(" Layer: ", layer) + + # Prefix if average_layers: if group_with_max: layer_prefix = '_maxL' @@ -174,7 +247,10 @@ def color_sft_x_y_projections( layer_prefix = '_meanL' else: layer_prefix = 'l{}'.format(layer) + + # Loop on head for head in range(nb_heads): + print(" Head: ", head) if average_heads: if group_with_max: head_suffix = '_maxH' @@ -189,7 +265,7 @@ def color_sft_x_y_projections( all_maxp = [] all_nb_looked = [] for s in range(len(sft.streamlines)): - a = att_type[s][layer][head, :, :] + a = attentions[s][layer][head, :, :] a, mean_att, nb_usage, looked_far, max_p, nb_looked = \ prepare_projections_from_options( a, rescale_0_1, rescale_non_lin, rescale_z) @@ -203,37 +279,46 @@ def color_sft_x_y_projections( filename_prefix = prefix_name + attention_names[i] + \ layer_prefix + head_suffix - # Mean att: not fixing the vmin, vmax. 
- name = 'x_mean_att' - sft.data_per_point[name] = all_mean_att - _color_sft_from_dpp(sft, name, **options_range_length) - filename_trk = filename_prefix + '_' + name + '.trk' - print("Saving {} with dpp {}" - .format(filename_trk, list(sft.data_per_point.keys()))) - save_tractogram(sft, filename_trk, bbox_valid_check=False) - del sft.data_per_point[name] - del sft.data_per_point['color'] - - # Others: Range is 0 - 1 (where 1 = 100% of streamline length) + # Saving various attentions + # Range is 0 - 1 (where 1 = 100% of streamline length) # Not saving the colorbar. - names = ('x_nb_usage', + names = ('x_mean_att', 'x_importance', 'y_looked_far', 'y_max_pos', 'y_nb_looked') - data = (all_nb_usage, + data = (all_mean_att, all_nb_usage, all_looked_far, all_maxp, all_nb_looked) for name, vectors in zip(names, data): + if 'mean_att' in name: + options = min_max_attention_values + else: + options = min_max_position sft.data_per_point[name] = vectors - _color_sft_from_dpp(sft, name, **options_range_length) + _color_sft_from_dpp(sft, name, cmap=cmap, **options) filename_trk = filename_prefix + '_' + name + '.trk' - print("Saving {} with dpp {}" - .format(filename_trk, - list(sft.data_per_point.keys()))) + logging.info("Saving {} with dpp {}" + .format(filename_trk, + list(sft.data_per_point.keys()))) save_tractogram(sft, filename_trk, bbox_valid_check=False) del sft.data_per_point[name] del sft.data_per_point['color'] def _color_sft_from_dpp(sft, key, cmap='viridis', vmin=None, vmax=None, - prepare_fig: bool = False, title=None, **kw): + prepare_fig: bool = False, title=None): + """ + Parameters + ---------- + sft : Tractogram + key: str + The name of the data_per_point to transform into a color + vmin: float + For the colorbar + vmax: float + For the colorbar + prepare_fig: bool + If True, prepare a figure with the colorbar + title: str + Title for the figure + """ cmap = get_lookup_table(cmap) tmp = [np.squeeze(sft.data_per_point[key][s]) for s in range(len(sft))] data 
= np.hstack(tmp) @@ -249,14 +334,18 @@ def _color_sft_from_dpp(sft, key, cmap='viridis', vmin=None, vmax=None, # Preparing a figure fig = None if prepare_fig: - fig = plt.figure(figsize=(9, 3)) - plt.imshow(np.array([[1, 0]]), cmap=cmap, vmin=vmin, vmax=vmax) - plt.gca().set_visible(False) - cax = plt.axes([0.1, 0.1, 0.6, 0.85]) - plt.colorbar(orientation="horizontal", cax=cax, aspect=0.01) + fig, ax = plt.subplots(figsize=(3, 6)) # tall figure for vertical bar + norm = plt.Normalize(vmin=vmin, vmax=vmax) + sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + sm.set_array([]) + + fig.colorbar(sm, ax=ax, orientation='vertical') if title is not None: - plt.title("Colorbar for key: {}\n".format(key) + title) + fig.suptitle("Colorbar for key: {}\n".format(key) + title) else: - plt.title("Colorbar for key: {}".format(key)) + fig.suptitle("Colorbar for key: {}".format(key)) + + ax.remove() # remove empty axes + fig.tight_layout() return sft, fig diff --git a/src/dwi_ml/projects/Transformers/tester/tt_visu_main.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_main.py index 3aafaa1b..a2629112 100644 --- a/src/dwi_ml/projects/Transformers/tester/tt_visu_main.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_main.py @@ -39,19 +39,23 @@ def tt_visualize_weights_main(args, parser): loads the models, runs it to get the attention, and calls the right visu method. 
""" + logging.getLogger().setLevel(level=args.verbose) + logging.getLogger('PIL').setLevel(logging.WARNING) + logging.getLogger('PIL.PngImagePlugin').setLevel(logging.WARNING) + logging.getLogger('matplotlib.colorbar').setLevel(logging.WARNING) + logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING) + + # ------ Finalize parser verification - if not (args.as_matrices or args.bertviz or args.colored_multi_length or - args.colored_x_y_summary or args.bertviz_locally): + if not (args.as_matrices or args.bertviz or args.color_multi_length or + args.color_x_y_summary or args.bertviz_locally): parser.error("Expecting at least one visualisation option.") if args.resample_nb is not None and \ not (args.as_matrices or args.bertviz or args.bertviz_locally): logging.warning("We only resample attention when visualizing matrices " - "or bertviz. Not required with current visualization " - "options. Ignoring.") - - average_heads = args.group_heads or args.group_all - average_layers = args.group_all + "or bertviz. Option --resample_plots not required " + "with current visualization options. Ignoring.") # -------- Verify inputs and outputs assert_inputs_exist(parser, [args.hdf5_file, args.in_sft]) @@ -62,21 +66,11 @@ def tt_visualize_weights_main(args, parser): # Whole filenames depend on rescaling options and grouping option. # Using the prefix to find any output. 
args = get_out_dir_and_create(args) - prefix_total = os.path.join(args.out_dir, args.out_prefix) + prefix_total = os.path.join(args.out_dir, '*', args.out_prefix) out_files = glob.glob(prefix_total + '*colored*.trk') + \ glob.glob(prefix_total + '*.png') assert_outputs_exist(parser, args, out_files) - if args.overwrite and len(out_files) > 0: - # logging.warning("Removing these files from a previous run: {}" - # .format(out_files)) - for f in out_files: - if os.path.isfile(f): - os.remove(f) - - sub_logger_level = 'WARNING' - logging.getLogger().setLevel(level=args.verbose) - if args.use_gpu: if torch.cuda.is_available(): logging.debug("We will be using GPU!") @@ -87,23 +81,47 @@ def tt_visualize_weights_main(args, parser): else: device = torch.device('cpu') - # ------------ Ok. Loading and formatting attention. + # Verify now that suggested cmap exists + _ = plt.get_cmap(args.cmap) + + # ------------ Ok. Loading and running transformer on batches + logging.debug("All inputs look ok. Running the Transformer!") sft, model, weights = _run_transformer_get_weights( - parser, args, sub_logger_level, device) + parser, args, sub_logger_level='WARNING', device=device) + logging.debug("Done! Preparing visualization of the weights!") # ------------ Now show all + average_heads = args.group_heads or args.group_all + average_layers = args.group_all _visu_encoder_decoder( - weights, sft, model, average_heads, average_layers, args, prefix_total) - - if args.show_now: - print("Showing matplotlib figures now. Everything is done, you can " - "close figures manually or enter ctrl + c safely.") - plt.show() + weights, sft, model, average_heads, average_layers, args) def _run_transformer_get_weights(parser, args, sub_logger_level, device): - # 1. Load model - logging.debug("Loading the model") + """ + Runs the transformer model on the input tractogram. 
+ + Parameters + ---------- + parser: ArgParser + args: Namespace + Contains the main parameters for this function: experiment path and + input tractogram, in particular. + sub_logger_level: str + device: torch.device + + Returns + ------- + sft: StatefulTractogram + The tractogram, formatted as required by the model (ex, resampling) and + by options in args (ex, reverse streamlines). + model: AbstractTransformerModel + The loaded model + weights: Tuple[list] + The attention weights + """ + # 1. Load the model + logging.info("Loading the model") if args.use_latest_epoch: model_dir = os.path.join(args.experiment_path, 'checkpoint/model') else: @@ -116,7 +134,7 @@ def _run_transformer_get_weights(parser, args, sub_logger_level, device): model_dir, log_level=sub_logger_level) # 2. Load SFT - logging.info("Loading tractogram. Note that space comptability " + logging.info("Loading then tractogram. Note that space compatibility " "with training data will NOT be verified.") args.bbox_check = False sft = load_tractogram_with_reference(parser, args, args.in_sft) @@ -128,11 +146,9 @@ def _run_transformer_get_weights(parser, args, sub_logger_level, device): if len(sft) > 1 and not ( args.color_multi_length or args.color_x_y_summary): # Taking only one streamline - line_id = 0 - logging.info(" Picking THE FIRST streamline ONLY to show with " - "bertviz / show as matrices: #{} / {}." 
- .format(line_id, len(sft))) - sft = sft[[line_id]] + logging.warning("Picking THE FIRST streamline ONLY to show with " + "bertviz / show as matrices") + sft = sft[[0]] if args.reverse_lines: sft.streamlines = [np.flip(line, axis=0) for line in sft.streamlines] @@ -159,7 +175,7 @@ def _run_transformer_get_weights(parser, args, sub_logger_level, device): def _visu_encoder_decoder( weights: Tuple, sft: StatefulTractogram, model, - average_heads: bool, average_layers: bool, args, prefix_name: str): + average_heads: bool, average_layers: bool, args): """ Parameters ---------- @@ -177,8 +193,10 @@ def _visu_encoder_decoder( average_layers: bool, Argparser's default = False. Must be False if average_head is False. args: Namespace - prefix_name: str - Includes the output path. + + Returns + ------- + (nothing. Saves outputs on disk) """ if len(weights) == 3: has_decoder = True @@ -188,18 +206,29 @@ def _visu_encoder_decoder( # 1. Prepare the streamlines has_eos = model.direction_getter.add_eos if not has_eos: - logging.warning("No EOS in model. Will ignore the last point per " + logging.warning("No EOS in model. We will ignore the last point per " "streamline") sft.streamlines = [s[:-1, :] for s in sft.streamlines] lengths = [len(s) for s in sft.streamlines] # 2. Arrange the weights weights = list(weights) - explanation = None + print( + "\n\n-------------- Found the following architecture: --------------") + print("Number of layers: {}".format(len(weights[0]))) + print("Number of heads: {}".format(weights[0][0].shape[1])) + print("Display options: Average heads: {}. 
Average layers: {}" + .format(average_heads, average_layers)) + + print("\n============== Rescaling the attention with option ==============") + explanation_rescale = None for i in range(len(weights)): - weights[i], explanation = reshape_unpad_rescale_attention( + weights[i], explanation_rescale = reshape_unpad_rescale_attention( weights[i], average_heads, average_layers, args.group_with_max, lengths, args.rescale_0_1, args.rescale_z, args.rescale_non_lin) + print(explanation_rescale) + + print("\n\n================ Visu on the whole tractogram ==============") if has_decoder: attention_names = ('encoder', 'decoder', 'cross') @@ -208,36 +237,42 @@ def _visu_encoder_decoder( if args.color_multi_length: print( - "\n-------------- Preparing the colors for each length of " + "\n\n-------------- Preparing the colors for each length of " "each streamline --------------") + prefix_name = os.path.join(args.out_dir, 'color_multi_length', + args.out_prefix) color_sft_duplicate_lines(sft, lengths, prefix_name, weights, attention_names, average_heads, average_layers, args.group_with_max, - explanation) + explanation_rescale, args.cmap) if args.color_x_y_summary: print( - "\n-------------- Preparing the colors summary (nb_usage, " + "\n\n-------------- Preparing the colors summary (nb_usage, " "where looked, etc) for each streamline --------------") + prefix_name = os.path.join(args.out_dir, 'color_x_y_summary', args.out_prefix) color_sft_x_y_projections( sft, prefix_name, weights, attention_names, average_heads, average_layers, args.group_with_max, args.rescale_0_1, args.rescale_non_lin, args.rescale_z, - explanation) + explanation_rescale, args.cmap) if args.bertviz or args.as_matrices: + # If only those options selected, we already chose one streamline before + # running the whole model to make it faster. Else, everything on the + # full tractogram is already done, we can continue with only one + # streamline. 
if args.color_multi_length or args.color_x_y_summary: # Taking only one streamline. Was not done yet. - line_id = 0 logging.info(" Picking THE FIRST streamline ONLY to show with " - "bertviz / show as matrices: #{} / {}." - .format(line_id, len(sft))) - sft = sft[[line_id]] - # Else we already chose one streamline before running the whole model. - - name = prefix_name + '_single_streamline.trk' - print("Saving the single line used for matrices, for debugging " - "purposes, as ", name) + "bertviz / show as matrices") + sft = sft[[0]] + + print("\n\n================ Visu on a single streamline ============== \n" + "Saving the single line used for matrices / bertviz, for debugging " + "purposes.\n") + name = os.path.join(args.out_dir, args.out_prefix) + '_single_streamline.trk' + logging.info("Saved as {}".format(name)) save_tractogram(sft, name, bbox_valid_check=False) this_seq_len = lengths[0] @@ -261,13 +296,15 @@ def _visu_encoder_decoder( print( "\n\n-------------- Preparing the attention as a matrix for " "one streamline --------------") + prefix_name = os.path.join(args.out_dir, 'as_matrices', args.out_prefix) for i in range(len(weights)): print("Matrix for ", attention_names[i]) + show_model_view_as_imshow( weights[i], prefix_name + '_matrix_' + attention_names[i], *weights_token[i], args.rescale_0_1, args.rescale_z, args.rescale_non_lin, average_heads, average_layers, - args.group_with_max) + args.group_with_max, args.cmap) if args.bertviz or args.bertviz_locally: print( @@ -279,10 +316,19 @@ def _visu_encoder_decoder( for att in weights[i]] if has_decoder: + print( + "\n\n---- Head view: ----") encoder_decoder_show_head_view( *weights, encoder_tokens, decoder_tokens) + + print( + "\n\n---- Model view: ----") encoder_decoder_show_model_view( *weights, encoder_tokens, decoder_tokens) else: + print( + "\n\n---- Head view: ----") encoder_show_head_view(*weights, encoder_tokens) + print( + "\n\n---- Model view: ----") encoder_show_model_view(*weights, 
encoder_tokens) diff --git a/src/dwi_ml/projects/Transformers/tester/tt_visu_matrix.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_matrix.py index 8952c72b..8c235b37 100644 --- a/src/dwi_ml/projects/Transformers/tester/tt_visu_matrix.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_matrix.py @@ -1,73 +1,83 @@ # -*- coding: utf-8 -*- +import logging + import numpy as np from matplotlib import pyplot as plt from mpl_toolkits.axes_grid1 import make_axes_locatable from dwi_ml.projects.Transformers.tester.tt_visu_utils import ( - get_visu_params_from_options, - prepare_projections_from_options) + get_min_max_from_options, + prepare_projections_from_options, get_rescale_name, + get_explanation_projections) def show_model_view_as_imshow( attention_one_line, fig_prefix, tokens_x, tokens_y, rescale_0_1, rescale_z, rescale_non_lin, - average_heads, average_layers, group_with_max): + average_heads, average_layers, group_with_max, cmap): nb_layers = len(attention_one_line) - size_x = len(tokens_x) - size_y = len(tokens_y) - - (options_main, options_range_length, explanation, rescale_name, - thresh) = get_visu_params_from_options( - rescale_0_1, rescale_non_lin, rescale_z) + rescale_name = get_rescale_name(rescale_0_1, rescale_non_lin, rescale_z) + explanation = get_explanation_projections(rescale_name) + min_max_attention_values, min_max_position = \ + get_min_max_from_options(rescale_name) for i in range(nb_layers): - att = attention_one_line[i] - nb_heads = att.shape[0] + layer_att = attention_one_line[i] + nb_heads = layer_att.shape[0] fig, axs = plt.subplots(1, nb_heads, figsize=(20, 8), layout='compressed') if nb_heads == 1: axs = [axs] for h in range(nb_heads): - a, mean_att, importance, looked_far, max_pos, nb_looked = \ + head_att, mean_att, importance, looked_far, max_pos, nb_looked = \ prepare_projections_from_options( - att[h, :, :], rescale_0_1, rescale_non_lin, rescale_z) + layer_att[h, :, :], rescale_0_1, rescale_non_lin, rescale_z) divider = 
make_axes_locatable(axs[h]) - ax_mean_att = divider.append_axes("bottom", size=0.2, pad=0) - ax_importance = divider.append_axes("bottom", size=0.2, pad=0) + ax_mean_att = divider.append_axes("bottom", size=0.2, pad=0.1) + ax_importance = divider.append_axes("bottom", size=0.2, pad=0.1) ax_lookedfar = divider.append_axes("right", size=0.2, pad=0) ax_max = divider.append_axes("right", size=0.2, pad=0) ax_nb_looked = divider.append_axes("right", size=0.2, pad=0) ax_cbar_main = divider.append_axes("right", size=0.3, pad=0.3) - ax_cbar_length = divider.append_axes("right", size=0.3, pad=0.55) # Plot the main image - im_main = axs[h].imshow(a, **options_main) + im_main = axs[h].imshow(head_att, + **min_max_attention_values, cmap=cmap, + interpolation='None') # Bottom and right images _ = ax_mean_att.imshow(mean_att[None, :], - **options_main, aspect='auto') - im_b = ax_importance.imshow(importance[None, :], - **options_range_length, aspect='auto') + **min_max_attention_values, cmap=cmap, + aspect='auto', interpolation='None') + _ = ax_importance.imshow(importance[None, :], + **min_max_position, cmap=cmap, + aspect='auto', interpolation='None') _ = ax_lookedfar.imshow(looked_far[:, None], - **options_range_length, aspect='auto') + **min_max_position, cmap=cmap, + aspect='auto', interpolation='None') _ = ax_max.imshow(max_pos[:, None], - **options_range_length, aspect='auto') + **min_max_position, cmap=cmap, + aspect='auto', interpolation='None') _ = ax_nb_looked.imshow(nb_looked[:, None], - **options_range_length, aspect='auto') + **min_max_position, cmap=cmap, + aspect='auto', interpolation='None') # Set the titles (see also suptitle below) if average_heads: if group_with_max: axs[h].set_title("Max of ({}) heads" .format(rescale_name)) + head_suffix = "_maxOfHeads" else: axs[h].set_title("Average of heads, {}" .format(rescale_name)) + head_suffix = "_meanHead" else: axs[h].set_title("Head {}".format(h)) + head_suffix = "_allHeads" # Titles proj X 
ax_mean_att.set_ylabel("Mean", rotation=0, labelpad=25) @@ -77,14 +87,8 @@ def show_model_view_as_imshow( ax_lookedfar.set_title("Looked far", rotation=45, loc='left') ax_max.set_title("Max pos", rotation=45, loc='left') ax_nb_looked.set_title("Nb looked", rotation=45, loc='left') - # ("Importance" is a bit too close to last tick. Tried to use - # loc='bottom' but then ignores labelpad). - - # Main image: set the ticks with tokens. - axs[h].set_xticks(np.arange(size_x), fontsize=10) - axs[h].set_yticks(np.arange(size_y), fontsize=10) - axs[h].set_xticklabels(tokens_x, rotation=-90) - axs[h].set_yticklabels(tokens_y) + axs[h].set_xlabel("The points that the attention looks at") + axs[h].set_ylabel("The current tractography point") # Move x ticks under the projections axs[h].tick_params(axis='x', pad=40) @@ -95,13 +99,8 @@ def show_model_view_as_imshow( plt.setp(ax.get_xticklabels(), visible=False) plt.setp(ax.get_yticklabels(), visible=False) - # Set the colorbars, with titles. - # ToDo. Colorbar mean_att. 
+ # Colorbar fig.colorbar(im_main, cax=ax_cbar_main) - ax_cbar_main.set_ylabel('Main figure', rotation=90, labelpad=-55) - fig.colorbar(im_b, cax=ax_cbar_length) - ax_cbar_length.set_ylabel('x / y projections: %% of length', - rotation=90, labelpad=-55) if average_layers: if group_with_max: @@ -117,6 +116,8 @@ def show_model_view_as_imshow( plt.suptitle("Layer: {}\n{}" .format(layer_title, explanation)) - name = fig_prefix + layer_name + '.png' - print("Saving matrix : {}".format(name)) + name = fig_prefix + layer_name + head_suffix + '.png' + print("Saving matrix : {}".format(layer_name + head_suffix)) + logging.info("Saved as {}".format(name)) plt.savefig(name) + plt.close() diff --git a/src/dwi_ml/projects/Transformers/tester/tt_visu_utils.py b/src/dwi_ml/projects/Transformers/tester/tt_visu_utils.py index e4ad0743..eed3cc3a 100644 --- a/src/dwi_ml/projects/Transformers/tester/tt_visu_utils.py +++ b/src/dwi_ml/projects/Transformers/tester/tt_visu_utils.py @@ -5,6 +5,7 @@ import numpy as np from scipy.ndimage import zoom +from sklearn.utils.sparsefuncs import min_max_axis from tqdm import tqdm from scilpy import get_home as get_scilpy_folder @@ -52,8 +53,6 @@ def reshape_unpad_rescale_attention( [nheads, length, length] Where nheads=1 if average_heads. """ - logging.info("Arranging attention based on visualisation options " - "(rescaling, averaging, etc.)") explanation = '' # 1. To numpy. Possibly average heads. @@ -66,35 +65,44 @@ def reshape_unpad_rescale_attention( # Averaging heads (but keeping 4D). if average_heads and not group_with_max: + logging.debug("Averaging heads on layer {}".format(layer)) attention_per_layer[layer] = np.mean(attention_per_layer[layer], axis=1, keepdims=True) # Possibly average layers (but keeping as list) if average_layers and not group_with_max: + logging.debug("Averaging layers.") attention_per_layer = [np.mean(attention_per_layer, axis=0)] nb_layers = 1 explanation += "Attention of each layer were then averaged.\n" - # 2. 
Rearrange attention into one list per line, unpadded, rescaled. + # Preparing explanations... if rescale_0_1: + logging.debug("Preparing to rescale! (0-1)") explanation += ("For each streamline, attention at each point (each " "row of the matrix) \nhas been " "rescaled between 0-1: X = X / max(row)") elif rescale_z: + logging.debug("Preparing to rescale! (z-score)") explanation += ("For each streamline, attention at each point (each " "row of the matrix) \nhas been " "rescaled to a z-score: X = (X-mu) / std") elif rescale_non_lin: + logging.debug("Preparing to rescale! (non-linear option)") explanation += ("For each streamline, attention at each point (each " "row of the matrix) \nhas been " "rescaled so that 0.5 is an average point.") + + # 2. Rearrange attention into one list per line, unpadded, rescaled. + # Must do it line by line to unpad correctly. attention_per_line = [] + print("Unpadding each line (and rescaling if asked):") for line in tqdm(range(len(lengths)), total=len(lengths), - desc="Rearranging, unpadding, rescaling (if asked)", maxinterval=3): attention_per_line.append([None] * nb_layers) for layer in range(nb_layers): - # 1. Taking one streamline, unpadding. + + # 1. Taking the right streamline, unpadding. line_att = attention_per_layer[layer][line, :, :, :] line_att = line_att[:, 0:lengths[line], 0:lengths[line]] @@ -145,7 +153,7 @@ def reshape_unpad_rescale_attention( line_att = tmp1 * where_below + tmp2 * where_above if average_heads and group_with_max: - explanation += "We then the maximal value through heads.\n" + explanation += "We then kept the maximal value through heads.\n" line_att = np.max(line_att, axis=0, keepdims=True) attention_per_line[-1][layer] = line_att @@ -154,6 +162,9 @@ def reshape_unpad_rescale_attention( explanation += "We then kept the maximal value trough layers." attention_per_line[-1] = [np.max(attention_per_line[-1], axis=0)] + if explanation == '': + explanation = ("No change done! 
Attention is exactly as produced by " + "the Transformer!") logging.info(explanation) return attention_per_line, explanation @@ -206,47 +217,57 @@ def prepare_encoder_tokens(this_seq_len, step_size, add_eos: bool): return encoder_tokens -def get_visu_params_from_options(rescale_0_1, rescale_non_lin, rescale_z): - """ - Defines options for prefix names, colormaps, vmin, vmax, explanation text, - etc. - """ - vmin_main, vmax_main, cmap_main = (0, 1, 'turbo') - vmin_pos, vmax_pos, cmap_pos = (0, 1, 'CMRmap') +def get_rescale_name(rescale_0_1, rescale_non_lin, rescale_z): if rescale_0_1: rescale_name = 'rescale_0_1' elif rescale_non_lin: rescale_name = 'rescale_non_lin' - # cmap_main = 'coolwarm' elif rescale_z: rescale_name = 'rescale_z' - # Range: We could limit it to help view better. Ex: ±3 std. - vmin_main = -3 - vmax_main = 3 else: rescale_name = 'None' + return rescale_name + +def get_explanation_projections(rescale_name): thresh = THRESH_IMPORTANT[rescale_name] explanation = ( - 'Importance: Number of times that this point was very important ' - '(>{:.2f}).\n' + 'Mean_attention: The average attention at that point when tracking in general\n' + 'Importance: Number of times that this point was very important (i.e. >{:.2f}).\n' "Looked far: Mean index of the important points (>{:.2f}) to decide " - "the next direction. 0 = current point. 100%% = very far behind.\n" + "the next direction. 0 = current point. 100\% = very far behind.\n" "Max_pos: Index of the point of maximal attention.\n" "Nb_looked: Number of points of important attention." .format(thresh, thresh)) + return explanation - options_main = {'interpolation': 'None', - 'cmap': cmap_main, - 'vmin': vmin_main, - 'vmax': vmax_main} - options_position = {'interpolation': 'None', - 'cmap': cmap_pos, - 'vmin': vmin_pos, - 'vmax': vmax_pos} +def get_min_max_from_options(rescale_name): + """ + Defines options for prefix names, colormaps, vmin, vmax, explanation text, + etc. 
- return options_main, options_position, explanation, rescale_name, thresh + Returns + ------- + options_imshow: dict + Useful for our plt showing. + + options_cmap: dict + Useful for the preparation of colormaps. + """ + vmin_main, vmax_main = (0, 1) + if rescale_name == 'rescale_z': + vmin_main = -3 + vmax_main = 3 + + min_max_rescaled_att = {'vmin': vmin_main, + 'vmax': vmax_main} + + # Currently, always 0-1 + min_max1_position = {'vmin': 0, + 'vmax': 1} + + return min_max_rescaled_att, min_max1_position def prepare_projections_from_options(a, rescale_0_1, rescale_non_lin, @@ -305,8 +326,25 @@ def get_config_filename(): def get_out_dir_and_create(args): # Define out_dir as experiment_path/visu_weights if not defined. # Create it if it does not exist. + + # If ran through the jupyter notebook, ran twice. So we need to kepp + # all the "if does not exist" if args.out_dir is None: args.out_dir = os.path.join(args.experiment_path, 'visu_weights') if not os.path.isdir(args.out_dir): os.mkdir(args.out_dir) + + def create_if_not_exist(path): + if not os.path.exists(path): + os.mkdir(path) + + if args.as_matrices: + create_if_not_exist(os.path.join(args.out_dir, 'as_matrices')) + if args.bertviz: + create_if_not_exist(os.path.join(args.out_dir, 'bertviz')) + if args.color_multi_length: + create_if_not_exist(os.path.join(args.out_dir, 'color_multi_length')) + if args.color_x_y_summary: + create_if_not_exist(os.path.join(args.out_dir, 'color_x_y_summary')) + return args diff --git a/src/dwi_ml/projects/Transformers/tester/tt_visualize_weights.ipynb b/src/dwi_ml/projects/Transformers/tester/tt_visualize_weights.ipynb index aaf60ee3..a3dbc53b 100644 --- a/src/dwi_ml/projects/Transformers/tester/tt_visualize_weights.ipynb +++ b/src/dwi_ml/projects/Transformers/tester/tt_visualize_weights.ipynb @@ -17,20 +17,19 @@ }, { "cell_type": "code", - "execution_count": null, "id": "a9eb63a1", "metadata": {}, - "outputs": [], "source": [ "\n", "import os\n", "import sys\n", "\n", - 
"from scilpy.io.fetcher import get_home as get_scilpy_folder\n", - "\n", - "from dwi_ml.general.testing.projects import \\\n", - " (build_argparser_transformer_visu, get_config_filename, tt_visualize_weights_main) \n" - ] + "from dwi_ml.projects.Transformers.tester.tt_visu_utils import get_config_filename\n", + "from dwi_ml.projects.Transformers.tester.tt_visu_argparser import build_argparser_transformer_visu\n", + "from dwi_ml.projects.Transformers.tester.tt_visu_main import tt_visualize_weights_main" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -58,7 +57,8 @@ "print(\"RUNNING WEIGHTS VISUALIZATION FOR TRACTOGRAPHYTRANSFORMER'S MODEL\")\n", "parser = build_argparser_transformer_visu()\n", "args = parser.parse_args()\n", - "tt_visualize_weights_main(args, parser)" + "tt_visualize_weights_main(args, parser)\n", + "\n" ] } ], diff --git a/src/dwi_ml/projects/Transformers/transformer_models.py b/src/dwi_ml/projects/Transformers/transformer_models.py index 2b21b09e..bf138257 100644 --- a/src/dwi_ml/projects/Transformers/transformer_models.py +++ b/src/dwi_ml/projects/Transformers/transformer_models.py @@ -552,16 +552,10 @@ def __init__(self, **kw): dropout=self.dropout_rate, activation=self.activation, batch_first=True, norm_first=self.norm_first) - # Receiving weird warning: enable_nested_tensor is True, - # but self.use_nested_tensor is False because encoder_layer.norm_first - # was True. - enable_nested = False if self.norm_first else True - # Note about norm: this is a final normalization step. Not linked to # the normalization decided with self.norm_first. 
self.modified_torch_transformer = ModifiedTransformerEncoder( - main_layer_encoder, self.n_layers_e, norm=None, - enable_nested_tensor=enable_nested) + main_layer_encoder, self.n_layers_e, norm=None) @property def d_model(self): @@ -771,16 +765,10 @@ def __init__(self, input_embedded_size, n_layers_d: int, **kw): activation=self.activation, batch_first=True, norm_first=self.norm_first) - # Receiving weird warning: enable_nested_tensor is True, - # but self.use_nested_tensor is False because encoder_layer.norm_first - # was True. - enable_nested = False if self.norm_first else True - # Note about norm: this is a final normalization step. Not linked to # the normalization decided with self.norm_first. encoder = ModifiedTransformerEncoder( - encoder_layer, self.n_layers_e, norm=None, - enable_nested_tensor=enable_nested) + encoder_layer, self.n_layers_e, norm=None) # Decoder decoder_layer = ModifiedTransformerDecoderLayer(