From 17c917e77d2d32dd9cf47b127add1f434e5cd58c Mon Sep 17 00:00:00 2001 From: Marcel Gregoriadis Date: Tue, 17 Dec 2024 00:43:33 +0100 Subject: [PATCH] added support for wandb logging --- allrank/click_models/click_utils.py | 2 +- allrank/config.py | 2 ++ allrank/main.py | 25 +++++++++++++++++++++++-- allrank/training/train_utils.py | 14 +++++++++++++- scripts/local_config.json | 3 ++- scripts/local_config_click_model.json | 3 ++- 6 files changed, 43 insertions(+), 6 deletions(-) diff --git a/allrank/click_models/click_utils.py b/allrank/click_models/click_utils.py index 99853ec..b0635cf 100644 --- a/allrank/click_models/click_utils.py +++ b/allrank/click_models/click_utils.py @@ -10,7 +10,7 @@ def click_on_slates(slates: Union[Tuple[np.ndarray, np.ndarray], Tuple[torch.Tensor, torch.Tensor]], click_model: ClickModel, include_empty: bool) -> Tuple[List[Union[np.ndarray, torch.Tensor]], List[List[int]]]: """ - This metod runs a click model on a list of slates and returns new slates with `y` taken from clicks + This method runs a click model on a list of slates and returns new slates with `y` taken from clicks :param slates: a Tuple of X, y: X being a list of slates represented by document vectors diff --git a/allrank/config.py b/allrank/config.py index e7a1d85..9360025 100644 --- a/allrank/config.py +++ b/allrank/config.py @@ -76,6 +76,7 @@ class Config: expected_metrics = attrib(type=Dict[str, Dict[str, float]], default={}) detect_anomaly = attrib(type=bool, default=False) click_model = attrib(type=Optional[NameArgsConfig], default=None) + wandb_project_id = attrib(type=Optional[str], default=None) @classmethod def from_json(cls, config_path): @@ -97,6 +98,7 @@ def from_dict(cls, config): config["metrics"] = cls._parse_metrics(config["metrics"]) config["lr_scheduler"] = NameArgsConfig(**config["lr_scheduler"]) config["loss"] = NameArgsConfig(**config["loss"]) + config["wandb_project_id"] = config.get("wandb_project_id") if "click_model" in config.keys(): config["click_model"] = 
NameArgsConfig(**config["click_model"]) return cls(**config) diff --git a/allrank/main.py b/allrank/main.py index 58792fb..39195a4 100644 --- a/allrank/main.py +++ b/allrank/main.py @@ -14,12 +14,12 @@ from allrank.utils.file_utils import create_output_dirs, PathsContainer, copy_local_to_gs from allrank.utils.ltr_logging import init_logger from allrank.utils.python_utils import dummy_context_mgr -from argparse import ArgumentParser, Namespace +from argparse import ArgumentParser, Namespace, BooleanOptionalAction from attr import asdict from functools import partial from pprint import pformat from torch import optim - +import wandb def parse_args() -> Namespace: parser = ArgumentParser("allRank") @@ -27,6 +27,7 @@ def parse_args() -> Namespace: parser.add_argument("--run-id", help="Name of this run to be recorded (must be unique within output dir)", required=True) parser.add_argument("--config-file-name", required=True, type=str, help="Name of json file with config") + parser.add_argument("--wandb", help="If true, log to wandb", action=BooleanOptionalAction) return parser.parse_args() @@ -86,6 +87,12 @@ def run(): else: scheduler = None + if args.wandb: + wandb.init(project=config.wandb_project_id, config=asdict(config)) + for metric, ks in config.metrics.items(): + for k in ks: + wandb.define_metric(f"{metric}_{k}", summary="max") + with torch.autograd.detect_anomaly() if config.detect_anomaly else dummy_context_mgr(): # type: ignore # run training result = fit( @@ -99,6 +106,7 @@ def run(): device=dev, output_dir=paths.output_dir, tensorboard_output_path=paths.tensorboard_output_path, + wandb_logging=args.wandb, **asdict(config.training) ) @@ -106,6 +114,19 @@ def run(): if urlparse(args.job_dir).scheme == "gs": copy_local_to_gs(paths.local_base_output_path, args.job_dir) + + if args.wandb: + abs_output_dir = os.path.abspath(paths.output_dir) + + artifact = wandb.Artifact(name="experiment-results", type="results", description="Experiment results") + 
artifact.add_reference(f"file://{abs_output_dir}") + wandb.log_artifact(artifact) + + artifact = wandb.Artifact(name="model", type="model", description="Trained model") + artifact.add_reference(f"file://{abs_output_dir}/model.pkl") + wandb.log_artifact(artifact) + + wandb.finish() assert_expected_metrics(result, config.expected_metrics) diff --git a/allrank/training/train_utils.py b/allrank/training/train_utils.py index 3bede64..b3ea53c 100644 --- a/allrank/training/train_utils.py +++ b/allrank/training/train_utils.py @@ -11,6 +11,7 @@ from allrank.training.early_stop import EarlyStop from allrank.utils.ltr_logging import get_logger from allrank.utils.tensorboard_utils import TensorboardSummaryWriter +import wandb logger = get_logger() @@ -76,7 +77,7 @@ def get_current_lr(optimizer): def fit(epochs, model, loss_func, optimizer, scheduler, train_dl, valid_dl, config, - gradient_clipping_norm, early_stopping_patience, device, output_dir, tensorboard_output_path): + gradient_clipping_norm, early_stopping_patience, device, output_dir, tensorboard_output_path, wandb_logging=False): tensorboard_summary_writer = TensorboardSummaryWriter(tensorboard_output_path) num_params = get_num_params(model) @@ -119,6 +120,17 @@ def fit(epochs, model, loss_func, optimizer, scheduler, train_dl, valid_dl, conf tensorboard_summary_writer.save_to_tensorboard(tensorboard_metrics_dict, epoch) logger.info(epoch_summary(epoch, train_loss, val_loss, train_metrics, val_metrics)) + + if wandb_logging: + wandb.log({ + "train_loss": train_loss, + "val_loss": val_loss, + **{ + f"{metric}_{k}": val_metrics.get(f"{metric}_{k}") + for metric, ks in config.metrics.items() + for k in ks + } + }) current_val_metric_value = val_metrics.get(config.val_metric) if scheduler: diff --git a/scripts/local_config.json b/scripts/local_config.json index 2dfc96c..85cb826 100644 --- a/scripts/local_config.json +++ b/scripts/local_config.json @@ -55,9 +55,10 @@ "n": 4 } }, + "wandb_project_id": "allRank", 
"expected_metrics" : { "val": { "ndcg_5": 0.76 } } -} \ No newline at end of file +} diff --git a/scripts/local_config_click_model.json b/scripts/local_config_click_model.json index 3251301..1a74a1d 100644 --- a/scripts/local_config_click_model.json +++ b/scripts/local_config_click_model.json @@ -55,6 +55,7 @@ "n": 4 } }, + "wandb_project_id": "allRank", "expected_metrics" : { "val": { "ndcg_5": 0.76 @@ -73,4 +74,4 @@ "q_percentile": 0.5 } } -} \ No newline at end of file +}