Commit 1a126a1

Refactoring
1 parent 666faf1 commit 1a126a1

10 files changed: +723 -617 lines changed

PolicyGradient/A3C Atari.ipynb

+2 -2

@@ -103,7 +103,7 @@
  "traceback": [
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-9-4c24ff5f438a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0ma3c\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmake_copy_params_op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mPolicyEval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobject\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_net\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_every\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msummary_writer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m<ipython-input-9-4c24ff5f438a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0ma3c\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmake_copy_params_op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mPolicyMonitor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobject\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_net\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_every\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msummary_writer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
  "\u001b[0;32m/Users/dennybritz/github/rl/PolicyGradient/a3c/worker.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0matari\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_processor\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mStateProcessor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0matari\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mhelpers\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0matari_helpers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mestimators\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mValueEstimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mPolicyEstimator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mTransition\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcollections\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamedtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Transition\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"state\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"action\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"reward\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"next_state\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"done\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
  "\u001b[0;31mImportError\u001b[0m: No module named 'estimators'"
  ]
@@ -112,7 +112,7 @@
  "source": [
  "from a3c.worker import make_copy_params_op\n",
  "\n",
- "class PolicyEval(object):\n",
+ "class PolicyMonitor(object):\n",
  " def __init__(env, policy_net, summary_writer):\n",
  " self.env = env\n",
  " self.global_policy_net = policy_net\n",

PolicyGradient/a3c/estimator_test.py

+99 -99 (whitespace-only changes)

@@ -19,111 +19,111 @@


  def make_env():
- return gym.envs.make("Breakout-v0")
+ return gym.envs.make("Breakout-v0")

  VALID_ACTIONS = [0, 1, 2, 3]

  class PolicyEstimatorTest(tf.test.TestCase):
- def testPredict(self):
- env = make_env()
- sp = StateProcessor()
- estimator = PolicyEstimator(len(VALID_ACTIONS))
-
- with self.test_session() as sess:
- sess.run(tf.initialize_all_variables())
-
- # Generate a state
- state = sp.process(env.reset())
- processed_state = atari_helpers.atari_make_initial_state(state)
- processed_states = np.array([processed_state])
-
- # Run feeds
- feed_dict = {
- estimator.states: processed_states,
- estimator.targets: [1.0],
- estimator.actions: [1]
- }
- loss = sess.run(estimator.loss, feed_dict)
- pred = sess.run(estimator.predictions, feed_dict)
-
- # Assertions
- self.assertTrue(loss > 0.0)
- self.assertEqual(pred["probs"].shape, (1, len(VALID_ACTIONS)))
- self.assertEqual(pred["logits"].shape, (1, len(VALID_ACTIONS)))
-
- def testGradient(self):
- env = make_env()
- sp = StateProcessor()
- estimator = PolicyEstimator(len(VALID_ACTIONS))
-
- with self.test_session() as sess:
- sess.run(tf.initialize_all_variables())
-
- # Generate a state
- state = sp.process(env.reset())
- processed_state = atari_helpers.atari_make_initial_state(state)
- processed_states = np.array([processed_state])
-
- # Run feeds
- feed_dict = {
- estimator.states: processed_states,
- estimator.targets: [1.0],
- estimator.actions: [1]
- }
- loss = sess.run(estimator.train_op, feed_dict)
-
- # Assertions
- self.assertTrue(loss > 0.0)
+ def testPredict(self):
+ env = make_env()
+ sp = StateProcessor()
+ estimator = PolicyEstimator(len(VALID_ACTIONS))
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+
+ # Generate a state
+ state = sp.process(env.reset())
+ processed_state = atari_helpers.atari_make_initial_state(state)
+ processed_states = np.array([processed_state])
+
+ # Run feeds
+ feed_dict = {
+ estimator.states: processed_states,
+ estimator.targets: [1.0],
+ estimator.actions: [1]
+ }
+ loss = sess.run(estimator.loss, feed_dict)
+ pred = sess.run(estimator.predictions, feed_dict)
+
+ # Assertions
+ self.assertTrue(loss > 0.0)
+ self.assertEqual(pred["probs"].shape, (1, len(VALID_ACTIONS)))
+ self.assertEqual(pred["logits"].shape, (1, len(VALID_ACTIONS)))
+
+ def testGradient(self):
+ env = make_env()
+ sp = StateProcessor()
+ estimator = PolicyEstimator(len(VALID_ACTIONS))
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+
+ # Generate a state
+ state = sp.process(env.reset())
+ processed_state = atari_helpers.atari_make_initial_state(state)
+ processed_states = np.array([processed_state])
+
+ # Run feeds
+ feed_dict = {
+ estimator.states: processed_states,
+ estimator.targets: [1.0],
+ estimator.actions: [1]
+ }
+ loss = sess.run(estimator.train_op, feed_dict)
+
+ # Assertions
+ self.assertTrue(loss > 0.0)


  class ValueEstimatorTest(tf.test.TestCase):
- def testPredict(self):
- env = make_env()
- sp = StateProcessor()
- estimator = ValueEstimator()
-
- with self.test_session() as sess:
- sess.run(tf.initialize_all_variables())
-
- # Generate a state
- state = sp.process(env.reset())
- processed_state = atari_helpers.atari_make_initial_state(state)
- processed_states = np.array([processed_state])
-
- # Run feeds
- feed_dict = {
- estimator.states: processed_states,
- estimator.targets: [1.0],
- }
- loss = sess.run(estimator.loss, feed_dict)
- pred = sess.run(estimator.predictions, feed_dict)
-
- # Assertions
- self.assertTrue(loss > 0.0)
- self.assertEqual(pred["logits"].shape, (1,))
-
- def testGradient(self):
- env = make_env()
- sp = StateProcessor()
- estimator = ValueEstimator()
-
- with self.test_session() as sess:
- sess.run(tf.initialize_all_variables())
-
- # Generate a state
- state = sp.process(env.reset())
- processed_state = atari_helpers.atari_make_initial_state(state)
- processed_states = np.array([processed_state])
-
- # Run feeds
- feed_dict = {
- estimator.states: processed_states,
- estimator.targets: [1.0],
- }
- loss = sess.run(estimator.train_op, feed_dict)
-
- # Assertions
- self.assertTrue(loss > 0.0)
+ def testPredict(self):
+ env = make_env()
+ sp = StateProcessor()
+ estimator = ValueEstimator()
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+
+ # Generate a state
+ state = sp.process(env.reset())
+ processed_state = atari_helpers.atari_make_initial_state(state)
+ processed_states = np.array([processed_state])
+
+ # Run feeds
+ feed_dict = {
+ estimator.states: processed_states,
+ estimator.targets: [1.0],
+ }
+ loss = sess.run(estimator.loss, feed_dict)
+ pred = sess.run(estimator.predictions, feed_dict)
+
+ # Assertions
+ self.assertTrue(loss > 0.0)
+ self.assertEqual(pred["logits"].shape, (1,))
+
+ def testGradient(self):
+ env = make_env()
+ sp = StateProcessor()
+ estimator = ValueEstimator()
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+
+ # Generate a state
+ state = sp.process(env.reset())
+ processed_state = atari_helpers.atari_make_initial_state(state)
+ processed_states = np.array([processed_state])
+
+ # Run feeds
+ feed_dict = {
+ estimator.states: processed_states,
+ estimator.targets: [1.0],
+ }
+ loss = sess.run(estimator.train_op, feed_dict)
+
+ # Assertions
+ self.assertTrue(loss > 0.0)

  if __name__ == '__main__':
- unittest.main()
+ unittest.main()
