@@ -153,20 +153,20 @@ def test_ema(self):
153
153
self .assertAllClose (v , [[2.0 , 3.0 ], [4.0 , 5.0 ]])
154
154
self .assertAllClose (
155
155
optimizer ._model_variables_moving_average [0 ],
156
- [[2.9 , 3.9 ], [4.9 , 5.9 ]], # avg of initial v + current v
156
+ [[2.0 , 3.0 ], [4.0 , 5.0 ]], # initialized after first step
157
157
)
158
158
self .strategy .run (lambda : optimizer .apply_gradients ([(grads , v )]))
159
159
self .assertAllClose (v , [[1.0 , 2.0 ], [3.0 , 4.0 ]])
160
160
self .assertAllClose (
161
161
optimizer ._model_variables_moving_average [0 ],
162
- [[2.71 , 3.71 ], [4.71 , 5.71 ]],
162
+ [[1.9 , 2.9 ], [3.9 , 4.9 ]],
163
163
)
164
164
self .strategy .run (lambda : optimizer .apply_gradients ([(grads , v )]))
165
165
# Variables were overwritten with EMA
166
- self .assertAllClose (v , [[2.439 , 3.439 ], [4.439 , 5.439 ]])
166
+ self .assertAllClose (v , [[1.71 , 2.71 ], [3.71 , 4.71 ]])
167
167
self .assertAllClose (
168
168
optimizer ._model_variables_moving_average [0 ],
169
- [[2.439 , 3.439 ], [4.439 , 5.439 ]],
169
+ [[1.71 , 2.71 ], [3.71 , 4.71 ]],
170
170
)
171
171
172
172
def test_gradient_accumulation (self ):
0 commit comments