From 24e582524229c91a7c1002fcc07008ad249d2bc8 Mon Sep 17 00:00:00 2001 From: princekumarlahon Date: Sun, 22 Mar 2026 21:11:35 +0530 Subject: [PATCH] Add type hints to CustomMLFlowLogger methods --- neural_lam/custom_loggers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/neural_lam/custom_loggers.py b/neural_lam/custom_loggers.py index 635f515ed..533c86fa1 100644 --- a/neural_lam/custom_loggers.py +++ b/neural_lam/custom_loggers.py @@ -6,6 +6,8 @@ import mlflow.pytorch import pytorch_lightning as pl from loguru import logger +from typing import List, Optional +from matplotlib.figure import Figure class CustomMLFlowLogger(pl.loggers.MLFlowLogger): @@ -15,7 +17,7 @@ class CustomMLFlowLogger(pl.loggers.MLFlowLogger): of version `2.0.3` at least. """ - def __init__(self, experiment_name, tracking_uri, run_name): + def __init__(self, experiment_name: str, tracking_uri: str, run_name: str) -> None: super().__init__( experiment_name=experiment_name, tracking_uri=tracking_uri ) @@ -25,7 +27,7 @@ def __init__(self, experiment_name, tracking_uri, run_name): mlflow.log_param("run_id", self.run_id) @property - def save_dir(self): + def save_dir(self) -> str: """ Returns the directory where the MLFlow artifacts are saved. Used to define the path to save output when using the logger. @@ -37,7 +39,7 @@ def save_dir(self): """ return "mlruns" - def log_image(self, key, images, step=None): + def log_image(self,key: str,images: List[Figure],step: Optional[int] = None,) -> None: """ Log a matplotlib figure as an image to MLFlow