Skip to content

Commit 3c26945

Browse files
HsuehErh.ChangHsuehErh.Chang
HsuehErh.Chang
authored and
HsuehErh.Chang
committed
=DQfD
1 parent cf6306d commit 3c26945

24 files changed

+388
-23
lines changed

DQNfromDemo/DQfD.py

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import sys
2+
from os import path
3+
4+
local = path.abspath(__file__)
5+
root = path.dirname(path.dirname(local))
6+
if root not in sys.path:
7+
sys.path.append(root)
8+
9+
from DQNwithNoisyNet import DQN_NoisyNet
10+
import torch
11+
12+
13+
class DeepQL(DQN_NoisyNet.DeepQL):
14+
def __init__(self, *args,lambda1=1.0,lambda2=1.0,lambda3=1e-5, **kwargs,):
15+
super().__init__(*args, **kwargs,L2=lambda3)
16+
self.ed = 1.0 # bonus for demonstration
17+
self.ea = 0.001
18+
self.margin = 0.8
19+
self.lambda1 = lambda1 # n-step return
20+
self.lambda2 = lambda2 # supervised loss
21+
self.lambda3 = lambda3 # L2
22+
self.replay.e = 0
23+
24+
def storeTransition(self, s, a, r, s_, done, isdemo):
25+
s = torch.Tensor(s)
26+
s_ = torch.Tensor(s_)
27+
error = self.calcError((s, a, r, s_, done))
28+
e = self.ed if isdemo else self.ea
29+
self.store((s, a, r, s_, done, isdemo), error + e)
30+
31+
def JE(self, samples):
32+
loss = torch.tensor(0.0)
33+
for s, a, *_, isdemo in samples:
34+
if not isdemo:
35+
continue
36+
QE = self.net(s, torch.Tensor(a))[0]
37+
Q = self.net(s, torch.Tensor(self.findMaxA(s)))[0]
38+
Q = QE if Q + self.margin < QE else Q
39+
loss += self.lambda2 * (Q - QE)
40+
return loss / self.mbsize
41+
42+
def update(self):
43+
self.opt.zero_grad()
44+
samples, idxs, IS = self.sample()
45+
if self.noisy:
46+
self.net.sample() # for choosing action
47+
maxA = [self.findMaxA(s[3]) for s in samples]
48+
maxA = torch.Tensor(maxA)
49+
s, a, *_, isdemo = zip(*samples)
50+
s = torch.stack(s)
51+
a = torch.Tensor(a)
52+
if self.noisy:
53+
self.net.sample() # for prediction
54+
self.net2.sample() # for estimating Q
55+
predict = self.net(s, a)[:, 0]
56+
look_ahead = [r if done else r + self.gamma * self.net2(s_, maxA[i]) for i, (s, a, r, s_, done, isdemo) in
57+
enumerate(samples)]
58+
target = torch.Tensor(look_ahead)
59+
60+
errors, ls = self.loss(predict, target, IS)
61+
if self.noisy:
62+
self.net.sample()
63+
ls += self.JE(samples)
64+
ls.backward()
65+
for i in range(self.mbsize):
66+
e = self.ed if isdemo[i] else self.ea
67+
self.replay.update(idxs[i], errors[i] + e)
68+
69+
self.opt.step()
70+
if self.c >= self.C:
71+
self.c = 0
72+
self.net2.load_state_dict(self.net.state_dict())
73+
self.net2.eval()
74+
else:
75+
self.c += 1
76+
77+
78+
class DeepQLv2(DQN_NoisyNet.DeepQLv2):
79+
def __init__(self, *args,lambda1=1.0,lambda2=1.0,lambda3=1e-5, **kwargs,):
80+
super().__init__(*args, **kwargs,L2=lambda3)
81+
self.ed = 1.0 # bonus for demonstration
82+
self.ea = 0.001
83+
self.margin = 0.8
84+
self.lambda1 = lambda1 # n-step return
85+
self.lambda2 = lambda2 # supervised loss
86+
self.lambda3 = lambda3 # L2
87+
self.replay.e = 0
88+
89+
def storeTransition(self, s, a, r, s_, done, isdemo):
90+
s = torch.Tensor(s)
91+
s_ = torch.Tensor(s_)
92+
error = self.calcError((s, a, r, s_, done))
93+
e = self.ed if isdemo else self.ea
94+
self.store((s, a, r, s_, done, isdemo), error + e)
95+
96+
def JE(self, samples):
97+
loss = torch.tensor(0.0)
98+
for s, a, *_, isdemo in samples:
99+
if not isdemo:
100+
continue
101+
QE = self.net(s)[a[0]]
102+
Q = max(self.net(s))
103+
Q = QE if Q + self.margin < QE else Q
104+
loss += self.lambda2 * (Q - QE)
105+
return loss / self.mbsize
106+
107+
def update(self):
108+
self.opt.zero_grad()
109+
110+
samples, idxs, IS = self.sample()
111+
if self.noisy:
112+
self.net.sample() # for choosing action
113+
maxA = [self.findMaxA(s[3]) for s in samples]
114+
s, a, *_, isdemo = zip(*samples)
115+
s = torch.stack(s)
116+
if self.noisy:
117+
self.net.sample() # for prediction
118+
self.net2.sample() # for estimating Q
119+
predict = [self.net(s[i])[a[i][0]] for i in range(self.mbsize)]
120+
look_ahead = [r if done else r + self.gamma * self.net2(s_)[maxA[i][0]] for i, (s, a, r, s_, done, isdemo) in
121+
enumerate(samples)]
122+
target = torch.Tensor(look_ahead)
123+
124+
errors, ls = self.loss(predict, target, IS)
125+
if self.noisy:
126+
self.net.sample()
127+
ls += self.JE(samples)
128+
ls.backward()
129+
for i in range(self.mbsize):
130+
e = self.ed if isdemo[i] else self.ea
131+
self.replay.update(idxs[i], errors[i] + e)
132+
133+
self.opt.step()
134+
if self.c >= self.C:
135+
self.c = 0
136+
self.net2.load_state_dict(self.net.state_dict())
137+
self.net2.eval()
138+
else:
139+
self.c += 1

