1313
1414import os
1515import time
16- from collections .abc import Callable , Sequence
16+ import warnings
17+ from collections .abc import Callable , Mapping , Sequence
1718from pathlib import Path
1819from typing import TYPE_CHECKING , Any
1920
2021import torch
22+ from torch .utils .data import Dataset
2123
2224from monai .config import IgniteInfo
23- from monai .utils import ensure_tuple , min_version , optional_import
25+ from monai .utils import CommonKeys , ensure_tuple , min_version , optional_import
2426
2527Events , _ = optional_import ("ignite.engine" , IgniteInfo .OPT_IMPORT_VERSION , min_version , "Events" )
2628mlflow , _ = optional_import ("mlflow" , descriptor = "Please install mlflow before using MLFlowHandler." )
2729mlflow .entities , _ = optional_import (
2830 "mlflow.entities" , descriptor = "Please install mlflow.entities before using MLFlowHandler."
2931)
32+ pandas , _ = optional_import ("pandas" , descriptor = "Please install pandas for recording the dataset." )
33+ tqdm , _ = optional_import ("tqdm" , "4.47.0" , min_version , "tqdm" )
3034
3135if TYPE_CHECKING :
3236 from ignite .engine import Engine
@@ -72,6 +76,14 @@ class MLFlowHandler:
7276 Must accept parameter "engine", use default logger if None.
7377 iteration_logger: customized callable logger for iteration level logging with MLFlow.
7478 Must accept parameter "engine", use default logger if None.
79+ dataset_logger: customized callable logger to log the dataset information with MLFlow.
80+ Must accept parameter "dataset_dict", use default logger if None.
81+ dataset_dict: a dictionary in which the key is the name of the dataset and the value is a PyTorch
82+ dataset, that needs to be recorded. This arg is only useful when MLFlow version >= 2.4.0.
83+ For more details about how to log data with MLFlow, please go to the website:
84+ https://mlflow.org/docs/latest/python_api/mlflow.data.html.
85+ dataset_keys: a key or a collection of keys to indicate contents in the dataset that
86+ need to be stored by MLFlow.
7587 output_transform: a callable that is used to transform the
7688 ``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}.
7789 By default this value logging happens when every iteration completed.
@@ -111,6 +123,9 @@ def __init__(
111123 epoch_log : bool | Callable [[Engine , int ], bool ] = True ,
112124 epoch_logger : Callable [[Engine ], Any ] | None = None ,
113125 iteration_logger : Callable [[Engine ], Any ] | None = None ,
126+ dataset_logger : Callable [[Mapping [str , Dataset ]], Any ] | None = None ,
127+ dataset_dict : Mapping [str , Dataset ] | None = None ,
128+ dataset_keys : str = CommonKeys .IMAGE ,
114129 output_transform : Callable = lambda x : x [0 ],
115130 global_epoch_transform : Callable = lambda x : x ,
116131 state_attributes : Sequence [str ] | None = None ,
@@ -126,6 +141,7 @@ def __init__(
126141 self .epoch_log = epoch_log
127142 self .epoch_logger = epoch_logger
128143 self .iteration_logger = iteration_logger
144+ self .dataset_logger = dataset_logger
129145 self .output_transform = output_transform
130146 self .global_epoch_transform = global_epoch_transform
131147 self .state_attributes = state_attributes
@@ -140,6 +156,8 @@ def __init__(
140156 self .close_on_complete = close_on_complete
141157 self .experiment = None
142158 self .cur_run = None
159+ self .dataset_dict = dataset_dict
160+ self .dataset_keys = ensure_tuple (dataset_keys )
143161
144162 def _delete_exist_param_in_dict (self , param_dict : dict ) -> None :
145163 """
@@ -210,6 +228,11 @@ def start(self, engine: Engine) -> None:
210228 self ._delete_exist_param_in_dict (attrs )
211229 self ._log_params (attrs )
212230
231+ if self .dataset_logger :
232+ self .dataset_logger (self .dataset_dict )
233+ else :
234+ self ._default_dataset_log (self .dataset_dict )
235+
213236 def _set_experiment (self ):
214237 experiment = self .experiment
215238 if not experiment :
@@ -222,6 +245,36 @@ def _set_experiment(self):
222245 raise ValueError (f"Cannot set a deleted experiment '{ self .experiment_name } ' as the active experiment" )
223246 self .experiment = experiment
224247
248+ @staticmethod
249+ def _get_pandas_dataset_info (pandas_dataset ):
250+ dataset_name = pandas_dataset .name
251+ return {
252+ f"{ dataset_name } _digest" : pandas_dataset .digest ,
253+ f"{ dataset_name } _samples" : pandas_dataset .profile ["num_rows" ],
254+ }
255+
256+ def _log_dataset (self , sample_dict : dict [str , Any ], context : str = "train" ) -> None :
257+ if not self .cur_run :
258+ raise ValueError ("Current Run is not Active to log the dataset" )
259+
260+ # Need to update the self.cur_run to sync the dataset log, otherwise the `inputs` info will be out-of-date.
261+ self .cur_run = self .client .get_run (self .cur_run .info .run_id )
262+ logged_set = [x for x in self .cur_run .inputs .dataset_inputs if x .dataset .name .startswith (context )]
263+ # In case there are datasets with the same name.
264+ dataset_count = str (len (logged_set ))
265+ dataset_name = f"{ context } _dataset_{ dataset_count } "
266+ sample_df = pandas .DataFrame (sample_dict )
267+ dataset = mlflow .data .from_pandas (sample_df , name = dataset_name )
268+ exist_dataset_list = list (
269+ filter (lambda x : x .dataset .digest == dataset .digest , self .cur_run .inputs .dataset_inputs )
270+ )
271+
272+ if not len (exist_dataset_list ):
273+ datasets = [mlflow .entities .DatasetInput (dataset ._to_mlflow_entity ())]
274+ self .client .log_inputs (run_id = self .cur_run .info .run_id , datasets = datasets )
275+ dataset_info = MLFlowHandler ._get_pandas_dataset_info (dataset )
276+ self ._log_params (dataset_info )
277+
225278 def _log_params (self , params : dict [str , Any ]) -> None :
226279 if not self .cur_run :
227280 raise ValueError ("Current Run is not Active to log params" )
@@ -352,3 +405,61 @@ def _default_iteration_log(self, engine: Engine) -> None:
352405 for i , param_group in enumerate (cur_optimizer .param_groups )
353406 }
354407 self ._log_metrics (params , step = engine .state .iteration )
408+
409+ def _default_dataset_log (self , dataset_dict : Mapping [str , Dataset ] | None ) -> None :
410+ """
411+ Execute dataset log operation based on the input dataset_dict. The dataset_dict should have a format
412+ like:
413+ {
414+ "dataset_name0": dataset0,
415+ "dataset_name1": dataset1,
416+ ......
417+ }
418+ The keys stand for names of datasets, which will be logged as prefixes of dataset names in MLFlow.
419+ The values are PyTorch datasets from which sample names are abstracted to build a Pandas DataFrame.
420+ If the input dataset_dict is None, this function will directly return and do nothing.
421+
422+ To use this function, every sample in the input datasets must contain keys specified by the `dataset_keys`
423+ parameter.
424+ This function will log a PandasDataset to MLFlow inputs, generated from the Pandas DataFrame.
425+ For more details about PandasDataset, please refer to this link:
426+ https://mlflow.org/docs/latest/python_api/mlflow.data.html#mlflow.data.pandas_dataset.PandasDataset
427+
428+ Please note that it may take a while to record the dataset if it has too many samples.
429+
430+ Args:
431+ dataset_dict: a dictionary in which the key is the name of the dataset and the value is a PyTorch
432+ dataset, that needs to be recorded.
433+
434+ """
435+
436+ if dataset_dict is None :
437+ return
438+ elif len (dataset_dict ) == 0 :
439+ warnings .warn ("There is no dataset to log!" )
440+
441+ # Log datasets to MLFlow one by one.
442+ for dataset_type , dataset in dataset_dict .items ():
443+ if dataset is None :
444+ raise AttributeError (f"The { dataset_type } dataset of is None. Cannot record it by MLFlow." )
445+
446+ sample_dict : dict [str , list [str ]] = {}
447+ dataset_samples = getattr (dataset , "data" , [])
448+ for sample in tqdm (dataset_samples , f"Recording the { dataset_type } dataset" ):
449+ for key in self .dataset_keys :
450+ if key not in sample_dict :
451+ sample_dict [key ] = []
452+
453+ if key in sample :
454+ value_to_log = sample [key ]
455+ else :
456+ raise KeyError (f"Unexpect key '{ key } ' in the sample." )
457+
458+ if not isinstance (value_to_log , str ):
459+ warnings .warn (
460+ f"Expected type string, got type { type (value_to_log )} of the { key } name."
461+ "May log an empty dataset in MLFlow"
462+ )
463+ else :
464+ sample_dict [key ].append (value_to_log )
465+ self ._log_dataset (sample_dict , dataset_type )
0 commit comments