Commit 3247db0

6612-support-mlflow-data-tracking (#6616)
Fixes #6612.

### Description

Add dataset tracking support to the MLFlow handler.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: binliu <[email protected]>
1 parent b8c3e86 commit 3247db0
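
A quick orientation before the diff: a minimal usage sketch of the capability this commit adds. The dataset names and file paths are hypothetical, and dataset logging requires mlflow >= 2.4.0.

```python
from monai.data import Dataset
from monai.handlers import MLFlowHandler

# hypothetical file list; each sample dict carries the default "image" key
train_ds = Dataset(data=[{"image": f"img_{i}.nii.gz"} for i in range(8)])

handler = MLFlowHandler(
    tracking_uri="file:./mlruns",
    dataset_dict={"train": train_ds},  # logged once when the run starts
)
# handler.attach(trainer)  # trainer: an ignite.engine.Engine, as usual in MONAI
```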

File tree

2 files changed: +168 −3 lines

monai/handlers/mlflow_handler.py

Lines changed: 113 additions & 2 deletions
```diff
@@ -13,20 +13,24 @@

 import os
 import time
-from collections.abc import Callable, Sequence
+import warnings
+from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING, Any

 import torch
+from torch.utils.data import Dataset

 from monai.config import IgniteInfo
-from monai.utils import ensure_tuple, min_version, optional_import
+from monai.utils import CommonKeys, ensure_tuple, min_version, optional_import

 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
 mlflow.entities, _ = optional_import(
     "mlflow.entities", descriptor="Please install mlflow.entities before using MLFlowHandler."
 )
+pandas, _ = optional_import("pandas", descriptor="Please install pandas for recording the dataset.")
+tqdm, _ = optional_import("tqdm", "4.47.0", min_version, "tqdm")

 if TYPE_CHECKING:
     from ignite.engine import Engine
```
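
A side note on the `optional_import` guards above: the second element of the returned tuple reports availability, and the stub raises the descriptor message lazily on first use. A small sketch of that contract:

```python
from monai.utils import optional_import

# Same guard as in the diff; `has_pandas` is False when pandas is missing,
# and touching the stub then raises the descriptor message.
pandas, has_pandas = optional_import("pandas", descriptor="Please install pandas for recording the dataset.")
if has_pandas:
    print(pandas.__version__)
```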
```diff
@@ -72,6 +76,14 @@ class MLFlowHandler:
             Must accept parameter "engine", use default logger if None.
         iteration_logger: customized callable logger for iteration level logging with MLFlow.
             Must accept parameter "engine", use default logger if None.
+        dataset_logger: customized callable logger to log the dataset information with MLFlow.
+            Must accept parameter "dataset_dict", use default logger if None.
+        dataset_dict: a dictionary in which the key is the name of the dataset and the value is a PyTorch
+            dataset that needs to be recorded. This arg is only useful when MLFlow version >= 2.4.0.
+            For more details about how to log data with MLFlow, please go to the website:
+            https://mlflow.org/docs/latest/python_api/mlflow.data.html
+        dataset_keys: a key or a collection of keys to indicate contents in the dataset that
+            need to be stored by MLFlow.
         output_transform: a callable that is used to transform the
             ``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}.
             By default this value logging happens when every iteration completed.
```
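
Per the docstring contract above, a custom `dataset_logger` only needs to accept the `dataset_dict` mapping. A hypothetical sketch (the body is illustrative, not the default behavior):

```python
from __future__ import annotations

from collections.abc import Mapping

from torch.utils.data import Dataset

def my_dataset_logger(dataset_dict: Mapping[str, Dataset] | None) -> None:
    # illustrative only: report dataset sizes instead of logging to MLFlow
    if not dataset_dict:
        return
    for name, ds in dataset_dict.items():
        print(f"{name}: {len(getattr(ds, 'data', []))} samples")

# MLFlowHandler(dataset_logger=my_dataset_logger, dataset_dict={...}, ...)
```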
```diff
@@ -111,6 +123,9 @@ def __init__(
         epoch_log: bool | Callable[[Engine, int], bool] = True,
         epoch_logger: Callable[[Engine], Any] | None = None,
         iteration_logger: Callable[[Engine], Any] | None = None,
+        dataset_logger: Callable[[Mapping[str, Dataset]], Any] | None = None,
+        dataset_dict: Mapping[str, Dataset] | None = None,
+        dataset_keys: str = CommonKeys.IMAGE,
         output_transform: Callable = lambda x: x[0],
         global_epoch_transform: Callable = lambda x: x,
         state_attributes: Sequence[str] | None = None,
@@ -126,6 +141,7 @@ def __init__(
         self.epoch_log = epoch_log
         self.epoch_logger = epoch_logger
         self.iteration_logger = iteration_logger
+        self.dataset_logger = dataset_logger
         self.output_transform = output_transform
         self.global_epoch_transform = global_epoch_transform
         self.state_attributes = state_attributes
@@ -140,6 +156,8 @@ def __init__(
         self.close_on_complete = close_on_complete
         self.experiment = None
         self.cur_run = None
+        self.dataset_dict = dataset_dict
+        self.dataset_keys = ensure_tuple(dataset_keys)

     def _delete_exist_param_in_dict(self, param_dict: dict) -> None:
         """
```
```diff
@@ -210,6 +228,11 @@ def start(self, engine: Engine) -> None:
         self._delete_exist_param_in_dict(attrs)
         self._log_params(attrs)

+        if self.dataset_logger:
+            self.dataset_logger(self.dataset_dict)
+        else:
+            self._default_dataset_log(self.dataset_dict)
+
     def _set_experiment(self):
         experiment = self.experiment
         if not experiment:
```
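
This `start` hook is what triggers dataset logging; `MLFlowHandler.attach` (unchanged by this commit) binds `start` to Ignite's `Events.STARTED`, so datasets are recorded once per run. A rough sketch of the equivalent manual wiring, with a hypothetical `step` function and handler instance:

```python
from ignite.engine import Engine, Events

def step(engine, batch):
    return batch  # trivial no-op iteration, for illustration only

engine = Engine(step)
# roughly what MLFlowHandler.attach arranges internally:
# engine.add_event_handler(Events.STARTED, mlflow_handler.start)
```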
```diff
@@ -222,6 +245,36 @@ def _set_experiment(self):
             raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment")
         self.experiment = experiment

+    @staticmethod
+    def _get_pandas_dataset_info(pandas_dataset):
+        dataset_name = pandas_dataset.name
+        return {
+            f"{dataset_name}_digest": pandas_dataset.digest,
+            f"{dataset_name}_samples": pandas_dataset.profile["num_rows"],
+        }
+
+    def _log_dataset(self, sample_dict: dict[str, Any], context: str = "train") -> None:
+        if not self.cur_run:
+            raise ValueError("Current Run is not Active to log the dataset")
+
+        # Need to update self.cur_run to sync the dataset log, otherwise the `inputs` info will be out-of-date.
+        self.cur_run = self.client.get_run(self.cur_run.info.run_id)
+        logged_set = [x for x in self.cur_run.inputs.dataset_inputs if x.dataset.name.startswith(context)]
+        # In case there are datasets with the same name.
+        dataset_count = str(len(logged_set))
+        dataset_name = f"{context}_dataset_{dataset_count}"
+        sample_df = pandas.DataFrame(sample_dict)
+        dataset = mlflow.data.from_pandas(sample_df, name=dataset_name)
+        exist_dataset_list = list(
+            filter(lambda x: x.dataset.digest == dataset.digest, self.cur_run.inputs.dataset_inputs)
+        )
+
+        if not len(exist_dataset_list):
+            datasets = [mlflow.entities.DatasetInput(dataset._to_mlflow_entity())]
+            self.client.log_inputs(run_id=self.cur_run.info.run_id, datasets=datasets)
+            dataset_info = MLFlowHandler._get_pandas_dataset_info(dataset)
+            self._log_params(dataset_info)
+
     def _log_params(self, params: dict[str, Any]) -> None:
         if not self.cur_run:
             raise ValueError("Current Run is not Active to log params")
```
```diff
@@ -352,3 +405,61 @@ def _default_iteration_log(self, engine: Engine) -> None:
             for i, param_group in enumerate(cur_optimizer.param_groups)
         }
         self._log_metrics(params, step=engine.state.iteration)
+
+    def _default_dataset_log(self, dataset_dict: Mapping[str, Dataset] | None) -> None:
+        """
+        Execute dataset log operation based on the input dataset_dict. The dataset_dict should have a format
+        like:
+            {
+                "dataset_name0": dataset0,
+                "dataset_name1": dataset1,
+                ...
+            }
+        The keys stand for names of datasets, which will be logged as prefixes of dataset names in MLFlow.
+        The values are PyTorch datasets from which sample names are extracted to build a Pandas DataFrame.
+        If the input dataset_dict is None, this function will directly return and do nothing.
+
+        To use this function, every sample in the input datasets must contain keys specified by the
+        `dataset_keys` parameter.
+        This function will log a PandasDataset to MLFlow inputs, generated from the Pandas DataFrame.
+        For more details about PandasDataset, please refer to this link:
+        https://mlflow.org/docs/latest/python_api/mlflow.data.html#mlflow.data.pandas_dataset.PandasDataset
+
+        Please note that it may take a while to record the dataset if it has too many samples.
+
+        Args:
+            dataset_dict: a dictionary in which the key is the name of the dataset and the value is a PyTorch
+                dataset that needs to be recorded.
+
+        """
+
+        if dataset_dict is None:
+            return
+        elif len(dataset_dict) == 0:
+            warnings.warn("There is no dataset to log!")
+
+        # Log datasets to MLFlow one by one.
+        for dataset_type, dataset in dataset_dict.items():
+            if dataset is None:
+                raise AttributeError(f"The {dataset_type} dataset is None. Cannot record it with MLFlow.")
+
+            sample_dict: dict[str, list[str]] = {}
+            dataset_samples = getattr(dataset, "data", [])
+            for sample in tqdm(dataset_samples, f"Recording the {dataset_type} dataset"):
+                for key in self.dataset_keys:
+                    if key not in sample_dict:
+                        sample_dict[key] = []
+
+                    if key in sample:
+                        value_to_log = sample[key]
+                    else:
+                        raise KeyError(f"Unexpected key '{key}' in the sample.")
+
+                    if not isinstance(value_to_log, str):
+                        warnings.warn(
+                            f"Expected type string, got type {type(value_to_log)} of the {key} name. "
+                            "May log an empty dataset in MLFlow."
+                        )
+                    else:
+                        sample_dict[key].append(value_to_log)
+            self._log_dataset(sample_dict, dataset_type)
```
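
Since `_default_dataset_log` reads samples from the dataset's `data` attribute and keeps only string values, the expected sample layout looks roughly like this (hypothetical paths; `monai.data.Dataset` keeps the list it is given as `.data`):

```python
from monai.data import Dataset

# each sample dict must contain the configured `dataset_keys`
# (default: CommonKeys.IMAGE == "image") with string values such as paths
samples = [{"image": f"case_{i:03d}.nii.gz"} for i in range(4)]
train_ds = Dataset(data=samples)
assert train_ds.data[0]["image"] == "case_000.nii.gz"
```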

tests/test_handler_mlflow.py

Lines changed: 55 additions & 1 deletion
```diff
@@ -23,8 +23,13 @@
 from ignite.engine import Engine, Events
 from parameterized import parameterized

+from monai.apps import download_and_extract
+from monai.bundle import ConfigWorkflow, download
 from monai.handlers import MLFlowHandler
-from monai.utils import path_to_uri
+from monai.utils import optional_import, path_to_uri
+from tests.utils import skip_if_downloading_fails, skip_if_quick
+
+_, has_dataset_tracking = optional_import("mlflow", "2.4.0")


 def get_event_filter(e):
```
```diff
@@ -230,6 +235,55 @@ def test_multi_thread(self):
             self.tmpdir_list.append(res)
             self.assertTrue(len(glob.glob(res)) > 0)

+    @skip_if_quick
+    @unittest.skipUnless(has_dataset_tracking, reason="Requires mlflow version >= 2.4.0.")
+    def test_dataset_tracking(self):
+        test_bundle_name = "endoscopic_tool_segmentation"
+        with tempfile.TemporaryDirectory() as tempdir:
+            resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/endoscopic_tool_dataset.zip"
+            md5 = "f82da47259c0a617202fb54624798a55"
+            compressed_file = os.path.join(tempdir, "endoscopic_tool_segmentation.zip")
+            data_dir = os.path.join(tempdir, "endoscopic_tool_dataset")
+            with skip_if_downloading_fails():
+                if not os.path.exists(data_dir):
+                    download_and_extract(resource, compressed_file, tempdir, md5)
+
+                download(test_bundle_name, bundle_dir=tempdir)
+
+            bundle_root = os.path.join(tempdir, test_bundle_name)
+            config_file = os.path.join(bundle_root, "configs/inference.json")
+            meta_file = os.path.join(bundle_root, "configs/metadata.json")
+            logging_file = os.path.join(bundle_root, "configs/logging.conf")
+            workflow = ConfigWorkflow(
+                workflow="infer",
+                config_file=config_file,
+                meta_file=meta_file,
+                logging_file=logging_file,
+                init_id="initialize",
+                run_id="run",
+                final_id="finalize",
+            )
+
+            tracking_path = os.path.join(bundle_root, "eval")
+            workflow.bundle_root = bundle_root
+            workflow.dataset_dir = data_dir
+            workflow.initialize()
+            infer_dataset = workflow.dataset
+            mlflow_handler = MLFlowHandler(
+                iteration_log=False,
+                epoch_log=False,
+                dataset_dict={"test": infer_dataset},
+                tracking_uri=path_to_uri(tracking_path),
+            )
+            mlflow_handler.attach(workflow.evaluator)
+            workflow.run()
+            workflow.finalize()
+
+            cur_run = mlflow_handler.client.get_run(mlflow_handler.cur_run.info.run_id)
+            logged_nontrain_set = [x for x in cur_run.inputs.dataset_inputs if x.dataset.name.startswith("test")]
+            self.assertEqual(len(logged_nontrain_set), 1)
+            mlflow_handler.close()


 if __name__ == "__main__":
     unittest.main()
```
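
After a run like this test's, the logged inputs can also be inspected with the plain MLflow client. A hypothetical sketch with a placeholder tracking path and run id:

```python
from mlflow import MlflowClient

from monai.utils import path_to_uri

tracking_path = "./eval"  # placeholder for the bundle's tracking directory
client = MlflowClient(tracking_uri=path_to_uri(tracking_path))
run = client.get_run("<run_id>")  # e.g. mlflow_handler.cur_run.info.run_id
for dataset_input in run.inputs.dataset_inputs:
    print(dataset_input.dataset.name, dataset_input.dataset.digest)
```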
