Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion allrank/click_models/click_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def click_on_slates(slates: Union[Tuple[np.ndarray, np.ndarray], Tuple[torch.Tensor, torch.Tensor]],
click_model: ClickModel, include_empty: bool) -> Tuple[List[Union[np.ndarray, torch.Tensor]], List[List[int]]]:
"""
This metod runs a click model on a list of slates and returns new slates with `y` taken from clicks
This method runs a click model on a list of slates and returns new slates with `y` taken from clicks

:param slates: a Tuple of X, y:
X being a list of slates represented by document vectors
Expand Down
2 changes: 2 additions & 0 deletions allrank/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class Config:
expected_metrics = attrib(type=Dict[str, Dict[str, float]], default={})
detect_anomaly = attrib(type=bool, default=False)
click_model = attrib(type=Optional[NameArgsConfig], default=None)
wandb_project_id = attrib(type=str, default=None)

@classmethod
def from_json(cls, config_path):
Expand All @@ -97,6 +98,7 @@ def from_dict(cls, config):
config["metrics"] = cls._parse_metrics(config["metrics"])
config["lr_scheduler"] = NameArgsConfig(**config["lr_scheduler"])
config["loss"] = NameArgsConfig(**config["loss"])
config["wandb_project_id"] = config["wandb_project_id"]
if "click_model" in config.keys():
config["click_model"] = NameArgsConfig(**config["click_model"])
return cls(**config)
Expand Down
25 changes: 23 additions & 2 deletions allrank/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@
from allrank.utils.file_utils import create_output_dirs, PathsContainer, copy_local_to_gs
from allrank.utils.ltr_logging import init_logger
from allrank.utils.python_utils import dummy_context_mgr
from argparse import ArgumentParser, Namespace
from argparse import ArgumentParser, Namespace, BooleanOptionalAction
from attr import asdict
from functools import partial
from pprint import pformat
from torch import optim

import wandb

def parse_args() -> Namespace:
    """Build the allRank command-line parser and parse ``sys.argv``.

    :return: a ``Namespace`` carrying ``job_dir``, ``run_id``,
        ``config_file_name`` and the optional ``wandb`` boolean flag
        (``--wandb`` / ``--no-wandb``; ``None`` when neither is given)
    """
    arg_parser = ArgumentParser("allRank")
    # The three mandatory identifiers of a run: where to write, what to
    # call it, and which config drives it.
    required_options = [
        ("--job-dir", "Base output path for all experiments"),
        ("--run-id", "Name of this run to be recorded (must be unique within output dir)"),
        ("--config-file-name", "Name of json file with config"),
    ]
    for flag, description in required_options:
        arg_parser.add_argument(flag, help=description, required=True)
    # BooleanOptionalAction auto-generates the paired --no-wandb switch.
    arg_parser.add_argument("--wandb", help="If true, log to wandb", action=BooleanOptionalAction)
    return arg_parser.parse_args()

Expand Down Expand Up @@ -86,6 +87,12 @@ def run():
else:
scheduler = None

if args.wandb:
wandb.init(project=config.wandb_project_id, config=asdict(config))
for metric, ks in config.metrics.items():
for k in ks:
wandb.define_metric(f"{metric}_{k}", summary="max")

with torch.autograd.detect_anomaly() if config.detect_anomaly else dummy_context_mgr(): # type: ignore
# run training
result = fit(
Expand All @@ -99,13 +106,27 @@ def run():
device=dev,
output_dir=paths.output_dir,
tensorboard_output_path=paths.tensorboard_output_path,
wandb_logging=args.wandb,
**asdict(config.training)
)

dump_experiment_result(args, config, paths.output_dir, result)

if urlparse(args.job_dir).scheme == "gs":
copy_local_to_gs(paths.local_base_output_path, args.job_dir)

if args.wandb:
abs_output_dir = os.path.abspath(paths.output_dir)

artifact = wandb.Artifact(name="experiment-results", type="results", description="Experiment results")
artifact.add_reference(f"file://{abs_output_dir}")
wandb.log_artifact(artifact)

artifact = wandb.Artifact(name="model", type="model", description="Trained model")
artifact.add_reference(f"file://{abs_output_dir}/model.pkl")
wandb.log_artifact(artifact)

wandb.finish()

assert_expected_metrics(result, config.expected_metrics)

Expand Down
14 changes: 13 additions & 1 deletion allrank/training/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from allrank.training.early_stop import EarlyStop
from allrank.utils.ltr_logging import get_logger
from allrank.utils.tensorboard_utils import TensorboardSummaryWriter
import wandb

logger = get_logger()

Expand Down Expand Up @@ -76,7 +77,7 @@ def get_current_lr(optimizer):


def fit(epochs, model, loss_func, optimizer, scheduler, train_dl, valid_dl, config,
gradient_clipping_norm, early_stopping_patience, device, output_dir, tensorboard_output_path):
gradient_clipping_norm, early_stopping_patience, device, output_dir, tensorboard_output_path, wandb_logging=False):
tensorboard_summary_writer = TensorboardSummaryWriter(tensorboard_output_path)

num_params = get_num_params(model)
Expand Down Expand Up @@ -119,6 +120,17 @@ def fit(epochs, model, loss_func, optimizer, scheduler, train_dl, valid_dl, conf
tensorboard_summary_writer.save_to_tensorboard(tensorboard_metrics_dict, epoch)

logger.info(epoch_summary(epoch, train_loss, val_loss, train_metrics, val_metrics))

if wandb_logging:
wandb.log({
"train_loss": train_loss,
"val_loss": val_loss,
**{
f"{metric}_{k}": val_metrics.get(f"{metric}_{k}")
for metric, ks in config.metrics.items()
for k in ks
}
})

current_val_metric_value = val_metrics.get(config.val_metric)
if scheduler:
Expand Down
3 changes: 2 additions & 1 deletion scripts/local_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@
"n": 4
}
},
"wandb_project_id": "allRank",
"expected_metrics" : {
"val": {
"ndcg_5": 0.76
}
}
}
}
3 changes: 2 additions & 1 deletion scripts/local_config_click_model.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"n": 4
}
},
"wandb_project_id": "allRank",
"expected_metrics" : {
"val": {
"ndcg_5": 0.76
Expand All @@ -73,4 +74,4 @@
"q_percentile": 0.5
}
}
}
}