-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun.py
51 lines (45 loc) · 1.33 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from models.generator import Generator
from agents.agent import Agent
from trainer import Trainer
from game.env import Env
import torch
def main(game_name, game_length):
    """Configure and run generator + RL-agent co-training for one game.

    Args:
        game_name: name of the game environment to train on (e.g. 'zelda').
        game_length: maximum episode length passed to the environment.
    """
    # --- Environment ---
    reward_mode = 'base'
    reward_scale = 1.0
    elite_prob = 0
    env = Env(game_name, game_length,
              {'reward_mode': reward_mode,
               'reward_scale': reward_scale,
               'elite_prob': elite_prob})

    # --- Generator network ---
    # NOTE(review): original code reused the names `lr`/`dropout` for both the
    # generator and the agent with different values — renamed to distinct
    # locals so a future edit cannot silently misconfigure one of them.
    latent_shape = (512,)   # latent noise vector shape fed to the generator
    gen_dropout = 0
    gen_lr = .0001
    gen = Generator(latent_shape, env, 'nearest', gen_dropout, gen_lr)

    # --- RL agent ---
    num_processes = 1
    experiment = "Experiments"
    agent_lr = .00025
    model = 'base'
    agent_dropout = .3
    reconstruct = None
    r_weight = .05
    # Class-level A2C hyperparameters shared by all Agent instances.
    Agent.num_steps = 5
    Agent.entropy_coef = .01
    Agent.value_loss_coef = .1
    agent = Agent(env, num_processes, experiment, 0, agent_lr, model,
                  agent_dropout, reconstruct, r_weight)

    # --- Training schedule ---
    gen_updates = 1e4        # total generator update steps
    gen_batch = 32           # generator minibatch size
    gen_batches = 1
    diversity_batches = 0
    rl_batch = 1e4           # RL steps per generator update
    pretrain = 0
    elite_persist = False
    elite_mode = 'mean'
    load_version = 0         # 0 = start fresh; otherwise resume this version
    notes = ''
    # Log the run's hyperparameters to TensorBoard (empty metrics dict).
    agent.writer.add_hparams({'Experiment': experiment, 'RL_LR': agent_lr,
                              'Minibatch': gen_batch, 'RL_Steps': rl_batch,
                              'Notes': notes}, {})

    t = Trainer(gen, agent, experiment, load_version, elite_mode, elite_persist)
    # Custom generator loss: squared mean activation; the second argument
    # (targets) is deliberately ignored.
    t.loss = lambda x, y: x.mean().pow(2)
    t.train(gen_updates, gen_batch, gen_batches, diversity_batches,
            rl_batch, pretrain)
# Script entry point: run the default experiment on 'zelda' with
# 1000-step episodes. (Idiomatic guard — no parentheses around the test.)
if __name__ == "__main__":
    main('zelda', 1000)