@@ -66,10 +66,14 @@ def test_unscaled_loss(self):
6666 mse = loss .mean_squared_error
6767 mae = loss .mean_absolute_error
6868 percentiles = loss .absolute_error_percentiles
69- self .assertNear (float (mse ), (3 ** 2 + 0 + 0 + 4 ** 2 + 0.5 ** 2 ) / 5 , 1e-6 )
70- self .assertNear (float (mae ), (3 + 0 + 0 + 4 + 0.5 ) / 5 , 1e-6 )
7169 self .assertNear (
72- float (loss .loss_tensor ), (2.5 + 0 + 0 + 3.5 + (0.5 ** 2 ) / 2 ) / 5 , 1e-6
70+ float (tf .squeeze (mse )), (3 ** 2 + 0 + 0 + 4 ** 2 + 0.5 ** 2 ) / 5 , 1e-6
71+ )
72+ self .assertNear (float (tf .squeeze (mae )), (3 + 0 + 0 + 4 + 0.5 ) / 5 , 1e-6 )
73+ self .assertNear (
74+ float (tf .squeeze (loss .loss_tensor )),
75+ (2.5 + 0 + 0 + 3.5 + (0.5 ** 2 ) / 2 ) / 5 ,
76+ 1e-6 ,
7377 )
7478 self .assertAllEqual (percentiles , ((0 ,), (0 ,), (0.5 ,), (3 ,), (4 ,)))
7579
@@ -86,9 +90,12 @@ def test_percentage_loss(self):
8690 mape = loss .mean_absolute_percentage_error
8791 percentiles = loss .absolute_percentage_error_percentiles
8892 self .assertAlmostEqual (
89- float (mspe ), ((3 / 4 ) ** 2 + 0 + 0 + 4 ** 2 + (0.5 / 2 ) ** 2 ) / 5
93+ float (tf .squeeze (mspe )),
94+ ((3 / 4 ) ** 2 + 0 + 0 + 4 ** 2 + (0.5 / 2 ) ** 2 ) / 5 ,
95+ )
96+ self .assertAlmostEqual (
97+ float (tf .squeeze (mape )), (3 / 4 + 0 + 0 + 4 + 0.5 / 2 ) / 5
9098 )
91- self .assertAlmostEqual (float (mape ), (3 / 4 + 0 + 0 + 4 + 0.5 / 2 ) / 5 )
9299 self .assertAllEqual (percentiles , ((0 ,), (0.5 / 2 ,), (0.5 / 2 ,), (4 ,)))
93100
94101 def test_normalized_loss_when_expected_value_greater_than_one (self ):
@@ -121,11 +128,11 @@ def test_normalized_loss_when_expected_value_greater_than_one(self):
121128 loss_type = options .LossType .MEAN_SQUARED_ERROR
122129 )
123130 self .assertAlmostEqual (
124- float (mean_absolute_error .loss_tensor ),
131+ float (tf . squeeze ( mean_absolute_error .loss_tensor ) ),
125132 (0.3 + 1.5 + 0.0 + 0.5 + 2.0 ) / 5 ,
126133 )
127134 self .assertAlmostEqual (
128- float (mean_squared_error .loss_tensor ),
135+ float (tf . squeeze ( mean_squared_error .loss_tensor ) ),
129136 (0.3 ** 2 + 1.5 ** 2 + 0.0 + 0.5 ** 2 + 2.0 ** 2 ) / 5 ,
130137 delta = 1e-6 ,
131138 )
@@ -272,10 +279,14 @@ def test_unknown_shape(self):
272279 self .assertEqual (loss .loss_tensor .shape , (1 ,))
273280 self .assertEqual (percentiles .shape , (len (percentile_ranks ), 1 ))
274281
275- self .assertNear (float (mse ), (3 ** 2 + 0 + 0 + 4 ** 2 + 0.5 ** 2 ) / 5 , 1e-6 )
276- self .assertNear (float (mae ), (3 + 0 + 0 + 4 + 0.5 ) / 5 , 1e-6 )
277282 self .assertNear (
278- float (loss .loss_tensor ), (2.5 + 0 + 0 + 3.5 + (0.5 ** 2 ) / 2 ) / 5 , 1e-6
283+ float (tf .squeeze (mse )), (3 ** 2 + 0 + 0 + 4 ** 2 + 0.5 ** 2 ) / 5 , 1e-6
284+ )
285+ self .assertNear (float (tf .squeeze (mae )), (3 + 0 + 0 + 4 + 0.5 ) / 5 , 1e-6 )
286+ self .assertNear (
287+ float (tf .squeeze (loss .loss_tensor )),
288+ (2.5 + 0 + 0 + 3.5 + (0.5 ** 2 ) / 2 ) / 5 ,
289+ 1e-6 ,
279290 )
280291 self .assertAllEqual (percentiles , ((0 ,), (0.5 ,), (3 ,), (4 ,)))
281292
@@ -349,10 +360,14 @@ def test_single_task_unknown_shape(self):
349360 self .assertEqual (loss .loss_tensor .shape , (num_tasks ,))
350361 self .assertEqual (percentiles .shape , (len (percentile_ranks ), num_tasks ))
351362
352- self .assertNear (float (mse ), (3 ** 2 + 0 + 0 + 4 ** 2 + 0.5 ** 2 ) / 5 , 1e-6 )
353- self .assertNear (float (mae ), (3 + 0 + 0 + 4 + 0.5 ) / 5 , 1e-6 )
354363 self .assertNear (
355- float (loss .loss_tensor ), (2.5 + 0 + 0 + 3.5 + (0.5 ** 2 ) / 2 ) / 5 , 1e-6
364+ float (tf .squeeze (mse )), (3 ** 2 + 0 + 0 + 4 ** 2 + 0.5 ** 2 ) / 5 , 1e-6
365+ )
366+ self .assertNear (float (tf .squeeze (mae )), (3 + 0 + 0 + 4 + 0.5 ) / 5 , 1e-6 )
367+ self .assertNear (
368+ float (tf .squeeze (loss .loss_tensor )),
369+ (2.5 + 0 + 0 + 3.5 + (0.5 ** 2 ) / 2 ) / 5 ,
370+ 1e-6 ,
356371 )
357372 self .assertAllEqual (percentiles , ((0 ,), (0.5 ,), (3 ,), (4 ,)))
358373
0 commit comments