Skip to content

Commit 4becc22

Browse files
committed
Split to before_update and after_update
1 parent d76bf73 commit 4becc22

File tree

9 files changed

+477
-82
lines changed

9 files changed

+477
-82
lines changed

docs/windowing.md

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -596,19 +596,20 @@ if __name__ == '__main__':
596596
### Early window expiration with triggers
597597
!!! info "New in v3.24.0"
598598

599-
To expire windows before their natural expiration time based on custom conditions, you can pass the `on_update` callback to `.tumbling_window()` and `.hopping_window()` methods.
599+
To expire windows before their natural expiration time based on custom conditions, you can pass a `before_update` callback, an `after_update` callback, or both to the `.tumbling_window()` and `.hopping_window()` methods.
600600

601601
This is useful when you want to emit results as soon as certain conditions are met, rather than waiting for the window to close naturally.
602602

603603
**How it works**:
604604

605-
- The `on_update` callback is invoked every time a window is updated with a new value.
606-
- It receives `old_value` and `new_value` - the raw aggregated values before and after the update.
607-
- For `collect()` operations, it receives lists of collected values.
608-
- If the callback returns `True`, the window is immediately expired and emitted downstream.
605+
- The `before_update` callback is invoked before the window aggregation is updated with a new value.
606+
- The `after_update` callback is invoked after the window aggregation has been updated with a new value.
607+
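- Both callbacks receive `aggregated`, `value` (the incoming value), `key`, `timestamp`, and `headers`. For `before_update`, `aggregated` is the aggregated value before the incoming value is applied; for `after_update`, it is the aggregated value after the update.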
608+
- For `collect()` operations without aggregation, `aggregated` contains the list of collected values.
609+
- If either callback returns `True`, the window is immediately expired and emitted downstream.
609610
- The expired window is removed from state, but may be "resurrected" if new data arrives within its time range before natural expiration.
610611

611-
**Example**:
612+
**Example with after_update**:
612613

613614
```python
614615
from typing import Any
@@ -620,16 +621,18 @@ app = Application(...)
620621
sdf = app.dataframe(...)
621622

622623

623-
def trigger_on_threshold(old_value: int, new_value: int) -> bool:
624+
def trigger_on_threshold(
625+
aggregated: int, value: Any, key: Any, timestamp: int, headers: Any
626+
) -> bool:
624627
"""
625628
Expire the window early when the sum exceeds 1000.
626629
"""
627-
return new_value > 1000
630+
return aggregated > 1000
628631

629632

630633
# Define a 1-hour tumbling window with early expiration trigger
631634
sdf = (
632-
sdf.tumbling_window(timedelta(hours=1), on_update=trigger_on_threshold)
635+
sdf.tumbling_window(timedelta(hours=1), after_update=trigger_on_threshold)
633636
.sum()
634637
.final()
635638
)
@@ -640,6 +643,25 @@ if __name__ == '__main__':
640643

641644
```
642645

646+
**Example with before_update**:
647+
648+
```python
649+
def trigger_before_large_value(
650+
aggregated: int, value: Any, key: Any, timestamp: int, headers: Any
651+
) -> bool:
652+
"""
653+
Expire the window before adding a value if adding it would push the sum above 1000.
654+
"""
655+
return (aggregated + value) > 1000
656+
657+
658+
sdf = (
659+
sdf.tumbling_window(timedelta(hours=1), before_update=trigger_before_large_value)
660+
.sum()
661+
.final()
662+
)
663+
```
664+
643665

644666
## Emitting results
645667

quixstreams/dataframe/dataframe.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,11 @@
7272
TumblingCountWindowDefinition,
7373
TumblingTimeWindowDefinition,
7474
)
75-
from .windows.base import WindowOnLateCallback, WindowOnUpdateCallback
75+
from .windows.base import (
76+
WindowAfterUpdateCallback,
77+
WindowBeforeUpdateCallback,
78+
WindowOnLateCallback,
79+
)
7680

7781
if typing.TYPE_CHECKING:
7882
from quixstreams.processing import ProcessingContext
@@ -1085,7 +1089,8 @@ def tumbling_window(
10851089
grace_ms: Union[int, timedelta] = 0,
10861090
name: Optional[str] = None,
10871091
on_late: Optional[WindowOnLateCallback] = None,
1088-
on_update: Optional[WindowOnUpdateCallback] = None,
1092+
before_update: Optional[WindowBeforeUpdateCallback] = None,
1093+
after_update: Optional[WindowAfterUpdateCallback] = None,
10891094
) -> TumblingTimeWindowDefinition:
10901095
"""
10911096
Create a time-based tumbling window transformation on this StreamingDataFrame.
@@ -1152,12 +1157,18 @@ def tumbling_window(
11521157
(default behavior).
11531158
Otherwise, no message will be logged.
11541159
1155-
:param on_update: an optional callback to trigger early window expiration based
1156-
on custom conditions.
1157-
The callback receives `old_value` and `new_value` (the raw aggregated values
1158-
before and after the update). If it returns `True`, the window will be expired
1159-
immediately, even if it hasn't reached its natural expiration time.
1160-
For `collect()` operations, the callback receives lists of collected values.
1160+
:param before_update: an optional callback to trigger early window expiration
1161+
before the window is updated.
1162+
The callback receives `aggregated` (current aggregated value or default/None),
1163+
`value`, `key`, `timestamp`, and `headers`.
1164+
If it returns `True`, the window will be expired immediately.
1165+
Default - `None`.
1166+
1167+
:param after_update: an optional callback to trigger early window expiration
1168+
after the window is updated.
1169+
The callback receives `aggregated` (updated aggregated value), `value`, `key`,
1170+
`timestamp`, and `headers`.
1171+
If it returns `True`, the window will be expired immediately.
11611172
Default - `None`.
11621173
11631174
:return: `TumblingTimeWindowDefinition` instance representing the tumbling window
@@ -1175,7 +1186,8 @@ def tumbling_window(
11751186
dataframe=self,
11761187
name=name,
11771188
on_late=on_late,
1178-
on_update=on_update,
1189+
before_update=before_update,
1190+
after_update=after_update,
11791191
)
11801192

11811193
def tumbling_count_window(
@@ -1235,7 +1247,8 @@ def hopping_window(
12351247
grace_ms: Union[int, timedelta] = 0,
12361248
name: Optional[str] = None,
12371249
on_late: Optional[WindowOnLateCallback] = None,
1238-
on_update: Optional[WindowOnUpdateCallback] = None,
1250+
before_update: Optional[WindowBeforeUpdateCallback] = None,
1251+
after_update: Optional[WindowAfterUpdateCallback] = None,
12391252
) -> HoppingTimeWindowDefinition:
12401253
"""
12411254
Create a time-based hopping window transformation on this StreamingDataFrame.
@@ -1313,12 +1326,18 @@ def hopping_window(
13131326
(default behavior).
13141327
Otherwise, no message will be logged.
13151328
1316-
:param on_update: an optional callback to trigger early window expiration based
1317-
on custom conditions.
1318-
The callback receives `old_value` and `new_value` (the raw aggregated values
1319-
before and after the update). If it returns `True`, the window will be expired
1320-
immediately, even if it hasn't reached its natural expiration time.
1321-
For `collect()` operations, the callback receives lists of collected values.
1329+
:param before_update: an optional callback to trigger early window expiration
1330+
before the window is updated.
1331+
The callback receives `aggregated` (current aggregated value or default/None),
1332+
`value`, `key`, `timestamp`, and `headers`.
1333+
If it returns `True`, the window will be expired immediately.
1334+
Default - `None`.
1335+
1336+
:param after_update: an optional callback to trigger early window expiration
1337+
after the window is updated.
1338+
The callback receives `aggregated` (updated aggregated value), `value`, `key`,
1339+
`timestamp`, and `headers`.
1340+
If it returns `True`, the window will be expired immediately.
13221341
Default - `None`.
13231342
13241343
:return: `HoppingTimeWindowDefinition` instance representing the hopping
@@ -1338,7 +1357,8 @@ def hopping_window(
13381357
dataframe=self,
13391358
name=name,
13401359
on_late=on_late,
1341-
on_update=on_update,
1360+
before_update=before_update,
1361+
after_update=after_update,
13421362
)
13431363

13441364
def hopping_count_window(

quixstreams/dataframe/windows/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
WindowResult: TypeAlias = dict[str, Any]
3535
WindowKeyResult: TypeAlias = tuple[Any, WindowResult]
3636
Message: TypeAlias = tuple[WindowResult, Any, int, Any]
37-
WindowOnUpdateCallback: TypeAlias = Callable[[Any, Any], bool]
37+
WindowBeforeUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool]
38+
WindowAfterUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool]
3839

3940
WindowAggregateFunc = Callable[[Any, Any], Any]
4041

@@ -66,6 +67,7 @@ def process_window(
6667
value: Any,
6768
key: Any,
6869
timestamp_ms: int,
70+
headers: Any,
6971
transaction: WindowedPartitionTransaction,
7072
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
7173
pass
@@ -135,6 +137,7 @@ def window_callback(
135137
value=value,
136138
key=key,
137139
timestamp_ms=timestamp_ms,
140+
headers=_headers,
138141
transaction=transaction,
139142
)
140143
# Use window start timestamp as a new record timestamp
@@ -177,7 +180,11 @@ def window_callback(
177180
transaction: WindowedPartitionTransaction,
178181
) -> Iterable[Message]:
179182
updated_windows, expired_windows = self.process_window(
180-
value=value, key=key, timestamp_ms=timestamp_ms, transaction=transaction
183+
value=value,
184+
key=key,
185+
timestamp_ms=timestamp_ms,
186+
headers=_headers,
187+
transaction=transaction,
181188
)
182189

183190
# loop over the expired_windows generator to ensure the windows

quixstreams/dataframe/windows/count_based.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def process_window(
5858
value: Any,
5959
key: Any,
6060
timestamp_ms: int,
61+
headers: Any,
6162
transaction: WindowedPartitionTransaction[str, CountWindowsData],
6263
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
6364
"""

quixstreams/dataframe/windows/definitions.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
)
1616
from .base import (
1717
Window,
18+
WindowAfterUpdateCallback,
19+
WindowBeforeUpdateCallback,
1820
WindowOnLateCallback,
19-
WindowOnUpdateCallback,
2021
)
2122
from .count_based import (
2223
CountWindow,
@@ -55,13 +56,15 @@ def __init__(
5556
name: Optional[str],
5657
dataframe: "StreamingDataFrame",
5758
on_late: Optional[WindowOnLateCallback] = None,
58-
on_update: Optional[WindowOnUpdateCallback] = None,
59+
before_update: Optional[WindowBeforeUpdateCallback] = None,
60+
after_update: Optional[WindowAfterUpdateCallback] = None,
5961
) -> None:
6062
super().__init__()
6163

6264
self._name = name
6365
self._on_late = on_late
64-
self._on_update = on_update
66+
self._before_update = before_update
67+
self._after_update = after_update
6568
self._dataframe = dataframe
6669

6770
@abstractmethod
@@ -242,7 +245,8 @@ def __init__(
242245
name: Optional[str] = None,
243246
step_ms: Optional[int] = None,
244247
on_late: Optional[WindowOnLateCallback] = None,
245-
on_update: Optional[WindowOnUpdateCallback] = None,
248+
before_update: Optional[WindowBeforeUpdateCallback] = None,
249+
after_update: Optional[WindowAfterUpdateCallback] = None,
246250
):
247251
if not isinstance(duration_ms, int):
248252
raise TypeError("Window size must be an integer")
@@ -257,7 +261,7 @@ def __init__(
257261
f"got {step_ms}ms"
258262
)
259263

260-
super().__init__(name, dataframe, on_late, on_update)
264+
super().__init__(name, dataframe, on_late, before_update, after_update)
261265

262266
self._duration_ms = duration_ms
263267
self._grace_ms = grace_ms
@@ -285,7 +289,8 @@ def __init__(
285289
dataframe: "StreamingDataFrame",
286290
name: Optional[str] = None,
287291
on_late: Optional[WindowOnLateCallback] = None,
288-
on_update: Optional[WindowOnUpdateCallback] = None,
292+
before_update: Optional[WindowBeforeUpdateCallback] = None,
293+
after_update: Optional[WindowAfterUpdateCallback] = None,
289294
):
290295
super().__init__(
291296
duration_ms=duration_ms,
@@ -294,7 +299,8 @@ def __init__(
294299
name=name,
295300
step_ms=step_ms,
296301
on_late=on_late,
297-
on_update=on_update,
302+
before_update=before_update,
303+
after_update=after_update,
298304
)
299305

300306
def _get_name(self, func_name: Optional[str]) -> str:
@@ -326,7 +332,8 @@ def _create_window(
326332
aggregators=aggregators or {},
327333
collectors=collectors or {},
328334
on_late=self._on_late,
329-
on_update=self._on_update,
335+
before_update=self._before_update,
336+
after_update=self._after_update,
330337
)
331338

332339

@@ -338,16 +345,18 @@ def __init__(
338345
dataframe: "StreamingDataFrame",
339346
name: Optional[str] = None,
340347
on_late: Optional[WindowOnLateCallback] = None,
341-
on_update: Optional[WindowOnUpdateCallback] = None,
348+
before_update: Optional[WindowBeforeUpdateCallback] = None,
349+
after_update: Optional[WindowAfterUpdateCallback] = None,
342350
):
343351
super().__init__(
344352
duration_ms=duration_ms,
345353
grace_ms=grace_ms,
346354
dataframe=dataframe,
347355
name=name,
348356
on_late=on_late,
357+
before_update=before_update,
358+
after_update=after_update,
349359
)
350-
self._on_update = on_update
351360

352361
def _get_name(self, func_name: Optional[str]) -> str:
353362
prefix = f"{self._name}_tumbling_window" if self._name else "tumbling_window"
@@ -377,7 +386,8 @@ def _create_window(
377386
aggregators=aggregators or {},
378387
collectors=collectors or {},
379388
on_late=self._on_late,
380-
on_update=self._on_update,
389+
before_update=self._before_update,
390+
after_update=self._after_update,
381391
)
382392

383393

@@ -389,11 +399,12 @@ def __init__(
389399
dataframe: "StreamingDataFrame",
390400
name: Optional[str] = None,
391401
on_late: Optional[WindowOnLateCallback] = None,
392-
on_update: Optional[WindowOnUpdateCallback] = None,
402+
before_update: Optional[WindowBeforeUpdateCallback] = None,
403+
after_update: Optional[WindowAfterUpdateCallback] = None,
393404
):
394-
if on_update is not None:
405+
if before_update is not None or after_update is not None:
395406
raise ValueError(
396-
"Sliding windows do not support the 'on_update' trigger callback. "
407+
"Sliding windows do not support trigger callbacks (before_update/after_update). "
397408
"Use tumbling or hopping windows instead."
398409
)
399410
super().__init__(
@@ -402,7 +413,8 @@ def __init__(
402413
dataframe=dataframe,
403414
name=name,
404415
on_late=on_late,
405-
on_update=on_update,
416+
before_update=before_update,
417+
after_update=after_update,
406418
)
407419

408420
def _get_name(self, func_name: Optional[str]) -> str:
@@ -434,7 +446,8 @@ def _create_window(
434446
aggregators=aggregators or {},
435447
collectors=collectors or {},
436448
on_late=self._on_late,
437-
on_update=self._on_update,
449+
before_update=self._before_update,
450+
after_update=self._after_update,
438451
)
439452

440453

quixstreams/dataframe/windows/sliding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def process_window(
3535
value: Any,
3636
key: Any,
3737
timestamp_ms: int,
38+
headers: Any,
3839
transaction: WindowedPartitionTransaction,
3940
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
4041
"""

0 commit comments

Comments
 (0)