# train.py
# Train a rejax algorithm from a YAML config, optionally over several seeds.
import timeit
import jax
import jax.numpy as jnp
import yaml
from matplotlib import pyplot as plt
from rejax import get_algo
def main(algo_str, config, seed_id, num_seeds, time_fit):
    """Train an algorithm over several seeds, plot the returns, optionally time it.

    Args:
        algo_str: Algorithm name, resolved to a class via ``get_algo``.
        config: Mapping of keyword arguments forwarded to ``algo_cls.create``.
        seed_id: Integer seed used to build the root ``jax.random.PRNGKey``.
        num_seeds: Number of PRNG keys split from the root key; training is
            vmapped over these keys.
        time_fit: If true, re-run the compiled training function three times
            and print the mean wall-clock time.

    Side effects:
        Prints progress and results, opens a matplotlib window, and copies all
        local variables into the module globals (for ``python -i`` debugging).
    """
    algo_cls = get_algo(algo_str)
    algo = algo_cls.create(**config)
    print(algo.config)

    old_eval_callback = algo.eval_callback

    def eval_callback(algo, ts, rng):
        # Delegate to the algorithm's own callback, then report progress.
        # jax.debug.print is used so the values are printed at run time even
        # though this function is traced under jit/vmap.
        lengths, returns = old_eval_callback(algo, ts, rng)
        jax.debug.print(
            "Step {}, Mean episode length: {}, Mean return: {}",
            ts.global_step,
            lengths.mean(),
            returns.mean(),
        )
        return lengths, returns

    algo = algo.replace(eval_callback=eval_callback)

    # Train it
    key = jax.random.PRNGKey(seed_id)
    keys = jax.random.split(key, num_seeds)

    # The algorithm is shared across seeds (None); only the key is batched (0).
    vmap_train = jax.jit(jax.vmap(algo_cls.train, in_axes=(None, 0)))
    ts, (_, returns) = vmap_train(algo, keys)
    returns.block_until_ready()

    # Average over the last axis once instead of once per plotted seed.
    # NOTE(review): returns is presumably (seed, evaluation, episode) - confirm.
    mean_returns = returns.mean(axis=-1)
    print(f"Achieved mean return of {mean_returns[:, -1]}")

    t = jnp.arange(returns.shape[1]) * algo.eval_freq
    colors = plt.cm.cool(jnp.linspace(0, 1, num_seeds))
    for i in range(num_seeds):
        plt.plot(t, mean_returns[i], c=colors[i])
    plt.show()

    if time_fit:
        print("Fitting 3 times, getting a mean time of... ", end="", flush=True)

        def time_fn():
            # JAX dispatches asynchronously: without blocking on the outputs,
            # timeit would mostly measure the time to enqueue the computation.
            return jax.block_until_ready(vmap_train(algo, keys))

        time = timeit.timeit(time_fn, number=3) / 3
        print(
            f"{time:.1f} seconds total, equalling to "
            f"{time / num_seeds:.1f} seconds per seed"
        )

    # Move local variables to global scope for debugging (run with -i)
    globals().update(locals())
if __name__ == "__main__":
    # Command-line entry point: parse arguments, load the per-algorithm section
    # of the YAML config, and hand everything to main().
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="configs/cartpole.yaml",
        help="Path to configuration file.",
    )
    parser.add_argument(
        "--algorithm",
        type=str,
        required=True,
        help="Algorithm name; must be a top-level key in the configuration file.",
    )
    parser.add_argument(
        "--time-fit",
        action="store_true",
        help="Time how long it takes to fit the agent by fitting 3 times.",
    )
    parser.add_argument(
        "--seed_id",
        type=int,
        default=0,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--num-seeds",
        type=int,
        default=1,
        help="Number of seeds to roll out.",
    )
    args = parser.parse_args()

    # YAML text is UTF-8 by default; do not depend on the platform locale.
    with open(args.config, "r", encoding="utf-8") as f:
        all_configs = yaml.safe_load(f)

    # An empty file parses to None, and a missing section would otherwise
    # surface as a bare KeyError traceback; report both as usage errors.
    if not isinstance(all_configs, dict) or args.algorithm not in all_configs:
        parser.error(
            f"no configuration for algorithm {args.algorithm!r} in {args.config}"
        )
    config = all_configs[args.algorithm]

    main(
        args.algorithm,
        config,
        args.seed_id,
        args.num_seeds,
        args.time_fit,
    )