diff --git a/docs/windowing.md b/docs/windowing.md index ea5a9a72f..c58e53b08 100644 --- a/docs/windowing.md +++ b/docs/windowing.md @@ -593,6 +593,76 @@ if __name__ == '__main__': ``` +### Early window expiration with triggers +!!! info New in v3.24.0 + +To expire windows before their natural expiration time based on custom conditions, you can pass `before_update` or `after_update` callbacks to `.tumbling_window()` and `.hopping_window()` methods. + +This is useful when you want to emit results as soon as certain conditions are met, rather than waiting for the window to close naturally. + +**How it works**: + +- The `before_update` callback is invoked before the window aggregation is updated with a new value. +- The `after_update` callback is invoked after the window aggregation has been updated with a new value. +- Both callbacks receive: `aggregated` (current or updated aggregated value), `value` (incoming value), `key`, `timestamp`, and `headers`. +- For `collect()` operations without aggregation, `aggregated` contains the list of collected values. +- If either callback returns `True`, the window is immediately expired and emitted downstream. +- The window metadata is deleted from state, but collected values (if using `.collect()`) remain until natural expiration. +- This means a triggered window can be "resurrected" if new data arrives within its time range - a new window will be created with the previously collected values still present. + +**Example with after_update**: + +```python +from typing import Any + +from datetime import timedelta +from quixstreams import Application + +app = Application(...) +sdf = app.dataframe(...) + + +def trigger_on_threshold( + aggregated: int, value: Any, key: Any, timestamp: int, headers: Any +) -> bool: + """ + Expire the window early when the sum exceeds 1000. + """ + return aggregated > 1000 + + +# Define a 1-hour tumbling window with early expiration trigger +sdf = ( + sdf.tumbling_window(timedelta(hours=1), after_update=trigger_on_threshold) + .sum() + .final() +) + +# Start the application +if __name__ == '__main__': + app.run() + +``` + +**Example with before_update**: + +```python +def trigger_before_large_value( + aggregated: int, value: Any, key: Any, timestamp: int, headers: Any +) -> bool: + """ + Expire the window before adding a value if it would make the sum too large. + """ + return (aggregated + value) > 1000 + + +sdf = ( + sdf.tumbling_window(timedelta(hours=1), before_update=trigger_before_large_value) + .sum() + .final() +) +``` + ## Emitting results diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 53e90c767..85fd1a660 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -72,7 +72,11 @@ TumblingCountWindowDefinition, TumblingTimeWindowDefinition, ) -from .windows.base import WindowOnLateCallback +from .windows.base import ( + WindowAfterUpdateCallback, + WindowBeforeUpdateCallback, + WindowOnLateCallback, +) if typing.TYPE_CHECKING: from quixstreams.processing import ProcessingContext @@ -1085,6 +1089,8 @@ def tumbling_window( grace_ms: Union[int, timedelta] = 0, name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + before_update: Optional[WindowBeforeUpdateCallback] = None, + after_update: Optional[WindowAfterUpdateCallback] = None, ) -> TumblingTimeWindowDefinition: """ Create a time-based tumbling window transformation on this StreamingDataFrame. @@ -1151,6 +1157,20 @@ def tumbling_window( (default behavior). 
Otherwise, no message will be logged. + :param before_update: an optional callback to trigger early window expiration + before the window is updated. + The callback receives `aggregated` (current aggregated value or default/None), + `value`, `key`, `timestamp`, and `headers`. + If it returns `True`, the window will be expired immediately. + Default - `None`. + + :param after_update: an optional callback to trigger early window expiration + after the window is updated. + The callback receives `aggregated` (updated aggregated value), `value`, `key`, + `timestamp`, and `headers`. + If it returns `True`, the window will be expired immediately. + Default - `None`. + :return: `TumblingTimeWindowDefinition` instance representing the tumbling window configuration. This object can be further configured with aggregation functions @@ -1166,6 +1186,8 @@ def tumbling_window( dataframe=self, name=name, on_late=on_late, + before_update=before_update, + after_update=after_update, ) def tumbling_count_window( @@ -1225,6 +1247,8 @@ def hopping_window( grace_ms: Union[int, timedelta] = 0, name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + before_update: Optional[WindowBeforeUpdateCallback] = None, + after_update: Optional[WindowAfterUpdateCallback] = None, ) -> HoppingTimeWindowDefinition: """ Create a time-based hopping window transformation on this StreamingDataFrame. @@ -1302,6 +1326,20 @@ def hopping_window( (default behavior). Otherwise, no message will be logged. + :param before_update: an optional callback to trigger early window expiration + before the window is updated. + The callback receives `aggregated` (current aggregated value or default/None), + `value`, `key`, `timestamp`, and `headers`. + If it returns `True`, the window will be expired immediately. + Default - `None`. + + :param after_update: an optional callback to trigger early window expiration + after the window is updated. + The callback receives `aggregated` (updated aggregated value), `value`, `key`, + `timestamp`, and `headers`. + If it returns `True`, the window will be expired immediately. + Default - `None`. + :return: `HoppingTimeWindowDefinition` instance representing the hopping window configuration. 
This object can be further configured with aggregation functions @@ -1319,6 +1357,8 @@ def hopping_window( dataframe=self, name=name, on_late=on_late, + before_update=before_update, + after_update=after_update, ) def hopping_count_window( diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py index 8040b2774..9aa073410 100644 --- a/quixstreams/dataframe/windows/base.py +++ b/quixstreams/dataframe/windows/base.py @@ -34,6 +34,8 @@ WindowResult: TypeAlias = dict[str, Any] WindowKeyResult: TypeAlias = tuple[Any, WindowResult] Message: TypeAlias = tuple[WindowResult, Any, int, Any] +WindowBeforeUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool] +WindowAfterUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool] WindowAggregateFunc = Callable[[Any, Any], Any] @@ -65,6 +67,7 @@ def process_window( value: Any, key: Any, timestamp_ms: int, + headers: Any, transaction: WindowedPartitionTransaction, ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: pass @@ -134,6 +137,7 @@ def window_callback( value=value, key=key, timestamp_ms=timestamp_ms, + headers=_headers, transaction=transaction, ) # Use window start timestamp as a new record timestamp @@ -176,7 +180,11 @@ def window_callback( transaction: WindowedPartitionTransaction, ) -> Iterable[Message]: updated_windows, expired_windows = self.process_window( - value=value, key=key, timestamp_ms=timestamp_ms, transaction=transaction + value=value, + key=key, + timestamp_ms=timestamp_ms, + headers=_headers, + transaction=transaction, ) # loop over the expired_windows generator to ensure the windows diff --git a/quixstreams/dataframe/windows/count_based.py b/quixstreams/dataframe/windows/count_based.py index 57c6b36e5..0899c66c4 100644 --- a/quixstreams/dataframe/windows/count_based.py +++ b/quixstreams/dataframe/windows/count_based.py @@ -58,6 +58,7 @@ def process_window( value: Any, key: Any, timestamp_ms: int, + headers: Any, transaction: WindowedPartitionTransaction[str, CountWindowsData], ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: """ diff --git a/quixstreams/dataframe/windows/definitions.py b/quixstreams/dataframe/windows/definitions.py index 90d4d815b..20e2ce944 100644 --- a/quixstreams/dataframe/windows/definitions.py +++ b/quixstreams/dataframe/windows/definitions.py @@ -15,6 +15,8 @@ ) from .base import ( Window, + WindowAfterUpdateCallback, + WindowBeforeUpdateCallback, WindowOnLateCallback, ) from .count_based import ( @@ -54,11 +56,15 @@ def __init__( name: Optional[str], dataframe: "StreamingDataFrame", on_late: Optional[WindowOnLateCallback] = None, + before_update: Optional[WindowBeforeUpdateCallback] = None, + after_update: Optional[WindowAfterUpdateCallback] = None, ) -> None: super().__init__() self._name = name self._on_late = on_late + self._before_update = before_update + self._after_update = after_update self._dataframe = dataframe @abstractmethod @@ -239,6 +245,8 @@ def __init__( name: Optional[str] = None, step_ms: Optional[int] = None, on_late: Optional[WindowOnLateCallback] = None, + before_update: Optional[WindowBeforeUpdateCallback] = None, + after_update: Optional[WindowAfterUpdateCallback] = None, ): if not isinstance(duration_ms, int): raise TypeError("Window size must be an integer") @@ -253,7 +261,7 @@ def __init__( f"got {step_ms}ms" ) - super().__init__(name, dataframe, on_late) + super().__init__(name, dataframe, on_late, before_update, after_update) self._duration_ms = duration_ms self._grace_ms = grace_ms @@ 
-281,6 +289,8 @@ def __init__( dataframe: "StreamingDataFrame", name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + before_update: Optional[WindowBeforeUpdateCallback] = None, + after_update: Optional[WindowAfterUpdateCallback] = None, ): super().__init__( duration_ms=duration_ms, @@ -289,6 +299,8 @@ def __init__( name=name, step_ms=step_ms, on_late=on_late, + before_update=before_update, + after_update=after_update, ) def _get_name(self, func_name: Optional[str]) -> str: @@ -320,6 +332,8 @@ def _create_window( aggregators=aggregators or {}, collectors=collectors or {}, on_late=self._on_late, + before_update=self._before_update, + after_update=self._after_update, ) @@ -331,6 +345,8 @@ def __init__( dataframe: "StreamingDataFrame", name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + before_update: Optional[WindowBeforeUpdateCallback] = None, + after_update: Optional[WindowAfterUpdateCallback] = None, ): super().__init__( duration_ms=duration_ms, @@ -338,6 +354,8 @@ def __init__( dataframe=dataframe, name=name, on_late=on_late, + before_update=before_update, + after_update=after_update, ) def _get_name(self, func_name: Optional[str]) -> str: @@ -368,6 +386,8 @@ def _create_window( aggregators=aggregators or {}, collectors=collectors or {}, on_late=self._on_late, + before_update=self._before_update, + after_update=self._after_update, ) @@ -379,13 +399,22 @@ def __init__( dataframe: "StreamingDataFrame", name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + before_update: Optional[WindowBeforeUpdateCallback] = None, + after_update: Optional[WindowAfterUpdateCallback] = None, ): + if before_update is not None or after_update is not None: + raise ValueError( + "Sliding windows do not support trigger callbacks (before_update/after_update). " + "Use tumbling or hopping windows instead." 
+ ) super().__init__( duration_ms=duration_ms, grace_ms=grace_ms, dataframe=dataframe, name=name, on_late=on_late, + before_update=before_update, + after_update=after_update, ) def _get_name(self, func_name: Optional[str]) -> str: @@ -417,6 +446,8 @@ def _create_window( aggregators=aggregators or {}, collectors=collectors or {}, on_late=self._on_late, + before_update=self._before_update, + after_update=self._after_update, ) diff --git a/quixstreams/dataframe/windows/sliding.py b/quixstreams/dataframe/windows/sliding.py index d3dfdbb39..f2ff2d461 100644 --- a/quixstreams/dataframe/windows/sliding.py +++ b/quixstreams/dataframe/windows/sliding.py @@ -35,6 +35,7 @@ def process_window( value: Any, key: Any, timestamp_ms: int, + headers: Any, transaction: WindowedPartitionTransaction, ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: """ diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py index c403cfdfa..4620974c4 100644 --- a/quixstreams/dataframe/windows/time_based.py +++ b/quixstreams/dataframe/windows/time_based.py @@ -1,3 +1,4 @@ +import itertools import logging from enum import Enum from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional @@ -9,6 +10,8 @@ MultiAggregationWindowMixin, SingleAggregationWindowMixin, Window, + WindowAfterUpdateCallback, + WindowBeforeUpdateCallback, WindowKeyResult, WindowOnLateCallback, get_window_ranges, @@ -46,6 +49,8 @@ def __init__( dataframe: "StreamingDataFrame", step_ms: Optional[int] = None, on_late: Optional[WindowOnLateCallback] = None, + before_update: Optional[WindowBeforeUpdateCallback] = None, + after_update: Optional[WindowAfterUpdateCallback] = None, ): super().__init__( name=name, @@ -56,6 +61,8 @@ def __init__( self._grace_ms = grace_ms self._step_ms = step_ms self._on_late = on_late + self._before_update = before_update + self._after_update = after_update self._closing_strategy = ClosingStrategy.KEY @@ -127,11 +134,14 @@ def process_window( value: Any, key: Any, timestamp_ms: int, + headers: Any, transaction: WindowedPartitionTransaction, ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: state = transaction.as_state(prefix=key) duration_ms = self._duration_ms grace_ms = self._grace_ms + before_update = self._before_update + after_update = self._after_update collect = self.collect aggregate = self.aggregate @@ -152,6 +162,7 @@ def process_window( max_expired_window_end = latest_timestamp - grace_ms max_expired_window_start = max_expired_window_end - duration_ms updated_windows: list[WindowKeyResult] = [] + triggered_windows: list[WindowKeyResult] = [] for start, end in ranges: if start <= max_expired_window_start: late_by_ms = max_expired_window_end - timestamp_ms @@ -169,18 +180,78 @@ def process_window( # since actual values are stored separately and combined into an array # during window expiration. 
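+            # Early-expiration triggers: before_update is called with the window
+            # state as it is BEFORE this value is applied, after_update with the
+            # state AFTER it. If either callback returns True, the window result is
+            # emitted downstream right away and its metadata is removed with
+            # transaction.delete_window(); collected values are left in storage and
+            # are only cleaned up by the normal expiration path.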
aggregated = None + if aggregate: current_value = state.get_window(start, end) if current_value is None: current_value = self._initialize_value() + # Check before_update trigger + if before_update and before_update( + current_value, value, key, timestamp_ms, headers + ): + # Get collected values for the result + # Do NOT include the current value - before_update means + # we expire BEFORE adding the current value + collected = state.get_from_collection(start, end) if collect else [] + + result = self._results(current_value, collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + aggregated = self._aggregate_value(current_value, value, timestamp_ms) - updated_windows.append( - ( - key, - self._results(aggregated, [], start, end), - ) - ) + + # Check after_update trigger + if after_update and after_update( + aggregated, value, key, timestamp_ms, headers + ): + # Get collected values for the result + collected = [] + if collect: + collected = state.get_from_collection(start, end) + # Add the current value that's being collected + collected.append(self._collect_value(value)) + + result = self._results(aggregated, collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + + result = self._results(aggregated, [], start, end) + updated_windows.append((key, result)) + elif collect and (before_update or after_update): + # For collect-only windows, get the old collected values + old_collected = state.get_from_collection(start, end) + + # Check before_update trigger (before adding new value) + if before_update and before_update( + old_collected, value, key, timestamp_ms, headers + ): + # Expire with the current collection (WITHOUT the new value) + result = self._results(None, old_collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + + # Check after_update trigger (conceptually after adding new value) + # For collect, "after update" means after the value would be added + if after_update: + new_collected = [*old_collected, self._collect_value(value)] + if after_update(new_collected, value, key, timestamp_ms, headers): + result = self._results(None, new_collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms) if collect: @@ -198,7 +269,10 @@ def process_window( key, state, max_expired_window_start, collect ) - return updated_windows, expired_windows + # Combine triggered windows with time-expired windows + all_expired_windows = itertools.chain(expired_windows, triggered_windows) + + return updated_windows, all_expired_windows def expire_by_partition( self, diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py index c80c9e2ad..2764651b5 100644 --- a/quixstreams/state/types.py +++ b/quixstreams/state/types.py @@ -391,6 
+391,16 @@ def expire_all_windows( """ ... + def delete_window(self, start_ms: int, end_ms: int, prefix: bytes) -> None: + """ + Delete a single window defined by start and end timestamps. + + :param start_ms: start of the window in milliseconds + :param end_ms: end of the window in milliseconds + :param prefix: a key prefix + """ + ... + def delete_windows( self, max_start_time: int, delete_values: bool, prefix: bytes ) -> None: diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py index 6a0b1fd5f..3c14edd76 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py @@ -1,3 +1,5 @@ +import functools + import pytest import quixstreams.dataframe.windows.aggregations as agg @@ -12,27 +14,431 @@ @pytest.fixture() def hopping_window_definition_factory(state_manager, dataframe_factory): def factory( - duration_ms: int, step_ms: int, grace_ms: int = 0 + duration_ms: int, + step_ms: int, + grace_ms: int = 0, + before_update=None, + after_update=None, ) -> HoppingTimeWindowDefinition: sdf = dataframe_factory( state_manager=state_manager, registry=DataFrameRegistry() ) window_def = HoppingTimeWindowDefinition( - duration_ms=duration_ms, step_ms=step_ms, grace_ms=grace_ms, dataframe=sdf + duration_ms=duration_ms, + step_ms=step_ms, + grace_ms=grace_ms, + dataframe=sdf, + before_update=before_update, + after_update=after_update, ) return window_def return factory -def process(window, value, key, transaction, timestamp_ms): +def process(window, value, key, transaction, timestamp_ms, headers=None): updated, expired = window.process_window( - value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms + value=value, + key=key, + timestamp_ms=timestamp_ms, + headers=headers, + transaction=transaction, ) return list(updated), list(expired) class TestHoppingWindow: + def test_hopping_window_with_after_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + # Define a trigger that expires windows when the sum reaches 100 or more + def trigger_on_sum_100(aggregated, value, key, timestamp, headers) -> bool: + return aggregated >= 100 + + window_def = hopping_window_definition_factory( + duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_sum_100 + ) + window = window_def.sum() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + _process = functools.partial( + process, window=window, key=key, transaction=tx + ) + + # Step 1: Add value=90 at timestamp 50ms + # Creates windows [0, 100) and [50, 150) with sum 90 each + updated, expired = _process(value=90, timestamp_ms=50) + assert len(updated) == 2 + assert updated[0][1]["value"] == 90 + assert updated[0][1]["start"] == 0 + assert updated[0][1]["end"] == 100 + assert updated[1][1]["value"] == 90 + assert updated[1][1]["start"] == 50 + assert updated[1][1]["end"] == 150 + assert not expired + + # Step 2: Add value=5 at timestamp 110ms + # With grace_ms=100, [0, 100) does NOT expire naturally yet + # [0, 100): stays 90 (timestamp 110 is outside [0, 100), not updated) + # [50, 150): 90 -> 95 (< 100, NOT TRIGGERED) + # [100, 200): newly created with sum 5 + updated, expired = _process(value=5, timestamp_ms=110) + assert len(updated) == 2 + assert updated[0][1]["value"] 
== 95 + assert updated[0][1]["start"] == 50 + assert updated[0][1]["end"] == 150 + assert updated[1][1]["value"] == 5 + assert updated[1][1]["start"] == 100 + assert updated[1][1]["end"] == 200 + # No windows expired (grace period keeps [0, 100) alive) + assert not expired + + # Step 3: Add value=5 at timestamp 90ms (late message) + # Timestamp 90 belongs to BOTH [0, 100) and [50, 150) + # [0, 100): 90 -> 95 (< 100, NOT TRIGGERED) + # [50, 150): 95 -> 100 (>= 100, TRIGGERED!) + updated, expired = _process(value=5, timestamp_ms=90) + # Only [0, 100) remains in updated (not triggered, 95 < 100) + # Only [50, 150) was triggered (100 >= 100) + assert len(updated) == 1 + assert updated[0][1]["value"] == 95 + assert updated[0][1]["start"] == 0 + assert updated[0][1]["end"] == 100 + assert len(expired) == 1 + assert expired[0][1]["value"] == 100 + assert expired[0][1]["start"] == 50 + assert expired[0][1]["end"] == 150 + + def test_hopping_window_with_before_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + """Test that before_update callback works for hopping windows.""" + + # Define a trigger that expires windows before adding a value + # if the sum would exceed 50 + def trigger_before_exceeding_50( + aggregated, value, key, timestamp, headers + ) -> bool: + return (aggregated + value) > 50 + + window_def = hopping_window_definition_factory( + duration_ms=100, + step_ms=50, + grace_ms=100, + before_update=trigger_before_exceeding_50, + ) + window = window_def.sum() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Helper to process and return results + def _process(value, timestamp_ms): + return process( + window, + value=value, + key=key, + transaction=tx, + timestamp_ms=timestamp_ms, + ) + + # Step 1: Add value=10 at timestamp 50ms + # Belongs to windows [0, 100) and [50, 150) (hopping windows overlap) + # Both windows: Sum=10, doesn't exceed 50, no trigger + updated, expired = _process(value=10, timestamp_ms=50) + assert len(updated) == 2 + assert updated[0][1]["value"] == 10 + assert updated[0][1]["start"] == 0 + assert updated[1][1]["value"] == 10 + assert updated[1][1]["start"] == 50 + assert not expired + + # Step 2: Add value=20 at timestamp 60ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows: Sum=30, doesn't exceed 50, no trigger + updated, expired = _process(value=20, timestamp_ms=60) + assert len(updated) == 2 + assert updated[0][1]["value"] == 30 # [0, 100) + assert updated[1][1]["value"] == 30 # [50, 150) + assert not expired + + # Step 3: Add value=25 at timestamp 70ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows: Sum would be 55 which exceeds 50, should trigger BEFORE adding + # Both expired windows should have value=30 (not 55) + updated, expired = _process(value=25, timestamp_ms=70) + assert not updated + assert len(expired) == 2 + assert expired[0][1]["value"] == 30 # [0, 100) before the update + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + assert expired[1][1]["value"] == 30 # [50, 150) before the update + assert expired[1][1]["start"] == 50 + assert expired[1][1]["end"] == 150 + + # Step 4: Add value=5 at timestamp 100ms + # Belongs to windows [50, 150) and [100, 200) + # Window [50, 150) sum=5, doesn't trigger + # Window [100, 200) sum=5, doesn't trigger + updated, expired = _process(value=5, timestamp_ms=100) + assert 
len(updated) == 2 + # Results should be for both windows + assert not expired + + def test_hopping_window_collect_with_after_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + """Test that after_update callback works with collect for hopping windows.""" + + # Define a trigger that expires windows when we collect 3 or more items + def trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool: + return len(aggregated) >= 3 + + window_def = hopping_window_definition_factory( + duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_count_3 + ) + window = window_def.collect() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + _process = functools.partial( + process, window=window, key=key, transaction=tx + ) + + # Step 1: Add first value at timestamp 50ms + # Creates windows [0, 100) and [50, 150) with 1 item each + updated, expired = _process(value=1, timestamp_ms=50) + assert not updated # collect doesn't emit on updates + assert not expired + + # Step 2: Add second value at timestamp 60ms + # Both windows now have 2 items + updated, expired = _process(value=2, timestamp_ms=60) + assert not updated + assert not expired + + # Step 3: Add third value at timestamp 70ms + # Both windows now have 3 items - BOTH SHOULD TRIGGER + updated, expired = _process(value=3, timestamp_ms=70) + assert not updated + assert len(expired) == 2 + # Window [0, 100) triggered + assert expired[0][1]["value"] == [1, 2, 3] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + # Window [50, 150) triggered + assert expired[1][1]["value"] == [1, 2, 3] + assert expired[1][1]["start"] == 50 + assert expired[1][1]["end"] == 150 + + # Step 4: Add fourth value at timestamp 110ms + # Timestamp 110 belongs to windows [50, 150) and [100, 200) + # Window [50, 150) is "resurrected" because collection values weren't deleted + # (for hopping windows, we don't delete collection on trigger to preserve + # values for overlapping windows) + # Window [50, 150) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN! 
+ # Window [100, 200) has [4] = 1 item - doesn't trigger + updated, expired = _process(value=4, timestamp_ms=110) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4] + assert expired[0][1]["start"] == 50 + assert expired[0][1]["end"] == 150 + + def test_hopping_window_collect_with_before_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + """Test that before_update callback works with collect for hopping windows.""" + + # Define a trigger that expires windows before adding a value + # if the collection would reach 3 or more items + def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool: + # For collect, aggregated is the list of collected values BEFORE adding + return len(aggregated) + 1 >= 3 + + window_def = hopping_window_definition_factory( + duration_ms=100, + step_ms=50, + grace_ms=100, + before_update=trigger_before_count_3, + ) + window = window_def.collect() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Helper to process and return results + def _process(value, timestamp_ms): + return process( + window, + value=value, + key=key, + transaction=tx, + timestamp_ms=timestamp_ms, + ) + + # Step 1: Add value=1 at timestamp 50ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows would have 1 item, no trigger + updated, expired = _process(value=1, timestamp_ms=50) + assert not updated # collect doesn't emit on updates + assert not expired + + # Step 2: Add value=2 at timestamp 60ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows would have 2 items, no trigger + updated, expired = _process(value=2, timestamp_ms=60) + assert not updated + assert not expired + + # Step 3: Add value=3 at timestamp 70ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows would have 3 items, triggers BEFORE adding + # Both windows should have [1, 2] (not [1, 2, 3]) + updated, expired = _process(value=3, timestamp_ms=70) + assert not updated + assert len(expired) == 2 + # Window [0, 100) + assert expired[0][1]["value"] == [1, 2] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + # Window [50, 150) + assert expired[1][1]["value"] == [1, 2] + assert expired[1][1]["start"] == 50 + assert expired[1][1]["end"] == 150 + + # Step 4: Add value=4 at timestamp 110ms + # Belongs to windows [50, 150) and [100, 200) + # Window [50, 150) resurrected with [1, 2, 3] - would be 4 items, triggers + # Window [100, 200) would have 1 item, no trigger + updated, expired = _process(value=4, timestamp_ms=110) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3] # Before adding 4 + assert expired[0][1]["start"] == 50 + assert expired[0][1]["end"] == 150 + + def test_hopping_window_agg_and_collect_with_before_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + """Test before_update with BOTH aggregation and collect for hopping windows. + + This verifies that: + 1. The triggered window does NOT include the triggering value in collect + 2. The triggering value IS still added to collection storage for future windows + 3. The aggregated value is BEFORE the triggering value + 4. 
For hopping windows, overlapping windows share the collection storage + """ + import quixstreams.dataframe.windows.aggregations as agg + + # Trigger when count would reach 3 + def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool: + # In multi-aggregation, keys are like 'count/Count', 'sum/Sum' + # Find the count aggregation value + for k, v in agg_dict.items(): + if k.startswith("count"): + return v + 1 >= 3 + return False + + window_def = hopping_window_definition_factory( + duration_ms=100, + step_ms=50, + grace_ms=100, + before_update=trigger_before_count_3, + ) + window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect()) + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + _process = functools.partial( + process, window=window, key=key, transaction=tx + ) + + # Step 1: Add value=1 at timestamp 50ms + # Windows [0, 100) and [50, 150) both get count=1 + updated, expired = _process(value=1, timestamp_ms=50) + assert len(updated) == 2 + assert not expired + + # Step 2: Add value=2 at timestamp 60ms + # Both windows get count=2 + updated, expired = _process(value=2, timestamp_ms=60) + assert len(updated) == 2 + assert not expired + + # Step 3: Add value=3 at timestamp 70ms + # Both windows: count would be 3, triggers BEFORE adding + updated, expired = _process(value=3, timestamp_ms=70) + assert not updated + assert len(expired) == 2 + + # Window [0, 100) + assert expired[0][1]["count"] == 2 # Before the update (not 3) + assert expired[0][1]["sum"] == 3 # Before the update (1+2, not 1+2+3) + # CRITICAL: collect should NOT include the triggering value (3) + assert expired[0][1]["collect"] == [1, 2] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Window [50, 150) + assert expired[1][1]["count"] == 2 # Before the update (not 3) + assert expired[1][1]["sum"] == 3 # Before the update (1+2, not 1+2+3) + # CRITICAL: collect should NOT include the triggering value (3) + assert expired[1][1]["collect"] == [1, 2] + assert expired[1][1]["start"] == 50 + assert expired[1][1]["end"] == 150 + + # Step 4: Add value=4 at timestamp 100ms + # This belongs to windows [50, 150) and [100, 200) + # The triggering value (3) should still be in collection storage + updated, expired = _process(value=4, timestamp_ms=100) + assert len(updated) == 2 + assert not expired + + # Step 5: Force natural expiration to verify collection includes triggering value + # Windows that were deleted by trigger won't resurrect in hopping windows + # since they were explicitly deleted. 
Let's verify the triggering value + # was still added to collection by adding more values to a later window + updated, expired = _process(value=5, timestamp_ms=120) + assert len(updated) == 2 # Windows [50,150) resurrected and [100,200) + assert not expired + + # Force expiration at timestamp 260 (well past grace period) + updated, expired = _process(value=6, timestamp_ms=260) + # This should expire windows that existed + assert len(expired) >= 1 + + # The key point: the triggering value (3) WAS added to collection storage + # So any window that overlaps with that timestamp includes it + # Verify at least one expired window contains the triggering value + found_triggering_value = False + for _, window_result in expired: + if 3 in window_result["collect"]: + found_triggering_value = True + break + assert ( + found_triggering_value + ), "Triggering value (3) should be in collection storage" + @pytest.mark.parametrize( "duration, grace, step, provided_name, func_name, expected_name", [ diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py index fc5ab8eba..2d763583d 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py @@ -21,9 +21,13 @@ } -def process(window, value, key, transaction, timestamp_ms): +def process(window, value, key, transaction, timestamp_ms, headers=None): updated, expired = window.process_window( - value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms + value=value, + key=key, + transaction=transaction, + timestamp_ms=timestamp_ms, + headers=headers, ) return list(updated), list(expired) diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py index 98d9f56c1..e363723b2 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py @@ -11,26 +11,345 @@ @pytest.fixture() def tumbling_window_definition_factory(state_manager, dataframe_factory): - def factory(duration_ms: int, grace_ms: int = 0) -> TumblingTimeWindowDefinition: + def factory( + duration_ms: int, + grace_ms: int = 0, + before_update=None, + after_update=None, + ) -> TumblingTimeWindowDefinition: sdf = dataframe_factory( state_manager=state_manager, registry=DataFrameRegistry() ) window_def = TumblingTimeWindowDefinition( - duration_ms=duration_ms, grace_ms=grace_ms, dataframe=sdf + duration_ms=duration_ms, + grace_ms=grace_ms, + dataframe=sdf, + before_update=before_update, + after_update=after_update, ) return window_def return factory -def process(window, value, key, transaction, timestamp_ms): +def process(window, value, key, transaction, timestamp_ms, headers=None): updated, expired = window.process_window( - value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms + value=value, + key=key, + timestamp_ms=timestamp_ms, + headers=headers, + transaction=transaction, ) return list(updated), list(expired) class TestTumblingWindow: + def test_tumbling_window_with_after_update_trigger( + self, tumbling_window_definition_factory, state_manager + ): + # Define a trigger that expires the window when the sum reaches 9 or more + def trigger_on_sum_9(aggregated, value, key, timestamp, headers) -> bool: + return aggregated >= 9 + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, 
after_update=trigger_on_sum_9 + ) + window = window_def.sum() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add value=2, sum becomes 2, delta from 0 is 2, should not trigger + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=50 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert not expired + + # Add value=2, sum becomes 4, delta from 2 is 2, should not trigger + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=60 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 4 + assert not expired + + # Add value=5, sum becomes 9, delta from 4 is 5, should trigger (>= 5) + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated # Window was triggered + assert len(expired) == 1 + assert expired[0][1]["value"] == 9 + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value should start a new window + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=80 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 3 + assert not expired + + def test_tumbling_window_with_before_update_trigger( + self, tumbling_window_definition_factory, state_manager + ): + """Test that before_update callback works and triggers before aggregation.""" + + # Define a trigger that expires the window before adding a value + # if the sum would exceed 10 + def trigger_before_exceeding_10( + aggregated, value, key, timestamp, headers + ) -> bool: + return (aggregated + value) > 10 + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, before_update=trigger_before_exceeding_10 + ) + window = window_def.sum() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add value=3, sum becomes 3, would not exceed 10, should not trigger + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=50 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 3 + assert not expired + + # Add value=5, sum becomes 8, would not exceed 10, should not trigger + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=60 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 8 + assert not expired + + # Add value=3, would make sum 11 which exceeds 10, should trigger BEFORE adding + # So the expired window should have value=8 (not 11) + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated # Window was triggered + assert len(expired) == 1 + assert expired[0][1]["value"] == 8 # Before the update (not 11) + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value should start a new window + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=80 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert not expired + + def test_tumbling_window_collect_with_after_update_trigger( + self, tumbling_window_definition_factory, state_manager + ): + """Test that after_update callback works with collect.""" + + # Define a trigger that expires the window 
when we collect 3 or more items + def trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool: + # For collect, aggregated is the list of collected values + return len(aggregated) >= 3 + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, after_update=trigger_on_count_3 + ) + window = window_def.collect() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add first value - should not trigger (count=1) + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=50 + ) + assert not updated # collect doesn't emit on updates + assert not expired + + # Add second value - should not trigger (count=2) + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=60 + ) + assert not updated + assert not expired + + # Add third value - should trigger (count=3) + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value at t=80 still belongs to window [0, 100) + # Window is "resurrected" because collection values weren't deleted + # (we let normal expiration handle cleanup for simplicity) + # Window [0, 100) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN + updated, expired = process( + window, value=4, key=key, transaction=tx, timestamp_ms=80 + ) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + def test_tumbling_window_collect_with_before_update_trigger( + self, tumbling_window_definition_factory, state_manager + ): + """Test that before_update callback works with collect.""" + + # Define a trigger that expires the window before adding a value + # if the collection would reach 3 or more items + def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool: + # For collect, aggregated is the list of collected values BEFORE adding the new value + return len(aggregated) + 1 >= 3 + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, before_update=trigger_before_count_3 + ) + window = window_def.collect() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add first value - should not trigger (count would be 1) + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=50 + ) + assert not updated # collect doesn't emit on updates + assert not expired + + # Add second value - should not trigger (count would be 2) + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=60 + ) + assert not updated + assert not expired + + # Add third value - should trigger BEFORE adding (count would be 3) + # Expired window should have [1, 2] (not [1, 2, 3]) + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2] # Before adding the third value + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value 
should start accumulating in the same window again + # (window was deleted but collection values remain until natural expiration) + updated, expired = process( + window, value=4, key=key, transaction=tx, timestamp_ms=80 + ) + assert not updated + # Window [0, 100) is "resurrected" with [1, 2, 3] + # Adding value 4 would make it 4 items, triggers again + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3] # Before adding 4 + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + def test_tumbling_window_agg_and_collect_with_before_update_trigger( + self, tumbling_window_definition_factory, state_manager + ): + """Test before_update with BOTH aggregation and collect. + + This verifies that: + 1. The triggered window does NOT include the triggering value in collect + 2. The triggering value IS still added to collection storage for future + 3. The aggregated value is BEFORE the triggering value + """ + import quixstreams.dataframe.windows.aggregations as agg + + # Trigger when count would reach 3 + def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool: + # In multi-aggregation, keys are like 'count/Count', 'sum/Sum' + # Find the count aggregation value + for k, v in agg_dict.items(): + if k.startswith("count"): + return v + 1 >= 3 + return False + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, before_update=trigger_before_count_3 + ) + window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect()) + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add value=1, count becomes 1 + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=50 + ) + assert len(updated) == 1 + assert not expired + + # Add value=2, count becomes 2 + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=60 + ) + assert len(updated) == 1 + assert not expired + + # Add value=3, would make count 3 + # Should trigger BEFORE adding + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated # Window was triggered + assert len(expired) == 1 + + assert expired[0][1]["count"] == 2 # Before the update (not 3) + assert expired[0][1]["sum"] == 3 # Before the update (1+2, not 1+2+3) + # CRITICAL: collect should NOT include the triggering value (3) + assert expired[0][1]["collect"] == [1, 2] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value should start a new window + # But the triggering value (3) should still be in storage + updated, expired = process( + window, value=4, key=key, transaction=tx, timestamp_ms=80 + ) + assert len(updated) == 1 + assert not expired + + # Force window expiration to see what was collected + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 1 + # The collection should include the triggering value (3) that was added to storage + # even though it wasn't in the triggered window result + assert expired[0][1]["collect"] == [1, 2, 3, 4] # All values before t=110 + @pytest.mark.parametrize( "duration, grace, provided_name, func_name, expected_name", [