Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,6 +2294,52 @@ def __post_init__(self):
)


@dataclass
class WandBSystemMetricsConfig:
"""Worker-side W&B system metrics collection.

The controller W&B client already records system metrics for the controller
process. Enable this config to attach GPU worker processes to the same W&B
run so W&B can sample system metrics from the GPU nodes too.
"""

enabled: bool = field(
default=False,
metadata={
"help": "Start non-primary W&B clients in worker processes to collect "
"worker system metrics. Requires wandb.mode='shared'.",
},
)
roles: list[str] | None = field(
default=("actor", "rollout", "critic", "ref", "teacher"),
metadata={
"help": "Worker roles that should start W&B system metrics clients. "
"Set to null to enable every configured worker role.",
},
)
gpu_device_ids: list[int] | None = field(
default=None,
metadata={
"help": "Optional GPU device ids passed to W&B's system metrics "
"collector. Leave unset to let W&B use the worker's visible devices.",
},
)

def __post_init__(self):
if self.roles is not None:
if not self.roles:
raise ValueError(
"stats_logger.wandb.system_metrics.roles must be null or a non-empty list."
)
self.roles = list(self.roles)
if self.gpu_device_ids is not None:
self.gpu_device_ids = list(self.gpu_device_ids)
if any(i < 0 for i in self.gpu_device_ids):
raise ValueError(
"stats_logger.wandb.system_metrics.gpu_device_ids must contain non-negative integers."
)


@dataclass
class WandBConfig:
"""Configuration for Weights & Biases experiment tracking."""
Expand All @@ -2316,6 +2362,12 @@ class WandBConfig:
tags: list[str] | None = None
config: dict | None = None
id_suffix: str | None = "train"
system_metrics: WandBSystemMetricsConfig = field(
default_factory=WandBSystemMetricsConfig,
metadata={
"help": "Worker-side W&B system metrics configuration.",
},
)

def __post_init__(self):
"""Validate WandB configuration."""
Expand All @@ -2324,6 +2376,11 @@ def __post_init__(self):
raise ValueError(
f"Invalid wandb mode: '{self.mode}'. Must be one of: {', '.join(valid_modes)}."
)
if self.system_metrics.enabled and self.mode != "shared":
raise ValueError(
"stats_logger.wandb.system_metrics.enabled requires "
"stats_logger.wandb.mode='shared'."
)


@dataclass
Expand Down
6 changes: 6 additions & 0 deletions areal/infra/rpc/ray_rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
)
from areal.utils.dynamic_import import import_from_string
from areal.utils.network import find_free_ports
from areal.utils.wandb_system_metrics import (
finish_worker_wandb_system_metrics,
init_worker_wandb_system_metrics,
)


@ray.remote
Expand Down Expand Up @@ -79,6 +83,7 @@ def alloc_ports(self, count: int):

def configure(self, config: BaseExperimentConfig, role: str, rank: int) -> None:
name_resolve.reconfigure(config.cluster.name_resolve)
init_worker_wandb_system_metrics(config, role=role, rank=rank)
# Set seed for any TrainEngine instances
for engine in self._engines.values():
if isinstance(engine, TrainEngine):
Expand Down Expand Up @@ -219,4 +224,5 @@ def destroy(self) -> None:
)
self._engines.clear()
self._default_engine_name = None
finish_worker_wandb_system_metrics()
ray.actor.exit_actor()
4 changes: 4 additions & 0 deletions areal/infra/rpc/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from areal.infra.rpc.guard.data_blueprint import data_bp
from areal.infra.rpc.guard.engine_blueprint import engine_bp, register_engine_hooks
from areal.utils import logging, perf_tracer
from areal.utils.wandb_system_metrics import (
register_worker_wandb_system_metrics_hooks,
)

logger = logging.getLogger("SyncRPCServer")

Expand Down Expand Up @@ -54,6 +57,7 @@ def main():
app.register_blueprint(data_bp)
app.register_blueprint(engine_bp)
register_engine_hooks(state)
register_worker_wandb_system_metrics_hooks(state)

