Skip to content

Commit 8baf943

Browse files
committed
Feature: on_update trigger
1 parent c17f3d7 commit 8baf943

File tree

7 files changed

+347
-13
lines changed

7 files changed

+347
-13
lines changed

quixstreams/dataframe/dataframe.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
TumblingCountWindowDefinition,
7373
TumblingTimeWindowDefinition,
7474
)
75-
from .windows.base import WindowOnLateCallback
75+
from .windows.base import WindowOnLateCallback, WindowOnUpdateCallback
7676

7777
if typing.TYPE_CHECKING:
7878
from quixstreams.processing import ProcessingContext
@@ -1085,6 +1085,7 @@ def tumbling_window(
10851085
grace_ms: Union[int, timedelta] = 0,
10861086
name: Optional[str] = None,
10871087
on_late: Optional[WindowOnLateCallback] = None,
1088+
on_update: Optional[WindowOnUpdateCallback] = None,
10881089
) -> TumblingTimeWindowDefinition:
10891090
"""
10901091
Create a time-based tumbling window transformation on this StreamingDataFrame.
@@ -1151,6 +1152,14 @@ def tumbling_window(
11511152
(default behavior).
11521153
Otherwise, no message will be logged.
11531154
1155+
:param on_update: an optional callback to trigger early window expiration based
1156+
on custom conditions.
1157+
The callback receives `old_value` and `new_value` (the raw aggregated values
1158+
before and after the update). If it returns `True`, the window will be expired
1159+
immediately, even if it hasn't reached its natural expiration time.
1160+
For `collect()` operations, the callback receives lists of collected values.
1161+
Default - `None`.
1162+
11541163
:return: `TumblingTimeWindowDefinition` instance representing the tumbling window
11551164
configuration.
11561165
This object can be further configured with aggregation functions
@@ -1166,6 +1175,7 @@ def tumbling_window(
11661175
dataframe=self,
11671176
name=name,
11681177
on_late=on_late,
1178+
on_update=on_update,
11691179
)
11701180

11711181
def tumbling_count_window(
@@ -1225,6 +1235,7 @@ def hopping_window(
12251235
grace_ms: Union[int, timedelta] = 0,
12261236
name: Optional[str] = None,
12271237
on_late: Optional[WindowOnLateCallback] = None,
1238+
on_update: Optional[WindowOnUpdateCallback] = None,
12281239
) -> HoppingTimeWindowDefinition:
12291240
"""
12301241
Create a time-based hopping window transformation on this StreamingDataFrame.
@@ -1302,6 +1313,14 @@ def hopping_window(
13021313
(default behavior).
13031314
Otherwise, no message will be logged.
13041315
1316+
:param on_update: an optional callback to trigger early window expiration based
1317+
on custom conditions.
1318+
The callback receives `old_value` and `new_value` (the raw aggregated values
1319+
before and after the update). If it returns `True`, the window will be expired
1320+
immediately, even if it hasn't reached its natural expiration time.
1321+
For `collect()` operations, the callback receives lists of collected values.
1322+
Default - `None`.
1323+
13051324
:return: `HoppingTimeWindowDefinition` instance representing the hopping
13061325
window configuration.
13071326
This object can be further configured with aggregation functions
@@ -1319,6 +1338,7 @@ def hopping_window(
13191338
dataframe=self,
13201339
name=name,
13211340
on_late=on_late,
1341+
on_update=on_update,
13221342
)
13231343

13241344
def hopping_count_window(

quixstreams/dataframe/windows/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
WindowResult: TypeAlias = dict[str, Any]
3535
WindowKeyResult: TypeAlias = tuple[Any, WindowResult]
3636
Message: TypeAlias = tuple[WindowResult, Any, int, Any]
37+
WindowOnUpdateCallback: TypeAlias = Callable[[Any, Any], bool]
3738

3839
WindowAggregateFunc = Callable[[Any, Any], Any]
3940

quixstreams/dataframe/windows/definitions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .base import (
1717
Window,
1818
WindowOnLateCallback,
19+
WindowOnUpdateCallback,
1920
)
2021
from .count_based import (
2122
CountWindow,
@@ -54,11 +55,13 @@ def __init__(
5455
name: Optional[str],
5556
dataframe: "StreamingDataFrame",
5657
on_late: Optional[WindowOnLateCallback] = None,
58+
on_update: Optional[WindowOnUpdateCallback] = None,
5759
) -> None:
5860
super().__init__()
5961

6062
self._name = name
6163
self._on_late = on_late
64+
self._on_update = on_update
6265
self._dataframe = dataframe
6366

6467
@abstractmethod
@@ -239,6 +242,7 @@ def __init__(
239242
name: Optional[str] = None,
240243
step_ms: Optional[int] = None,
241244
on_late: Optional[WindowOnLateCallback] = None,
245+
on_update: Optional[WindowOnUpdateCallback] = None,
242246
):
243247
if not isinstance(duration_ms, int):
244248
raise TypeError("Window size must be an integer")
@@ -253,7 +257,7 @@ def __init__(
253257
f"got {step_ms}ms"
254258
)
255259

256-
super().__init__(name, dataframe, on_late)
260+
super().__init__(name, dataframe, on_late, on_update)
257261

258262
self._duration_ms = duration_ms
259263
self._grace_ms = grace_ms
@@ -281,6 +285,7 @@ def __init__(
281285
dataframe: "StreamingDataFrame",
282286
name: Optional[str] = None,
283287
on_late: Optional[WindowOnLateCallback] = None,
288+
on_update: Optional[WindowOnUpdateCallback] = None,
284289
):
285290
super().__init__(
286291
duration_ms=duration_ms,
@@ -289,6 +294,7 @@ def __init__(
289294
name=name,
290295
step_ms=step_ms,
291296
on_late=on_late,
297+
on_update=on_update,
292298
)
293299

294300
def _get_name(self, func_name: Optional[str]) -> str:
@@ -320,6 +326,7 @@ def _create_window(
320326
aggregators=aggregators or {},
321327
collectors=collectors or {},
322328
on_late=self._on_late,
329+
on_update=self._on_update,
323330
)
324331

325332

@@ -331,6 +338,7 @@ def __init__(
331338
dataframe: "StreamingDataFrame",
332339
name: Optional[str] = None,
333340
on_late: Optional[WindowOnLateCallback] = None,
341+
on_update: Optional[WindowOnUpdateCallback] = None,
334342
):
335343
super().__init__(
336344
duration_ms=duration_ms,
@@ -339,6 +347,7 @@ def __init__(
339347
name=name,
340348
on_late=on_late,
341349
)
350+
self._on_update = on_update
342351

343352
def _get_name(self, func_name: Optional[str]) -> str:
344353
prefix = f"{self._name}_tumbling_window" if self._name else "tumbling_window"
@@ -368,6 +377,7 @@ def _create_window(
368377
aggregators=aggregators or {},
369378
collectors=collectors or {},
370379
on_late=self._on_late,
380+
on_update=self._on_update,
371381
)
372382

373383

@@ -379,13 +389,20 @@ def __init__(
379389
dataframe: "StreamingDataFrame",
380390
name: Optional[str] = None,
381391
on_late: Optional[WindowOnLateCallback] = None,
392+
on_update: Optional[WindowOnUpdateCallback] = None,
382393
):
394+
if on_update is not None:
395+
raise ValueError(
396+
"Sliding windows do not support the 'on_update' trigger callback. "
397+
"Use tumbling or hopping windows instead."
398+
)
383399
super().__init__(
384400
duration_ms=duration_ms,
385401
grace_ms=grace_ms,
386402
dataframe=dataframe,
387403
name=name,
388404
on_late=on_late,
405+
on_update=on_update,
389406
)
390407

391408
def _get_name(self, func_name: Optional[str]) -> str:
@@ -417,6 +434,7 @@ def _create_window(
417434
aggregators=aggregators or {},
418435
collectors=collectors or {},
419436
on_late=self._on_late,
437+
on_update=self._on_update,
420438
)
421439

422440

quixstreams/dataframe/windows/time_based.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Window,
1212
WindowKeyResult,
1313
WindowOnLateCallback,
14+
WindowOnUpdateCallback,
1415
get_window_ranges,
1516
)
1617

@@ -46,6 +47,7 @@ def __init__(
4647
dataframe: "StreamingDataFrame",
4748
step_ms: Optional[int] = None,
4849
on_late: Optional[WindowOnLateCallback] = None,
50+
on_update: Optional[WindowOnUpdateCallback] = None,
4951
):
5052
super().__init__(
5153
name=name,
@@ -56,6 +58,7 @@ def __init__(
5658
self._grace_ms = grace_ms
5759
self._step_ms = step_ms
5860
self._on_late = on_late
61+
self._on_update = on_update
5962

6063
self._closing_strategy = ClosingStrategy.KEY
6164

@@ -132,6 +135,7 @@ def process_window(
132135
state = transaction.as_state(prefix=key)
133136
duration_ms = self._duration_ms
134137
grace_ms = self._grace_ms
138+
on_update = self._on_update
135139

136140
collect = self.collect
137141
aggregate = self.aggregate
@@ -152,6 +156,7 @@ def process_window(
152156
max_expired_window_end = latest_timestamp - grace_ms
153157
max_expired_window_start = max_expired_window_end - duration_ms
154158
updated_windows: list[WindowKeyResult] = []
159+
triggered_windows: list[WindowKeyResult] = []
155160
for start, end in ranges:
156161
if start <= max_expired_window_start:
157162
late_by_ms = max_expired_window_end - timestamp_ms
@@ -169,18 +174,44 @@ def process_window(
169174
# since actual values are stored separately and combined into an array
170175
# during window expiration.
171176
aggregated = None
177+
172178
if aggregate:
173179
current_value = state.get_window(start, end)
174180
if current_value is None:
175181
current_value = self._initialize_value()
176182

177183
aggregated = self._aggregate_value(current_value, value, timestamp_ms)
178-
updated_windows.append(
179-
(
180-
key,
181-
self._results(aggregated, [], start, end),
182-
)
183-
)
184+
185+
if on_update and on_update(current_value, aggregated):
186+
# Get collected values for the result
187+
collected = []
188+
if collect:
189+
collected = state.get_from_collection(start, end)
190+
# Add the current value that's being collected
191+
collected.append(self._collect_value(value))
192+
193+
result = self._results(aggregated, collected, start, end)
194+
triggered_windows.append((key, result))
195+
transaction.delete_window(start, end, prefix=key)
196+
# Note: We don't delete from collection here - normal expiration
197+
# will handle cleanup for both tumbling and hopping windows
198+
continue
199+
200+
result = self._results(aggregated, [], start, end)
201+
updated_windows.append((key, result))
202+
elif collect and on_update:
203+
# For collect-only windows, get the old and new collected values
204+
old_collected = state.get_from_collection(start, end)
205+
new_collected = [*old_collected, self._collect_value(value)]
206+
207+
if on_update(old_collected, new_collected):
208+
result = self._results(None, new_collected, start, end)
209+
triggered_windows.append((key, result))
210+
transaction.delete_window(start, end, prefix=key)
211+
# Note: We don't delete from collection here - normal expiration
212+
# will handle cleanup for both tumbling and hopping windows
213+
continue
214+
184215
state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms)
185216

186217
if collect:
@@ -198,7 +229,10 @@ def process_window(
198229
key, state, max_expired_window_start, collect
199230
)
200231

201-
return updated_windows, expired_windows
232+
# Combine triggered windows with time-expired windows
233+
all_expired_windows = triggered_windows + list(expired_windows)
234+
235+
return updated_windows, iter(all_expired_windows)
202236

203237
def expire_by_partition(
204238
self,

quixstreams/state/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,16 @@ def expire_all_windows(
391391
"""
392392
...
393393

394+
def delete_window(self, start_ms: int, end_ms: int, prefix: bytes) -> None:
395+
"""
396+
Delete a single window defined by start and end timestamps.
397+
398+
:param start_ms: start of the window in milliseconds
399+
:param end_ms: end of the window in milliseconds
400+
:param prefix: a key prefix
401+
"""
402+
...
403+
394404
def delete_windows(
395405
self, max_start_time: int, delete_values: bool, prefix: bytes
396406
) -> None:

0 commit comments

Comments
 (0)