diff --git a/ai/rl/dqn/rainbow/_agent.py b/ai/rl/dqn/rainbow/_agent.py
index fa7a43b..61f5f86 100644
--- a/ai/rl/dqn/rainbow/_agent.py
+++ b/ai/rl/dqn/rainbow/_agent.py
@@ -217,7 +217,7 @@ def _get_distributional_loss(
             self._config.n_atoms,
             self._config.v_max,
             self._config.v_min,
-            self._config.discount_factor,
+            self.discount_factor,
         )
 
     def _get_td_loss(
@@ -245,7 +245,7 @@ def _get_td_loss(
         target_values = self._target_network(next_states).max(dim=1).values
         return self._td_loss(
             current_q_values,
-            rewards + ~terminals * self._config.discount_factor * target_values,
+            rewards + ~terminals * self.discount_factor * target_values,
         )
 
     @property
diff --git a/ai/rl/dqn/rainbow/trainers/seed/_actor.py b/ai/rl/dqn/rainbow/trainers/seed/_actor.py
index 29434fc..81f9c6d 100644
--- a/ai/rl/dqn/rainbow/trainers/seed/_actor.py
+++ b/ai/rl/dqn/rainbow/trainers/seed/_actor.py
@@ -133,6 +133,7 @@ def run(self):
         if first and logging_client is not None:
             value = float(_get_values(mask.unsqueeze(0), model_output.unsqueeze(0), use_distributional, z)[0])
             logging_client.log("Actor/Start value", value)
+            first = False
 
        next_state, reward, terminal, truncated, _ = env.step(action)
        terminal = terminal or truncated
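
Context for the first file: switching the loss computations from `self._config.discount_factor` to `self.discount_factor` suggests the agent now reads the discount factor from an instance attribute or property rather than directly from the config. The diff does not show that definition, so the sketch below is an assumption about what such a property might look like (the backing field `_discount_factor` and the setter are hypothetical, not taken from this patch):

```python
class Agent:
    """Minimal sketch; only the pieces relevant to this diff are shown."""

    def __init__(self, config):
        self._config = config
        # Hypothetical: copy the initial value out of the (immutable) config.
        self._discount_factor = config.discount_factor

    @property
    def discount_factor(self) -> float:
        """Discount factor used by the TD and distributional losses."""
        return self._discount_factor

    @discount_factor.setter
    def discount_factor(self, value: float) -> None:
        # Hypothetical setter: lets a schedule adjust gamma at runtime
        # without rebuilding the config object.
        self._discount_factor = value
```

In the second file, adding `first = False` after the logging call makes the "Actor/Start value" metric fire only on the first step, instead of being re-logged on every step while `first` remains true.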