
Commit b3762a2

carsondenison, carsondenisonHP, pre-commit-ci[bot], and Borda authored
Fixed epsilon decay in dqn example (#117)
* fixed epsilon decay in dqn example
* update requirements

Co-authored-by: carsondenisonHP <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 214d1d1 commit b3762a2

File tree: 2 files changed, +17 -10 lines


lightning_examples/reinforce-learning-DQN/.meta.yml

+5-2
@@ -1,9 +1,9 @@
 title: How to train a Deep Q Network
 author: PL team
 created: 2021-01-31
-updated: 2021-06-17
+updated: 2021-12-03
 license: CC BY-SA
-build: 2
+build: 1
 tags:
   - RL
 description: |
@@ -13,6 +13,9 @@ description: |
   2. Handle unsupervised learning by using an IterableDataset where the dataset itself is constantly updated during training
   3. Each training step carries has the agent taking an action in the environment and storing the experience in the IterableDataset
 requirements:
+  - torchvision<=0.10
+  - torchaudio<=0.10
+  - torchtext<=0.10
   - gym
 accelerator:
   - CPU

lightning_examples/reinforce-learning-DQN/dqn.py

+12-8
@@ -1,7 +1,7 @@
 # %%
 import os
 from collections import OrderedDict, deque, namedtuple
-from typing import List, Tuple
+from typing import Iterator, List, Tuple

 import gym
 import numpy as np
@@ -99,7 +99,7 @@ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
         self.buffer = buffer
         self.sample_size = sample_size

-    def __iter__(self) -> Tuple:
+    def __iter__(self) -> Iterator[Tuple]:
         states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
         for i in range(len(dones)):
             yield states[i], actions[i], rewards[i], dones[i], new_states[i]
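The annotation change in this hunk reflects what the method actually returns: it is a generator, so calling __iter__ produces an Iterator[Tuple], not a Tuple. A minimal, self-contained sketch of the same pattern (the toy dataset and its fields are illustrative, not taken from the tutorial):

from typing import Iterator, Tuple

import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyExperienceDataset(IterableDataset):
    """Toy stand-in for the tutorial's RLDataset; yields one experience at a time."""

    def __init__(self, size: int = 4) -> None:
        self.size = size

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, float]]:
        # Generator function: calling __iter__ returns an iterator, hence the annotation.
        for i in range(self.size):
            yield torch.zeros(3), float(i)


loader = DataLoader(ToyExperienceDataset(), batch_size=2)
for states, rewards in loader:
    print(states.shape, rewards)  # e.g. torch.Size([2, 3]) tensor([0., 1.], dtype=torch.float64)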
@@ -247,7 +247,7 @@ def populate(self, steps: int = 1000) -> None:
         Args:
             steps: number of random steps to populate the buffer with
         """
-        for i in range(steps):
+        for _ in range(steps):
             self.agent.play_step(self.net, epsilon=1.0)

     def forward(self, x: Tensor) -> Tensor:
@@ -273,7 +273,7 @@ def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
         """
         states, actions, rewards, dones, next_states = batch

-        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
+        state_action_values = self.net(states).gather(1, actions.long().unsqueeze(-1)).squeeze(-1)

         with torch.no_grad():
             next_state_values = self.target_net(next_states).max(1)[0]
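The added .long() cast in this hunk matters because torch.gather requires int64 indices; if the replay buffer hands back actions with another dtype, the gather call raises a dtype error. A small standalone illustration of the indexing (the values below are made up for the example):

import torch

# Q-values for a batch of 2 states with 2 actions each (illustrative numbers)
q_values = torch.tensor([[0.1, 0.9],
                         [0.7, 0.3]])
# Actions as they might come out of a replay buffer, stored as int32
actions = torch.tensor([1, 0], dtype=torch.int32)

# Cast to int64 before gathering the Q-value of the action actually taken
chosen_q = q_values.gather(1, actions.long().unsqueeze(-1)).squeeze(-1)
print(chosen_q)  # tensor([0.9000, 0.7000])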
@@ -284,6 +284,11 @@ def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:

         return nn.MSELoss()(state_action_values, expected_state_action_values)

+    def get_epsilon(self, start: int, end: int, frames: int) -> float:
+        if self.global_step > frames:
+            return end
+        return start - (self.global_step / frames) * (start - end)
+
     def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> OrderedDict:
         """Carries out a single step through the environment to update the replay buffer. Then calculates loss
         based on the minibatch recieved.
@@ -296,14 +301,13 @@ def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> OrderedDict:
             Training loss and log metrics
         """
         device = self.get_device(batch)
-        epsilon = max(
-            self.hparams.eps_end,
-            self.hparams.eps_start - self.global_step + 1 / self.hparams.eps_last_frame,
-        )
+        epsilon = self.get_epsilon(self.hparams.eps_start, self.hparams.eps_end, self.hparams.eps_last_frame)
+        self.log("epsilon", epsilon)

         # step through environment with agent
         reward, done = self.agent.play_step(self.net, epsilon, device)
         self.episode_reward += reward
+        self.log("episode reward", self.episode_reward)

         # calculates training loss
         loss = self.dqn_mse_loss(batch)
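The last two hunks contain the fix named in the commit title. The removed expression self.hparams.eps_start - self.global_step + 1 / self.hparams.eps_last_frame parses as (eps_start - global_step) + (1 / eps_last_frame), so it subtracts the raw step count rather than the fraction of the decay window that has elapsed; for typical values of eps_start around 1.0, epsilon is clamped to eps_end after the very first steps. The new get_epsilon decays epsilon linearly from eps_start to eps_end over eps_last_frame steps and then stays flat. A standalone sketch of the corrected schedule (the hyperparameter values below are illustrative, not necessarily the tutorial's defaults):

def get_epsilon(global_step: int, start: float, end: float, frames: int) -> float:
    """Linear interpolation from `start` to `end` over `frames` steps, then constant."""
    if global_step > frames:
        return end
    return start - (global_step / frames) * (start - end)


# Assumed example values: start=1.0, end=0.02, frames=1000
for step in (0, 250, 500, 1000, 5000):
    print(step, round(get_epsilon(step, 1.0, 0.02, 1000), 3))
# -> 0 1.0, 250 0.755, 500 0.51, 1000 0.02, 5000 0.02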

0 commit comments