state.register_cleanup_hook(lambda: perf_tracer.save(force=True))

Expand Down
2 changes: 2 additions & 0 deletions areal/infra/scheduler/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
gethostip,
)
from areal.utils.offload import get_tms_env_vars
from areal.utils.wandb_system_metrics import prepare_wandb_run_identity

logger = logging.getLogger("LocalScheduler")

Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
exp_config: BaseExperimentConfig | None = None,
):
self.gpu_devices = gpu_devices or self._detect_gpus()
prepare_wandb_run_identity(exp_config)

# Resolve experiment/trial names (exp_config overwrites direct params)
self.experiment_name = experiment_name
Expand Down
3 changes: 3 additions & 0 deletions areal/infra/scheduler/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from areal.utils import logging
from areal.utils.offload import get_tms_env_vars
from areal.utils.wandb_system_metrics import prepare_wandb_run_identity

logger = logging.getLogger("RayScheduler")

Expand All @@ -63,6 +64,8 @@ def __init__(
exp_config: BaseExperimentConfig | None = None,
n_gpus_per_node: int = 8,
):
prepare_wandb_run_identity(exp_config)

self.exp_config = exp_config
self._n_gpus_per_node = n_gpus_per_node
self.startup_timeout = startup_timeout
Expand Down
3 changes: 3 additions & 0 deletions areal/infra/scheduler/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from areal.utils.fs import validate_shared_path
from areal.utils.network import format_hostport, split_hostport
from areal.utils.offload import get_tms_env_vars
from areal.utils.wandb_system_metrics import prepare_wandb_run_identity

logger = logging.getLogger("SlurmScheduler")

Expand Down Expand Up @@ -88,6 +89,8 @@ def __init__(
etcd3_addr: str = "localhost:2379",
exp_config: BaseExperimentConfig | None = None,
):
prepare_wandb_run_identity(exp_config)

# Get n_gpus_per_node from parameter or config
self._n_gpus_per_node = n_gpus_per_node
if exp_config is not None:
Expand Down
6 changes: 3 additions & 3 deletions areal/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def __init__(
train_batch_size=config.train_dataset.batch_size,
)

# Initialize W&B primary before worker configuration in shared mode.
self.stats_logger = StatsLogger(config, ft_spec)

self.actor.initialize(addr=None, ft_spec=ft_spec, role="actor")
self.ref.initialize(addr=None, ft_spec=ft_spec, role="ref")

