Skip to content

Commit 5e776c8

Browse files
authored
fix(trainer): skip controller-side CUDA sync in single-controller mode (#1377)
* fix(trainer): skip controller-side CUDA sync in single-controller mode * fix(trainer): also skip controller-side barrier in single-controller mode
1 parent 0cfcd04 commit 5e776c8

1 file changed

Lines changed: 13 additions & 9 deletions

File tree

areal/trainer/rl_trainer.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ def _save_hf(self, epoch: int, epoch_step: int, global_step: int):
11131113
name="critic",
11141114
)
11151115
# Async mode: synchronization handled by AsyncCheckpointManager
1116-
if not self.saver.is_async:
1116+
if not self.saver.is_async and not is_single_controller():
11171117
dist.barrier(group=self.actor.cpu_group)
11181118
current_platform.synchronize()
11191119

@@ -1139,8 +1139,9 @@ def _save_recover_checkpoint(self, epoch: int, epoch_step: int, global_step: int
11391139
processor=self.processor,
11401140
)
11411141

1142-
dist.barrier(group=self.actor.cpu_group)
1143-
current_platform.synchronize()
1142+
if not is_single_controller():
1143+
dist.barrier(group=self.actor.cpu_group)
1144+
current_platform.synchronize()
11441145

11451146
def _evaluate_fn(
11461147
self,
@@ -1161,8 +1162,9 @@ def _evaluate_fn(
11611162
cnt += 1
11621163
self.eval_rollout.wait(cnt, timeout=None)
11631164

1164-
dist.barrier(group=self.actor.cpu_group)
1165-
current_platform.synchronize()
1165+
if not is_single_controller():
1166+
dist.barrier(group=self.actor.cpu_group)
1167+
current_platform.synchronize()
11661168

11671169
def _evaluate(
11681170
self,
@@ -1188,8 +1190,9 @@ def _evaluate(
11881190
epoch_step,
11891191
global_step,
11901192
)
1191-
dist.barrier(group=self.actor.cpu_group)
1192-
current_platform.synchronize()
1193+
if not is_single_controller():
1194+
dist.barrier(group=self.actor.cpu_group)
1195+
current_platform.synchronize()
11931196

11941197
def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int):
11951198
# Upload statistics to the logger (e.g., wandb)
@@ -1199,8 +1202,9 @@ def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int
11991202
stats.update(self.eval_rollout.export_stats())
12001203
self.stats_logger.commit(epoch, epoch_step, global_step, stats)
12011204

1202-
dist.barrier(group=self.actor.cpu_group)
1203-
current_platform.synchronize()
1205+
if not is_single_controller():
1206+
dist.barrier(group=self.actor.cpu_group)
1207+
current_platform.synchronize()
12041208

12051209
def _validate_cfg(self):
12061210
"""validate config for incompatible settings before weight initialization, to avoid wasted resources on spawning workers and loading models."""

0 commit comments

Comments
 (0)