Skip to content

Commit b140135

Browse files
authored
REF: simplify ohlc (pandas-dev#41091)
1 parent 3435ebf commit b140135

File tree

3 files changed

+29
-30
lines changed

3 files changed

+29
-30
lines changed

pandas/core/groupby/generic.py

+5 -26
Original file line number | Diff line number | Diff line change
@@ -363,20 +363,10 @@ def _cython_agg_general(
363363
result = self.grouper._cython_operation(
364364
"aggregate", obj._values, how, axis=0, min_count=min_count
365365
)
366-
367-
if how == "ohlc":
368-
# e.g. ohlc
369-
agg_names = ["open", "high", "low", "close"]
370-
assert len(agg_names) == result.shape[1]
371-
for result_column, result_name in zip(result.T, agg_names):
372-
key = base.OutputKey(label=result_name, position=idx)
373-
output[key] = result_column
374-
idx += 1
375-
else:
376-
assert result.ndim == 1
377-
key = base.OutputKey(label=name, position=idx)
378-
output[key] = result
379-
idx += 1
366+
assert result.ndim == 1
367+
key = base.OutputKey(label=name, position=idx)
368+
output[key] = result
369+
idx += 1
380370

381371
if not output:
382372
raise DataError("No numeric types to aggregate")
@@ -942,10 +932,6 @@ def count(self) -> Series:
942932
)
943933
return self._reindex_output(result, fill_value=0)
944934

945-
def _apply_to_column_groupbys(self, func):
946-
""" return a pass thru """
947-
return func(self)
948-
949935
def pct_change(self, periods=1, fill_method="pad", limit=None, freq=None):
950936
"""Calculate pct_change of each value to previous entry in group"""
951937
# TODO: Remove this conditional when #23918 is fixed
@@ -1137,6 +1123,7 @@ def _cython_agg_general(
11371123
def _cython_agg_manager(
11381124
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
11391125
) -> Manager2D:
1126+
# Note: we never get here with how="ohlc"; that goes through SeriesGroupBy
11401127

11411128
data: Manager2D = self._get_data_to_aggregate()
11421129

@@ -1227,21 +1214,13 @@ def array_func(values: ArrayLike) -> ArrayLike:
12271214
# generally if we have numeric_only=False
12281215
# and non-applicable functions
12291216
# try to python agg
1230-
1231-
if alt is None:
1232-
# we cannot perform the operation
1233-
# in an alternate way, exclude the block
1234-
assert how == "ohlc"
1235-
raise
1236-
12371217
result = py_fallback(values)
12381218

12391219
return cast_agg_result(result, values, how)
12401220
return result
12411221

12421222
# TypeError -> we may have an exception in trying to aggregate
12431223
# continue and exclude the block
1244-
# NotImplementedError -> "ohlc" with wrong dtype
12451224
new_mgr = data.grouped_reduce(array_func, ignore_failures=True)
12461225

12471226
if not len(new_mgr):

pandas/core/groupby/groupby.py

+19 -1
Original file line number | Diff line number | Diff line change
@@ -1791,7 +1791,25 @@ def ohlc(self) -> DataFrame:
17911791
DataFrame
17921792
Open, high, low and close values within each group.
17931793
"""
1794-
return self._apply_to_column_groupbys(lambda x: x._cython_agg_general("ohlc"))
1794+
if self.obj.ndim == 1:
1795+
# self._iterate_slices() yields only self._selected_obj
1796+
obj = self._selected_obj
1797+
1798+
is_numeric = is_numeric_dtype(obj.dtype)
1799+
if not is_numeric:
1800+
raise DataError("No numeric types to aggregate")
1801+
1802+
res_values = self.grouper._cython_operation(
1803+
"aggregate", obj._values, "ohlc", axis=0, min_count=-1
1804+
)
1805+
1806+
agg_names = ["open", "high", "low", "close"]
1807+
result = self.obj._constructor_expanddim(
1808+
res_values, index=self.grouper.result_index, columns=agg_names
1809+
)
1810+
return self._reindex_output(result)
1811+
1812+
return self._apply_to_column_groupbys(lambda x: x.ohlc())
17951813

17961814
@final
17971815
@doc(DataFrame.describe)

pandas/tests/resample/test_datetime_index.py

+5 -3
Original file line number | Diff line number | Diff line change
@@ -58,14 +58,16 @@ def test_custom_grouper(index):
5858
g = s.groupby(b)
5959

6060
# check all cython functions work
61-
funcs = ["add", "mean", "prod", "ohlc", "min", "max", "var"]
61+
g.ohlc() # doesn't use _cython_agg_general
62+
funcs = ["add", "mean", "prod", "min", "max", "var"]
6263
for f in funcs:
6364
g._cython_agg_general(f)
6465

6566
b = Grouper(freq=Minute(5), closed="right", label="right")
6667
g = s.groupby(b)
6768
# check all cython functions work
68-
funcs = ["add", "mean", "prod", "ohlc", "min", "max", "var"]
69+
g.ohlc() # doesn't use _cython_agg_general
70+
funcs = ["add", "mean", "prod", "min", "max", "var"]
6971
for f in funcs:
7072
g._cython_agg_general(f)
7173

@@ -79,7 +81,7 @@ def test_custom_grouper(index):
7981
idx = DatetimeIndex(idx, freq="5T")
8082
expect = Series(arr, index=idx)
8183

82-
# GH2763 - return in put dtype if we can
84+
# GH2763 - return input dtype if we can
8385
result = g.agg(np.sum)
8486
tm.assert_series_equal(result, expect)
8587

0 commit comments

Comments
 (0)