diff --git a/bsuite/baselines/tf/boot_dqn/agent.py b/bsuite/baselines/tf/boot_dqn/agent.py index e423e984..095b9c7c 100644 --- a/bsuite/baselines/tf/boot_dqn/agent.py +++ b/bsuite/baselines/tf/boot_dqn/agent.py @@ -120,7 +120,6 @@ def _step(self, transitions: Sequence[tf.Tensor]): loss = tf.reduce_mean(tf.stack(losses)) gradients = tape.gradient(loss, variables) - self._total_steps.assign_add(1) self._optimizer.apply(gradients, variables) # Periodically update the target network. @@ -132,6 +131,7 @@ def _step(self, transitions: Sequence[tf.Tensor]): def select_action(self, timestep: dm_env.TimeStep) -> base.Action: """Select values via Thompson sampling, then use epsilon-greedy policy.""" + self._total_steps.assign_add(1) if self._rng.rand() < self._epsilon_fn(self._total_steps.numpy()): return self._rng.randint(self._num_actions)