DQNfromDemo/Test/CartPole.py

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from os import path
2+
import sys
3+
local=path.abspath(__file__)
4+
root=path.dirname(path.dirname(path.dirname(local)))
5+
if root not in sys.path:
6+
sys.path.append(root)
7+
8+
import gym
9+
import torch
10+
import matplotlib.pyplot as plt
11+
import math
12+
import torch.nn as nn
13+
import torch.nn.functional as F
14+
from DQNwithNoisyNet.NoisyLayer import NoisyLinear
15+
from DQNfromDemo import DQfD
16+
from operator import methodcaller
17+
import json
18+
19+
20+
class Net(nn.Module):
21+
def __init__(self):
22+
super().__init__()
23+
self.fc1_s = nn.Linear(4, 40)
24+
self.fc1_a = nn.Linear(1, 40)
25+
self.fc2 = nn.Linear(40, 1)
26+
27+
def forward(self, s, a):
28+
x = self.fc1_s(s) + self.fc1_a(a)
29+
x = F.relu(x)
30+
x = self.fc2(x)
31+
return x
32+
33+
34+
class Net2(nn.Module):
35+
def __init__(self):
36+
super().__init__()
37+
self.fc1 = nn.Linear(4, 40)
38+
self.fc2 = nn.Linear(40, 2)
39+
40+
def forward(self, s):
41+
x = self.fc1(s)
42+
x = F.relu(x)
43+
x = self.fc2(x)
44+
return x
45+
46+
47+
class NoisyNet(nn.Module):
48+
def __init__(self):
49+
super().__init__()
50+
self.fc1_s = NoisyLinear(4, 40)
51+
self.fc1_a = NoisyLinear(1, 40)
52+
self.fc2 = NoisyLinear(40, 1)
53+
54+
def forward(self, s, a):
55+
x = self.fc1_s(s) + self.fc1_a(a)
56+
x = F.relu(x)
57+
x = self.fc2(x)
58+
return x
59+
60+
def sample(self):
61+
for layer in self.children():
62+
if hasattr(layer, "sample"):
63+
layer.sample()
64+
65+
66+
class NoisyNet2(nn.Module):
67+
def __init__(self):
68+
super().__init__()
69+
self.fc1 = NoisyLinear(4, 40)
70+
self.fc2 = NoisyLinear(40, 2)
71+
72+
def forward(self, s):
73+
x = self.fc1(s)
74+
x = F.relu(x)
75+
x = self.fc2(x)
76+
return x
77+
78+
def sample(self):
79+
for layer in self.children():
80+
if hasattr(layer, "sample"):
81+
layer.sample()
82+
83+
84+
if __name__ == "__main__":
85+
env = gym.make('CartPole-v1')
86+
s = env.reset()
87+
A = [[0], [1]]
88+
dqn = DQfD.DeepQL(Net, noisy=False, lr=0.005, gamma=1, actionFinder=lambda x: A,N=5000)
89+
process = []
90+
randomness = []
91+
epoch = 100
92+
eps_start = 0.05
93+
eps_end = 0.95
94+
N = 1 - eps_start
95+
lam = -math.log((1 - eps_end) / N) / epoch
96+
total = 0
97+
count = 0 # successful count
98+
with open("CartPoleDemo.txt","r") as file:
99+
data=json.load(file)
100+
for k,v in data.items():
101+
for s,a,r,s_,done in v:
102+
dqn.storeTransition(s,a,r,s_,done,True)
103+
for i in range(1000):
104+
if i % 100 == 0:
105+
print("pretraining:",i)
106+
dqn.update()
107+
108+
109+
for i in range(epoch):
110+
print(i)
111+
dqn.eps = 1 - N * math.exp(-lam * i)
112+
count = count + 1 if total >= 500 else 0
113+
if count >= 2:
114+
dqn.eps = 1
115+
break
116+
total = 0
117+
while True:
118+
a = dqn.act(s)
119+
s_, r, done, _ = env.step(a[0])
120+
total += r
121+
r = -1 if done and total < 500 else 0.002
122+
dqn.storeTransition(s, a, r, s_, done,False)
123+
dqn.update()
124+
s = s_
125+
if done:
126+
s = env.reset()
127+
print('total:', total)
128+
process.append(total)
129+
break
130+
131+
total = 0
132+
s = env.reset()
133+
dqn.eps = 1
134+
while True:
135+
a = dqn.act(s)[0]
136+
s, r, done, _ = env.step(a)
137+
total += 1
138+
env.render()
139+
if done:
140+
s = env.reset()
141+
print(total)
142+
total = 0
143+
144+
env.close()

