def make_env():
    return gym.envs.make("Breakout-v0")


# Breakout-v0's full action set: 0 = NOOP, 1 = FIRE, 2 = RIGHT, 3 = LEFT.
VALID_ACTIONS = [0, 1, 2, 3]
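
# Both test classes below depend on helpers defined elsewhere in this repo.
# As a rough sketch of the assumed behavior (an assumption, not the actual
# implementations): StateProcessor.process() converts a raw (210, 160, 3) RGB
# Atari frame into an (84, 84) grayscale image, and
# atari_helpers.atari_make_initial_state() stacks that single frame four times
# along the channel axis to build the initial (84, 84, 4) input, e.g.:
#
#     def atari_make_initial_state(state):
#         return np.stack([state] * 4, axis=2)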


class PolicyEstimatorTest(tf.test.TestCase):
    def testPredict(self):
        env = make_env()
        sp = StateProcessor()
        estimator = PolicyEstimator(len(VALID_ACTIONS))

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            # Generate a state
            state = sp.process(env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            processed_states = np.array([processed_state])

            # Run feeds
            feed_dict = {
                estimator.states: processed_states,
                estimator.targets: [1.0],
                estimator.actions: [1]
            }
            loss = sess.run(estimator.loss, feed_dict)
            pred = sess.run(estimator.predictions, feed_dict)

            # Assertions
            self.assertTrue(loss > 0.0)
            self.assertEqual(pred["probs"].shape, (1, len(VALID_ACTIONS)))
            self.assertEqual(pred["logits"].shape, (1, len(VALID_ACTIONS)))

    def testGradient(self):
        env = make_env()
        sp = StateProcessor()
        estimator = PolicyEstimator(len(VALID_ACTIONS))

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            # Generate a state
            state = sp.process(env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            processed_states = np.array([processed_state])

            # Run feeds
            feed_dict = {
                estimator.states: processed_states,
                estimator.targets: [1.0],
                estimator.actions: [1]
            }
            # A bare train op evaluates to None, so fetch the loss alongside it.
            _, loss = sess.run([estimator.train_op, estimator.loss], feed_dict)

            # Assertions
            self.assertTrue(loss > 0.0)
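
# For context (an assumption about PolicyEstimator, not its actual code): the
# loss asserted above is taken to be the usual REINFORCE-style objective,
# something like:
#
#     picked_log_prob = tf.log(probs)[0, action]
#     loss = -picked_log_prob * target
#
# With target = 1.0 the log probability is negative, so the loss comes out
# strictly positive, which is what both tests check.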


class ValueEstimatorTest(tf.test.TestCase):
    def testPredict(self):
        env = make_env()
        sp = StateProcessor()
        estimator = ValueEstimator()

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            # Generate a state
            state = sp.process(env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            processed_states = np.array([processed_state])

            # Run feeds
            feed_dict = {
                estimator.states: processed_states,
                estimator.targets: [1.0],
            }
            loss = sess.run(estimator.loss, feed_dict)
            pred = sess.run(estimator.predictions, feed_dict)

            # Assertions
            self.assertTrue(loss > 0.0)
            self.assertEqual(pred["logits"].shape, (1,))

    def testGradient(self):
        env = make_env()
        sp = StateProcessor()
        estimator = ValueEstimator()

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            # Generate a state
            state = sp.process(env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            processed_states = np.array([processed_state])

            # Run feeds
            feed_dict = {
                estimator.states: processed_states,
                estimator.targets: [1.0],
            }
            # A bare train op evaluates to None, so fetch the loss alongside it.
            _, loss = sess.run([estimator.train_op, estimator.loss], feed_dict)

            # Assertions
            self.assertTrue(loss > 0.0)
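
# Similarly (assumed, not verified against the estimator code): ValueEstimator
# emits one scalar value per state, hence the (1,) "logits" shape above, with
# a squared-error loss along the lines of:
#
#     loss = tf.reduce_mean(tf.squared_difference(logits, targets))
#
# which is positive whenever the prediction differs from the target.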


if __name__ == '__main__':
    unittest.main()