Skip to content

Commit 1738cdc

Browse files
edoyango and tennlee authored
reverse logic for DropValue filters (#232)
This aligns the logic with the docstrings. Now a PipelineFilterException is raised when the percentage of matching values in the input is equal to or greater than the supplied percentage threshold. Before, the exception was raised if the percentage of matching values was strictly less than the supplied threshold. Co-authored-by: Tennessee Leeuwenburg <[email protected]>
1 parent 0b5ae9e commit 1738cdc

File tree

6 files changed

+39
-39
lines changed

6 files changed

+39
-39
lines changed

packages/pipeline/src/pyearthtools/pipeline/operations/dask/filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def filter(self, sample: da.Array) -> None:
133133
lambda x: ((da.count_nonzero(x == self._value) / math.prod(x.shape)) * 100) >= self._percentage
134134
) # noqa
135135

136-
if not function(sample):
136+
if function(sample):
137137
raise PipelineFilterException(sample, f"Data contained more than {self._percentage}% of {self._value}.")
138138

139139

packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def filter(self, sample: np.ndarray) -> None:
135135
lambda x: ((np.count_nonzero(x == self._value) / math.prod(x.shape)) * 100) >= self._percentage
136136
) # noqa
137137

138-
if not function(sample):
138+
if function(sample):
139139
raise PipelineFilterException(sample, f"Data contained more than {self._percentage}% of {self._value}.")
140140

141141

packages/pipeline/src/pyearthtools/pipeline/operations/xarray/filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def filter(self, sample: T) -> None:
187187
else:
188188
raise TypeError("This filter only accepts xr.DataArray or xr.Dataset")
189189

190-
if not drop:
190+
if drop:
191191
raise PipelineFilterException(sample, f"Data contained more than {self._percentage}% of {self._value}.")
192192

193193

packages/pipeline/tests/operations/dask/test_dask_filter.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,24 +62,24 @@ def test_DropValue():
6262

6363
original = da.from_array([[0, 0], [1, 2]])
6464

65-
# drop case (num zeros < threshold)
65+
# non-drop case (num zeros < threshold)
6666
drop = filters.DropValue(0, 75)
67-
with pytest.raises(PipelineFilterException):
68-
drop.filter(original)
67+
drop.filter(original)
6968

70-
# non-drop case (num zeros >= threshold)
69+
# drop case (num zeros >= threshold)
7170
drop = filters.DropValue(0, 50)
72-
drop.filter(original)
71+
with pytest.raises(PipelineFilterException):
72+
drop.filter(original)
7373

74-
# drop case (num nans < threshold)
74+
# non-drop case (num nans < threshold)
7575
original = da.from_array([[np.nan, np.nan], [1, 2]])
7676
drop = filters.DropValue("nan", 75)
77-
with pytest.raises(PipelineFilterException):
78-
drop.filter(original)
77+
drop.filter(original)
7978

80-
# non-drop case (num nans >= threshold)
79+
# drop case (num nans >= threshold)
8180
drop = filters.DropValue("nan", 50)
82-
drop.filter(original)
81+
with pytest.raises(PipelineFilterException):
82+
drop.filter(original)
8383

8484

8585
def test_Shape():

packages/pipeline/tests/operations/numpy/test_numpy_filter.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,22 +64,22 @@ def test_DropValue():
6464

6565
drop = filters.DropValue(value=1, percentage=75)
6666

67-
with pytest.raises(PipelineFilterException):
68-
drop.filter(original)
67+
drop.filter(original)
6968

70-
# test no drop case
69+
# test drop case
7170
drop = filters.DropValue(value=1, percentage=50)
72-
drop.filter(original)
71+
with pytest.raises(PipelineFilterException):
72+
drop.filter(original)
7373

74-
# test with nan - drop case
74+
# test with nan - non-drop case
7575
drop = filters.DropValue(value="nan", percentage=75)
7676

77-
with pytest.raises(PipelineFilterException):
78-
drop.filter(original)
77+
drop.filter(original)
7978

80-
# no drop case
79+
# drop case
8180
drop = filters.DropValue(value="nan", percentage=50)
82-
drop.filter(original)
81+
with pytest.raises(PipelineFilterException):
82+
drop.filter(original)
8383

8484

8585
def test_Shape():

packages/pipeline/tests/operations/xarray/test_xarray_filter.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -97,42 +97,42 @@ def test_DropValue():
9797
{"var1": xr.DataArray(np.array([[1, 1], [3, 4]])), "var2": xr.DataArray(np.array([[np.nan, np.nan], [6, 7]]))}
9898
)
9999

100-
# check var1 of dataset drop case
100+
# check var1 of dataset non-drop case
101101
drop = filters.DropValue(1, 75)
102-
with pytest.raises(PipelineFilterException):
103-
drop.filter(original["var1"])
102+
drop.filter(original["var1"])
104103

105-
# check var1 of dataset non-drop case
104+
# check var1 of dataset drop case
106105
drop = filters.DropValue(1, 50)
107-
drop.filter(original["var1"])
106+
with pytest.raises(PipelineFilterException):
107+
drop.filter(original["var1"])
108108

109-
# check var2 of dataset drop case (using nan)
109+
# check var2 of dataset non-drop case (using nan)
110110
drop = filters.DropValue("nan", 75)
111+
drop.filter(original["var2"])
112+
113+
# check var2 of dataset drop case
114+
drop = filters.DropValue("nan", 50)
111115
with pytest.raises(PipelineFilterException):
112116
drop.filter(original["var2"])
113117

114-
# check var2 of dataset non-drop case
115-
drop = filters.DropValue("nan", 50)
116-
drop.filter(original["var2"])
118+
# check whole dataset non-drop case
119+
drop = filters.DropValue(1, 50)
120+
drop.filter(original)
117121

118122
# check whole dataset drop case
119-
drop = filters.DropValue(1, 50)
123+
drop = filters.DropValue(1, 10)
120124
with pytest.raises(PipelineFilterException):
121125
drop.filter(original)
122126

123-
# check whole dataset non-drop case
124-
drop = filters.DropValue(1, 10)
127+
# check whole dataset nan non-drop case
128+
drop = filters.DropValue("nan", 50)
125129
drop.filter(original)
126130

127131
# check whole dataset nan drop case
128-
drop = filters.DropValue("nan", 50)
132+
drop = filters.DropValue("nan", 10)
129133
with pytest.raises(PipelineFilterException):
130134
drop.filter(original)
131135

132-
# check whole dataset nan non-drop case
133-
drop = filters.DropValue("nan", 10)
134-
drop.filter(original)
135-
136136
# check invalid type
137137
with pytest.raises(TypeError):
138138
drop.filter(np.empty((1, 1)))

0 commit comments

Comments
 (0)