diff --git a/ml-agents/mlagents/trainers/ghost/trainer.py b/ml-agents/mlagents/trainers/ghost/trainer.py
index f49a643574..52e5e45643 100644
--- a/ml-agents/mlagents/trainers/ghost/trainer.py
+++ b/ml-agents/mlagents/trainers/ghost/trainer.py
@@ -195,26 +195,7 @@ def _process_trajectory(self, trajectory: Trajectory) -> None:
         i.e. in asymmetric games. We assume the last reward determines the winner.
         :param trajectory: Trajectory.
         """
-        if (
-            trajectory.done_reached
-            and trajectory.all_group_dones_reached
-            and not trajectory.interrupted
-        ):
-            # Assumption is that final reward is >0/0/<0 for win/draw/loss
-            final_reward = (
-                trajectory.steps[-1].reward + trajectory.steps[-1].group_reward
-            )
-            result = 0.5
-            if final_reward > 0:
-                result = 1.0
-            elif final_reward < 0:
-                result = 0.0
-
-            change = self.controller.compute_elo_rating_changes(
-                self.current_elo, result
-            )
-            self.change_current_elo(change)
-            self._stats_reporter.add_stat("Self-play/ELO", self.current_elo)
+        self.update_elo_ratings(trajectory)
 
     def advance(self) -> None:
         """
@@ -478,3 +459,31 @@ def subscribe_trajectory_queue(
             parsed_behavior_id.brain_name
         ] = internal_trajectory_queue
         self.trainer.subscribe_trajectory_queue(internal_trajectory_queue)
+
+    def update_elo_ratings(self, trajectory: Trajectory) -> None:
+        """
+        Updates the ELO ratings based on the outcome of an episode.
+        This method encapsulates the ELO update logic that was previously
+        part of the _process_trajectory method.
+
+        :param trajectory: Trajectory containing the episode outcome.
+        """
+        if (
+            trajectory.done_reached
+            and trajectory.all_group_dones_reached
+            and not trajectory.interrupted
+        ):
+            final_reward = (
+                trajectory.steps[-1].reward + trajectory.steps[-1].group_reward
+            )
+            result = 0.5
+            if final_reward > 0:
+                result = 1.0
+            elif final_reward < 0:
+                result = 0.0
+
+            change = self.controller.compute_elo_rating_changes(
+                self.current_elo, result
+            )
+            self.change_current_elo(change)
+            self._stats_reporter.add_stat("Self-play/ELO", self.current_elo)
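
Note on the refactored method: update_elo_ratings maps the final episode reward to a game result of 1.0/0.5/0.0 (win/draw/loss) and passes it to GhostController.compute_elo_rating_changes. The sketch below is a minimal illustration of the standard ELO update such a controller would typically apply; it is not the ML-Agents implementation, and the opponent rating and K-factor are illustrative assumptions.

# Minimal sketch of a standard ELO update, showing what a controller like
# GhostController.compute_elo_rating_changes could do with the 1.0/0.5/0.0
# result produced by update_elo_ratings. The opponent rating and K-factor
# below are hypothetical, not taken from the ML-Agents source.


def compute_elo_change(
    rating: float, opponent_rating: float, result: float, k: float = 16.0
) -> float:
    """Return the rating delta for `rating` given a game result.

    :param result: 1.0 for a win, 0.5 for a draw, 0.0 for a loss,
        matching the convention used in update_elo_ratings.
    """
    # Expected score of the player under the logistic ELO model.
    expected = 1.0 / (1.0 + 10.0 ** ((opponent_rating - rating) / 400.0))
    return k * (result - expected)


# Example: a 1200-rated learning team beats a 1200-rated snapshot.
change = compute_elo_change(1200.0, 1200.0, result=1.0)
new_elo = 1200.0 + change  # 1208.0 with k=16.0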