
Commit 5ec2a0f

PPO-fix (#145)
* Investigating PPO crashing
* Removing debugging prints
* generalize calc_log_probs and refactor for SIL
* improve reliability of ppo and sil loss calc
* add log_prob nanguard at creation
* improve logger
* add computation logging
* improve debug logging
* use base_case_openai for test
* fix SIL log_probs
* fix singleton cont action separate AC output unit
* fix PPO weight copy
* replace clone with detach properly
* revert detach to clone to fix PPO
* typo
* refactor log_probs to policy_util
* add net arg to calc_pdparam function
* add PPOSIL
* refactor calc_pdparams in policy_util
* fix typo
1 parent 925c1d2 commit 5ec2a0f

25 files changed: +721 -142 lines changed

slm_lab/agent/__init__.py

Lines changed: 7 additions & 3 deletions
@@ -122,13 +122,15 @@ def post_body_init(self):
     @lab_api
     def reset(self, state_a):
         '''Do agent reset per session, such as memory pointer'''
+        logger.debug(f'Agent {self.a} reset')
         for (e, b), body in util.ndenumerate_nonan(self.body_a):
             body.memory.epi_reset(state_a[(e, b)])

     @lab_api
     def act(self, state_a):
         '''Standard act method from algorithm.'''
         action_a = self.algorithm.act(state_a)
+        logger.debug(f'Agent {self.a} act: {action_a}')
         return action_a

     @lab_api
@@ -144,6 +146,7 @@ def update(self, action_a, reward_a, state_a, done_a):
             body.loss = loss_a[(e, b)]
         explore_var_a = self.algorithm.update()
         explore_var_a = util.guard_data_a(self, explore_var_a, 'explore_var')
+        logger.debug(f'Agent {self.a} loss: {loss_a}, explore_var_a {explore_var_a}')
         return loss_a, explore_var_a

     @lab_api
@@ -179,12 +182,13 @@ def get(self, a):

     @lab_api
     def reset(self, state_space):
-        logger.debug('AgentSpace.reset')
+        logger.debug3('AgentSpace.reset')
         _action_v, _loss_v, _explore_var_v = self.aeb_space.init_data_v(AGENT_DATA_NAMES)
         for agent in self.agents:
             state_a = state_space.get(a=agent.a)
             agent.reset(state_a)
         _action_space, _loss_space, _explore_var_space = self.aeb_space.add(AGENT_DATA_NAMES, [_action_v, _loss_v, _explore_var_v])
+        logger.debug3(f'action_space: {_action_space}')
         return _action_space

     @lab_api
@@ -197,7 +201,7 @@ def act(self, state_space):
             action_a = agent.act(state_a)
             action_v[a, 0:len(action_a)] = action_a
         action_space, = self.aeb_space.add(data_names, [action_v])
-        logger.debug(f'\naction_space: {action_space}')
+        logger.debug3(f'\naction_space: {action_space}')
         return action_space

     @lab_api
@@ -214,7 +218,7 @@ def update(self, action_space, reward_space, state_space, done_space):
             loss_v[a, 0:len(loss_a)] = loss_a
             explore_var_v[a, 0:len(explore_var_a)] = explore_var_a
         loss_space, explore_var_space = self.aeb_space.add(data_names, [loss_v, explore_var_v])
-        logger.debug(f'\nloss_space: {loss_space}\nexplore_var_space: {explore_var_space}')
+        logger.debug3(f'\nloss_space: {loss_space}\nexplore_var_space: {explore_var_space}')
         return loss_space, explore_var_space

     @lab_api
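The logging changes above demote the very chatty per-space dumps from debug to a finer debug3 level, so ordinary debug output stays readable. As a rough sketch only (the level number and the way slm_lab.lib.logger actually registers debug3 are assumptions, not taken from this commit), an extra level below DEBUG can be added on top of Python's standard logging module like this:

import logging

DEBUG3 = 8  # illustrative value, just below logging.DEBUG (10)
logging.addLevelName(DEBUG3, 'DEBUG3')

def debug3(self, msg, *args, **kwargs):
    # emit only when the logger is explicitly configured for this very verbose level
    if self.isEnabledFor(DEBUG3):
        self._log(DEBUG3, msg, args, **kwargs)

logging.Logger.debug3 = debug3

logger = logging.getLogger(__name__)
logger.setLevel(DEBUG3)
logger.debug3('action_space: ...')  # suppressed at the default DEBUG/INFO levels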

slm_lab/agent/algorithm/actor_critic.py

Lines changed: 26 additions & 18 deletions
@@ -161,6 +161,8 @@ def init_nets(self):
             assert 'Separate' in net_type
             self.share_architecture = False
             out_dim = self.body.action_dim * [2]
+            if len(out_dim) == 1:
+                out_dim = out_dim[0]
             critic_out_dim = 1

         self.net_spec['type'] = net_type = net_type.replace('Shared', '').replace('Separate', '')
@@ -195,37 +197,39 @@ def init_nets(self):
         self.post_init_nets()

     @lab_api
-    def calc_pdparam(self, x, evaluate=True):
+    def calc_pdparam(self, x, evaluate=True, net=None):
         '''
         The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.
         '''
+        net = self.net if net is None else net
         if evaluate:
-            pdparam = self.net.wrap_eval(x)
+            pdparam = net.wrap_eval(x)
         else:
-            self.net.train()
-            pdparam = self.net(x)
+            net.train()
+            pdparam = net(x)
         if self.share_architecture:
             # MLPHeterogenousTails, get front (no critic)
             if self.body.is_discrete:
-                return pdparam[0]
+                pdparam = pdparam[0]
             else:
                 if len(pdparam) == 2:  # only (loc, scale) and (v)
-                    return pdparam[0]
+                    pdparam = pdparam[0]
                 else:
-                    return pdparam[:-1]
-        else:
-            return pdparam
+                    pdparam = pdparam[:-1]
+        logger.debug(f'pdparam: {pdparam}')
+        return pdparam

-    def calc_v(self, x, evaluate=True):
+    def calc_v(self, x, evaluate=True, net=None):
         '''
         Forward-pass to calculate the predicted state-value from critic.
         '''
+        net = self.net if net is None else net
         if self.share_architecture:
             if evaluate:
-                out = self.net.wrap_eval(x)
+                out = net.wrap_eval(x)
             else:
-                self.net.train()
-                out = self.net(x)
+                net.train()
+                out = net(x)
             # MLPHeterogenousTails, get last
             v = out[-1].squeeze_(dim=1)
         else:
@@ -235,6 +239,7 @@ def calc_v(self, x, evaluate=True):
                 self.critic.train()
                 out = self.critic(x)
             v = out.squeeze_(dim=1)
+        logger.debug(f'v: {v}')
         return v

     @lab_api
@@ -264,7 +269,7 @@ def train_shared(self):
         self.to_train = 0
         self.body.log_probs = []
         self.body.entropies = []
-        logger.debug(f'Total loss: {loss:.2f}')
+        logger.debug(f'Total loss: {loss:.4f}')
         self.last_loss = loss.item()
         return self.last_loss

@@ -282,7 +287,7 @@ def train_separate(self):
         self.to_train = 0
         self.body.entropies = []
         self.body.log_probs = []
-        logger.debug(f'Total loss: {loss:.2f}')
+        logger.debug(f'Total loss: {loss:.4f}')
         self.last_loss = loss.item()
         return self.last_loss

@@ -309,7 +314,7 @@ def train_critic(self, batch):

     def calc_policy_loss(self, batch, advs):
         '''Calculate the actor's policy loss'''
-        assert len(self.body.log_probs) == len(advs), f'{len(self.body.log_probs)} vs {len(advs)}'
+        assert len(self.body.log_probs) == len(advs), f'batch_size of log_probs {len(self.body.log_probs)} vs advs: {len(advs)}'
         log_probs = torch.stack(self.body.log_probs)
         policy_loss = - self.policy_loss_coef * log_probs * advs
         if self.add_entropy:
@@ -318,7 +323,7 @@ def calc_policy_loss(self, batch, advs):
         policy_loss = torch.mean(policy_loss)
         if torch.cuda.is_available() and self.net.gpu:
             policy_loss = policy_loss.cuda()
-        logger.debug(f'Actor policy loss: {policy_loss:.2f}')
+        logger.debug(f'Actor policy loss: {policy_loss:.4f}')
         return policy_loss

     def calc_val_loss(self, batch, v_targets):
@@ -329,7 +334,7 @@ def calc_val_loss(self, batch, v_targets):
         val_loss = self.val_loss_coef * self.net.loss_fn(v_preds, v_targets)
         if torch.cuda.is_available() and self.net.gpu:
             val_loss = val_loss.cuda()
-        logger.debug(f'Critic value loss: {val_loss:.2f}')
+        logger.debug(f'Critic value loss: {val_loss:.4f}')
         return val_loss

     def calc_gae_advs_v_targets(self, batch):
@@ -360,6 +365,7 @@ def calc_gae_advs_v_targets(self, batch):
         adv_std[adv_std != adv_std] = 0
         adv_std += 1e-08
         adv_targets = (adv_targets - adv_targets.mean()) / adv_std
+        logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
         return adv_targets, v_targets

     def calc_nstep_advs_v_targets(self, batch):
@@ -375,6 +381,7 @@ def calc_nstep_advs_v_targets(self, batch):
         if torch.cuda.is_available() and self.net.gpu:
             nstep_advs = nstep_advs.cuda()
         adv_targets = v_targets = nstep_advs
+        logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
         return adv_targets, v_targets

     def calc_td_advs_v_targets(self, batch):
@@ -388,6 +395,7 @@ def calc_td_advs_v_targets(self, batch):
             td_returns = td_returns.cuda()
         v_targets = td_returns
         adv_targets = v_targets - v_preds  # TD error, but called adv for API consistency
+        logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
         return adv_targets, v_targets

     @lab_api
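The recurring signature change in this file, `calc_pdparam(self, x, evaluate=True, net=None)` and `calc_v(self, x, evaluate=True, net=None)`, lets a caller run the forward pass through an explicitly chosen network instead of always `self.net`, which is what allows the PPO/SIL refactor to recompute pdparams against a second (e.g. old or fixed) network. A minimal hedged sketch of the fallback pattern, using a toy stand-in class rather than the repo's ActorCritic:

import torch
import torch.nn as nn

class TinyAlgo:
    '''Minimal stand-in showing the net=None fallback pattern from this diff.'''
    def __init__(self, net):
        self.net = net

    def calc_pdparam(self, x, evaluate=True, net=None):
        net = self.net if net is None else net  # fall back to the default net
        if evaluate:
            net.eval()
            with torch.no_grad():
                return net(x)
        net.train()
        return net(x)

algo = TinyAlgo(nn.Linear(4, 2))
old_net = nn.Linear(4, 2)  # e.g. a frozen copy of the policy (illustrative)
states = torch.randn(3, 4)
pdparam_new = algo.calc_pdparam(states)               # uses self.net, as before
pdparam_old = algo.calc_pdparam(states, net=old_net)  # uses the supplied net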

slm_lab/agent/algorithm/base.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def post_init_nets(self):
         logger.info(f'Initialized algorithm models for lab_mode: {util.get_lab_mode()}')

     @lab_api
-    def calc_pdparam(self, x, evaluate=True):
+    def calc_pdparam(self, x, evaluate=True, net=None):
         '''
         To get the pdparam for action policy sampling, do a forward pass of the appropriate net, and pick the correct outputs.
         The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.

slm_lab/agent/algorithm/dqn.py

Lines changed: 10 additions & 4 deletions
@@ -109,6 +109,7 @@ def calc_q_targets(self, batch):
         q_targets = (max_q_targets * batch['actions']) + (q_preds * (1 - batch['actions']))
         if torch.cuda.is_available() and self.net.gpu:
             q_targets = q_targets.cuda()
+        logger.debug(f'q_targets: {q_targets}')
         return q_targets

     @lab_api
@@ -221,6 +222,7 @@ def calc_q_targets(self, batch):
         q_targets = (max_q_targets * batch['actions']) + (q_preds * (1 - batch['actions']))
         if torch.cuda.is_available() and self.net.gpu:
             q_targets = q_targets.cuda()
+        logger.debug(f'q_targets: {q_targets}')
         return q_targets

     def update_nets(self):
@@ -333,12 +335,13 @@ def init_nets(self):
         self.eval_net = self.target_net

     @lab_api
-    def calc_pdparam(self, x, evaluate=True):
+    def calc_pdparam(self, x, evaluate=True, net=None):
         '''
         Calculate pdparams for multi-action by chunking the network logits output
         '''
-        pdparam = super(MultitaskDQN, self).calc_pdparam(x, evaluate=evaluate)
+        pdparam = super(MultitaskDQN, self).calc_pdparam(x, evaluate=evaluate, net=net)
         pdparam = torch.cat(torch.split(pdparam, self.action_dims, dim=1))
+        logger.debug(f'pdparam: {pdparam}')
         return pdparam

     @lab_api
@@ -359,6 +362,7 @@ def act(self, state_a):
             action_pd = action_pd_a[idx]
             body.entropies.append(action_pd.entropy())
             body.log_probs.append(action_pd.log_prob(action_a[idx].float()))
+            assert not torch.isnan(body.log_probs[-1])
         return action_a.cpu().numpy()

     @lab_api
@@ -410,6 +414,7 @@ def calc_q_targets(self, batch):
         q_targets = torch.cat(multi_q_targets, dim=1)
         if torch.cuda.is_available() and self.net.gpu:
             q_targets = q_targets.cuda()
+        logger.debug(f'q_targets: {q_targets}')
         return q_targets


@@ -432,12 +437,12 @@ def init_nets(self):
         self.eval_net = self.target_net

     @lab_api
-    def calc_pdparam(self, x, evaluate=True):
+    def calc_pdparam(self, x, evaluate=True, net=None):
         '''
         Calculate pdparams for multi-action by chunking the network logits output
         '''
         x = torch.cat(torch.split(x, self.state_dims, dim=1)).unsqueeze_(dim=1)
-        pdparam = SARSA.calc_pdparam(self, x, evaluate=evaluate)
+        pdparam = SARSA.calc_pdparam(self, x, evaluate=evaluate, net=net)
         return pdparam

     @lab_api
@@ -479,6 +484,7 @@ def calc_q_targets(self, batch):
             multi_q_targets.append(q_targets)
         # return as list for compatibility with net output in training_step
         q_targets = multi_q_targets
+        logger.debug(f'q_targets: {q_targets}')
         return q_targets

     @lab_api
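The `assert not torch.isnan(body.log_probs[-1])` added in `act` is the "log_prob nanguard at creation" from the commit message: a NaN is rejected at the point it is produced instead of surfacing later as a NaN loss. A self-contained sketch of the same guard pattern (the helper name `guard_nan` is illustrative, not from this diff):

import torch
from torch import distributions

def guard_nan(t, name='tensor'):
    # fail fast where the value is created, instead of letting NaN reach the loss
    assert not torch.isnan(t).any(), f'{name} has NaN: {t}'
    return t

action_pd = distributions.Categorical(logits=torch.tensor([0.1, 0.2, 0.7]))
action = action_pd.sample()
log_prob = guard_nan(action_pd.log_prob(action), name='log_prob')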

slm_lab/agent/algorithm/math_util.py

Lines changed: 0 additions & 3 deletions
@@ -86,6 +86,3 @@ def calc_gaes(rewards, v_preds, next_v_preds, gamma, lam):
     assert not np.isnan(gaes).any(), f'GAE has nan: {gaes}'
     gaes = torch.from_numpy(gaes).float()
     return gaes
-
-
-# Q-learning calc

slm_lab/agent/algorithm/policy_util.py

Lines changed: 53 additions & 3 deletions
@@ -16,6 +16,7 @@
 from slm_lab.lib import logger, util
 from torch import distributions
 import numpy as np
+import pydash as ps
 import torch

 logger = logger.get_logger(__name__)
@@ -155,15 +156,17 @@ def sample_action_pd(ActionPD, pdparam, body):
         action_pd = ActionPD(logits=pdparam)
     else:  # continuous outputs a list, loc and scale
         assert len(pdparam) == 2, pdparam
-        # scale (stdev) must be >=0
-        clamp_pdparam = torch.stack([pdparam[0], torch.clamp(pdparam[1], 1e-8)])
-        action_pd = ActionPD(*clamp_pdparam)
+        # scale (stdev) must be >0, use softplus
+        if pdparam[1] < 5:
+            pdparam[1] = torch.log(1 + torch.exp(pdparam[1])) + 1e-8
+        action_pd = ActionPD(*pdparam)
     action = action_pd.sample()
     return action, action_pd


 # interface action sampling methods

+
 def default(state, algorithm, body):
     '''Plain policy by direct sampling using outputs of net as logits and constructing ActionPD as appropriate'''
     ActionPD, pdparam, body = init_action_pd(state, algorithm, body)
@@ -341,3 +344,50 @@ def rate_decay(algorithm, body):
 def periodic_decay(algorithm, body):
     '''Apply _periodic_decay to explore_var'''
     return fn_decay_explore_var(algorithm, body, _periodic_decay)
+
+
+# misc calc methods
+
+
+def guard_multi_pdparams(pdparams, body):
+    '''Guard pdparams for multi action'''
+    action_dim = body.action_dim
+    is_multi_action = ps.is_iterable(action_dim)
+    if is_multi_action:
+        assert ps.is_list(pdparams)
+        pdparams = [t.clone() for t in pdparams]  # clone for grad safety
+        assert len(pdparams) == len(action_dim), pdparams
+        # transpose into (batch_size, [action_dims])
+        pdparams = [list(torch.split(t, action_dim, dim=0)) for t in torch.cat(pdparams, dim=1)]
+    return pdparams
+
+
+def calc_log_probs(algorithm, net, body, batch):
+    '''
+    Method to calculate log_probs fresh from batch data
+    Body already stores log_prob from self.net. This is used for PPO where log_probs needs to be recalculated.
+    '''
+    states, actions = batch['states'], batch['actions']
+    action_dim = body.action_dim
+    is_multi_action = ps.is_iterable(action_dim)
+    # construct log_probs for each state-action
+    pdparams = algorithm.calc_pdparam(states, net=net)
+    pdparams = guard_multi_pdparams(pdparams, body)
+    assert len(pdparams) == len(states), f'batch_size of pdparams: {len(pdparams)} vs states: {len(states)}'
+
+    pdtypes = ACTION_PDS[body.action_type]
+    ActionPD = getattr(distributions, body.action_pdtype)
+
+    log_probs = []
+    for idx, pdparam in enumerate(pdparams):
+        if not is_multi_action:  # already cloned for multi_action above
+            pdparam = pdparam.clone()  # clone for grad safety
+        _action, action_pd = sample_action_pd(ActionPD, pdparam, body)
+        log_probs.append(action_pd.log_prob(actions[idx]))
+    log_probs = torch.stack(log_probs)
+    if is_multi_action:
+        log_probs = log_probs.mean(dim=1)
+    log_probs = torch.tensor(log_probs, requires_grad=True)
+    assert not torch.isnan(log_probs).any(), f'log_probs: {log_probs}, \npdparams: {pdparams} \nactions: {actions}'
+    logger.debug(f'log_probs: {log_probs}')
+    return log_probs
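Two additions in this file carry most of the PPO fix. First, the scale (stdev) of a continuous action distribution is now passed through a softplus, log(1 + exp(x)) + 1e-8, which is strictly positive; the `< 5` check skips the transform for larger raw values, where softplus(x) ≈ x anyway. Second, `calc_log_probs` recomputes log-probabilities for a whole batch through a caller-supplied net, so a PPO-style algorithm can compare them with the log-probs recorded while acting. A hedged sketch of how such recomputed log-probs typically feed a clipped surrogate loss (`old_log_probs`, `advs`, and `clip_eps` are illustrative names, not from this diff):

import torch

def softplus_scale(raw_scale, eps=1e-8):
    # strictly positive stdev, mirroring the log(1 + exp(x)) + 1e-8 transform above
    return torch.log(1 + torch.exp(raw_scale)) + eps

def clipped_surrogate_loss(log_probs, old_log_probs, advs, clip_eps=0.2):
    # PPO probability ratio between the new and old policies
    ratios = torch.exp(log_probs - old_log_probs)
    surr1 = ratios * advs
    surr2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    # maximize the clipped surrogate, i.e. minimize its negative mean
    return -torch.min(surr1, surr2).mean()

# one side would come from calc_log_probs(algorithm, net, body, batch), the other
# from the log-probs stored during acting; which is "old" depends on which net is passed
loss = clipped_surrogate_loss(
    log_probs=torch.tensor([-1.0, -0.5, -0.2]),
    old_log_probs=torch.tensor([-1.1, -0.6, -0.4]),
    advs=torch.tensor([0.5, -0.3, 1.2]),
)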

0 commit comments
