diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 73b98c21f..33f9d67ef 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -1113,7 +1113,7 @@ def _save_hf(self, epoch: int, epoch_step: int, global_step: int): name="critic", ) # Async mode: synchronization handled by AsyncCheckpointManager - if not self.saver.is_async: + if not self.saver.is_async and not is_single_controller(): dist.barrier(group=self.actor.cpu_group) current_platform.synchronize() @@ -1139,8 +1139,9 @@ def _save_recover_checkpoint(self, epoch: int, epoch_step: int, global_step: int processor=self.processor, ) - dist.barrier(group=self.actor.cpu_group) - current_platform.synchronize() + if not is_single_controller(): + dist.barrier(group=self.actor.cpu_group) + current_platform.synchronize() def _evaluate_fn( self, @@ -1161,8 +1162,9 @@ def _evaluate_fn( cnt += 1 self.eval_rollout.wait(cnt, timeout=None) - dist.barrier(group=self.actor.cpu_group) - current_platform.synchronize() + if not is_single_controller(): + dist.barrier(group=self.actor.cpu_group) + current_platform.synchronize() def _evaluate( self, @@ -1188,8 +1190,9 @@ def _evaluate( epoch_step, global_step, ) - dist.barrier(group=self.actor.cpu_group) - current_platform.synchronize() + if not is_single_controller(): + dist.barrier(group=self.actor.cpu_group) + current_platform.synchronize() def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int): # Upload statistics to the logger (e.g., wandb) @@ -1199,8 +1202,9 @@ def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int stats.update(self.eval_rollout.export_stats()) self.stats_logger.commit(epoch, epoch_step, global_step, stats) - dist.barrier(group=self.actor.cpu_group) - current_platform.synchronize() + if not is_single_controller(): + dist.barrier(group=self.actor.cpu_group) + current_platform.synchronize() def _validate_cfg(self): """validate config for incompatible settings before weight initialization, to avoid wasted resources on spawning workers and loading models."""