Commit fbf482e

Fix ActorCritic family (#133)
* fix reinforce val loss calc as mean
* make ac loss coef calc consistent
* update val_loss_coef to 0.01 by default
* add tmp ac config to run
* remove tmp config
* add default coef values to prevent inheritance breakage
* add val loss coef to PPO
* update ac family search range
* add entropy
* reinforce back to sum per textbook
* monitor log 4 digits loss
* change sum to mean in reinforce policy loss calc
* update fn to new PyTorch syntax
* fix computation graph error
* remove unnecessary mean from reinforce loss
* add assert_trained method
* fix ActorCritic class
* expand SIL param, check compute carry gradient, tuning needed
* standardize args, restore policy_loss_coef to ac calc
* compact rewrite of reinforce loss works
* compact rewrite of actor critic loss works
* fix PPO
* use same rewrite for SIL log_probs, still need fixing/tuning
* add screen to linux plot
* fix platform
* split cmd
* use xvfb wrapper
* debug log
* save as context manager
* add xvfb, prevent linux plot crash now
* fix silly SIL loss var typo; working, needs tuning
* add sha to all spec written
* update orca install, use npm; update npm build
* fix plotting for linux
1 parent ca948b3 commit fbf482e
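
Several of the plotting items in the commit message above ("use xvfb wrapper", "save as context manager", "add xvfb, prevent linux plot crash now") amount to running the plot export inside a virtual display on headless Linux. A minimal sketch of that idea using the xvfbwrapper package; the helper name and the save callable are illustrative assumptions, not the lab's actual viz code:

import platform
from xvfbwrapper import Xvfb  # virtual framebuffer, used only on headless Linux

def save_with_virtual_display(save_fn):
    '''Hypothetical helper: run a plot-export callable inside Xvfb on Linux.'''
    if platform.system() == 'Linux':
        # Xvfb supports the context-manager protocol; the display is torn down on exit
        with Xvfb():
            return save_fn()
    return save_fn()

# usage sketch: save_with_virtual_display(lambda: plotly.io.write_image(fig, 'plot.png'))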

20 files changed: +269 additions, -237 deletions

.circleci/config.yml

Lines changed: 3 additions & 4 deletions
@@ -24,7 +24,7 @@ jobs:
           name: Install apt-get packages for lab
           command: |
             sudo apt-get update
-            sudo apt-get install -y python-numpy python-dev cmake zlib1g-dev libjpeg-dev xvfb libav-tools xorg-dev python-opengl libboost-all-dev libsdl2-dev swig build-essential libstdc++6
+            sudo apt-get install -y python-numpy python-dev cmake libhdf5-dev libopenblas-dev zlib1g-dev libjpeg-dev xvfb libav-tools xorg-dev python-opengl libboost-all-dev libsdl2-dev swig build-essential libstdc++6
           environment:
             LD_PRELOAD: /usr/lib/libtcmalloc_minimal.so.4
@@ -34,10 +34,9 @@ jobs:
             if which yarn >/dev/null; then
               echo "Yarn is already installed"
             else
-              sudo npm install -g yarn
+              sudo npm install --unsafe-perm=true --allow-root -g yarn [email protected] orca
             fi
             yarn install
-            yarn global add [email protected] orca
       - save_cache:
           paths:
             - node_modules
@@ -100,7 +99,7 @@ jobs:
           name: Install apt-get packages for lab
           command: |
             sudo apt-get update
-            sudo apt-get install -y python-numpy python-dev cmake zlib1g-dev libjpeg-dev xvfb libav-tools xorg-dev python-opengl libboost-all-dev libsdl2-dev swig build-essential libstdc++6
+            sudo apt-get install -y python-numpy python-dev cmake libhdf5-dev libopenblas-dev zlib1g-dev libjpeg-dev xvfb libav-tools xorg-dev python-opengl libboost-all-dev libsdl2-dev swig build-essential libstdc++6
       - run: echo '. ~/miniconda3/etc/profile.d/conda.sh' >> $BASH_ENV
       - run:
           name: Run Python tests

bin/setup_macOS

Lines changed: 1 addition & 13 deletions
@@ -23,32 +23,20 @@ for item in "${hb_list[@]}"; do
   brew info "${item}" | grep --quiet "Not installed" && brew install "${item}"
 done

-echo "--- Installing Atom and Hydrogen for interactive computing ---"
-if which atom >/dev/null; then
-  echo "Atom is already installed"
-else
-  brew cask install atom
-fi
-if apm ls | grep " Hydrogen" >/dev/null; then
-  echo "Hydrogen is already installed"
-else
-  apm install hydrogen
-fi
-
 echo "--- Installing NodeJS Lab interface ---"
 if which node >/dev/null; then
   echo "NodeJS is already installed"
 else
   brew install node
   brew install yarn
+  npm install --unsafe-perm=true --allow-root -g [email protected] orca
 fi

 echo "--- Installing npm modules for Lab interface ---"
 if [ -d ./node_modules ]; then
   echo "Npm modules are already installed"
 else
   yarn install
-  yarn global add [email protected] orca
 fi

 echo "--- Installing Python for Lab backend ---"

bin/setup_ubuntu

Lines changed: 3 additions & 18 deletions
@@ -7,39 +7,24 @@ trap "exit" INT

 echo "--- Installing system dependencies ---"
 sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
-sudo add-apt-repository -y ppa:webupd8team/atom
 sudo apt-get update
-sudo apt-get install -y cmake gcc-4.9 g++-4.9 git
-sudo apt-get install -y libhdf5-dev libopenblas-dev
-sudo apt-get install -y cmake zlib1g-dev libjpeg-dev xvfb libav-tools xorg-dev python-opengl libboost-all-dev libsdl2-dev swig build-essential libstdc++6
-
-echo "--- Installing Atom and Hydrogen for interactive computing ---"
-if which atom >/dev/null; then
-  echo "Atom is already installed"
-else
-  sudo apt-get install -y atom
-fi
-if apm ls | grep " Hydrogen" >/dev/null; then
-  echo "Hydrogen is already installed"
-else
-  apm install hydrogen
-fi
+sudo apt-get install -y git cmake gcc g++
+sudo apt-get install -y zlib1g-dev libjpeg-dev xvfb libav-tools xorg-dev python-opengl libboost-all-dev libsdl2-dev swig build-essential libstdc++6 libhdf5-dev libopenblas-dev

 echo "--- Installing NodeJS Lab interface ---"
 if which node >/dev/null; then
   echo "Nodejs is already installed"
 else
   curl -sL https://deb.nodesource.com/setup_8.x | sudo -E bash -
   sudo apt-get install -y nodejs
-  sudo npm install -g yarn
+  sudo npm install --unsafe-perm=true --allow-root -g yarn [email protected] orca
 fi

 echo "--- Installing npm modules for Lab interface ---"
 if [ -d ./node_modules ]; then
   echo "Npm modules are already installed"
 else
   yarn install
-  yarn global add [email protected] orca
 fi

 echo "--- Installing Python for Lab backend ---"

slm_lab/agent/algorithm/actor_critic.py

Lines changed: 16 additions & 15 deletions
@@ -61,7 +61,7 @@ class ActorCritic(Reinforce):
         "add_entropy": false,
         "entropy_coef": 0.01,
         "policy_loss_coef": 1.0,
-        "val_loss_coef": 1.0,
+        "val_loss_coef": 0.01,
         "continuous_action_clip": 2.0,
         "training_frequency": 1,
         "training_epoch": 8
@@ -87,6 +87,8 @@ def init_algorithm_params(self):
             explore_var_start=np.nan,
             explore_var_end=np.nan,
             explore_anneal_epi=np.nan,
+            policy_loss_coef=1.0,
+            val_loss_coef=1.0,
         ))
         util.set_attr(self, self.algorithm_spec, [
             'action_pdtype',
@@ -263,9 +265,9 @@ def train_shared(self):
             batch = self.sample()
             with torch.no_grad():
                 advs, v_targets = self.calc_advs_v_targets(batch)
-            policy_loss = self.calc_policy_loss(advs)  # from actor
+            policy_loss = self.calc_policy_loss(batch, advs)  # from actor
             val_loss = self.calc_val_loss(batch, v_targets)  # from critic
-            loss = self.policy_loss_coef * policy_loss + self.val_loss_coef * val_loss
+            loss = policy_loss + val_loss
             self.net.training_step(loss=loss)
             # reset
             self.to_train = 0
@@ -282,9 +284,7 @@ def train_separate(self):
         '''
         if self.to_train == 1:
             batch = self.sample()
-            with torch.no_grad():
-                advs, v_targets = self.calc_advs_v_targets(batch)
-            policy_loss = self.train_actor(advs)
+            policy_loss = self.train_actor(batch)
             val_loss = self.train_critic(batch)
             loss = val_loss + abs(policy_loss)
             # reset
@@ -295,9 +295,11 @@ def train_separate(self):
             self.last_loss = loss.item()
         return self.last_loss

-    def train_actor(self, advs):
+    def train_actor(self, batch):
         '''Trains the actor when the actor and critic are separate networks'''
-        policy_loss = self.calc_policy_loss(advs)
+        with torch.no_grad():
+            advs, _v_targets = self.calc_advs_v_targets(batch)
+        policy_loss = self.calc_policy_loss(batch, advs)
         self.net.training_step(loss=policy_loss)
         return policy_loss

@@ -314,15 +316,14 @@ def train_critic(self, batch):
         val_loss = total_val_loss / self.training_epoch
         return val_loss

-    def calc_policy_loss(self, advs):
+    def calc_policy_loss(self, batch, advs):
         '''Calculate the actor's policy loss'''
         assert len(self.body.log_probs) == len(advs), f'{len(self.body.log_probs)} vs {len(advs)}'
-        log_probs = torch.tensor(self.body.log_probs, requires_grad=True)
-        entropies = torch.tensor(self.body.entropies, requires_grad=True)
+        log_probs = torch.stack(self.body.log_probs)
+        policy_loss = - self.policy_loss_coef * log_probs * advs
         if self.add_entropy:
-            policy_loss = (- log_probs * advs) - self.entropy_coef * entropies
-        else:
-            policy_loss = - log_probs * advs
+            entropies = torch.stack(self.body.entropies)
+            policy_loss += (-self.entropy_coef * entropies)
         policy_loss = torch.mean(policy_loss)
         if torch.cuda.is_available() and self.net.gpu:
             policy_loss = policy_loss.cuda()
@@ -334,7 +335,7 @@ def calc_val_loss(self, batch, v_targets):
         v_targets = v_targets.unsqueeze(dim=-1)
         v_preds = self.calc_v(batch['states'], evaluate=False).unsqueeze_(dim=-1)
         assert v_preds.shape == v_targets.shape
-        val_loss = self.net.loss_fn(v_preds, v_targets)
+        val_loss = self.val_loss_coef * self.net.loss_fn(v_preds, v_targets)
         if torch.cuda.is_available() and self.net.gpu:
             val_loss = val_loss.cuda()
         logger.debug(f'Critic value loss: {val_loss:.2f}')
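
The substantive bug behind "fix computation graph error" above is that wrapping a list of log-prob tensors in torch.tensor() copies their values into a new leaf tensor, so the policy loss no longer backpropagates into the policy network; torch.stack keeps each element attached to the graph. A self-contained sketch of the difference, using a toy distribution rather than the lab's classes:

import torch

# stand-in "policy network" parameters: learnable logits over 3 discrete actions
logits = torch.zeros(3, requires_grad=True)

advs = torch.tensor([1.0, -0.5, 2.0])
log_probs = []
for a in (0, 2, 1):
    pd = torch.distributions.Categorical(logits=logits)
    log_probs.append(pd.log_prob(torch.tensor(a)))

# broken: copies values and detaches them from `logits`, so no gradient reaches the policy
# detached = torch.tensor(log_probs, requires_grad=True)

# fixed: stack preserves the graph, as in the rewritten calc_policy_loss above
policy_loss = torch.mean(-torch.stack(log_probs) * advs)
policy_loss.backward()
print(logits.grad)  # non-zero gradient, so the actor actually trains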

slm_lab/agent/algorithm/ppo.py

Lines changed: 41 additions & 27 deletions
@@ -67,6 +67,7 @@ def init_algorithm_params(self):
             explore_var_start=np.nan,
             explore_var_end=np.nan,
             explore_anneal_epi=np.nan,
+            val_loss_coef=1.0,
         ))
         util.set_attr(self, self.algorithm_spec, [
             'action_pdtype',
@@ -80,6 +81,7 @@ def init_algorithm_params(self):
             'lam',
             'clip_eps',
             'entropy_coef',
+            'val_loss_coef',
             'training_frequency',  # horizon
             'training_epoch',
         ])
@@ -90,6 +92,8 @@ def init_algorithm_params(self):
         self.action_policy_update = getattr(policy_util, self.action_policy_update)
         for body in self.agent.nanflat_body_a:
             body.explore_var = self.explore_var_start
+        # PPO uses GAE
+        self.calc_advs_v_targets = self.calc_gae_advs_v_targets

     @lab_api
     def init_nets(self):
@@ -111,20 +115,20 @@ def calc_log_probs(self, batch, use_old_net=False):
         # get ActionPD, don't append to state_buffer
         ActionPD, _pdparam, _body = policy_util.init_action_pd(states[0].cpu().numpy(), self, self.body, append=False)
         # construct log_probs for each state-action
-        pdparams = self.calc_pdparam(states)
+        pdparams = self.calc_pdparam(states, evaluate=False)
         log_probs = []
         for idx, pdparam in enumerate(pdparams):
             _action, action_pd = policy_util.sample_action_pd(ActionPD, pdparam, self.body)
             log_prob = action_pd.log_prob(actions[idx])
             log_probs.append(log_prob)
-        log_probs = torch.tensor(log_probs)
+        log_probs = torch.stack(log_probs)
         if use_old_net:
             # swap back
             self.old_net = self.net
             self.net = self.tmp_net
         return log_probs

-    def calc_loss(self, batch):
+    def calc_policy_loss(self, batch, advs):
         '''
         The PPO loss function (subscript t is omitted)
         L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]
@@ -133,35 +137,34 @@ def calc_loss(self, batch):
         1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
         where ratio = pi(a|s) / pi_old(a|s)

-        2. L^VF = E[ (V(s_t) - V^target)^2 ]
+        2. L^VF = E[ mse(V(s_t), V^target) ]

         3. S = E[ entropy ]
         '''
         # decay clip_eps by episode
         clip_eps = policy_util._linear_decay(self.clip_eps, 0.1 * self.clip_eps, self.clip_eps_anneal_epi, self.body.env.clock.get('epi'))

-        with torch.no_grad():
-            adv_targets, v_targets = self.calc_gae_advs_v_targets(batch)
-
         # L^CLIP
         log_probs = self.calc_log_probs(batch, use_old_net=False)
         old_log_probs = self.calc_log_probs(batch, use_old_net=True)
         assert log_probs.shape == old_log_probs.shape
-        assert adv_targets.shape == log_probs.shape
+        assert advs.shape == log_probs.shape
         ratios = torch.exp(log_probs - old_log_probs)
-        sur_1 = ratios * adv_targets
-        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * adv_targets
+        sur_1 = ratios * advs
+        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
         # flip sign because need to maximize
         clip_loss = -torch.mean(torch.min(sur_1, sur_2))

-        # L^VF
-        val_loss = self.calc_val_loss(batch, v_targets)  # from critic
+        # L^VF (inherit from ActorCritic)

         # S entropy bonus
-        ent_mean = torch.mean(torch.tensor(self.body.entropies))
-        ent_penalty = -self.entropy_coef * ent_mean
-        loss = clip_loss + val_loss + ent_penalty
-        return loss
+        ent_penalty = 0
+        for e in self.body.entropies:
+            ent_penalty += (-self.entropy_coef * e)
+        ent_penalty /= len(self.body.entropies)
+
+        policy_loss = clip_loss + ent_penalty
+        return policy_loss

     def train_shared(self):
         '''
@@ -171,8 +174,13 @@ def train_shared(self):
             batch = self.sample()
             total_loss = torch.tensor(0.0)
             for _ in range(self.training_epoch):
-                loss = self.calc_loss(batch)
-                self.net.training_step(loss=loss)
+                with torch.no_grad():
+                    advs, v_targets = self.calc_advs_v_targets(batch)
+                policy_loss = self.calc_policy_loss(batch, advs)  # from actor
+                val_loss = self.calc_val_loss(batch, v_targets)  # from critic
+                loss = policy_loss + val_loss
+                # retain for entropies etc.
+                self.net.training_step(loss=loss, retain_graph=True)
                 total_loss += loss.cpu()
             loss = total_loss / self.training_epoch
             net_util.copy(self.net, self.old_net)
@@ -190,15 +198,9 @@ def train_separate(self):
         '''
         if self.to_train == 1:
             batch = self.sample()
-            total_loss = torch.tensor(0.0)
-            for _ in range(self.training_epoch):
-                loss = self.calc_loss(batch)
-                # to reuse loss for critic
-                self.net.training_step(loss=loss, retain_graph=True)
-                # critic.optim.step using the same loss
-                self.critic.training_step(loss=loss)
-                total_loss += loss.cpu()
-            loss = total_loss / self.training_epoch
+            policy_loss = self.train_actor(batch)
+            val_loss = self.train_critic(batch)
+            loss = val_loss + abs(policy_loss)
             net_util.copy(self.net, self.old_net)
             net_util.copy(self.critic, self.old_critic)
             # reset
@@ -208,3 +210,15 @@ def train_separate(self):
             logger.debug(f'Loss: {loss:.2f}')
             self.last_loss = loss.item()
         return self.last_loss
+
+    def train_actor(self, batch):
+        '''Trains the actor when the actor and critic are separate networks'''
+        total_policy_loss = torch.tensor(0.0)
+        for _ in range(self.training_epoch):
+            with torch.no_grad():
+                advs, _v_targets = self.calc_advs_v_targets(batch)
+            policy_loss = self.calc_policy_loss(batch, advs)
+            # retain for entropies etc.
+            self.net.training_step(loss=policy_loss, retain_graph=True)
+        val_loss = total_policy_loss / self.training_epoch
+        return policy_loss
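
The clipped surrogate that the new calc_policy_loss computes can be exercised in isolation. A minimal sketch with made-up tensors; the function and variable names are illustrative, not the lab's API:

import torch

def clipped_surrogate_loss(log_probs, old_log_probs, advs, clip_eps=0.2):
    # ratio = pi(a|s) / pi_old(a|s), computed in log space for numerical stability
    ratios = torch.exp(log_probs - old_log_probs)
    sur_1 = ratios * advs
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    # negate: optimizers minimize, but the surrogate objective is maximized
    return -torch.mean(torch.min(sur_1, sur_2))

log_probs = torch.tensor([-0.9, -1.2, -0.3], requires_grad=True)
old_log_probs = torch.tensor([-1.0, -1.0, -1.0])
advs = torch.tensor([0.5, -0.2, 1.5])
loss = clipped_surrogate_loss(log_probs, old_log_probs, advs)
loss.backward()  # gradients flow only through the current-policy log_probs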

slm_lab/agent/algorithm/reinforce.py

Lines changed: 6 additions & 7 deletions
@@ -159,15 +159,14 @@ def calc_policy_loss(self, batch):
         adv_std += 1e-08
         advs = (advs - advs.mean()) / adv_std
         assert len(self.body.log_probs) == len(advs), f'{len(self.body.log_probs)} vs {len(advs)}'
-        policy_loss = torch.tensor(0.0)
+        log_probs = torch.stack(self.body.log_probs)
+        policy_loss = - log_probs * advs
+        if self.add_entropy:
+            entropies = torch.stack(self.body.entropies)
+            policy_loss += (-self.entropy_coef * entropies)
+        policy_loss = torch.sum(policy_loss)
         if torch.cuda.is_available() and self.net.gpu:
-            advs = advs.cuda()
             policy_loss = policy_loss.cuda()
-        for logp, adv, ent in zip(self.body.log_probs, advs, self.body.entropies):
-            if self.add_entropy:
-                policy_loss += (-logp * adv - self.entropy_coef * ent).cpu()
-            else:
-                policy_loss += (-logp * adv).cpu()
         return policy_loss

     @lab_api
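
Per the commit notes ("reinforce back to sum per textbook"), REINFORCE reverts to summing over timesteps, i.e. loss = -sum_t log pi(a_t|s_t) * A_t, while ActorCritic above keeps a mean. A toy comparison of the two reductions on the same stacked tensors (illustrative values only):

import torch

log_probs = torch.tensor([-0.7, -1.1, -0.4], requires_grad=True)
advs = torch.tensor([1.0, 0.5, -0.2])

per_step = -log_probs * advs
reinforce_loss = torch.sum(per_step)      # REINFORCE: sum per textbook
actor_critic_loss = torch.mean(per_step)  # ActorCritic: mean over timesteps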
