|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import threading |
3 | 4 | import time |
4 | 5 | from collections import deque |
5 | 6 | from collections.abc import Callable, Generator |
6 | 7 | from dataclasses import dataclass |
| 8 | +from queue import Queue |
7 | 9 |
|
8 | 10 | import cv2 |
9 | 11 | import numpy as np |
@@ -255,6 +257,130 @@ def callback(scene: np.ndarray, index: int) -> np.ndarray: |
255 | 257 | sink.write_frame(frame=result_frame) |
256 | 258 |
|
257 | 259 |
|
def process_video_threads(
    source_path: str,
    target_path: str,
    callback: Callable[[np.ndarray, int], np.ndarray],
    *,
    max_frames: int | None = None,
    prefetch: int = 32,
    writer_buffer: int = 32,
    show_progress: bool = False,
    progress_message: str = "Processing video (with threads)",
) -> None:
    """
    Process a video using a threaded pipeline that asynchronously
    reads frames, applies a callback to each, and writes the results
    to an output file.

    Overview:
        This function implements a three-stage pipeline designed to maximize
        frame throughput.

            │ Reader │ >> │ Processor │ >> │ Writer │
             (thread)       (main)          (thread)

        - Reader thread: reads frames from disk into a queue ('read_q').
          When the queue is full the reader blocks, so at most 'prefetch'
          decoded frames are waiting to be processed.

        - Main thread: dequeues frames, applies 'callback(frame, idx)' and
          enqueues the processed result into 'write_q'. The callback only ever
          runs in the calling thread, so a stateful detector or tracker used
          inside it needs no synchronization with the reader/writer threads.

        - Writer thread: dequeues processed frames and writes them to disk.

        Both queues are bounded to enforce back-pressure:
            - The reader cannot outpace processing.
            - The processor cannot outpace writing.

    When to use it:
        - It helps most when the callback spends its time in code that releases
          the Python GIL (for example OpenCV operations or GPU inference), so
          decoding, processing and encoding can overlap.
        - If the callback is dominated by pure-Python, GIL-bound work, a
          process-based approach is likely more effective.

    Error handling:
        If the callback, the reader, or the writer raises, the pipeline stops
        early, both worker threads are shut down and joined while the output
        sink is still open, and the exception is re-raised in the calling
        thread. Frames already handed to the writer are still written.

    Args:
        source_path (str): The path to the source video file.
        target_path (str): The path to the target video file.
        callback (Callable[[np.ndarray, int], np.ndarray]): A function that takes in
            a numpy ndarray representation of a video frame and an
            int index of the frame and returns a processed numpy ndarray
            representation of the frame.
        max_frames (int | None): The maximum number of frames to process.
        prefetch (int): The maximum number of frames buffered by the reader thread.
            As with any `queue.Queue`, a value <= 0 makes the buffer unbounded.
        writer_buffer (int): The maximum number of frames buffered before writing.
            As with any `queue.Queue`, a value <= 0 makes the buffer unbounded.
        show_progress (bool): Whether to show a progress bar.
        progress_message (str): The message to display in the progress bar.

    Raises:
        Exception: Whatever `callback` raised, or the first exception raised
            inside the reader or writer thread.
    """
    source_video_info = VideoInfo.from_video_path(video_path=source_path)
    total_frames = (
        min(source_video_info.total_frames, max_frames)
        if max_frames is not None
        else source_video_info.total_frames
    )

    # A private end-of-stream marker. Using a unique object (rather than None)
    # means a callback that returns None cannot be mistaken for "no more frames".
    sentinel = object()
    read_q: Queue = Queue(maxsize=prefetch)
    write_q: Queue = Queue(maxsize=writer_buffer)

    # Set as soon as any stage fails (or the main loop ends) so the other
    # stages stop producing new work.
    stop_event = threading.Event()
    # Exceptions captured in worker threads, re-raised in the calling thread.
    errors: list[Exception] = []

    def reader_thread() -> None:
        try:
            gen = get_video_frames_generator(source_path=source_path, end=max_frames)
            for idx, frame in enumerate(gen):
                if stop_event.is_set():
                    break
                read_q.put((idx, frame))
        except Exception as exc:
            errors.append(exc)
            stop_event.set()
        finally:
            # Always signal end-of-stream, even after a failure, so the main
            # thread can never block forever waiting for another frame.
            read_q.put(sentinel)

    def writer_thread(video_sink: VideoSink) -> None:
        failed = False
        while True:
            frame = write_q.get()
            if frame is sentinel:
                break
            if failed:
                # Keep consuming after a failure so the producer's put()
                # cannot block on a full queue; the frames are discarded.
                continue
            try:
                video_sink.write_frame(frame=frame)
            except Exception as exc:
                errors.append(exc)
                failed = True
                stop_event.set()

    # Both threads are joined below on every path; 'daemon=True' is only a
    # safety net so they cannot keep the interpreter alive on a hard exit.
    t_reader = threading.Thread(target=reader_thread, daemon=True)
    with VideoSink(target_path=target_path, video_info=source_video_info) as sink:
        t_writer = threading.Thread(target=writer_thread, args=(sink,), daemon=True)
        process_bar = tqdm(
            total=total_frames, disable=not show_progress, desc=progress_message
        )
        t_reader.start()
        t_writer.start()

        reader_done = False
        try:
            # Main thread: take a frame, apply the callback, hand it to the writer.
            while not stop_event.is_set():
                item = read_q.get()
                if item is sentinel:
                    reader_done = True
                    break
                idx, frame = item
                write_q.put(callback(frame, idx))
                process_bar.update(1)
        finally:
            # Shut the pipeline down in order, whether we finished normally,
            # the callback raised, or a worker thread reported a failure.
            stop_event.set()
            # Drain until the reader's sentinel arrives: this unblocks a reader
            # stuck on a full queue and guarantees it has finished.
            while not reader_done:
                reader_done = read_q.get() is sentinel
            write_q.put(sentinel)
            t_reader.join()
            # The writer must finish before the 'with' block closes the sink.
            t_writer.join()
            process_bar.close()

        if errors:
            raise errors[0]
| 382 | + |
| 383 | + |
258 | 384 | class FPSMonitor: |
259 | 385 | """ |
260 | 386 | A class for monitoring frames per second (FPS) to benchmark latency. |
|
0 commit comments