diff --git a/amp_rsl_rl/runners/amp_on_policy_runner.py b/amp_rsl_rl/runners/amp_on_policy_runner.py index e090dd6..5e074e5 100644 --- a/amp_rsl_rl/runners/amp_on_policy_runner.py +++ b/amp_rsl_rl/runners/amp_on_policy_runner.py @@ -234,12 +234,46 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals ) elif self.logger_type == "wandb": from rsl_rl.utils.wandb_utils import WandbSummaryWriter + import wandb + + # Update the run name with a sequence number. This function is useful to + # replicate the same behaviour of rsl-rl-lib before v2.3.0 + def update_run_name_with_sequence(prefix: str) -> None: + # Retrieve the current wandb run details (project and entity) + project = wandb.run.project + entity = wandb.run.entity + + # Use wandb's API to list all runs in your project + api = wandb.Api() + runs = api.runs(f"{entity}/{project}") + + max_num = 0 + # Iterate through runs to extract the numeric suffix after the prefix. + for run in runs: + if run.name.startswith(prefix): + # Extract the numeric part from the run name. + numeric_suffix = run.name[ + len(prefix) : + ] # e.g., from "prefix564", get "564" + try: + run_num = int(numeric_suffix) + if run_num > max_num: + max_num = run_num + except ValueError: + continue + + # Increment to get the new run number + new_num = max_num + 1 + new_run_name = f"{prefix}{new_num}" + + # Update the wandb run's name + wandb.run.name = new_run_name + print("Updated run name to:", wandb.run.name) self.writer = WandbSummaryWriter( log_dir=self.log_dir, flush_secs=10, cfg=self.cfg ) - - import wandb + update_run_name_with_sequence(prefix=self.cfg["wandb_project"]) wandb.gym.monitor() self.writer.log_config( diff --git a/amp_rsl_rl/utils/__init__.py b/amp_rsl_rl/utils/__init__.py index 404527b..0ede374 100644 --- a/amp_rsl_rl/utils/__init__.py +++ b/amp_rsl_rl/utils/__init__.py @@ -10,4 +10,10 @@ from .motion_loader import AMPLoader, download_amp_dataset_from_hf from .exporter import export_policy_as_onnx -__all__ = ["Normalizer", "RunningMeanStd", "AMPLoader", "download_amp_dataset_from_hf", "export_policy_as_onnx"] +__all__ = [ + "Normalizer", + "RunningMeanStd", + "AMPLoader", + "download_amp_dataset_from_hf", + "export_policy_as_onnx", +]