diff --git a/neural_lam/custom_loggers.py b/neural_lam/custom_loggers.py index 635f515ed..43e90497d 100644 --- a/neural_lam/custom_loggers.py +++ b/neural_lam/custom_loggers.py @@ -1,5 +1,6 @@ # Standard library import sys +from typing import Optional, Dict, Any # Third-party import mlflow @@ -15,7 +16,9 @@ class CustomMLFlowLogger(pl.loggers.MLFlowLogger): of version `2.0.3` at least. """ - def __init__(self, experiment_name, tracking_uri, run_name): + def __init__( + self, experiment_name: str, tracking_uri: str, run_name: str + ) -> None: super().__init__( experiment_name=experiment_name, tracking_uri=tracking_uri ) @@ -25,7 +28,7 @@ def __init__(self, experiment_name, tracking_uri, run_name): mlflow.log_param("run_id", self.run_id) @property - def save_dir(self): + def save_dir(self) -> str: """ Returns the directory where the MLFlow artifacts are saved. Used to define the path to save output when using the logger. @@ -37,7 +40,9 @@ def save_dir(self): """ return "mlruns" - def log_image(self, key, images, step=None): + def log_image( + self, key: str, images: list, step: Optional[int] = None + ) -> None: """ Log a matplotlib figure as an image to MLFlow