1111 Window ,
1212 WindowKeyResult ,
1313 WindowOnLateCallback ,
14+ WindowOnUpdateCallback ,
1415 get_window_ranges ,
1516)
1617
@@ -46,6 +47,7 @@ def __init__(
4647 dataframe : "StreamingDataFrame" ,
4748 step_ms : Optional [int ] = None ,
4849 on_late : Optional [WindowOnLateCallback ] = None ,
50+ on_update : Optional [WindowOnUpdateCallback ] = None ,
4951 ):
5052 super ().__init__ (
5153 name = name ,
@@ -56,6 +58,7 @@ def __init__(
5658 self ._grace_ms = grace_ms
5759 self ._step_ms = step_ms
5860 self ._on_late = on_late
61+ self ._on_update = on_update
5962
6063 self ._closing_strategy = ClosingStrategy .KEY
6164
@@ -132,6 +135,7 @@ def process_window(
132135 state = transaction .as_state (prefix = key )
133136 duration_ms = self ._duration_ms
134137 grace_ms = self ._grace_ms
138+ on_update = self ._on_update
135139
136140 collect = self .collect
137141 aggregate = self .aggregate
@@ -152,6 +156,7 @@ def process_window(
152156 max_expired_window_end = latest_timestamp - grace_ms
153157 max_expired_window_start = max_expired_window_end - duration_ms
154158 updated_windows : list [WindowKeyResult ] = []
159+ triggered_windows : list [WindowKeyResult ] = []
155160 for start , end in ranges :
156161 if start <= max_expired_window_start :
157162 late_by_ms = max_expired_window_end - timestamp_ms
@@ -169,18 +174,44 @@ def process_window(
169174 # since actual values are stored separately and combined into an array
170175 # during window expiration.
171176 aggregated = None
177+
172178 if aggregate :
173179 current_value = state .get_window (start , end )
174180 if current_value is None :
175181 current_value = self ._initialize_value ()
176182
177183 aggregated = self ._aggregate_value (current_value , value , timestamp_ms )
178- updated_windows .append (
179- (
180- key ,
181- self ._results (aggregated , [], start , end ),
182- )
183- )
184+
185+ if on_update and on_update (current_value , aggregated ):
186+ # Get collected values for the result
187+ collected = []
188+ if collect :
189+ collected = state .get_from_collection (start , end )
190+ # Add the current value that's being collected
191+ collected .append (self ._collect_value (value ))
192+
193+ result = self ._results (aggregated , collected , start , end )
194+ triggered_windows .append ((key , result ))
195+ transaction .delete_window (start , end , prefix = key )
196+ # Note: We don't delete from collection here - normal expiration
197+ # will handle cleanup for both tumbling and hopping windows
198+ continue
199+
200+ result = self ._results (aggregated , [], start , end )
201+ updated_windows .append ((key , result ))
202+ elif collect and on_update :
203+ # For collect-only windows, get the old and new collected values
204+ old_collected = state .get_from_collection (start , end )
205+ new_collected = [* old_collected , self ._collect_value (value )]
206+
207+ if on_update (old_collected , new_collected ):
208+ result = self ._results (None , new_collected , start , end )
209+ triggered_windows .append ((key , result ))
210+ transaction .delete_window (start , end , prefix = key )
211+ # Note: We don't delete from collection here - normal expiration
212+ # will handle cleanup for both tumbling and hopping windows
213+ continue
214+
184215 state .update_window (start , end , value = aggregated , timestamp_ms = timestamp_ms )
185216
186217 if collect :
@@ -198,7 +229,10 @@ def process_window(
198229 key , state , max_expired_window_start , collect
199230 )
200231
201- return updated_windows , expired_windows
232+ # Combine triggered windows with time-expired windows
233+ all_expired_windows = triggered_windows + list (expired_windows )
234+
235+ return updated_windows , iter (all_expired_windows )
202236
203237 def expire_by_partition (
204238 self ,
0 commit comments