Skip to content

Commit b1f9f5b

Browse files
committed
Move SessionWindow to a dedicated module
1 parent 11949a1 commit b1f9f5b

File tree

4 files changed

+373
-348
lines changed

4 files changed

+373
-348
lines changed

quixstreams/dataframe/windows/definitions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@
2222
CountWindowMultiAggregation,
2323
CountWindowSingleAggregation,
2424
)
25+
from .session import (
26+
SessionWindow,
27+
SessionWindowMultiAggregation,
28+
SessionWindowSingleAggregation,
29+
)
2530
from .sliding import (
2631
SlidingWindow,
2732
SlidingWindowMultiAggregation,
2833
SlidingWindowSingleAggregation,
2934
)
3035
from .time_based import (
31-
SessionWindow,
32-
SessionWindowMultiAggregation,
33-
SessionWindowSingleAggregation,
3436
TimeWindow,
3537
TimeWindowMultiAggregation,
3638
TimeWindowSingleAggregation,
Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
import logging
2+
import time
3+
from typing import TYPE_CHECKING, Any, Iterable, Optional
4+
5+
from quixstreams.context import message_context
6+
from quixstreams.state import WindowedPartitionTransaction, WindowedState
7+
8+
from .base import (
9+
MultiAggregationWindowMixin,
10+
SingleAggregationWindowMixin,
11+
Window,
12+
WindowKeyResult,
13+
WindowOnLateCallback,
14+
)
15+
from .time_based import ClosingStrategy, ClosingStrategyValues
16+
17+
if TYPE_CHECKING:
18+
from quixstreams.dataframe.dataframe import StreamingDataFrame
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class SessionWindow(Window):
24+
"""
25+
Session window groups events that occur within a specified timeout period.
26+
27+
A session starts with the first event and extends each time a new event arrives
28+
within the timeout period. The session closes after the timeout period with no
29+
new events.
30+
31+
Each session window can have different start and end times based on the actual
32+
events, making sessions dynamic rather than fixed-time intervals.
33+
"""
34+
35+
def __init__(
36+
self,
37+
timeout_ms: int,
38+
grace_ms: int,
39+
name: str,
40+
dataframe: "StreamingDataFrame",
41+
on_late: Optional[WindowOnLateCallback] = None,
42+
):
43+
super().__init__(
44+
name=name,
45+
dataframe=dataframe,
46+
)
47+
48+
self._timeout_ms = timeout_ms
49+
self._grace_ms = grace_ms
50+
self._on_late = on_late
51+
self._closing_strategy = ClosingStrategy.KEY
52+
53+
def final(
54+
self, closing_strategy: ClosingStrategyValues = "key"
55+
) -> "StreamingDataFrame":
56+
"""
57+
Apply the session window aggregation and return results only when the sessions
58+
are closed.
59+
60+
The format of returned sessions:
61+
```python
62+
{
63+
"start": <session start time in milliseconds>,
64+
"end": <session end time in milliseconds>,
65+
"value: <aggregated session value>,
66+
}
67+
```
68+
69+
The individual session is closed when the event time
70+
(the maximum observed timestamp across the partition) passes
71+
the last event timestamp + timeout + grace period.
72+
The closed sessions cannot receive updates anymore and are considered final.
73+
74+
:param closing_strategy: the strategy to use when closing sessions.
75+
Possible values:
76+
- `"key"` - messages advance time and close sessions with the same key.
77+
If some message keys appear irregularly in the stream, the latest sessions can remain unprocessed until a message with the same key is received.
78+
- `"partition"` - messages advance time and close sessions for the whole partition to which this message key belongs.
79+
If timestamps between keys are not ordered, it may increase the number of discarded late messages.
80+
Default - `"key"`.
81+
"""
82+
self._closing_strategy = ClosingStrategy.new(closing_strategy)
83+
return super().final()
84+
85+
def current(
86+
self, closing_strategy: ClosingStrategyValues = "key"
87+
) -> "StreamingDataFrame":
88+
"""
89+
Apply the session window transformation to the StreamingDataFrame to return results
90+
for each updated session.
91+
92+
The format of returned sessions:
93+
```python
94+
{
95+
"start": <session start time in milliseconds>,
96+
"end": <session end time in milliseconds>,
97+
"value: <aggregated session value>,
98+
}
99+
```
100+
101+
This method processes streaming data and returns results as they come,
102+
regardless of whether the session is closed or not.
103+
104+
:param closing_strategy: the strategy to use when closing sessions.
105+
Possible values:
106+
- `"key"` - messages advance time and close sessions with the same key.
107+
If some message keys appear irregularly in the stream, the latest sessions can remain unprocessed until a message with the same key is received.
108+
- `"partition"` - messages advance time and close sessions for the whole partition to which this message key belongs.
109+
If timestamps between keys are not ordered, it may increase the number of discarded late messages.
110+
Default - `"key"`.
111+
"""
112+
self._closing_strategy = ClosingStrategy.new(closing_strategy)
113+
return super().current()
114+
115+
def process_window(
116+
self,
117+
value: Any,
118+
key: Any,
119+
timestamp_ms: int,
120+
transaction: WindowedPartitionTransaction,
121+
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
122+
state = transaction.as_state(prefix=key)
123+
timeout_ms = self._timeout_ms
124+
grace_ms = self._grace_ms
125+
126+
collect = self.collect
127+
aggregate = self.aggregate
128+
129+
# Determine the latest timestamp for expiration logic
130+
if self._closing_strategy == ClosingStrategy.PARTITION:
131+
latest_expired_timestamp = transaction.get_latest_expired(prefix=b"")
132+
latest_timestamp = max(timestamp_ms, latest_expired_timestamp)
133+
else:
134+
state_ts = state.get_latest_timestamp() or 0
135+
latest_timestamp = max(timestamp_ms, state_ts)
136+
137+
# Calculate session expiry threshold
138+
session_expiry_threshold = latest_timestamp - grace_ms
139+
140+
# Check if the event is too late
141+
if timestamp_ms < session_expiry_threshold:
142+
late_by_ms = session_expiry_threshold - timestamp_ms
143+
self._on_expired_session(
144+
value=value,
145+
key=key,
146+
start=timestamp_ms,
147+
end=timestamp_ms + timeout_ms,
148+
timestamp_ms=timestamp_ms,
149+
late_by_ms=late_by_ms,
150+
)
151+
return [], []
152+
153+
# Look for an existing session that can be extended
154+
can_extend_session = False
155+
existing_aggregated = None
156+
old_window_to_delete = None
157+
158+
# Search for active sessions that can accommodate the new event
159+
search_start = max(0, timestamp_ms - timeout_ms * 2)
160+
windows = state.get_windows(
161+
search_start, timestamp_ms + timeout_ms + 1, backwards=True
162+
)
163+
164+
for (window_start, window_end), aggregated_value, _ in windows:
165+
# Calculate the time gap between the new event and the session's last activity
166+
session_last_activity = window_end - timeout_ms
167+
time_gap = timestamp_ms - session_last_activity
168+
169+
# Check if this session can be extended
170+
if time_gap <= timeout_ms + grace_ms and timestamp_ms >= window_start:
171+
session_start = window_start
172+
session_end = timestamp_ms + timeout_ms
173+
can_extend_session = True
174+
existing_aggregated = aggregated_value
175+
old_window_to_delete = (window_start, window_end)
176+
break
177+
178+
# If no extendable session found, start a new one
179+
if not can_extend_session:
180+
session_start = timestamp_ms
181+
session_end = timestamp_ms + timeout_ms
182+
183+
# Process the event for this session
184+
updated_windows: list[WindowKeyResult] = []
185+
186+
# Delete the old window if extending an existing session
187+
if can_extend_session and old_window_to_delete:
188+
old_start, old_end = old_window_to_delete
189+
transaction.delete_window(old_start, old_end, prefix=state._prefix) # type: ignore # noqa: SLF001
190+
191+
# Add to collection if needed
192+
if collect:
193+
state.add_to_collection(
194+
value=self._collect_value(value),
195+
id=timestamp_ms,
196+
)
197+
198+
# Update the session window aggregation
199+
aggregated = None
200+
if aggregate:
201+
current_value = (
202+
existing_aggregated if can_extend_session else self._initialize_value()
203+
)
204+
205+
aggregated = self._aggregate_value(current_value, value, timestamp_ms)
206+
updated_windows.append(
207+
(
208+
key,
209+
self._results(aggregated, [], session_start, session_end),
210+
)
211+
)
212+
213+
state.update_window(
214+
session_start, session_end, value=aggregated, timestamp_ms=timestamp_ms
215+
)
216+
217+
# Expire old sessions
218+
if self._closing_strategy == ClosingStrategy.PARTITION:
219+
expired_windows = self.expire_sessions_by_partition(
220+
transaction, session_expiry_threshold, collect
221+
)
222+
else:
223+
expired_windows = self.expire_sessions_by_key(
224+
key, state, session_expiry_threshold, collect
225+
)
226+
227+
return updated_windows, expired_windows
228+
229+
def expire_sessions_by_partition(
230+
self,
231+
transaction: WindowedPartitionTransaction,
232+
expiry_threshold: int,
233+
collect: bool,
234+
) -> Iterable[WindowKeyResult]:
235+
start = time.monotonic()
236+
count = 0
237+
238+
# Import the parsing function to extract message keys from window keys
239+
from quixstreams.state.rocksdb.windowed.serialization import parse_window_key
240+
241+
expired_results = []
242+
243+
# Collect all keys and extract unique prefixes to avoid iteration conflicts
244+
all_keys = list(transaction.keys())
245+
seen_prefixes = set()
246+
247+
for key_bytes in all_keys:
248+
try:
249+
prefix, start_ms, end_ms = parse_window_key(key_bytes)
250+
if prefix not in seen_prefixes:
251+
seen_prefixes.add(prefix)
252+
except (ValueError, IndexError):
253+
# Skip invalid window key formats
254+
continue
255+
256+
# Expire sessions for each unique prefix
257+
for prefix in seen_prefixes:
258+
state = transaction.as_state(prefix=prefix)
259+
prefix_expired = list(
260+
self.expire_sessions_by_key(prefix, state, expiry_threshold, collect)
261+
)
262+
expired_results.extend(prefix_expired)
263+
count += len(prefix_expired)
264+
265+
if count:
266+
logger.debug(
267+
"Expired %s session windows in %ss",
268+
count,
269+
round(time.monotonic() - start, 2),
270+
)
271+
272+
return expired_results
273+
274+
def expire_sessions_by_key(
275+
self,
276+
key: Any,
277+
state: WindowedState,
278+
expiry_threshold: int,
279+
collect: bool,
280+
) -> Iterable[WindowKeyResult]:
281+
start = time.monotonic()
282+
count = 0
283+
284+
# Get all windows and check which ones have expired
285+
all_windows = list(
286+
state.get_windows(0, expiry_threshold + self._timeout_ms, backwards=False)
287+
)
288+
289+
windows_to_delete = []
290+
for (window_start, window_end), aggregated, _ in all_windows:
291+
# Session expires when the session end time has passed the expiry threshold
292+
if window_end <= expiry_threshold:
293+
collected = []
294+
if collect:
295+
collected = state.get_from_collection(window_start, window_end)
296+
297+
windows_to_delete.append((window_start, window_end))
298+
count += 1
299+
yield (
300+
key,
301+
self._results(aggregated, collected, window_start, window_end),
302+
)
303+
304+
# Clean up expired windows
305+
for window_start, window_end in windows_to_delete:
306+
state._transaction.delete_window( # type: ignore # noqa: SLF001
307+
window_start,
308+
window_end,
309+
prefix=state._prefix, # type: ignore # noqa: SLF001
310+
)
311+
if collect:
312+
state.delete_from_collection(window_end, start=window_start)
313+
314+
if count:
315+
logger.debug(
316+
"Expired %s session windows in %ss",
317+
count,
318+
round(time.monotonic() - start, 2),
319+
)
320+
321+
def _on_expired_session(
322+
self,
323+
value: Any,
324+
key: Any,
325+
start: int,
326+
end: int,
327+
timestamp_ms: int,
328+
late_by_ms: int,
329+
) -> None:
330+
ctx = message_context()
331+
to_log = True
332+
333+
# Trigger the "on_late" callback if provided
334+
if self._on_late:
335+
to_log = self._on_late(
336+
value,
337+
key,
338+
timestamp_ms,
339+
late_by_ms,
340+
start,
341+
end,
342+
self._name,
343+
ctx.topic,
344+
ctx.partition,
345+
ctx.offset,
346+
)
347+
if to_log:
348+
logger.warning(
349+
"Skipping session processing for the closed session "
350+
f"timestamp_ms={timestamp_ms} "
351+
f"session={(start, end)} "
352+
f"late_by_ms={late_by_ms} "
353+
f"store_name={self._name} "
354+
f"partition={ctx.topic}[{ctx.partition}] "
355+
f"offset={ctx.offset}"
356+
)
357+
358+
359+
class SessionWindowSingleAggregation(SingleAggregationWindowMixin, SessionWindow):
360+
pass
361+
362+
363+
class SessionWindowMultiAggregation(MultiAggregationWindowMixin, SessionWindow):
364+
pass

0 commit comments

Comments
 (0)