# main.py (forked from alirezakazemipour/ACKTR-PyTorch)

from common import *
from brain import Brain
from tqdm import tqdm
import multiprocessing as mp
import numpy as np
import gym

if __name__ == '__main__':
    params = get_params()
    set_random_seeds(params["seed"])
    init_wandb(params["online_wandb"])

    test_env = gym.make(params["env_name"])
    params.update({"n_actions": test_env.action_space.n})
    test_env.close()
    del test_env

    params.update({"rollout_length": params["batch_size"] // params["n_workers"]})
    brain = Brain(**params)
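
    # rollout_length = batch_size // n_workers, so one training iteration collects
    # roughly batch_size transitions across all workers. Brain (see brain.py) is
    # presumably the actor-critic model together with the ACKTR (K-FAC) update.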
    if not params["do_test"]:
        logger = Logger(brain, **params)

        if not params["train_from_scratch"]:
            init_iteration, episode = logger.load_weights()
        else:
            init_iteration = 0
            episode = 0
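
        # Spawn one Worker process per parallel environment; each exchanges
        # observations, actions, and transitions with the main process over its
        # own end of a multiprocessing Pipe (Worker is presumably provided by
        # the star import from common).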
        parents = []
        for i in range(params["n_workers"]):
            parent_conn, child_conn = mp.Pipe()
            parents.append(parent_conn)
            w = Worker(i, conn=child_conn, **params)
            w.start()
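
        # Pre-allocated rollout buffers: one row per worker, one column per
        # time step, i.e. shape (n_workers, rollout_length, ...).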
        rollout_base_shape = params["n_workers"], params["rollout_length"]
        total_states = np.zeros(rollout_base_shape + params["state_shape"], dtype=np.uint8)
        total_actions = np.zeros(rollout_base_shape, dtype=np.int32)
        total_rewards = np.zeros(rollout_base_shape)
        total_dones = np.zeros(rollout_base_shape, dtype=bool)
        total_values = np.zeros(rollout_base_shape, dtype=np.float32)
        next_states = np.zeros((rollout_base_shape[0],) + params["state_shape"], dtype=np.uint8)
        infos = {}

        logger.on()
        episode_reward = 0
        episode_length = 0
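
        # Synchronous rollout collection: at every step, receive the current state
        # from each worker, pick actions and values for all workers in one batched
        # forward pass, send the actions back, then receive the resulting transitions.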
        for iteration in tqdm(range(init_iteration + 1, params["total_iterations"] + 1)):
            for t in range(params["rollout_length"]):
                for worker_id, parent in enumerate(parents):
                    s = parent.recv()
                    total_states[worker_id, t] = s

                total_actions[:, t], total_values[:, t] = brain.get_actions_and_values(total_states[:, t], batch=True)
                for parent, a, v in zip(parents, total_actions[:, t], total_values[:, t]):
                    parent.send((int(a), v))

                for worker_id, parent in enumerate(parents):
                    s_, r, d, infos[worker_id] = parent.recv()
                    total_rewards[worker_id, t] = r
                    total_dones[worker_id, t] = d
                    next_states[worker_id] = s_

                # Episode statistics are tracked for worker 0 only; an episode is
                # counted as finished when done is True and no lives remain.
                episode_reward += total_rewards[0, t]
                episode_length += 1
                if total_dones[0, t] and infos[0]["lives"] == 0:
                    episode += 1
                    logger.log_episode(episode, episode_reward, episode_length)
                    episode_reward = 0
                    episode_length = 0
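
            # The rollout usually ends mid-episode, so the value of the last observed
            # states is passed to brain.train, presumably to bootstrap the n-step returns.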
            _, next_values = brain.get_actions_and_values(next_states, batch=True)
            training_logs = brain.train(np.concatenate(total_states),
                                        np.concatenate(total_actions),
                                        total_rewards,
                                        total_dones,
                                        total_values,
                                        next_values)
            logger.log_iteration(iteration, training_logs)

    else:
        logger = Logger(brain, **params)
        logger.load_weights()
        play = Evaluator(brain, 1, **params)
        play.evaluate()