-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsarsaLambdaAgent.py
30 lines (26 loc) · 1.03 KB
/
sarsaLambdaAgent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# import gym
import itertools
from collections import defaultdict
import numpy as np
from Agent import Agent
class SarsaLambdaAgent(Agent):
    """Tabular SARSA(lambda) agent using eligibility traces.

    Action values ``Q`` and eligibility traces ``E`` are kept per state as
    length-``nA`` NumPy float arrays, created lazily (``defaultdict``) the
    first time a state is looked up.
    """

    # Supported eligibility-trace update rules.
    _TRACE_TYPES = ('accumulate', 'replace')

    def __init__(self, trace_decay, gamma, nA, trace_type='accumulate'):
        """Create the agent.

        Args:
            trace_decay: lambda, the trace decay factor (traces are multiplied
                by ``gamma * trace_decay`` after every update).
            gamma: discount factor applied to the bootstrapped next value.
            nA: number of discrete actions; size of each Q / E row.
            trace_type: ``'accumulate'`` adds 1 to the visited pair's trace;
                ``'replace'`` resets the visited pair's trace to 1 and clears
                the traces of the other actions in that state. Defaults to
                ``'accumulate'``, the only mode previously in effect.

        Raises:
            ValueError: if ``trace_type`` is not a supported mode.
        """
        if trace_type not in self._TRACE_TYPES:
            raise ValueError(
                "trace_type must be one of {}, got {!r}".format(
                    self._TRACE_TYPES, trace_type))
        self.trace_decay = trace_decay
        self.gamma = gamma
        self.num_actions = nA
        self.trace_type = trace_type
        self.Q = defaultdict(lambda: np.zeros(nA))
        self.E = defaultdict(lambda: np.zeros(nA))

    def update(self, state, next_state, reward, action, next_action, Nas):
        """Apply one SARSA(lambda) step for the transition (s, a, r, s', a').

        Args:
            state: state the action was taken in.
            next_state: state reached after the transition.
            reward: reward observed on the transition.
            action: action taken in ``state``.
            next_action: action chosen for ``next_state``; its current Q value
                is the bootstrap target.
            Nas: visit counts indexed as ``Nas[state][action]``. The step size
                is ``1 / Nas[state][action]``. NOTE(review): assumes the caller
                has already counted this visit so the count is positive.

        No terminal-state special case is applied here: the target always
        bootstraps from ``Q[next_state][next_action]``. Traces in ``E`` are
        never cleared by this class; presumably the caller (or the Agent
        base) resets them between episodes -- confirm.
        """
        delta = (reward
                 + self.gamma * self.Q[next_state][next_action]
                 - self.Q[state][action])

        # Mark the visited (state, action) pair as eligible.
        trace = self.E[state]
        if self.trace_type == 'accumulate':
            trace[action] += 1
        else:  # 'replace'
            trace[:] = 0
            trace[action] = 1

        alpha = 1.0 / Nas[state][action]
        decay = self.gamma * self.trace_decay

        # Only states that already have a trace row can receive a non-zero
        # update, so iterate E rather than every state in Q. E's keys are not
        # added or removed inside the loop; rows are modified in place.
        for s, e in self.E.items():
            self.Q[s] += alpha * delta * e
            e *= decay