From 24bbd37b450db6f22182039c46d6d32539a10218 Mon Sep 17 00:00:00 2001
From: Erik Arakelyan
Date: Tue, 21 Jun 2022 12:28:56 +0400
Subject: [PATCH 1/2] Added Aim integration

---
 README.md                  |  9 ++++--
 doc/source/quick_start.rst |  7 ++++-
 torchdrug/core/__init__.py |  4 +--
 torchdrug/core/engine.py   |  2 ++
 torchdrug/core/logger.py   | 61 +++++++++++++++++++++++++++++++++++++-
 5 files changed, 77 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 568bdafd..87b285c7 100644
--- a/README.md
+++ b/README.md
@@ -148,12 +148,17 @@ solver = core.Engine(task, train_set, valid_set, test_set, optimizer, gpus=[0, 1
 solver = core.Engine(task, train_set, valid_set, test_set, optimizer, gpus=[0, 1, 2, 3, 0, 1, 2, 3])
 ```
 
-Experiments can be easily tracked and managed through [Weights & Biases platform].
+Experiments can be easily tracked and managed through the [Aim] or [Weights & Biases] platforms.
+
+```python
+solver = core.Engine(task, train_set, valid_set, test_set, optimizer, logger="aim")
+```
+
 ```python
 solver = core.Engine(task, train_set, valid_set, test_set, optimizer, logger="wandb")
 ```
 
-[Weights & Biases platform]: https://wandb.ai/
+[Aim]: https://aimstack.io/
+[Weights & Biases]: https://wandb.ai/
 
 Contributing
 ------------
diff --git a/doc/source/quick_start.rst b/doc/source/quick_start.rst
index 2717d1fa..c074bfb6 100644
--- a/doc/source/quick_start.rst
+++ b/doc/source/quick_start.rst
@@ -177,9 +177,14 @@ the script using ``python -m torch.distributed.launch --nproc_per_node=4``.
     solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                          batch_size=256, gpus=[0, 1, 2, 3])
 
-We may log the training and evaluation metrics to Weights & Biases platform for
+We may log the training and evaluation metrics to the Aim or Weights & Biases platforms for
 better experiment tracking in the browser.
 
+.. code:: python
+
+    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
+                         batch_size=1024, logger="aim")
+
 .. code:: python
 
     solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                          batch_size=1024, logger="wandb")
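+
+We can also pass a logger instance in place of a string, for example to point
+Aim at a specific repository or to resume a previous run. The snippet below is
+a minimal sketch rather than part of the tutorial pipeline: the experiment
+name and run hash are placeholders for values from your own runs. Once
+tracking has started, the runs can be explored in the browser with Aim's
+``aim up`` command.
+
+.. code:: python
+
+    # resume logging into an existing Aim run (placeholder hash)
+    logger = core.AimLogger(experiment_name="QuickStart", repo=".",
+                            run_hash="<hash-of-a-previous-run>")
+    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
+                         batch_size=1024, logger=logger)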
diff --git a/torchdrug/core/__init__.py b/torchdrug/core/__init__.py
index 2e0abba2..83432dbf 100644
--- a/torchdrug/core/__init__.py
+++ b/torchdrug/core/__init__.py
@@ -1,9 +1,9 @@
 from .core import _MetaContainer, Registry, Configurable, make_configurable
 from .engine import Engine
 from .meter import Meter
-from .logger import LoggerBase, LoggingLogger, WandbLogger
+from .logger import LoggerBase, LoggingLogger, WandbLogger, AimLogger
 
 __all__ = [
     "_MetaContainer", "Registry", "Configurable",
-    "Engine", "Meter", "LoggerBase", "LoggingLogger", "WandbLogger",
+    "Engine", "Meter", "LoggerBase", "LoggingLogger", "WandbLogger", "AimLogger",
 ]
\ No newline at end of file
diff --git a/torchdrug/core/engine.py b/torchdrug/core/engine.py
index 4f6f1d3d..35172cd5 100644
--- a/torchdrug/core/engine.py
+++ b/torchdrug/core/engine.py
@@ -112,6 +112,8 @@ def __init__(self, task, train_set, valid_set, test_set, optimizer, scheduler=No
             logger = core.LoggingLogger()
         elif logger == "wandb":
             logger = core.WandbLogger(project=task.__class__.__name__)
+        elif logger == "aim":
+            logger = core.AimLogger(experiment_name=task.__class__.__name__)
         else:
             raise ValueError("Unknown logger `%s`" % logger)
         self.meter = core.Meter(log_interval=log_interval, silent=self.rank > 0, logger=logger)
diff --git a/torchdrug/core/logger.py b/torchdrug/core/logger.py
index d6599283..b353165a 100644
--- a/torchdrug/core/logger.py
+++ b/torchdrug/core/logger.py
@@ -124,4 +124,63 @@ def log(self, record, step_id, category="train/batch"):
 
     def log_config(self, confg_dict):
         super(WandbLogger, self).log_config(confg_dict)
-        self.run.config.update(confg_dict)
\ No newline at end of file
+        self.run.config.update(confg_dict)
+
+
+@R.register("core.AimLogger")
+class AimLogger(LoggingLogger):
+    """
+    Log outputs with `Aim`_ and track the experiment progress.
+
+    Note that this class also outputs logs with the builtin logging module.
+    Records are tracked through an `aim.Run`_ object.
+
+    .. _Aim:
+        https://aimstack.readthedocs.io/en/latest/
+
+    .. _aim.Run:
+        https://aimstack.readthedocs.io/en/latest/refs/sdk.html#module-aim.sdk.run
+
+    Parameters:
+        experiment_name (str, optional): name of the experiment
+        repo (str, optional): path to the Aim repository. Default is ``.``.
+        run_hash (str, optional): hash of a previous run to continue from. If not provided, a new run is created.
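+
+    Example (a minimal sketch; assumes ``task``, the dataset splits and
+    ``optimizer`` are already defined, as in the quick start)::
+
+        >>> logger = AimLogger(experiment_name="PropertyPrediction")
+        >>> solver = core.Engine(task, train_set, valid_set, test_set,
+        ...                      optimizer, logger=logger)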
Please install it with `pip install aim`") + + self.aim_run = aim.Run(repo = repo, experiment=experiment_name, run_hash=run_has) + + # self.run.define_metric("train/batch/*", step_metric="batch", summary="none") + # for split in ["train", "valid", "test"]: + # self.run.define_metric("%s/epoch/*" % split, step_metric="epoch") + + def log(self, record, step_id, category="train/batch"): + + super(AimLogger, self).log(record, step_id, category) + + print("_________Started Aim Logging_________") + print(category) + context_type,context_specific = category.split("/") + + print(record) + print(self.aim_run.repo) + + if category == "train/epoch": + for metric_name in sorted(record.keys()): + self.aim_run.track(record[metric_name], step=step_id, name=metric_name, context={f"{context_type}_average": context_specific}) + elif category == "valid/epoch": + for metric_name in sorted(record.keys()): + self.aim_run[metric_name] = record[metric_name].item() + else: + for metric_name in sorted(record.keys()): + self.aim_run.track(record[metric_name], step=step_id, name=metric_name, context={context_type: context_specific}) + + print("_________Ended Aim Logging_________") + + def log_config(self, confg_dict): + self.aim_run["config"] = confg_dict From 0a0b20c1ff792d88114ac08796e9ced2fca75768 Mon Sep 17 00:00:00 2001 From: Hovhannes Tamoyan Date: Wed, 17 Aug 2022 12:11:37 +0200 Subject: [PATCH 2/2] Applied pep8 formating --- torchdrug/core/__init__.py | 2 +- torchdrug/core/engine.py | 39 +++++++++++++++++++++++++------------- torchdrug/core/logger.py | 29 +++++++++++++++++----------- 3 files changed, 45 insertions(+), 25 deletions(-) diff --git a/torchdrug/core/__init__.py b/torchdrug/core/__init__.py index 83432dbf..fa67da3e 100644 --- a/torchdrug/core/__init__.py +++ b/torchdrug/core/__init__.py @@ -6,4 +6,4 @@ __all__ = [ "_MetaContainer", "Registry", "Configurable", "Engine", "Meter", "LoggerBase", "LoggingLogger", "WandbLogger", "AimLogger", -] \ No newline at end of file +] diff --git a/torchdrug/core/engine.py b/torchdrug/core/engine.py index 35172cd5..d9f99929 100644 --- a/torchdrug/core/engine.py +++ b/torchdrug/core/engine.py @@ -94,7 +94,8 @@ def __init__(self, task, train_set, valid_set, test_set, optimizer, scheduler=No train_set, valid_set, test_set = result new_params = list(task.parameters()) if len(new_params) != len(old_params): - optimizer.add_param_group({"params": new_params[len(old_params):]}) + optimizer.add_param_group( + {"params": new_params[len(old_params):]}) if self.world_size > 1: task = nn.SyncBatchNorm.convert_sync_batchnorm(task) if self.device.type == "cuda": @@ -113,10 +114,12 @@ def __init__(self, task, train_set, valid_set, test_set, optimizer, scheduler=No elif logger == "wandb": logger = core.WandbLogger(project=task.__class__.__name__) elif logger == "aim": - logger = core.AimLogger(experiment_name=task.__class__.__name__) + logger = core.AimLogger( + experiment_name=task.__class__.__name__) else: raise ValueError("Unknown logger `%s`" % logger) - self.meter = core.Meter(log_interval=log_interval, silent=self.rank > 0, logger=logger) + self.meter = core.Meter(log_interval=log_interval, + silent=self.rank > 0, logger=logger) self.meter.log_config(self.config_dict()) def train(self, num_epoch=1, batch_per_epoch=None): @@ -130,8 +133,10 @@ def train(self, num_epoch=1, batch_per_epoch=None): num_epoch (int, optional): number of epochs batch_per_epoch (int, optional): number of batches per epoch """ - sampler = torch_data.DistributedSampler(self.train_set, self.world_size, 
-        sampler = torch_data.DistributedSampler(self.train_set, self.world_size, self.rank)
-        dataloader = data.DataLoader(self.train_set, self.batch_size, sampler=sampler, num_workers=self.num_worker)
+        sampler = torch_data.DistributedSampler(
+            self.train_set, self.world_size, self.rank)
+        dataloader = data.DataLoader(
+            self.train_set, self.batch_size, sampler=sampler, num_workers=self.num_worker)
         batch_per_epoch = batch_per_epoch or len(dataloader)
         model = self.model
         if self.world_size > 1:
@@ -139,7 +144,8 @@ def train(self, num_epoch=1, batch_per_epoch=None):
             model = nn.parallel.DistributedDataParallel(model, device_ids=[self.device],
                                                         find_unused_parameters=True)
         else:
-            model = nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
+            model = nn.parallel.DistributedDataParallel(
+                model, find_unused_parameters=True)
         model.train()
 
         for epoch in self.meter(num_epoch):
@@ -148,7 +154,8 @@ def train(self, num_epoch=1, batch_per_epoch=None):
             metrics = []
             start_id = 0
             # the last gradient update may contain less than gradient_interval batches
-            gradient_interval = min(batch_per_epoch - start_id, self.gradient_interval)
+            gradient_interval = min(
+                batch_per_epoch - start_id, self.gradient_interval)
 
             for batch_id, batch in enumerate(islice(dataloader, batch_per_epoch)):
                 if self.device.type == "cuda":
@@ -156,7 +163,8 @@ def train(self, num_epoch=1, batch_per_epoch=None):
                 loss, metric = model(batch)
                 if not loss.requires_grad:
-                    raise RuntimeError("Loss doesn't require grad. Did you define any loss in the task?")
+                    raise RuntimeError(
+                        "Loss doesn't require grad. Did you define any loss in the task?")
                 loss = loss / gradient_interval
                 loss.backward()
                 metrics.append(metric)
@@ -173,7 +181,8 @@ def train(self, num_epoch=1, batch_per_epoch=None):
                     metrics = []
                     start_id = batch_id + 1
-                    gradient_interval = min(batch_per_epoch - start_id, self.gradient_interval)
+                    gradient_interval = min(
+                        batch_per_epoch - start_id, self.gradient_interval)
 
         if self.scheduler:
             self.scheduler.step()
@@ -194,8 +203,10 @@ def evaluate(self, split, log=True):
         logger.warning(pretty.separator)
         logger.warning("Evaluate on %s" % split)
         test_set = getattr(self, "%s_set" % split)
-        sampler = torch_data.DistributedSampler(test_set, self.world_size, self.rank)
-        dataloader = data.DataLoader(test_set, self.batch_size, sampler=sampler, num_workers=self.num_worker)
+        sampler = torch_data.DistributedSampler(
+            test_set, self.world_size, self.rank)
+        dataloader = data.DataLoader(
+            test_set, self.batch_size, sampler=sampler, num_workers=self.num_worker)
         model = self.model
 
         model.eval()
@@ -269,7 +280,8 @@ def load_config_dict(cls, config):
         Construct an instance from the configuration dict.
""" if getattr(cls, "_registry_key", cls.__name__) != config["class"]: - raise ValueError("Expect config class to be `%s`, but found `%s`" % (cls.__name__, config["class"])) + raise ValueError("Expect config class to be `%s`, but found `%s`" % ( + cls.__name__, config["class"])) optimizer_config = config.pop("optimizer") new_config = {} @@ -279,7 +291,8 @@ def load_config_dict(cls, config): if k != "class": new_config[k] = v optimizer_config["params"] = new_config["task"].parameters() - new_config["optimizer"] = core.Configurable.load_config_dict(optimizer_config) + new_config["optimizer"] = core.Configurable.load_config_dict( + optimizer_config) return cls(**new_config) diff --git a/torchdrug/core/logger.py b/torchdrug/core/logger.py index b353165a..7d08c7df 100644 --- a/torchdrug/core/logger.py +++ b/torchdrug/core/logger.py @@ -100,18 +100,21 @@ def __init__(self, project=None, name=None, dir=None, **kwargs): try: import wandb except ModuleNotFoundError: - raise ModuleNotFoundError("Wandb is not found. Please install it with `pip install wandb`") + raise ModuleNotFoundError( + "Wandb is not found. Please install it with `pip install wandb`") if wandb.run is not None: warnings.warn( - "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse" + "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse" " this run. If this is not desired, call `wandb.finish()` or `WandbLogger.finish()` before instantiating `WandbLogger`." ) self.run = wandb.run else: - self.run = wandb.init(project=project, name=name, dir=dir, **kwargs) + self.run = wandb.init( + project=project, name=name, dir=dir, **kwargs) - self.run.define_metric("train/batch/*", step_metric="batch", summary="none") + self.run.define_metric( + "train/batch/*", step_metric="batch", summary="none") for split in ["train", "valid", "test"]: self.run.define_metric("%s/epoch/*" % split, step_metric="epoch") @@ -151,9 +154,11 @@ def __init__(self, experiment_name=None, repo='.', run_has=None): try: import aim except ModuleNotFoundError: - raise ModuleNotFoundError("Aim is not found. Please install it with `pip install aim`") + raise ModuleNotFoundError( + "Aim is not found. 
Please install it with `pip install aim`") - self.aim_run = aim.Run(repo = repo, experiment=experiment_name, run_hash=run_has) + self.aim_run = aim.Run( + repo=repo, experiment=experiment_name, run_hash=run_has) # self.run.define_metric("train/batch/*", step_metric="batch", summary="none") # for split in ["train", "valid", "test"]: @@ -165,21 +170,23 @@ def log(self, record, step_id, category="train/batch"): print("_________Started Aim Logging_________") print(category) - context_type,context_specific = category.split("/") + context_type, context_specific = category.split("/") print(record) print(self.aim_run.repo) if category == "train/epoch": for metric_name in sorted(record.keys()): - self.aim_run.track(record[metric_name], step=step_id, name=metric_name, context={f"{context_type}_average": context_specific}) - elif category == "valid/epoch": + self.aim_run.track(record[metric_name], step=step_id, name=metric_name, context={ + f"{context_type}_average": context_specific}) + elif category == "valid/epoch": for metric_name in sorted(record.keys()): self.aim_run[metric_name] = record[metric_name].item() else: for metric_name in sorted(record.keys()): - self.aim_run.track(record[metric_name], step=step_id, name=metric_name, context={context_type: context_specific}) - + self.aim_run.track(record[metric_name], step=step_id, name=metric_name, context={ + context_type: context_specific}) + print("_________Ended Aim Logging_________") def log_config(self, confg_dict):