Expand Down Expand Up @@ -174,9 +177,6 @@ def __init__(
self.saver = Saver(config.saver, ft_spec)
self.recover_handler = RecoverHandler(config.recover, ft_spec)

# Set up statistics logging (wandb, tensorboard, etc.)
self.stats_logger = StatsLogger(config, ft_spec)

# Set up checkpointing for recover
self.recover_info = self.recover_handler.load(
self.actor,
Expand Down
7 changes: 4 additions & 3 deletions areal/trainer/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ def __init__(
train_batch_size=config.train_dataset.batch_size,
)

# Initialize the controller W&B client before workers are configured so
# worker-side shared-mode clients can attach to an existing primary run.
self.stats_logger = StatsLogger(config, ft_spec)

# Initialize engines first — the scheduler must know about roles
# before the data controller can colocate with them.
engine_init_kwargs = {"addr": None, "ft_spec": ft_spec}
Expand Down Expand Up @@ -354,9 +358,6 @@ def __init__(
self.saver = Saver(config.saver, ft_spec)
self.recover_handler = RecoverHandler(config.recover, ft_spec)

# Set up statistics logging (wandb, tensoboard, etc.)
self.stats_logger = StatsLogger(config, ft_spec)

# Set up checkpointing for recover
self.recover_info = self.recover_handler.load(
self.actor,
Expand Down
5 changes: 2 additions & 3 deletions areal/trainer/rw_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def __init__(
train_batch_size=config.train_dataset.batch_size,
)

self.stats_logger = StatsLogger(config, ft_spec)

self.actor.initialize(addr=None, ft_spec=ft_spec, role="actor")

self.valid_dataloader: StatefulDataLoader | None = None
Expand Down Expand Up @@ -164,9 +166,6 @@ def __init__(
self.saver = Saver(config.saver, ft_spec)
self.recover_handler = RecoverHandler(config.recover, ft_spec)

# Set up statistics logging (wandb, tensorboard, etc.)
self.stats_logger = StatsLogger(config, ft_spec)

# Set up checkpointing for recover
self.recover_info = self.recover_handler.load(
self.actor,
Expand Down
3 changes: 2 additions & 1 deletion areal/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def __init__(
train_batch_size=config.train_dataset.batch_size,
)

self.stats_logger = StatsLogger(config, ft_spec)

self.actor.initialize(addr=None, ft_spec=ft_spec, role="actor")

self.valid_dataloader: StatefulDataLoader | None = None
Expand All @@ -135,7 +137,6 @@ def __init__(
self.evaluator = Evaluator(config.evaluator, ft_spec)
self.saver = Saver(config.saver, ft_spec)
self.recover_handler = RecoverHandler(config.recover, ft_spec)
self.stats_logger = StatsLogger(config, ft_spec)
self.recover_info = self.recover_handler.load(
self.actor,
self.saver,
Expand Down
1 change: 1 addition & 0 deletions areal/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"StatsLogger": "light_green",
"StatsTracker": "light_green",
"PerfTracer": "light_green",
"WandBSystemMetrics": "light_green",
# RPC servers - white
"SyncRPCServer": "white",
"RayRPCServer": "white",
Expand Down
34 changes: 24 additions & 10 deletions areal/utils/stats_logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

import getpass
import os
import time
from dataclasses import asdict
Expand All @@ -15,6 +14,10 @@
from areal.api.cli_args import BaseExperimentConfig, StatsLoggerConfig
from areal.utils import logging
from areal.utils.printing import tabulate_stats
from areal.utils.wandb_system_metrics import (
resolve_wandb_run_id,
stats_logger_log_path,
)
from areal.version import version_info

logger = logging.getLogger("StatsLogger", "system")
Expand Down Expand Up @@ -48,10 +51,6 @@ def init(self):
if self.config.wandb.mode != "disabled":
wandb.login()

suffix = self.config.wandb.id_suffix
if suffix == "timestamp":
suffix = time.strftime("%Y_%m_%d_%H_%M_%S")

exp_config_dict = asdict(self.exp_config)
exp_config_dict["version_info"] = {
"commit_id": version_info.commit,
Expand All @@ -60,7 +59,15 @@ def init(self):
"version": version_info.full_version_with_dirty_description,
}

wandb.init(
wandb_settings = None
if self.config.wandb.mode == "shared":
wandb_settings = wandb.Settings(
mode="shared",
x_primary=True,
x_label="controller",
)

wandb_init_kwargs = dict(
mode=self.config.wandb.mode,
entity=self.config.wandb.entity,
project=self.config.wandb.project or self.config.experiment_name,
Expand All @@ -73,9 +80,12 @@ def init(self):
config=exp_config_dict, # save all experiment config to wandb
dir=self.get_log_path(self.config),
force=True,
id=f"{self.config.experiment_name}_{self.config.trial_name}_{suffix}",
id=resolve_wandb_run_id(self.config),
resume="allow",
)
if wandb_settings is not None:
wandb_init_kwargs["settings"] = wandb_settings
wandb.init(**wandb_init_kwargs)

swanlab_config = self.config.swanlab
if swanlab_config.mode != "disabled":
Expand Down Expand Up @@ -176,6 +186,10 @@ def get_log_path(
raise ValueError(
"fileroot, experiment_name, and trial_name must be provided."
)
path = f"{fileroot}/logs/{getpass.getuser()}/{experiment_name}/{trial_name}"
os.makedirs(path, exist_ok=True)
return path
return stats_logger_log_path(
StatsLoggerConfig(
experiment_name=experiment_name,
trial_name=trial_name,
fileroot=fileroot,
)
)
Loading
Loading