Skip to content

Commit 8daadaf

Browse files
authored
Add tests for select ops and update numpy+dask Select class (#233)
* Add xarray select tests. * Align numpy select functionality with its docstring: this fixes the implementation of numpy.Select so that its results match what the docstring states, by switching the implementation to indexing with native slices. * Add numpy select tests. * Align dask select functionality with its docstring: this fixes the implementation of dask.Select so that its results match what the docstring states, by switching the implementation to indexing with native slices. * Add dask select tests.
1 parent 4077e7f commit 8daadaf

File tree

5 files changed

+239
-28
lines changed

5 files changed

+239
-28
lines changed

packages/pipeline/src/pyearthtools/pipeline/operations/dask/select.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,27 +65,20 @@ def __init__(
6565
self.tuple_index = tuple_index
6666

6767
def _index(self, data, array_index):
68-
shape = data.shape
69-
for i, index in enumerate(reversed(array_index)):
70-
if index is None:
71-
pass
72-
selected_data = da.take(data, indices=index, axis=-(i + 1))
73-
if len(selected_data.shape) < len(shape):
74-
selected_data = da.expand_dims(selected_data, axis=-(i + 1))
75-
data = selected_data
76-
return data
68+
# below comprehension:
69+
# - ensures indexer is tuple (requirement)
70+
# - converts instances of None into slice(None)
71+
indexer = tuple(slice(None) if index is None else index for index in array_index)
72+
return data[indexer]
7773

7874
def apply_func(self, data):
    """Apply ``self.array_index`` to an array, or to arrays held in a tuple.

    Args:
        data: A single array, or a tuple of arrays.

    Returns:
        For a single array, the result of ``self._index``. For a tuple, a new
        tuple: every element is indexed when ``self.tuple_index`` is ``None``;
        otherwise only the element at ``self.tuple_index`` is indexed and the
        others are passed through unchanged.

    Raises:
        IndexError: If ``self.tuple_index`` does not address an element of
            the tuple.
    """
    array_index = self.array_index

    if not isinstance(data, tuple):
        return self._index(data, array_index)

    if self.tuple_index is None:
        return tuple(self._index(x, array_index) for x in data)

    # Resolve the position through a range so that negative values count from
    # the end and an out-of-range value raises IndexError, exactly as plain
    # tuple indexing would. Comparing the raw value against enumerate()
    # counters would silently select nothing in both of those cases.
    target = range(len(data))[self.tuple_index]
    return tuple(
        self._index(arr, array_index) if position == target else arr
        for position, arr in enumerate(data)
    )
9184

packages/pipeline/src/pyearthtools/pipeline/operations/numpy/select.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,27 +63,20 @@ def __init__(
6363
self.tuple_index = tuple_index
6464

6565
def _index(self, data, array_index):
66-
shape = data.shape
67-
for i, index in enumerate(reversed(array_index)):
68-
if index is None:
69-
pass
70-
selected_data = np.take(data, indices=index, axis=-(i + 1))
71-
if len(selected_data.shape) < len(shape):
72-
selected_data = np.expand_dims(selected_data, axis=-(i + 1))
73-
data = selected_data
74-
return data
66+
# below comprehension:
67+
# - ensures indexer is tuple (requirement)
68+
# - converts instances of None into slice(None)
69+
indexer = tuple(slice(None) if index is None else index for index in array_index)
70+
return data[indexer]
7571

7672
def apply_func(self, data):
    """Apply ``self.array_index`` to an array, or to arrays held in a tuple.

    Args:
        data: A single array, or a tuple of arrays.

    Returns:
        For a single array, the result of ``self._index``. For a tuple, a new
        tuple: every element is indexed when ``self.tuple_index`` is ``None``;
        otherwise only the element at ``self.tuple_index`` is indexed and the
        others are passed through unchanged.

    Raises:
        IndexError: If ``self.tuple_index`` does not address an element of
            the tuple.
    """
    array_index = self.array_index

    if not isinstance(data, tuple):
        return self._index(data, array_index)

    if self.tuple_index is None:
        return tuple(self._index(x, array_index) for x in data)

    # Resolve the position through a range so that negative values count from
    # the end and an out-of-range value raises IndexError, exactly as plain
    # tuple indexing would. Comparing the raw value against enumerate()
    # counters would silently select nothing in both of those cases.
    target = range(len(data))[self.tuple_index]
    return tuple(
        self._index(arr, array_index) if position == target else arr
        for position, arr in enumerate(data)
    )
8982

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright Commonwealth of Australia, Bureau of Meteorology 2024.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
import dask.array as da
17+
18+
19+
from pyearthtools.pipeline.operations.dask import select
20+
21+
22+
@pytest.fixture(scope="module")
def sample():
    """Dask array holding 0..23, reshaped to (2, 3, 4)."""
    flat = da.array(range(24))
    return flat.reshape((2, 3, 4))
26+
27+
28+
def test_Select(sample):
    """Check the dask ``Select`` operation on single arrays and on tuples."""
    # a single index entry
    first_only = select.Select([0])
    picked = first_only.apply_func(sample)
    assert picked.shape == (3, 4)
    assert (picked == sample[0, :, :]).all().compute()

    # several entries, with None in the middle position
    first_and_last = select.Select([0, None, 3])
    expected = sample[0, :, 3]
    picked = first_and_last.apply_func(sample)
    assert picked.shape == (3,)
    assert (picked == expected).all().compute()

    # a tuple of arrays and no tuple index: each member is checked
    results = first_and_last.apply_func((sample, sample))
    for member in results:
        assert member.shape == (3,)
        assert (member == sample[0, :, 3]).all().compute()

    # a tuple of arrays with tuple_index=1: only the second member changes
    second_only = select.Select(array_index=(0,), tuple_index=1)
    results = second_only.apply_func((sample, sample))
    untouched = results[0]
    assert untouched.shape == sample.shape
    assert (untouched == sample).all().compute()
    changed = results[1]
    assert changed.shape == (3, 4)
    assert (changed == sample[0]).all().compute()
59+
60+
61+
def test_Slice(sample):
    """Check the dask ``Slice`` operation, with and without ``reverse_slice``."""
    forward = select.Slice((1,), (2,), (1, 4))
    sliced = forward.apply_func(sample)
    assert sliced.shape == (1, 2, 3)
    assert (sliced == sample[:1, :2, 1:4]).all().compute()

    # same operation built with reverse_slice=True
    backward = select.Slice((1,), (2,), reverse_slice=True)
    sliced = backward.apply_func(sample)
    assert sliced.shape == (2, 1, 2)
    assert (sliced == sample[:, :1, :2]).all().compute()
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright Commonwealth of Australia, Bureau of Meteorology 2024.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
import numpy as np
17+
18+
19+
from pyearthtools.pipeline.operations.numpy import select
20+
21+
22+
@pytest.fixture(scope="module")
def sample():
    """Numpy array holding 0..23, reshaped to (2, 3, 4)."""
    flat = np.array(range(24))
    return flat.reshape((2, 3, 4))
26+
27+
28+
def test_Select(sample):
    """Check the numpy ``Select`` operation on single arrays and on tuples."""
    # a single index entry
    first_only = select.Select([0])
    picked = first_only.apply_func(sample)
    assert picked.shape == (3, 4)
    assert np.array_equal(picked, sample[0, :, :])

    # several entries, with None in the middle position
    first_and_last = select.Select([0, None, 3])
    expected = sample[0, :, 3]
    picked = first_and_last.apply_func(sample)
    assert picked.shape == (3,)
    assert np.array_equal(picked, expected)

    # a tuple of arrays and no tuple index: each member is checked
    results = first_and_last.apply_func((sample, sample))
    for member in results:
        assert member.shape == (3,)
        assert np.array_equal(member, sample[0, :, 3])

    # a tuple of arrays with tuple_index=1: only the second member changes
    second_only = select.Select(array_index=(0,), tuple_index=1)
    results = second_only.apply_func((sample, sample))
    untouched = results[0]
    assert untouched.shape == sample.shape
    assert np.array_equal(untouched, sample)
    changed = results[1]
    assert changed.shape == (3, 4)
    assert np.array_equal(changed, sample[0])
59+
60+
61+
def test_Slice(sample):
    """Check the numpy ``Slice`` operation, with and without ``reverse_slice``."""
    forward = select.Slice((1,), (2,), (1, 4))
    sliced = forward.apply_func(sample)
    assert sliced.shape == (1, 2, 3)
    assert np.array_equal(sliced, sample[:1, :2, 1:4])

    # same operation built with reverse_slice=True
    backward = select.Slice((1,), (2,), reverse_slice=True)
    sliced = backward.apply_func(sample)
    assert sliced.shape == (2, 1, 2)
    assert np.array_equal(sliced, sample[:, :1, :2])
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright Commonwealth of Australia, Bureau of Meteorology 2024.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import xarray as xr
16+
import pytest
17+
import numpy as np
18+
19+
20+
from pyearthtools.pipeline.operations.xarray import select
21+
22+
23+
@pytest.fixture(scope="module")
def sample():
    """Dataset with two 3x3 variables: ``var1`` (0..8) and ``var2`` (9..17)."""
    coords = {"dim0": range(3), "dim1": range(3)}
    var1 = xr.DataArray(np.array(range(9)).reshape((3, 3)), coords)
    var2 = xr.DataArray(np.array(range(9, 18)).reshape((3, 3)), coords)
    return xr.Dataset({"var1": var1, "var2": var2})
33+
34+
35+
def test_SelectDataset(sample):
    """Check that ``SelectDataset`` keeps ``var1`` and leaves out ``var2``."""
    keep_var1 = select.SelectDataset(("var1",))
    result = keep_var1.apply_func(sample)

    assert "var1" in result
    assert "var2" not in result
    # the kept variable must equal the one in the input
    kept = result["var1"]
    assert kept.equals(sample["var1"])
45+
46+
47+
def test_DropDataset(sample):
    """Check that ``DropDataset`` removes ``var1`` and leaves ``var2``."""
    drop_var1 = select.DropDataset(("var1",))
    result = drop_var1.apply_func(sample)

    assert "var1" not in result
    assert "var2" in result
    # the remaining variable must equal the one in the input
    remaining = result["var2"]
    assert remaining.equals(sample["var2"])
56+
57+
58+
def test_SliceDataset(sample):
    """Tests the SliceDataset xarray operation.

    The same slice arguments are supplied as a dict and as keyword arguments,
    and the resulting operation is applied to both the dataset and one of its
    data arrays; in every case the ``dim0`` and ``dim1`` coordinates of the
    output are checked.
    """
    args = {"dim0": (0, 2, 2), "dim1": (0, 1)}

    def check_slicer(slicer, data):
        # Call the slicer that was passed in. Reading a variable from the
        # enclosing scope instead would test whatever that name happens to be
        # bound to at call time, not the argument.
        output = slicer.apply_func(data)

        assert np.array_equal(output.coords["dim0"].values, [0, 2])
        assert np.array_equal(output.coords["dim1"].values, [0, 1])

    # test passing dict to SliceDataset
    from_dict = select.SliceDataset(args)
    check_slicer(from_dict, sample)

    # test passing kwargs to SliceDataset
    from_kwargs = select.SliceDataset(**args)
    check_slicer(from_kwargs, sample)

    # test passing dataarray to slicer
    check_slicer(from_kwargs, sample["var1"])

0 commit comments

Comments
 (0)