Skip to content

Commit 9aba8ae

Browse files
itholic authored and HyukjinKwon committed
Implement rolling().count() in Series and Frame (#990)
This PR implements rolling().count():

```python
>>> import databricks.koalas as ks
>>> s = ks.Series([2, 3, float("nan"), 10])
>>> s.rolling(1).count()
0    1.0
1    1.0
2    0.0
3    1.0
Name: 0, dtype: float64
>>> s.to_frame().rolling(1).count()
     0
0  1.0
1  1.0
2  0.0
3  1.0
```

Relates to #977
1 parent 5fc4b61 commit 9aba8ae

File tree

6 files changed

+168
-9
lines changed

6 files changed

+168
-9
lines changed

databricks/koalas/generic.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1427,8 +1427,9 @@ def median(self, accuracy=10000):
14271427
return DataFrame(kdf._internal.copy(sdf=sdf, index_map=[('__DUMMY__', None)])) \
14281428
._to_internal_pandas().transpose().iloc[:, 0]
14291429

1430-
def rolling(self, *args, **kwargs):
1431-
return Rolling(self)
1430+
# TODO: 'center', 'win_type', 'on', 'axis' parameters should be implemented.
1431+
def rolling(self, window, min_periods=None):
1432+
return Rolling(self, window=window, min_periods=min_periods)
14321433

14331434
# TODO: 'center' and 'axis' parameters should be implemented.
14341435
# 'axis' implementation, refer https://github.com/databricks/koalas/pull/607

databricks/koalas/groupby.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1648,8 +1648,8 @@ def nunique(self, dropna=True):
16481648
F.when(F.count(F.when(col.isNull(), 1).otherwise(None)) >= 1, 1).otherwise(0))
16491649
return self._reduce_for_stat_function(stat_function, only_numeric=False)
16501650

1651-
def rolling(self, *args, **kwargs):
1652-
return RollingGroupby(self)
1651+
def rolling(self, window, *args, **kwargs):
1652+
return RollingGroupby(self, window)
16531653

