Skip to content

Commit 69fbe00

Browse files
committed
100% coverage for numpy reshape
1 parent 93e9b47 commit 69fbe00

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

packages/pipeline/tests/operations/numpy/test_numpy_reshape.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from pyearthtools.pipeline.operations.numpy import reshape
16+
from unittest.mock import MagicMock
1617

1718
import numpy as np
1819
import pytest
@@ -115,6 +116,53 @@ def test_Flattener_1_dim():
115116
assert np.all(undo_output == random_array), "Undo Flatten 1 dimension."
116117

117118

119+
def success_then_fail(self, *args, **kwargs):
120+
yield self
121+
raise ValueError()
122+
123+
124+
def test_Flattener_exceptions():
125+
"""Tests all the exceptions that can be raised in the Flattener class."""
126+
# try instantiating flattener with invalid dim
127+
with pytest.raises(ValueError):
128+
reshape.Flattener(flatten_dims=0)
129+
130+
# test undo without apply
131+
f = reshape.Flattener(shape_attempt=(2, 1, 1))
132+
random_array = np.random.randn(4, 3, 5)
133+
with pytest.raises(RuntimeError):
134+
f.undo(random_array)
135+
136+
# _configure_shape_attempt error when apply not run
137+
with pytest.raises(RuntimeError):
138+
f._configure_shape_attempt()
139+
140+
# test undo when flatten_dims unset
141+
output = f.apply(random_array)
142+
f.flatten_dims = None # "accidentally" overwrite the dims
143+
with pytest.raises(RuntimeError):
144+
f.undo(output)
145+
146+
# setup flattener
147+
mock_array = MagicMock()
148+
mock_array.__len__.return_value = 1
149+
mock_array.shape = tuple([1])
150+
mock_array.reshape.return_value = mock_array
151+
f = reshape.Flattener()
152+
output = f.apply(mock_array)
153+
154+
# trigger ValueError in undo when reshape fails
155+
mock_array.reshape.side_effect = ValueError
156+
with pytest.raises(ValueError):
157+
f.undo(mock_array)
158+
159+
# error when input array shape not same rank as shape_attempt
160+
f = reshape.Flattener(shape_attempt=("...", 2))
161+
output = f.apply(random_array)
162+
with pytest.raises(IndexError):
163+
f.undo(output)
164+
165+
118166
def test_Flatten():
119167
f1 = reshape.Flatten(flatten_dims=2)
120168
random_array = np.random.randn(4, 3, 5)
@@ -157,10 +205,36 @@ def test_Flatten_with_shape_attempt_with_ellipses():
157205
assert f.undo_func(undo_data).shape == (2, 1, 1, 1)
158206

159207

208+
def test_Flatten_with_many_arrays():
209+
incoming_data = (np.zeros((8, 1, 3, 3)), np.zeros((8, 1, 3, 6)))
210+
f = reshape.Flatten()
211+
output = f.apply_func(incoming_data)
212+
assert isinstance(output, tuple)
213+
assert output[0].shape == (8 * 1 * 3 * 3,)
214+
assert output[1].shape == (8 * 1 * 3 * 6,)
215+
# undo
216+
output = f.undo(output)
217+
assert isinstance(output, tuple)
218+
assert output[0].shape == incoming_data[0].shape
219+
assert output[1].shape == incoming_data[1].shape
220+
221+
160222
def test_SwapAxis():
161223
s = reshape.SwapAxis(1, 3)
162224
random_array = np.random.randn(5, 7, 8, 2)
163225
output = s.apply_func(random_array)
164226
assert output.shape == (5, 2, 8, 7), "Swap axes 1 and 3"
165227
undo_output = s.undo_func(output)
166228
assert np.all(undo_output == random_array), "Undo axis swap."
229+
230+
231+
def test_Flattener_prod_shape_helper():
232+
"""Tests the Flattener._prod_shape method with numpy input."""
233+
f = reshape.Flattener()
234+
data = np.array(
235+
(
236+
(1, 2, 3),
237+
(4, 5, 6),
238+
)
239+
)
240+
assert f._prod_shape(data) == 6 # product of data shape

0 commit comments

Comments
 (0)