From 0f1469af609354ac45b0cfeb60dc2ebb2314245b Mon Sep 17 00:00:00 2001
From: Omar Hussein
Date: Wed, 28 Aug 2024 11:27:21 -0400
Subject: [PATCH] Refactor ELO rating update logic in GhostTrainer

- Extract ELO rating update logic from _process_trajectory into a new method update_elo_ratings
- Improves code organization and reusability
- No functional changes to the ELO update algorithm
---
 ml-agents/mlagents/trainers/ghost/trainer.py | 49 ++++++++++++--------
 1 file changed, 29 insertions(+), 20 deletions(-)

diff --git a/ml-agents/mlagents/trainers/ghost/trainer.py b/ml-agents/mlagents/trainers/ghost/trainer.py
index f49a643574..52e5e45643 100644
--- a/ml-agents/mlagents/trainers/ghost/trainer.py
+++ b/ml-agents/mlagents/trainers/ghost/trainer.py
@@ -195,26 +195,7 @@ def _process_trajectory(self, trajectory: Trajectory) -> None:
         i.e. in asymmetric games. We assume the last reward determines the winner.
         :param trajectory: Trajectory.
         """
-        if (
-            trajectory.done_reached
-            and trajectory.all_group_dones_reached
-            and not trajectory.interrupted
-        ):
-            # Assumption is that final reward is >0/0/<0 for win/draw/loss
-            final_reward = (
-                trajectory.steps[-1].reward + trajectory.steps[-1].group_reward
-            )
-            result = 0.5
-            if final_reward > 0:
-                result = 1.0
-            elif final_reward < 0:
-                result = 0.0
-
-            change = self.controller.compute_elo_rating_changes(
-                self.current_elo, result
-            )
-            self.change_current_elo(change)
-            self._stats_reporter.add_stat("Self-play/ELO", self.current_elo)
+        self.update_elo_ratings(trajectory)
 
     def advance(self) -> None:
         """
@@ -478,3 +459,31 @@ def subscribe_trajectory_queue(
                 parsed_behavior_id.brain_name
             ] = internal_trajectory_queue
             self.trainer.subscribe_trajectory_queue(internal_trajectory_queue)
+
+    def update_elo_ratings(self, trajectory: Trajectory) -> None:
+        """
+        Updates the ELO ratings based on the outcome of an episode.
+        This method encapsulates the ELO update logic that was previously
+        part of the _process_trajectory method.
+
+        :param trajectory: Trajectory containing the episode outcome.
+        """
+        if (
+            trajectory.done_reached
+            and trajectory.all_group_dones_reached
+            and not trajectory.interrupted
+        ):
+            final_reward = (
+                trajectory.steps[-1].reward + trajectory.steps[-1].group_reward
+            )
+            result = 0.5
+            if final_reward > 0:
+                result = 1.0
+            elif final_reward < 0:
+                result = 0.0
+
+            change = self.controller.compute_elo_rating_changes(
+                self.current_elo, result
+            )
+            self.change_current_elo(change)
+            self._stats_reporter.add_stat("Self-play/ELO", self.current_elo)
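
Note (illustrative, not part of the patch): the diff maps the final reward to a 1.0/0.5/0.0 result and hands it to GhostController.compute_elo_rating_changes, whose body is not shown here. Below is a minimal sketch of a standard ELO update for readers who want to see how such a result typically becomes a rating delta; the K_FACTOR value, the opponent_elo argument, and the helper name elo_change are assumptions for illustration, not the ml-agents API.

    # Minimal sketch of a standard ELO update (assumed formula, not the
    # GhostController implementation). result is 1.0 win / 0.5 draw / 0.0 loss.
    K_FACTOR = 32.0

    def elo_change(current_elo: float, opponent_elo: float, result: float) -> float:
        # Expected score of the current policy against the opponent.
        expected = 1.0 / (1.0 + 10.0 ** ((opponent_elo - current_elo) / 400.0))
        # Positive delta for better-than-expected outcomes, negative otherwise.
        return K_FACTOR * (result - expected)

    # Example: a win against an equally rated opponent yields +16 with K=32.
    print(elo_change(1200.0, 1200.0, 1.0))  # 16.0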