[feat] Integrating W&B Tables for Prediction Visualization #1154

Open · wants to merge 20 commits into base: main

Changes from 1 commit
wandb checkpointing
ayulockin committed Oct 27, 2021
commit 5d4ee99c6c6bb3fccaac4a2e7272e4dbd36a341d
5 changes: 4 additions & 1 deletion mmf/configs/defaults.yaml
@@ -44,7 +44,7 @@ training:
   # Weights and Biases control, by default Weights and Biases (wandb) is disabled
   wandb:
     # Whether to use Weights and Biases Logger, (Default: false)
-    enabled: false
+    enabled: true
     # An entity is a username or team name where you're sending runs.
     # This is necessary if you want to log your metrics to a team account. By default
     # it will log the run to your user account.
@@ -54,6 +54,9 @@ training:
     # Experiment/ run name to be used while logging the experiment
     # under the project with wandb
     name: ${training.experiment_name}
+    # You can save your model checkpoints as W&B Artifacts for model versioning.
+    # Set the value to `true` to enable this feature.
+    log_checkpoint: true
     # Specify other argument values that you want to pass to wandb.init(). Check out the documentation
     # at https://docs.wandb.ai/ref/python/init to see what arguments are available.
     # job_type: 'train'
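As a usage note, here is a minimal sketch (the experiment and project names are made up, not MMF defaults) of how these two flags behave once the config is resolved with OmegaConf: `${training.experiment_name}` is interpolated at resolution time, and checkpoint artifacts are uploaded only when both `enabled` and `log_checkpoint` are true.

```python
from omegaconf import OmegaConf

# Illustrative stand-in for the training.wandb block above;
# the experiment name and project are made up for the example.
config = OmegaConf.create(
    {
        "training": {
            "experiment_name": "hateful_memes_run",
            "wandb": {
                "enabled": True,
                "project": "mmf",
                "name": "${training.experiment_name}",
                "log_checkpoint": True,
            },
        }
    }
)

resolved = OmegaConf.to_container(config, resolve=True)
print(resolved["training"]["wandb"]["name"])  # hateful_memes_run

# Checkpoints are logged as W&B artifacts only when both flags hold.
wandb_cfg = config.training.wandb
should_log_ckpt = wandb_cfg.enabled and wandb_cfg.log_checkpoint
```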
13 changes: 13 additions & 0 deletions mmf/utils/checkpoint.py
@@ -522,6 +522,7 @@ def save(self, update, iteration=None, update_best=False):
             best_metric = (
                 self.trainer.early_stop_callback.early_stopping.best_monitored_value
             )
+
         model = self.trainer.model
         data_parallel = registry.get("data_parallel") or registry.get("distributed")
         fp16_scaler = getattr(self.trainer, "scaler", None)
@@ -574,6 +575,18 @@ def save(self, update, iteration=None, update_best=False):
         with open_if_main(current_ckpt_filepath, "wb") as f:
             self.save_func(ckpt, f)
 
+        # Save the current checkpoint as W&B artifacts for model versioning.
+        if (
+            self.config.training.wandb.enabled
+            and self.config.training.wandb.log_checkpoint
+        ):
+            logger.info(
+                "Saving current checkpoint as W&B Artifacts for model versioning"
+            )
+            self.trainer.logistics_callback.wandb_logger.log_model_checkpoint(
+                current_ckpt_filepath, ckpt
+            )
+
         # Remove old checkpoints if max_to_keep is set
         # In XLA, only delete checkpoint files in main process
         if self.max_to_keep > 0 and is_main():
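With both flags on, every call to `save()` uploads a new version of the run's model artifact. As a usage sketch, a checkpoint logged this way can be pulled back with the public W&B API; the `my-team/mmf/run_abc123_model` path below is hypothetical, so substitute your own entity, project, and run id.

```python
import wandb

# Download the latest checkpoint version that Checkpoint.save() uploaded.
# The artifact path is hypothetical; fill in your entity/project/run id.
api = wandb.Api()
artifact = api.artifact("my-team/mmf/run_abc123_model:latest", type="model")
ckpt_dir = artifact.download()  # local directory containing current.pt
print(ckpt_dir)
```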
22 changes: 21 additions & 1 deletion mmf/utils/logger.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 
 import collections
+import copy
 import functools
 import json
 import logging
@@ -429,11 +430,12 @@ def __init__(
 
         self._wandb_init = dict(entity=entity, config=config, project=project)
 
-        wandb_params = config.training.wandb
+        wandb_params = copy.copy(config.training.wandb)
         with omegaconf.open_dict(wandb_params):
             wandb_params.pop("enabled")
             wandb_params.pop("entity")
             wandb_params.pop("project")
+            wandb_params.pop("log_checkpoint")
 
         init_kwargs = OmegaConf.to_container(wandb_params, resolve=True)
         self._wandb_init.update(**init_kwargs)
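Worth spelling out why the switch to `copy.copy` matters: `config.training.wandb` is a node of the global config, and popping `enabled`, `entity`, `project`, and `log_checkpoint` off the original would mutate it for every later reader, including the new checkpoint gate in `checkpoint.py`. A minimal sketch of the difference, using a made-up config node:

```python
import copy

from omegaconf import OmegaConf, open_dict

# Made-up stand-in for config.training.wandb.
config = OmegaConf.create(
    {"enabled": True, "entity": None, "project": "mmf", "job_type": "train"}
)

wandb_params = copy.copy(config)  # copy first, then strip non-init keys
with open_dict(wandb_params):  # mirrors the diff; required in struct mode
    wandb_params.pop("enabled")  # not a wandb.init() argument

init_kwargs = OmegaConf.to_container(wandb_params, resolve=True)
assert "enabled" not in init_kwargs
assert config.enabled  # the original config node is left intact
```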
@@ -479,3 +481,21 @@ def log_metrics(self, metrics: Dict[str, float], commit=True):
             return
 
         self._wandb.log(metrics, commit=commit)
+
+    def log_model_checkpoint(self, model_path, ckpt_dict):
+        """
+        Log the model checkpoint to the wandb dashboard.
+
+        Args:
+            model_path (str): Path to the model file.
+            ckpt_dict (Dict[str, Any]): Checkpoint dictionary.
+        """
+        if not self._should_log_wandb():
+            return
+
+        model_artifact = self._wandb.Artifact(
+            "run_" + self._wandb.run.id + "_model", type="model"
+        )
+
+        model_artifact.add_file(model_path, name="current.pt")
+        self._wandb.log_artifact(model_artifact, aliases=["latest"])
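Stripped of MMF's wrappers, the artifact flow above reduces to the following self-contained sketch (the project name and checkpoint path are illustrative, not taken from MMF). Logging an artifact with the same name again creates a new version, and the `latest` alias tracks the most recent one:

```python
import pathlib

import wandb

# Illustrative values; in MMF these come from the training config
# and the checkpoint directory.
pathlib.Path("current.ckpt").write_bytes(b"dummy checkpoint bytes")

run = wandb.init(project="mmf", job_type="train")

# One model artifact per run; each log_artifact() call adds a version.
model_artifact = wandb.Artifact("run_" + run.id + "_model", type="model")
model_artifact.add_file("current.ckpt", name="current.pt")
run.log_artifact(model_artifact, aliases=["latest"])

run.finish()
```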