@@ -1113,7 +1113,7 @@ def _save_hf(self, epoch: int, epoch_step: int, global_step: int):
11131113 name = "critic" ,
11141114 )
11151115 # Async mode: synchronization handled by AsyncCheckpointManager
1116- if not self .saver .is_async :
1116+ if not self .saver .is_async and not is_single_controller () :
11171117 dist .barrier (group = self .actor .cpu_group )
11181118 current_platform .synchronize ()
11191119
@@ -1139,8 +1139,9 @@ def _save_recover_checkpoint(self, epoch: int, epoch_step: int, global_step: int
11391139 processor = self .processor ,
11401140 )
11411141
1142- dist .barrier (group = self .actor .cpu_group )
1143- current_platform .synchronize ()
1142+ if not is_single_controller ():
1143+ dist .barrier (group = self .actor .cpu_group )
1144+ current_platform .synchronize ()
11441145
11451146 def _evaluate_fn (
11461147 self ,
@@ -1161,8 +1162,9 @@ def _evaluate_fn(
11611162 cnt += 1
11621163 self .eval_rollout .wait (cnt , timeout = None )
11631164
1164- dist .barrier (group = self .actor .cpu_group )
1165- current_platform .synchronize ()
1165+ if not is_single_controller ():
1166+ dist .barrier (group = self .actor .cpu_group )
1167+ current_platform .synchronize ()
11661168
11671169 def _evaluate (
11681170 self ,
@@ -1188,8 +1190,9 @@ def _evaluate(
11881190 epoch_step ,
11891191 global_step ,
11901192 )
1191- dist .barrier (group = self .actor .cpu_group )
1192- current_platform .synchronize ()
1193+ if not is_single_controller ():
1194+ dist .barrier (group = self .actor .cpu_group )
1195+ current_platform .synchronize ()
11931196
11941197 def _export_and_commit_stats (self , epoch : int , epoch_step : int , global_step : int ):
11951198 # Upload statistics to the logger (e.g., wandb)
@@ -1199,8 +1202,9 @@ def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int
11991202 stats .update (self .eval_rollout .export_stats ())
12001203 self .stats_logger .commit (epoch , epoch_step , global_step , stats )
12011204
1202- dist .barrier (group = self .actor .cpu_group )
1203- current_platform .synchronize ()
1205+ if not is_single_controller ():
1206+ dist .barrier (group = self .actor .cpu_group )
1207+ current_platform .synchronize ()
12041208
12051209 def _validate_cfg (self ):
12061210 """validate config for incompatible settings before weight initialization, to avoid wasted resources on spawning workers and loading models."""
0 commit comments