Commit 738e4b6

Split to before_update and after_update
1 parent d76bf73 commit 738e4b6

9 files changed: +479 -83 lines changed

docs/windowing.md

Lines changed: 33 additions & 10 deletions
@@ -596,19 +596,21 @@ if __name__ == '__main__':
 ### Early window expiration with triggers
 !!! info New in v3.24.0
 
-To expire windows before their natural expiration time based on custom conditions, you can pass the `on_update` callback to `.tumbling_window()` and `.hopping_window()` methods.
+To expire windows before their natural expiration time based on custom conditions, you can pass `before_update` or `after_update` callbacks to `.tumbling_window()` and `.hopping_window()` methods.
 
 This is useful when you want to emit results as soon as certain conditions are met, rather than waiting for the window to close naturally.
 
 **How it works**:
 
-- The `on_update` callback is invoked every time a window is updated with a new value.
-- It receives `old_value` and `new_value` - the raw aggregated values before and after the update.
-- For `collect()` operations, it receives lists of collected values.
-- If the callback returns `True`, the window is immediately expired and emitted downstream.
-- The expired window is removed from state, but may be "resurrected" if new data arrives within its time range before natural expiration.
+- The `before_update` callback is invoked before the window aggregation is updated with a new value.
+- The `after_update` callback is invoked after the window aggregation has been updated with a new value.
+- Both callbacks receive: `aggregated` (current or updated aggregated value), `value` (incoming value), `key`, `timestamp`, and `headers`.
+- For `collect()` operations without aggregation, `aggregated` contains the list of collected values.
+- If either callback returns `True`, the window is immediately expired and emitted downstream.
+- The window metadata is deleted from state, but collected values (if using `.collect()`) remain until natural expiration.
+- This means a triggered window can be "resurrected" if new data arrives within its time range - a new window will be created with the previously collected values still present.
 
-**Example**:
+**Example with after_update**:
 
 ```python
 from typing import Any
@@ -620,16 +622,18 @@ app = Application(...)
 sdf = app.dataframe(...)
 
 
-def trigger_on_threshold(old_value: int, new_value: int) -> bool:
+def trigger_on_threshold(
+    aggregated: int, value: Any, key: Any, timestamp: int, headers: Any
+) -> bool:
     """
     Expire the window early when the sum exceeds 1000.
     """
-    return new_value > 1000
+    return aggregated > 1000
 
 
 # Define a 1-hour tumbling window with early expiration trigger
 sdf = (
-    sdf.tumbling_window(timedelta(hours=1), on_update=trigger_on_threshold)
+    sdf.tumbling_window(timedelta(hours=1), after_update=trigger_on_threshold)
     .sum()
     .final()
 )
@@ -640,6 +644,25 @@ if __name__ == '__main__':
 
 ```
 
+**Example with before_update**:
+
+```python
+def trigger_before_large_value(
+    aggregated: int, value: Any, key: Any, timestamp: int, headers: Any
+) -> bool:
+    """
+    Expire the window before adding a value if it would make the sum too large.
+    """
+    return (aggregated + value) > 1000
+
+
+sdf = (
+    sdf.tumbling_window(timedelta(hours=1), before_update=trigger_before_large_value)
+    .sum()
+    .final()
+)
+```
+
 
 ## Emitting results
 
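
The difference between the two triggers documented above can be sketched without the library. The following is a plain-Python simulation, not Quix Streams code: `run_sum_window` is a hypothetical helper, the starting aggregate of `0` is an assumption, and the sketch only records what each callback would receive as `aggregated` and whether it asks for expiration. The `before_update` variant also guards against `aggregated` being `None`, because the docstring in this commit describes that argument as "current aggregated value or default/None".

```python
from typing import Any, Callable, Optional

# Same shape as the callbacks in this commit:
# (aggregated, value, key, timestamp, headers) -> bool
Trigger = Callable[[Any, Any, Any, int, Any], bool]


def trigger_before_large_value(
    aggregated: Optional[int], value: int, key: Any, timestamp: int, headers: Any
) -> bool:
    # Treat a missing aggregate as 0 so the addition cannot fail on None.
    return ((aggregated or 0) + value) > 1000


def trigger_on_threshold(
    aggregated: int, value: int, key: Any, timestamp: int, headers: Any
) -> bool:
    return aggregated > 1000


def run_sum_window(
    values: list[int],
    before_update: Optional[Trigger] = None,
    after_update: Optional[Trigger] = None,
) -> list[tuple[str, int, bool]]:
    """Hypothetical helper: log (phase, aggregated, fired) for every callback call."""
    log: list[tuple[str, int, bool]] = []
    aggregated = 0  # assumed default for a sum
    for ts, value in enumerate(values):
        if before_update is not None:
            fired = before_update(aggregated, value, "key", ts, None)
            log.append(("before", aggregated, fired))
        aggregated += value
        if after_update is not None:
            fired = after_update(aggregated, value, "key", ts, None)
            log.append(("after", aggregated, fired))
    return log


if __name__ == "__main__":
    for entry in run_sum_window(
        [600, 300, 200],
        before_update=trigger_before_large_value,
        after_update=trigger_on_threshold,
    ):
        print(entry)
```

For the third value, `before_update` sees the aggregate without the incoming value (900) while `after_update` sees the updated aggregate (1100). The sketch deliberately does not model what happens to the window once a trigger fires.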

quixstreams/dataframe/dataframe.py

Lines changed: 37 additions & 17 deletions
@@ -72,7 +72,11 @@
     TumblingCountWindowDefinition,
     TumblingTimeWindowDefinition,
 )
-from .windows.base import WindowOnLateCallback, WindowOnUpdateCallback
+from .windows.base import (
+    WindowAfterUpdateCallback,
+    WindowBeforeUpdateCallback,
+    WindowOnLateCallback,
+)
 
 if typing.TYPE_CHECKING:
     from quixstreams.processing import ProcessingContext
@@ -1085,7 +1089,8 @@ def tumbling_window(
         grace_ms: Union[int, timedelta] = 0,
         name: Optional[str] = None,
         on_late: Optional[WindowOnLateCallback] = None,
-        on_update: Optional[WindowOnUpdateCallback] = None,
+        before_update: Optional[WindowBeforeUpdateCallback] = None,
+        after_update: Optional[WindowAfterUpdateCallback] = None,
     ) -> TumblingTimeWindowDefinition:
         """
         Create a time-based tumbling window transformation on this StreamingDataFrame.
@@ -1152,12 +1157,18 @@ def tumbling_window(
             (default behavior).
             Otherwise, no message will be logged.
 
-        :param on_update: an optional callback to trigger early window expiration based
-            on custom conditions.
-            The callback receives `old_value` and `new_value` (the raw aggregated values
-            before and after the update). If it returns `True`, the window will be expired
-            immediately, even if it hasn't reached its natural expiration time.
-            For `collect()` operations, the callback receives lists of collected values.
+        :param before_update: an optional callback to trigger early window expiration
+            before the window is updated.
+            The callback receives `aggregated` (current aggregated value or default/None),
+            `value`, `key`, `timestamp`, and `headers`.
+            If it returns `True`, the window will be expired immediately.
+            Default - `None`.
+
+        :param after_update: an optional callback to trigger early window expiration
+            after the window is updated.
+            The callback receives `aggregated` (updated aggregated value), `value`, `key`,
+            `timestamp`, and `headers`.
+            If it returns `True`, the window will be expired immediately.
             Default - `None`.
 
         :return: `TumblingTimeWindowDefinition` instance representing the tumbling window
@@ -1175,7 +1186,8 @@ def tumbling_window(
             dataframe=self,
             name=name,
             on_late=on_late,
-            on_update=on_update,
+            before_update=before_update,
+            after_update=after_update,
         )
 
     def tumbling_count_window(
@@ -1235,7 +1247,8 @@ def hopping_window(
         grace_ms: Union[int, timedelta] = 0,
         name: Optional[str] = None,
         on_late: Optional[WindowOnLateCallback] = None,
-        on_update: Optional[WindowOnUpdateCallback] = None,
+        before_update: Optional[WindowBeforeUpdateCallback] = None,
+        after_update: Optional[WindowAfterUpdateCallback] = None,
     ) -> HoppingTimeWindowDefinition:
         """
         Create a time-based hopping window transformation on this StreamingDataFrame.
@@ -1313,12 +1326,18 @@ def hopping_window(
             (default behavior).
             Otherwise, no message will be logged.
 
-        :param on_update: an optional callback to trigger early window expiration based
-            on custom conditions.
-            The callback receives `old_value` and `new_value` (the raw aggregated values
-            before and after the update). If it returns `True`, the window will be expired
-            immediately, even if it hasn't reached its natural expiration time.
-            For `collect()` operations, the callback receives lists of collected values.
+        :param before_update: an optional callback to trigger early window expiration
+            before the window is updated.
+            The callback receives `aggregated` (current aggregated value or default/None),
+            `value`, `key`, `timestamp`, and `headers`.
+            If it returns `True`, the window will be expired immediately.
+            Default - `None`.
+
+        :param after_update: an optional callback to trigger early window expiration
+            after the window is updated.
+            The callback receives `aggregated` (updated aggregated value), `value`, `key`,
+            `timestamp`, and `headers`.
+            If it returns `True`, the window will be expired immediately.
             Default - `None`.
 
         :return: `HoppingTimeWindowDefinition` instance representing the hopping
@@ -1338,7 +1357,8 @@ def hopping_window(
             dataframe=self,
             name=name,
             on_late=on_late,
-            on_update=on_update,
+            before_update=before_update,
+            after_update=after_update,
         )
 
     def hopping_count_window(
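
Because `on_update` is removed from `tumbling_window()` and `hopping_window()`, callers that passed a two-argument callback need a five-argument one instead. The following is a minimal sketch of one way to migrate, assuming the old callback only inspected the updated aggregate; `as_after_update` and `exceeds_1000` are hypothetical names that are not part of the library.

```python
from typing import Any, Callable


def as_after_update(
    predicate: Callable[[Any], bool],
) -> Callable[[Any, Any, Any, int, Any], bool]:
    """Hypothetical adapter: lift a predicate on the aggregate to the new signature."""

    def after_update(
        aggregated: Any, value: Any, key: Any, timestamp: int, headers: Any
    ) -> bool:
        # Only the updated aggregate is inspected; the other arguments are ignored.
        return predicate(aggregated)

    return after_update


# Hypothetical trigger equivalent to the old `new_value > 1000` check.
exceeds_1000 = as_after_update(lambda total: total > 1000)
```

The result could then be passed as `after_update=exceeds_1000`. A callback that compared `old_value` with `new_value` has no drop-in equivalent, since neither new callback is documented as receiving both aggregates in one call.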

quixstreams/dataframe/windows/base.py

Lines changed: 9 additions & 2 deletions
@@ -34,7 +34,8 @@
 WindowResult: TypeAlias = dict[str, Any]
 WindowKeyResult: TypeAlias = tuple[Any, WindowResult]
 Message: TypeAlias = tuple[WindowResult, Any, int, Any]
-WindowOnUpdateCallback: TypeAlias = Callable[[Any, Any], bool]
+WindowBeforeUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool]
+WindowAfterUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool]
 
 WindowAggregateFunc = Callable[[Any, Any], Any]
 
@@ -66,6 +67,7 @@ def process_window(
         value: Any,
         key: Any,
         timestamp_ms: int,
+        headers: Any,
         transaction: WindowedPartitionTransaction,
     ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
         pass
@@ -135,6 +137,7 @@ def window_callback(
                 value=value,
                 key=key,
                 timestamp_ms=timestamp_ms,
+                headers=_headers,
                 transaction=transaction,
             )
             # Use window start timestamp as a new record timestamp
@@ -177,7 +180,11 @@ def window_callback(
             transaction: WindowedPartitionTransaction,
         ) -> Iterable[Message]:
             updated_windows, expired_windows = self.process_window(
-                value=value, key=key, timestamp_ms=timestamp_ms, transaction=transaction
+                value=value,
+                key=key,
+                timestamp_ms=timestamp_ms,
+                headers=_headers,
+                transaction=transaction,
             )
 
             # loop over the expired_windows generator to ensure the windows
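
With `key` and `headers` now threaded through `process_window`, a trigger can depend on more than the aggregate. The following is a sketch of a callback matching the `Callable[[Any, Any, Any, int, Any], bool]` alias. The per-key limits, the `flush` header name, and the assumption that headers arrive as an iterable of `(name, value)` pairs are all illustrative, not taken from the commit.

```python
from typing import Any, Callable

# Local copy of the alias shape introduced in this commit.
WindowAfterUpdateCallback = Callable[[Any, Any, Any, int, Any], bool]

# Hypothetical per-key limits; keys and numbers are made up for illustration.
LIMITS = {"sensor-a": 100, "sensor-b": 500}
DEFAULT_LIMIT = 1000


def flush_requested(headers: Any) -> bool:
    # Assumption: headers, when present, are an iterable of (name, value) pairs.
    return any(name == "flush" for name, _ in (headers or []))


def per_key_trigger(
    aggregated: Any, value: Any, key: Any, timestamp: int, headers: Any
) -> bool:
    """Expire when the key's limit is exceeded or a 'flush' header is present."""
    return aggregated > LIMITS.get(key, DEFAULT_LIMIT) or flush_requested(headers)


# The annotation documents that the function fits the alias.
callback: WindowAfterUpdateCallback = per_key_trigger
```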

quixstreams/dataframe/windows/count_based.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ def process_window(
         value: Any,
         key: Any,
         timestamp_ms: int,
+        headers: Any,
         transaction: WindowedPartitionTransaction[str, CountWindowsData],
     ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
         """

quixstreams/dataframe/windows/definitions.py

Lines changed: 29 additions & 16 deletions
@@ -15,8 +15,9 @@
 )
 from .base import (
     Window,
+    WindowAfterUpdateCallback,
+    WindowBeforeUpdateCallback,
     WindowOnLateCallback,
-    WindowOnUpdateCallback,
 )
 from .count_based import (
     CountWindow,
@@ -55,13 +56,15 @@ def __init__(
         name: Optional[str],
         dataframe: "StreamingDataFrame",
         on_late: Optional[WindowOnLateCallback] = None,
-        on_update: Optional[WindowOnUpdateCallback] = None,
+        before_update: Optional[WindowBeforeUpdateCallback] = None,
+        after_update: Optional[WindowAfterUpdateCallback] = None,
     ) -> None:
         super().__init__()
 
         self._name = name
         self._on_late = on_late
-        self._on_update = on_update
+        self._before_update = before_update
+        self._after_update = after_update
         self._dataframe = dataframe
 
     @abstractmethod
@@ -242,7 +245,8 @@ def __init__(
         name: Optional[str] = None,
         step_ms: Optional[int] = None,
         on_late: Optional[WindowOnLateCallback] = None,
-        on_update: Optional[WindowOnUpdateCallback] = None,
+        before_update: Optional[WindowBeforeUpdateCallback] = None,
+        after_update: Optional[WindowAfterUpdateCallback] = None,
     ):
         if not isinstance(duration_ms, int):
             raise TypeError("Window size must be an integer")
@@ -257,7 +261,7 @@ def __init__(
                 f"got {step_ms}ms"
             )
 
-        super().__init__(name, dataframe, on_late, on_update)
+        super().__init__(name, dataframe, on_late, before_update, after_update)
 
         self._duration_ms = duration_ms
         self._grace_ms = grace_ms
@@ -285,7 +289,8 @@ def __init__(
         dataframe: "StreamingDataFrame",
         name: Optional[str] = None,
         on_late: Optional[WindowOnLateCallback] = None,
-        on_update: Optional[WindowOnUpdateCallback] = None,
+        before_update: Optional[WindowBeforeUpdateCallback] = None,
+        after_update: Optional[WindowAfterUpdateCallback] = None,
     ):
         super().__init__(
             duration_ms=duration_ms,
@@ -294,7 +299,8 @@ def __init__(
             name=name,
             step_ms=step_ms,
             on_late=on_late,
-            on_update=on_update,
+            before_update=before_update,
+            after_update=after_update,
         )
 
     def _get_name(self, func_name: Optional[str]) -> str:
@@ -326,7 +332,8 @@ def _create_window(
             aggregators=aggregators or {},
             collectors=collectors or {},
             on_late=self._on_late,
-            on_update=self._on_update,
+            before_update=self._before_update,
+            after_update=self._after_update,
         )
 
 
@@ -338,16 +345,18 @@ def __init__(
         dataframe: "StreamingDataFrame",
         name: Optional[str] = None,
         on_late: Optional[WindowOnLateCallback] = None,
-        on_update: Optional[WindowOnUpdateCallback] = None,
+        before_update: Optional[WindowBeforeUpdateCallback] = None,
+        after_update: Optional[WindowAfterUpdateCallback] = None,
     ):
         super().__init__(
             duration_ms=duration_ms,
             grace_ms=grace_ms,
             dataframe=dataframe,
             name=name,
             on_late=on_late,
+            before_update=before_update,
+            after_update=after_update,
         )
-        self._on_update = on_update
 
     def _get_name(self, func_name: Optional[str]) -> str:
         prefix = f"{self._name}_tumbling_window" if self._name else "tumbling_window"
@@ -377,7 +386,8 @@ def _create_window(
             aggregators=aggregators or {},
             collectors=collectors or {},
             on_late=self._on_late,
-            on_update=self._on_update,
+            before_update=self._before_update,
+            after_update=self._after_update,
         )
 
 
@@ -389,11 +399,12 @@ def __init__(
         dataframe: "StreamingDataFrame",
         name: Optional[str] = None,
         on_late: Optional[WindowOnLateCallback] = None,
-        on_update: Optional[WindowOnUpdateCallback] = None,
+        before_update: Optional[WindowBeforeUpdateCallback] = None,
+        after_update: Optional[WindowAfterUpdateCallback] = None,
     ):
-        if on_update is not None:
+        if before_update is not None or after_update is not None:
             raise ValueError(
-                "Sliding windows do not support the 'on_update' trigger callback. "
+                "Sliding windows do not support trigger callbacks (before_update/after_update). "
                 "Use tumbling or hopping windows instead."
             )
         super().__init__(
@@ -402,7 +413,8 @@ def __init__(
             dataframe=dataframe,
             name=name,
             on_late=on_late,
-            on_update=on_update,
+            before_update=before_update,
+            after_update=after_update,
         )
 
     def _get_name(self, func_name: Optional[str]) -> str:
@@ -434,7 +446,8 @@ def _create_window(
             aggregators=aggregators or {},
             collectors=collectors or {},
             on_late=self._on_late,
-            on_update=self._on_update,
+            before_update=self._before_update,
+            after_update=self._after_update,
         )
 
 
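
The sliding-window definition now rejects either trigger at construction time. The following is a minimal stand-in that mirrors only that validation, for illustration. `SlidingDefinitionSketch` and `always` are hypothetical names, and the real class takes further arguments that are omitted here.

```python
from typing import Any, Callable, Optional

Trigger = Callable[[Any, Any, Any, int, Any], bool]


class SlidingDefinitionSketch:
    """Hypothetical stand-in that mirrors only the validation shown in the diff."""

    def __init__(
        self,
        before_update: Optional[Trigger] = None,
        after_update: Optional[Trigger] = None,
    ) -> None:
        # Fail fast if either trigger is supplied, as the diff does.
        if before_update is not None or after_update is not None:
            raise ValueError(
                "Sliding windows do not support trigger callbacks "
                "(before_update/after_update). "
                "Use tumbling or hopping windows instead."
            )


def always(aggregated: Any, value: Any, key: Any, timestamp: int, headers: Any) -> bool:
    return True


try:
    SlidingDefinitionSketch(after_update=always)
except ValueError as exc:
    rejected = str(exc)  # message explaining which window types to use instead
else:
    rejected = ""
```

Raising in `__init__` surfaces the misconfiguration when the pipeline is defined rather than when the first message is processed.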

quixstreams/dataframe/windows/sliding.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ def process_window(
         value: Any,
         key: Any,
         timestamp_ms: int,
+        headers: Any,
         transaction: WindowedPartitionTransaction,
     ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
         """
