1
1
# %%
2
2
import os
3
3
from collections import OrderedDict , deque , namedtuple
4
- from typing import List , Tuple
4
+ from typing import Iterator , List , Tuple
5
5
6
6
import gym
7
7
import numpy as np
@@ -99,7 +99,7 @@ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
99
99
self .buffer = buffer
100
100
self .sample_size = sample_size
101
101
102
- def __iter__ (self ) -> Tuple :
102
+ def __iter__ (self ) -> Iterator [ Tuple ] :
103
103
states , actions , rewards , dones , new_states = self .buffer .sample (self .sample_size )
104
104
for i in range (len (dones )):
105
105
yield states [i ], actions [i ], rewards [i ], dones [i ], new_states [i ]
@@ -247,7 +247,7 @@ def populate(self, steps: int = 1000) -> None:
247
247
Args:
248
248
steps: number of random steps to populate the buffer with
249
249
"""
250
- for i in range (steps ):
250
+ for _ in range (steps ):
251
251
self .agent .play_step (self .net , epsilon = 1.0 )
252
252
253
253
def forward (self , x : Tensor ) -> Tensor :
@@ -273,7 +273,7 @@ def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
273
273
"""
274
274
states , actions , rewards , dones , next_states = batch
275
275
276
- state_action_values = self .net (states ).gather (1 , actions .unsqueeze (- 1 )).squeeze (- 1 )
276
+ state_action_values = self .net (states ).gather (1 , actions .long (). unsqueeze (- 1 )).squeeze (- 1 )
277
277
278
278
with torch .no_grad ():
279
279
next_state_values = self .target_net (next_states ).max (1 )[0 ]
@@ -284,6 +284,11 @@ def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
284
284
285
285
return nn .MSELoss ()(state_action_values , expected_state_action_values )
286
286
287
+ def get_epsilon (self , start : int , end : int , frames : int ) -> float :
288
+ if self .global_step > frames :
289
+ return end
290
+ return start - (self .global_step / frames ) * (start - end )
291
+
287
292
def training_step (self , batch : Tuple [Tensor , Tensor ], nb_batch ) -> OrderedDict :
288
293
"""Carries out a single step through the environment to update the replay buffer. Then calculates loss
289
294
based on the minibatch recieved.
@@ -296,14 +301,13 @@ def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> OrderedDict:
296
301
Training loss and log metrics
297
302
"""
298
303
device = self .get_device (batch )
299
- epsilon = max (
300
- self .hparams .eps_end ,
301
- self .hparams .eps_start - self .global_step + 1 / self .hparams .eps_last_frame ,
302
- )
304
+ epsilon = self .get_epsilon (self .hparams .eps_start , self .hparams .eps_end , self .hparams .eps_last_frame )
305
+ self .log ("epsilon" , epsilon )
303
306
304
307
# step through environment with agent
305
308
reward , done = self .agent .play_step (self .net , epsilon , device )
306
309
self .episode_reward += reward
310
+ self .log ("episode reward" , self .episode_reward )
307
311
308
312
# calculates training loss
309
313
loss = self .dqn_mse_loss (batch )
0 commit comments