DQNfromDemo/Test/CartPoleDemo.txt

+1
Large diffs are not rendered by default.
File renamed without changes.

DQNfromDemo/__init__.pyc

121 Bytes
Binary file not shown.
4.94 KB
Binary file not shown.
160 Bytes
Binary file not shown.

DQN_NoisyNet/DQN_NoisyNet.py DQNwithNoisyNet/DQN_NoisyNet.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,14 @@
33
from torch import optim
44
import torch
55
import math
6-
from .prioritized_memory import Memory, WeightedMSE
76

8-
#(s,a) => Q(s,a)
7+
if __package__:
8+
from .prioritized_memory import Memory, WeightedMSE
9+
else:
10+
from prioritized_memory import Memory, WeightedMSE
11+
12+
13+
# (s,a) => Q(s,a)
914
class DeepQL:
1015
def __init__(self, Net, noisy=True, eps=0.9, lr=5e-3, gamma=0.9, mbsize=20, C=100, N=500, L2=0, actionFinder=None):
1116
self.exp = []
@@ -17,7 +22,7 @@ def __init__(self, Net, noisy=True, eps=0.9, lr=5e-3, gamma=0.9, mbsize=20, C=10
1722
self.net2 = Net()
1823
self.net2.load_state_dict(self.net.state_dict())
1924
self.net2.eval()
20-
self.C = C #for target replacement
25+
self.C = C # for target replacement
2126
self.c = 0
2227
self.replay = Memory(capacity=N)
2328
self.loss = WeightedMSE()
@@ -108,7 +113,7 @@ def update(self):
108113
self.c += 1
109114

110115

111-
#s => Q[s,a1], Q[s,a2]...
116+
# s => Q[s,a1], Q[s,a2]...
112117
class DeepQLv2:
113118
def __init__(self, Net, noisy=True, eps=0.9, lr=5e-3, gamma=0.9, mbsize=20, C=100, N=500, L2=0, actionFinder=None):
114119
self.exp = []
@@ -126,8 +131,8 @@ def __init__(self, Net, noisy=True, eps=0.9, lr=5e-3, gamma=0.9, mbsize=20, C=10
126131
self.eps = eps
127132
self.noisy = noisy
128133
self.actionFinder = actionFinder
129-
self.A = []
130-
# (state:tensor => Action :List[List])
134+
*_,last=self.net.children()
135+
self.A = list(range(last.out_features))
131136

132137
def act(self, state):
133138
# state:list[float] A:list[list]
@@ -152,8 +157,6 @@ def findMaxA(self, state):
152157
net = self.net
153158
net.eval()
154159
Q = net(state)
155-
if not self.A:
156-
self.A = list(range(len(Q))) #[0,1,2,3...]
157160
net.train()
158161
return [int(Q.argmax())]
159162

DQN_NoisyNet/NoisyLayer.py DQNwithNoisyNet/NoisyLayer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ def reset_parameters(self,sig0):
3939
self.bias_sig.data.zero_()
4040
self.bias_sig.data = self.bias_sig.data.zero_() + sig0 / self.weight_mu.shape[1]
4141

42-
def sample(self, zero=1):
42+
def sample(self):
4343
size_in = self.in_features
4444
size_out = self.out_features
45-
noise_in = f(self.dist.sample((1, size_in))) * zero
46-
noise_out = f(self.dist.sample((1, size_out))) * zero
45+
noise_in = f(self.dist.sample((1, size_in)))
46+
noise_out = f(self.dist.sample((1, size_out)))
4747
self.weight = self.weight_mu + self.weight_sig * torch.mm(noise_out.t(), noise_in)
4848
self.bias = (self.bias_mu + self.bias_sig * noise_out).squeeze()
4949

File renamed without changes.

DQN_NoisyNet/Test/CartPole.py DQNwithNoisyNet/Test/CartPole.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1-
import numpy as np
1+
from os import path
2+
import sys
3+
local=path.abspath(__file__)
4+
root=path.dirname(path.dirname(path.dirname(local)))
5+
if root not in sys.path:
6+
sys.path.append(root)
7+
28
import gym
39
import torch
410
import matplotlib.pyplot as plt
511
import math
612
import torch.nn as nn
713
import torch.nn.functional as F
8-
from NoisyLayer import NoisyLinear
9-
import DQN_NoisyNet
14+
from DQNwithNoisyNet.NoisyLayer import NoisyLinear
15+
from DQNwithNoisyNet import DQN_NoisyNet
1016
from operator import methodcaller
1117

1218

@@ -79,7 +85,6 @@ def sample(self):
7985
s = env.reset()
8086
A = [[0], [1]]
8187
dqn = DQN_NoisyNet.DeepQLv2(NoisyNet2, noisy=True, lr=0.002, gamma=1, actionFinder=lambda x: A)
82-
8388
process = []
8489
randomness = []
8590
epoch = 200
@@ -121,10 +126,10 @@ def sample(self):
121126
plt.show()
122127
env.close()
123128

124-
# torch.save(dqn.net.state_dict(),"./model.txt")
129+
#torch.save(dqn.net.state_dict(),"./CartPoleExpert.txt")
125130
# dqn.eps=1
126131
total = 0
127-
# dqn.net.load_state_dict(torch.load("./model.txt"))
132+
#dqn.net.load_state_dict(torch.load("./CartPoleExpert.txt"))
128133
s = env.reset()
129134
s = torch.Tensor(s)
130135
while True:
3.4 KB
Binary file not shown.

0 commit comments

Comments
 (0)