Skip to content

Commit 935cbf4

Browse files
Use tf.squeeze before casting non-0d tensors to float (#370)
This becomes a hard error after numpy 2.4 which we are now using internally.
1 parent 3ba6877 commit 935cbf4

File tree

2 files changed

+35
-20
lines changed

2 files changed

+35
-20
lines changed

gematria/model/python/loss_utils_test.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,14 @@ def test_unscaled_loss(self):
6666
mse = loss.mean_squared_error
6767
mae = loss.mean_absolute_error
6868
percentiles = loss.absolute_error_percentiles
69-
self.assertNear(float(mse), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6)
70-
self.assertNear(float(mae), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6)
7169
self.assertNear(
72-
float(loss.loss_tensor), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6
70+
float(tf.squeeze(mse)), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6
71+
)
72+
self.assertNear(float(tf.squeeze(mae)), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6)
73+
self.assertNear(
74+
float(tf.squeeze(loss.loss_tensor)),
75+
(2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5,
76+
1e-6,
7377
)
7478
self.assertAllEqual(percentiles, ((0,), (0,), (0.5,), (3,), (4,)))
7579

@@ -86,9 +90,12 @@ def test_percentage_loss(self):
8690
mape = loss.mean_absolute_percentage_error
8791
percentiles = loss.absolute_percentage_error_percentiles
8892
self.assertAlmostEqual(
89-
float(mspe), ((3 / 4) ** 2 + 0 + 0 + 4**2 + (0.5 / 2) ** 2) / 5
93+
float(tf.squeeze(mspe)),
94+
((3 / 4) ** 2 + 0 + 0 + 4**2 + (0.5 / 2) ** 2) / 5,
95+
)
96+
self.assertAlmostEqual(
97+
float(tf.squeeze(mape)), (3 / 4 + 0 + 0 + 4 + 0.5 / 2) / 5
9098
)
91-
self.assertAlmostEqual(float(mape), (3 / 4 + 0 + 0 + 4 + 0.5 / 2) / 5)
9299
self.assertAllEqual(percentiles, ((0,), (0.5 / 2,), (0.5 / 2,), (4,)))
93100

94101
def test_normalized_loss_when_expected_value_greater_than_one(self):
@@ -121,11 +128,11 @@ def test_normalized_loss_when_expected_value_greater_than_one(self):
121128
loss_type=options.LossType.MEAN_SQUARED_ERROR
122129
)
123130
self.assertAlmostEqual(
124-
float(mean_absolute_error.loss_tensor),
131+
float(tf.squeeze(mean_absolute_error.loss_tensor)),
125132
(0.3 + 1.5 + 0.0 + 0.5 + 2.0) / 5,
126133
)
127134
self.assertAlmostEqual(
128-
float(mean_squared_error.loss_tensor),
135+
float(tf.squeeze(mean_squared_error.loss_tensor)),
129136
(0.3**2 + 1.5**2 + 0.0 + 0.5**2 + 2.0**2) / 5,
130137
delta=1e-6,
131138
)
@@ -272,10 +279,14 @@ def test_unknown_shape(self):
272279
self.assertEqual(loss.loss_tensor.shape, (1,))
273280
self.assertEqual(percentiles.shape, (len(percentile_ranks), 1))
274281

275-
self.assertNear(float(mse), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6)
276-
self.assertNear(float(mae), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6)
277282
self.assertNear(
278-
float(loss.loss_tensor), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6
283+
float(tf.squeeze(mse)), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6
284+
)
285+
self.assertNear(float(tf.squeeze(mae)), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6)
286+
self.assertNear(
287+
float(tf.squeeze(loss.loss_tensor)),
288+
(2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5,
289+
1e-6,
279290
)
280291
self.assertAllEqual(percentiles, ((0,), (0.5,), (3,), (4,)))
281292

@@ -349,10 +360,14 @@ def test_single_task_unknown_shape(self):
349360
self.assertEqual(loss.loss_tensor.shape, (num_tasks,))
350361
self.assertEqual(percentiles.shape, (len(percentile_ranks), num_tasks))
351362

352-
self.assertNear(float(mse), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6)
353-
self.assertNear(float(mae), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6)
354363
self.assertNear(
355-
float(loss.loss_tensor), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6
364+
float(tf.squeeze(mse)), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6
365+
)
366+
self.assertNear(float(tf.squeeze(mae)), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6)
367+
self.assertNear(
368+
float(tf.squeeze(loss.loss_tensor)),
369+
(2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5,
370+
1e-6,
356371
)
357372
self.assertAllEqual(percentiles, ((0,), (0.5,), (3,), (4,)))
358373

gematria/model/python/model_base_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def test_randomized_expected_outputs_delta(self):
359359
for i, prefix_throughputs in enumerate(
360360
inverse_throughputs.prefix_inverse_throughput_cycles
361361
):
362-
expected_output_delta = float(expected_output_deltas[i])
362+
expected_output_delta = float(tf.squeeze(expected_output_deltas[i]))
363363
possible_deltas = set()
364364
for prefix_throughput in prefix_throughputs:
365365
for previous_prefix_throughput in previous_prefix_throughputs:
@@ -676,10 +676,10 @@ def test_training_with_full_variable_list(self):
676676
)
677677
biases = model._variable_groups[TestModelWithVarGroups.BIAS]
678678
for bias in biases:
679-
self.assertNotAlmostEqual(float(bias), -0.5)
679+
self.assertNotAlmostEqual(float(tf.squeeze(bias)), -0.5)
680680
weights = model._variable_groups[TestModelWithVarGroups.WEIGHTS]
681681
for weight in weights:
682-
self.assertNotAlmostEqual(float(weight), 0.5)
682+
self.assertNotAlmostEqual(float(tf.squeeze(weight)), 0.5)
683683

684684
def test_training_bias_only(self):
685685
task_list = ['foo', 'bar']
@@ -698,10 +698,10 @@ def test_training_bias_only(self):
698698
)
699699
biases = model._variable_groups[TestModelWithVarGroups.BIAS]
700700
for bias in biases:
701-
self.assertNotAlmostEqual(float(bias), -0.5)
701+
self.assertNotAlmostEqual(float(tf.squeeze(bias)), -0.5)
702702
weights = model._variable_groups[TestModelWithVarGroups.WEIGHTS]
703703
for weight in weights:
704-
self.assertAlmostEqual(float(weight), 0.5)
704+
self.assertAlmostEqual(float(tf.squeeze(weight)), 0.5)
705705

706706
def test_grad_clipping(self):
707707
task_list = ['foo', 'bar']
@@ -737,10 +737,10 @@ def test_training_weight_only(self):
737737
)
738738
biases = model._variable_groups[TestModelWithVarGroups.BIAS]
739739
for bias in biases:
740-
self.assertAlmostEqual(float(bias), -0.5)
740+
self.assertAlmostEqual(float(tf.squeeze(bias)), -0.5)
741741
weights = model._variable_groups[TestModelWithVarGroups.WEIGHTS]
742742
for weight in weights:
743-
self.assertNotAlmostEqual(float(weight), 0.5)
743+
self.assertNotAlmostEqual(float(tf.squeeze(weight)), 0.5)
744744

745745

746746
if __name__ == '__main__':

0 commit comments

Comments
 (0)