16541654
def expanding(self, min_periods=1):
16551655
"""

databricks/koalas/missing/window.py

-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ class _MissingPandasLikeRolling(object):
6464
aggregate = unsupported_function_rolling("aggregate")
6565
apply = unsupported_function_rolling("apply")
6666
corr = unsupported_function_rolling("corr")
67-
count = unsupported_function_rolling("count")
6867
cov = unsupported_function_rolling("cov")
6968
kurt = unsupported_function_rolling("kurt")
7069
median = unsupported_function_rolling("median")
@@ -105,7 +104,6 @@ class _MissingPandasLikeRollingGroupby(object):
105104
aggregate = unsupported_function_rolling("aggregate")
106105
apply = unsupported_function_rolling("apply")
107106
corr = unsupported_function_rolling("corr")
108-
count = unsupported_function_rolling("count")
109107
cov = unsupported_function_rolling("cov")
110108
kurt = unsupported_function_rolling("kurt")
111109
median = unsupported_function_rolling("median")
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import databricks.koalas as ks
19+
from databricks.koalas.testing.utils import ReusedSQLTestCase, TestUtils
20+
from databricks.koalas.window import Rolling
21+
22+
23+
class RollingTests(ReusedSQLTestCase, TestUtils):
24+
25+
def test_rolling_error(self):
26+
with self.assertRaisesRegex(ValueError, "window must be >= 0"):
27+
ks.range(10).rolling(window=-1)
28+
with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
29+
ks.range(10).rolling(window=1, min_periods=-1)
30+
31+
with self.assertRaisesRegex(
32+
TypeError,
33+
"kdf_or_kser must be a series or dataframe; however, got:.*int"):
34+
Rolling(1, 2)
35+
36+
def _test_rolling_func(self, f):
37+
kser = ks.Series([1, 2, 3])
38+
pser = kser.to_pandas()
39+
self.assert_eq(repr(getattr(kser.rolling(2), f)()), repr(getattr(pser.rolling(2), f)()))
40+
41+
kdf = ks.DataFrame({'a': [1, 2, 3, 2], 'b': [4.0, 2.0, 3.0, 1.0]})
42+
pdf = kdf.to_pandas()
43+
self.assert_eq(repr(getattr(kdf.rolling(2), f)()), repr(getattr(pdf.rolling(2), f)()))
44+
45+
def test_rolling_count(self):
46+
self._test_rolling_func("count")

databricks/koalas/window.py

+108-2
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,26 @@ class _RollingAndExpanding(object):
3333

3434

3535
class Rolling(_RollingAndExpanding):
36-
def __init__(self, obj):
37-
self.obj = obj
36+
def __init__(self, kdf_or_kser, window, min_periods=None):
37+
from databricks.koalas import DataFrame, Series
38+
from databricks.koalas.groupby import SeriesGroupBy, DataFrameGroupBy
39+
window = window - 1
40+
min_periods = min_periods if min_periods is not None else 0
41+
42+
if window < 0:
43+
raise ValueError("window must be >= 0")
44+
if (min_periods is not None) and (min_periods < 0):
45+
raise ValueError("min_periods must be >= 0")
46+
self._window_val = window
47+
self._min_periods = min_periods
48+
self.kdf_or_kser = kdf_or_kser
49+
if not isinstance(kdf_or_kser, (DataFrame, Series, DataFrameGroupBy, SeriesGroupBy)):
50+
raise TypeError(
51+
"kdf_or_kser must be a series or dataframe; however, got: %s" % type(kdf_or_kser))
52+
if isinstance(kdf_or_kser, (DataFrame, Series)):
53+
self._index_scols = kdf_or_kser._internal.index_scols
54+
self._window = Window.orderBy(self._index_scols).rowsBetween(
55+
Window.currentRow - window, Window.currentRow)
3856

3957
def __getattr__(self, item: str) -> Any:
4058
if hasattr(_MissingPandasLikeRolling, item):
@@ -45,6 +63,91 @@ def __getattr__(self, item: str) -> Any:
4563
return partial(property_or_func, self)
4664
raise AttributeError(item)
4765

66+
def _apply_as_series_or_frame(self, func):
67+
"""
68+
Decorator that wraps a function that handles a Spark column in order
69+
to support it in both Koalas Series and DataFrame.
70+
Note that the given `func` name should be the same as the API's method name.
71+
"""
72+
from databricks.koalas import DataFrame, Series
73+
74+
if isinstance(self.kdf_or_kser, Series):
75+
kser = self.kdf_or_kser
76+
return kser._with_new_scol(
77+
func(kser._scol)).rename(kser.name)
78+
elif isinstance(self.kdf_or_kser, DataFrame):
79+
kdf = self.kdf_or_kser
80+
applied = []
81+
for column in kdf.columns:
82+
applied.append(
83+
getattr(kdf[column].rolling(self._window_val + 1,
84+
self._min_periods), func.__name__)())
85+
86+
sdf = kdf._sdf.select(
87+
kdf._internal.index_scols + [c._scol for c in applied])
88+
internal = kdf._internal.copy(
89+
sdf=sdf,
90+
data_columns=[c._internal.data_columns[0] for c in applied],
91+
column_index=[c._internal.column_index[0] for c in applied])
92+
return DataFrame(internal)
93+
94+
def count(self):
95+
"""
96+
The rolling count of any non-NaN observations inside the window.
97+
98+
.. note:: the current implementation of this API uses Spark's Window without
99+
specifying a partition specification. This moves all data into
100+
a single partition on a single machine and could cause serious
101+
performance degradation. Avoid this method on very large datasets.
102+
103+
Returns
104+
-------
105+
Series or DataFrame
106+
Returned object type is determined by the caller of the rolling
107+
calculation.
108+
109+
See Also
110+
--------
111+
Series.rolling : Calling object with Series data.
112+
DataFrame.rolling : Calling object with DataFrames.
113+
DataFrame.count : Count of the full DataFrame.
114+
115+
Examples
116+
--------
117+
>>> s = ks.Series([2, 3, float("nan"), 10])
118+
>>> s.rolling(1).count()
119+
0 1.0
120+
1 1.0
121+
2 0.0
122+
3 1.0
123+
Name: 0, dtype: float64
124+
125+
>>> s.rolling(3).count()
126+
0 1.0
127+
1 2.0
128+
2 2.0
129+
3 2.0
130+
Name: 0, dtype: float64
131+
132+
>>> s.to_frame().rolling(1).count()
133+
0
134+
0 1.0
135+
1 1.0
136+
2 0.0
137+
3 1.0
138+
139+
>>> s.to_frame().rolling(3).count()
140+
0
141+
0 1.0
142+
1 2.0
143+
2 2.0
144+
3 2.0
145+
"""
146+
def count(scol):
147+
return F.count(scol).over(self._window)
148+
149+
return self._apply_as_series_or_frame(count).astype('float64')
150+
48151

49152
class RollingGroupby(Rolling):
50153
def __getattr__(self, item: str) -> Any:
@@ -56,6 +159,9 @@ def __getattr__(self, item: str) -> Any:
56159
return partial(property_or_func, self)
57160
raise AttributeError(item)
58161

162+
def count(self):
163+
raise NotImplementedError("groupby.rolling().count() is currently not implemented yet.")
164+
59165

60166
class Expanding(_RollingAndExpanding):
61167
def __init__(self, kdf_or_kser, min_periods=1):

docs/source/reference/window.rst

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
======
22
Window
33
======
4+
.. currentmodule:: databricks.koalas.window
45

56
Rolling objects are returned by ``.rolling`` calls: :func:`koalas.DataFrame.rolling`, :func:`koalas.Series.rolling`, etc.
67
Expanding objects are returned by ``.expanding`` calls: :func:`koalas.DataFrame.expanding`, :func:`koalas.Series.expanding`, etc.
78

9+
Standard moving window functions
10+
--------------------------------
11+
12+
.. autosummary::
13+
:toctree: api/
14+
15+
Rolling.count
16+
817
Standard expanding window functions
918
-----------------------------------
10-
.. currentmodule:: databricks.koalas.window
1119

1220
.. autosummary::
1321
:toctree: api/

0 commit comments

Comments
 (0)