@@ -33,8 +33,26 @@ class _RollingAndExpanding(object):
33
33
34
34
35
35
class Rolling (_RollingAndExpanding ):
36
- def __init__ (self , obj ):
37
- self .obj = obj
36
def __init__(self, kdf_or_kser, window, min_periods=None):
    """
    Set up a rolling window over a Koalas object.

    Parameters
    ----------
    kdf_or_kser : DataFrame, Series, DataFrameGroupBy or SeriesGroupBy
        Object the rolling computation is applied to.
    window : int
        Size of the window in rows, counting the current row. Must be >= 0.
    min_periods : int, optional
        Stored on the instance as ``_min_periods``; ``None`` is stored as 0.
        Must be >= 0 when given.

    Raises
    ------
    ValueError
        If ``window`` or ``min_periods`` is negative.
    TypeError
        If ``kdf_or_kser`` is not one of the supported types.
    """
    from databricks.koalas import DataFrame, Series
    from databricks.koalas.groupby import SeriesGroupBy, DataFrameGroupBy

    # Validate the value the caller actually passed. Checking after the
    # ``window - 1`` conversion rejected ``window=0`` with a message stating
    # that 0 is allowed.
    if window < 0:
        raise ValueError("window must be >= 0")
    if min_periods is None:
        min_periods = 0
    elif min_periods < 0:
        raise ValueError("min_periods must be >= 0")
    if not isinstance(kdf_or_kser, (DataFrame, Series, DataFrameGroupBy, SeriesGroupBy)):
        raise TypeError(
            "kdf_or_kser must be a series or dataframe; however, got: %s" % type(kdf_or_kser))

    # The frame spans ``preceding`` rows before the current row plus the
    # current row itself, i.e. ``window`` rows in total.
    # NOTE(review): ``window == 0`` yields ``preceding == -1``, a frame whose
    # lower bound lies after its upper bound -- confirm how Spark treats it.
    preceding = window - 1

    # Kept as the offset (window - 1); code that needs the original size adds 1 back.
    self._window_val = preceding
    self._min_periods = min_periods
    self.kdf_or_kser = kdf_or_kser

    # Only plain Series/DataFrame inputs get a window spec here; group-by
    # inputs leave ``_index_scols`` and ``_window`` unset.
    if isinstance(kdf_or_kser, (DataFrame, Series)):
        self._index_scols = kdf_or_kser._internal.index_scols
        self._window = Window.orderBy(self._index_scols).rowsBetween(
            Window.currentRow - preceding, Window.currentRow)
38
56
39
57
def __getattr__ (self , item : str ) -> Any :
40
58
if hasattr (_MissingPandasLikeRolling , item ):
@@ -45,6 +63,91 @@ def __getattr__(self, item: str) -> Any:
45
63
return partial (property_or_func , self )
46
64
raise AttributeError (item )
47
65
66
def _apply_as_series_or_frame(self, func):
    """
    Apply a Spark-column function to the wrapped Koalas Series or DataFrame.

    For a Series, ``func`` is called on its Spark column and the result is
    returned as a Series with the original name. For a DataFrame, every
    column is processed by calling the rolling method named ``func.__name__``
    on that column, so the given `func` name must be the same as the API's
    method name. Any other wrapped object results in ``None``.
    """
    from databricks.koalas import DataFrame, Series

    target = self.kdf_or_kser

    if isinstance(target, Series):
        windowed_scol = func(target._scol)
        return target._with_new_scol(windowed_scol).rename(target.name)

    if not isinstance(target, DataFrame):
        # Neither a Series nor a DataFrame: nothing is applied.
        return None

    # Re-enter the Series path once per column. ``_window_val`` holds
    # ``window - 1``, so add 1 back to recover the size passed to ``rolling``.
    per_column = [
        getattr(target[label].rolling(self._window_val + 1, self._min_periods),
                func.__name__)()
        for label in target.columns
    ]

    selected = target._internal.index_scols + [kser._scol for kser in per_column]
    new_sdf = target._sdf.select(selected)
    new_internal = target._internal.copy(
        sdf=new_sdf,
        data_columns=[kser._internal.data_columns[0] for kser in per_column],
        column_index=[kser._internal.column_index[0] for kser in per_column])
    return DataFrame(new_internal)
93
+
94
def count(self):
    """
    The rolling count of any non-NaN observations inside the window.

    .. note:: the current implementation of this API uses Spark's Window without
        specifying a partition specification. This moves all data into a
        single partition on a single machine and could cause serious
        performance degradation. Avoid this method on very large datasets.

    Returns
    -------
    Series or DataFrame
        Returned object type is determined by the caller of the rolling
        calculation. The result is cast to ``float64``.

    See Also
    --------
    Series.rolling : Calling object with Series data.
    DataFrame.rolling : Calling object with DataFrames.
    DataFrame.count : Count of the full DataFrame.

    Examples
    --------
    >>> s = ks.Series([2, 3, float("nan"), 10])
    >>> s.rolling(1).count()
    0    1.0
    1    1.0
    2    0.0
    3    1.0
    Name: 0, dtype: float64

    >>> s.rolling(3).count()
    0    1.0
    1    2.0
    2    2.0
    3    2.0
    Name: 0, dtype: float64

    >>> s.to_frame().rolling(1).count()
         0
    0  1.0
    1  1.0
    2  0.0
    3  1.0

    >>> s.to_frame().rolling(3).count()
         0
    0  1.0
    1  2.0
    2  2.0
    3  2.0
    """
    # The helper has to be named ``count``: the DataFrame path looks up the
    # per-column rolling method through ``func.__name__``.
    def count(scol):
        counted = F.count(scol)
        return counted.over(self._window)

    windowed = self._apply_as_series_or_frame(count)
    return windowed.astype('float64')
150
+
48
151
49
152
class RollingGroupby (Rolling ):
50
153
def __getattr__ (self , item : str ) -> Any :
@@ -56,6 +159,9 @@ def __getattr__(self, item: str) -> Any:
56
159
return partial (property_or_func , self )
57
160
raise AttributeError (item )
58
161
162
def count(self):
    """
    Rolling count for group-by objects; not available.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "groupby.rolling().count() is currently not implemented yet."
    raise NotImplementedError(message)
164
+
59
165
60
166
class Expanding (_RollingAndExpanding ):
61
167
def __init__ (self , kdf_or_kser , min_periods = 1 ):
0 commit comments