From 8e20bf1a305b269a9b8b80b6bc6972d28f705b0e Mon Sep 17 00:00:00 2001 From: Benjamin Gallusser Date: Wed, 3 Dec 2025 09:21:37 -0800 Subject: [PATCH 1/2] Expose prediction batch size to CLI --- trackastra/cli.py | 12 +++++++++++- trackastra/model/model_api.py | 18 +++++++++++++----- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/trackastra/cli.py b/trackastra/cli.py index 11dbf4b..cd44b98 100644 --- a/trackastra/cli.py +++ b/trackastra/cli.py @@ -84,6 +84,12 @@ def cli(): " falling back to cpu." ), ) + p_track.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size for model inference. If not set, uses device-dependent default.", + ) p_track.set_defaults(cmd=_track_from_disk) if len(sys.argv) == 1: @@ -114,7 +120,11 @@ def _track_from_disk(args): ) track_graph, masks = model.track_from_disk( - args.imgs, args.masks, mode=args.mode, max_distance=args.max_distance + args.imgs, + args.masks, + mode=args.mode, + max_distance=args.max_distance, + batch_size=args.batch_size, ) if args.output_ctc: diff --git a/trackastra/model/model_api.py b/trackastra/model/model_api.py index e16accc..243bfaf 100644 --- a/trackastra/model/model_api.py +++ b/trackastra/model/model_api.py @@ -132,6 +132,7 @@ def from_folder( dir: Path | str, device: str | None = None, checkpoint_path: str | None = None, + **kwargs, ): """Load a Trackastra model from a local folder. @@ -152,11 +153,17 @@ def from_folder( Path(dir).expanduser(), map_location="cpu", checkpoint_path=checkpoint_path ) train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader) - return cls(transformer=transformer, train_args=train_args, device=device) + return cls( + transformer=transformer, train_args=train_args, device=device, **kwargs + ) @classmethod def from_pretrained( - cls, name: str, device: str | None = None, download_dir: Path | None = None + cls, + name: str, + device: str | None = None, + download_dir: Path | None = None, + **kwargs, ): """Load a pretrained Trackastra model. @@ -172,7 +179,7 @@ def from_pretrained( """ folder = download_pretrained(name, download_dir) # download zip from github to location/name, then unzip - return cls.from_folder(folder, device=device) + return cls.from_folder(folder, device=device, **kwargs) def _predict( self, @@ -210,7 +217,8 @@ def _predict( as_torch=True, ) - logger.info("Predicting windows") + batch_size = batch_size or self.batch_size + logger.info(f"Predicting windows with batch size {batch_size}") predictions = predict_windows( windows=windows, features=features, @@ -218,7 +226,7 @@ def _predict( edge_threshold=edge_threshold, spatial_dim=masks.ndim - 1, progbar_class=progbar_class, - batch_size=batch_size or self.batch_size, + batch_size=batch_size, ) return predictions From 25be65163873952e85a91ff94341f1dc015553de Mon Sep 17 00:00:00 2001 From: Benjamin Gallusser Date: Wed, 3 Dec 2025 09:29:39 -0800 Subject: [PATCH 2/2] Expose n_workers for feature extraction to CLI --- trackastra/cli.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/trackastra/cli.py b/trackastra/cli.py index cd44b98..2fe3027 100644 --- a/trackastra/cli.py +++ b/trackastra/cli.py @@ -90,6 +90,12 @@ def cli(): default=None, help="Batch size for model inference. If not set, uses device-dependent default.", ) + p_track.add_argument( + "--n-workers", + type=int, + default=0, + help="Number of workers for feature extraction.", + ) p_track.set_defaults(cmd=_track_from_disk) if len(sys.argv) == 1: @@ -125,6 +131,7 @@ def _track_from_disk(args): mode=args.mode, max_distance=args.max_distance, batch_size=args.batch_size, + n_workers=args.n_workers, ) if args.output_ctc: