|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | from pyearthtools.pipeline.operations.numpy import reshape |
| 16 | +from unittest.mock import MagicMock |
16 | 17 |
|
17 | 18 | import numpy as np |
18 | 19 | import pytest |
@@ -115,6 +116,53 @@ def test_Flattener_1_dim(): |
115 | 116 | assert np.all(undo_output == random_array), "Undo Flatten 1 dimension." |
116 | 117 |
|
117 | 118 |
|
| 119 | +def success_then_fail(self, *args, **kwargs): |
| 120 | + yield self |
| 121 | + raise ValueError() |
| 122 | + |
| 123 | + |
| 124 | +def test_Flattener_exceptions(): |
| 125 | + """Tests all the exceptions that can be raised in the Flattener class.""" |
| 126 | + # try instantiating flattener with invalid dim |
| 127 | + with pytest.raises(ValueError): |
| 128 | + reshape.Flattener(flatten_dims=0) |
| 129 | + |
| 130 | + # test undo without apply |
| 131 | + f = reshape.Flattener(shape_attempt=(2, 1, 1)) |
| 132 | + random_array = np.random.randn(4, 3, 5) |
| 133 | + with pytest.raises(RuntimeError): |
| 134 | + f.undo(random_array) |
| 135 | + |
| 136 | + # _configure_shape_attempt error when apply not run |
| 137 | + with pytest.raises(RuntimeError): |
| 138 | + f._configure_shape_attempt() |
| 139 | + |
| 140 | + # test undo when flatten_dims unset |
| 141 | + output = f.apply(random_array) |
| 142 | + f.flatten_dims = None # "accidentally" overwrite the dims |
| 143 | + with pytest.raises(RuntimeError): |
| 144 | + f.undo(output) |
| 145 | + |
| 146 | + # setup flattener |
| 147 | + mock_array = MagicMock() |
| 148 | + mock_array.__len__.return_value = 1 |
| 149 | + mock_array.shape = tuple([1]) |
| 150 | + mock_array.reshape.return_value = mock_array |
| 151 | + f = reshape.Flattener() |
| 152 | + output = f.apply(mock_array) |
| 153 | + |
| 154 | + # trigger ValueError in undo when reshape fails |
| 155 | + mock_array.reshape.side_effect = ValueError |
| 156 | + with pytest.raises(ValueError): |
| 157 | + f.undo(mock_array) |
| 158 | + |
| 159 | + # error when input array shape not same rank as shape_attempt |
| 160 | + f = reshape.Flattener(shape_attempt=("...", 2)) |
| 161 | + output = f.apply(random_array) |
| 162 | + with pytest.raises(IndexError): |
| 163 | + f.undo(output) |
| 164 | + |
| 165 | + |
118 | 166 | def test_Flatten(): |
119 | 167 | f1 = reshape.Flatten(flatten_dims=2) |
120 | 168 | random_array = np.random.randn(4, 3, 5) |
@@ -157,10 +205,36 @@ def test_Flatten_with_shape_attempt_with_ellipses(): |
157 | 205 | assert f.undo_func(undo_data).shape == (2, 1, 1, 1) |
158 | 206 |
|
159 | 207 |
|
| 208 | +def test_Flatten_with_many_arrays(): |
| 209 | + incoming_data = (np.zeros((8, 1, 3, 3)), np.zeros((8, 1, 3, 6))) |
| 210 | + f = reshape.Flatten() |
| 211 | + output = f.apply_func(incoming_data) |
| 212 | + assert isinstance(output, tuple) |
| 213 | + assert output[0].shape == (8 * 1 * 3 * 3,) |
| 214 | + assert output[1].shape == (8 * 1 * 3 * 6,) |
| 215 | + # undo |
| 216 | + output = f.undo(output) |
| 217 | + assert isinstance(output, tuple) |
| 218 | + assert output[0].shape == incoming_data[0].shape |
| 219 | + assert output[1].shape == incoming_data[1].shape |
| 220 | + |
| 221 | + |
160 | 222 | def test_SwapAxis(): |
161 | 223 | s = reshape.SwapAxis(1, 3) |
162 | 224 | random_array = np.random.randn(5, 7, 8, 2) |
163 | 225 | output = s.apply_func(random_array) |
164 | 226 | assert output.shape == (5, 2, 8, 7), "Swap axes 1 and 3" |
165 | 227 | undo_output = s.undo_func(output) |
166 | 228 | assert np.all(undo_output == random_array), "Undo axis swap." |
| 229 | + |
| 230 | + |
| 231 | +def test_Flattener_prod_shape_helper(): |
| 232 | + """Tests the Flattener._prod_shape method with numpy input.""" |
| 233 | + f = reshape.Flattener() |
| 234 | + data = np.array( |
| 235 | + ( |
| 236 | + (1, 2, 3), |
| 237 | + (4, 5, 6), |
| 238 | + ) |
| 239 | + ) |
| 240 | + assert f._prod_shape(data) == 6 # product of data shape |
0 commit comments