diff --git a/cellpose/contrib/cellposetrt/__init__.py b/cellpose/contrib/cellposetrt/__init__.py new file mode 100644 index 00000000..34a8e4c9 --- /dev/null +++ b/cellpose/contrib/cellposetrt/__init__.py @@ -0,0 +1,193 @@ +"""TensorRT-backed Cellpose model module. + + TensorRT is NVIDIA's neural‑network inference compiler/runtime for NVIDIA GPUs. It takes + an ONNX/graph, picks optimized kernels, fuses layers, and plans memory/scheduling + specialized for a given GPU architecture and fixed input profile. + + By specializing for fixed input shapes and fusing ops, TensorRT can deliver + higher performance than standard PyTorch inference (1.7x speedup in RTX 5090). + A CellposeSAM model can be converted to the TensorRT format by running + cellpose/contrib/cellposetrt/trt_build.py. + + `CellposeModelTRT(engine_path=...)` in this module is a drop-in replacement for + the standard `CellposeModel` to run CellposeSAM via TensorRT. +""" + +from pathlib import Path + +import tensorrt as trt +import torch + +from cellpose import models + + +class TRTEngineModule(torch.nn.Module): + """TensorRT-backed CellposeSAM model. + + It is not intended for auxiliary training/export variants that add extra outputs + (e.g., BioImage.IO downsampled tensors, denoise/perceptual losses). + + Notes + - Requires TensorRT >= 10. + - Engines are compiled for fixed profiles batch size and tile size. + """ + def __init__(self, engine_path: str|Path, device=torch.device("cuda")): + super().__init__() + + self.device = torch.device(device) + if self.device.type != "cuda": + raise RuntimeError( + f"TensorRT backend requires a CUDA device, got '{self.device.type}'. CPUs/MLX are unsupported." 
+ ) + + ver = getattr(trt, "__version__", None) + if not ver: + raise RuntimeError("TensorRT >= 10 required (version unknown).") + try: + major = int(str(ver).split(".")[0]) + except Exception: + raise RuntimeError(f"TensorRT >= 10 required (found {ver}).") + if major < 10: + raise RuntimeError(f"TensorRT >= 10 required (found {ver}).") + + logger = trt.Logger(trt.Logger.ERROR) + with open(engine_path, "rb") as f, trt.Runtime(logger) as runtime: + try: + self._engine = runtime.deserialize_cuda_engine(f.read()) + except Exception as exc: + raise ValueError(f"{engine_path} is not a valid TensorRT engine") from exc + + self._ctx = self._engine.create_execution_context() + + # Names exported by our ONNX: 'input' -> y/style + self._name_in = "input" + self._name_y = "y" + self._name_s = "style" + + def _to_torch_dtype(dt): + if dt == trt.DataType.BF16: + return torch.bfloat16 + if dt == trt.DataType.HALF: + return torch.float16 + if dt == trt.DataType.FLOAT: + return torch.float32 + raise ValueError(f"Unsupported TensorRT dtype: {dt}") + + # Sanity: make sure the names exist and modes are right + for name in (self._name_in, self._name_y, self._name_s): + self._engine.get_tensor_dtype(name) + + # Capture per-tensor dtypes from engine + self._dtype_in = _to_torch_dtype(self._engine.get_tensor_dtype(self._name_in)) + self._dtype_y = _to_torch_dtype(self._engine.get_tensor_dtype(self._name_y)) + self._dtype_s = _to_torch_dtype(self._engine.get_tensor_dtype(self._name_s)) + + self.dtype = self._dtype_in + + # Detect fixed batch dimension from engine input shape (None if dynamic) + self._in_dims = tuple(self._engine.get_tensor_shape(self._name_in)) # (N,C,H,W) with -1 for dynamic dims + self._fixedN = self._in_dims[0] if self._in_dims[0]> 0 else None + + def forward(self, X: torch.Tensor): + if not X.is_cuda: + raise RuntimeError("Input must be a CUDA tensor") + if X.device != self.device: + X = X.to(self.device, non_blocking=True) + if X.dtype != self._dtype_in: + X = 
X.to(self._dtype_in) + X = X.contiguous() + N, C, H, W = X.shape + + # Require exact N match when engine has fixed batch. + if self._fixedN is not None and N != self._fixedN: + raise ValueError( + f"Input batch {N} must equal engine fixed batch N={self._fixedN}. " + f"Adjust batch_size or rebuild the engine." + ) + effective_N = self._fixedN or N + + # 1) Set input shape by name + self._ctx.set_input_shape(self._name_in, (effective_N, C, H, W)) + + # 2) Allocate outputs; query shapes from engine (may have -1 -> allocate by heuristics if needed) + # Cellpose heads are [N,3,H,W] and [N,S], so we can shape from input. + # Read S from engine if available; otherwise default to 256 (Cellpose style vec size) and adjust if needed. + try: + # If engine carries concrete dims (profile-dependent), use them + s_dims = tuple(self._engine.get_tensor_shape(self._name_s)) + if any(d < 0 for d in s_dims): + S = s_dims[-1] if s_dims[-1] > 0 else 256 + else: + S = s_dims[-1] + except Exception: + S = 256 + + y = torch.empty((effective_N, 3, H, W), device=X.device, dtype=self._dtype_y) + s = torch.empty((effective_N, S), device=X.device, dtype=self._dtype_s) + + stream = torch.cuda.current_stream(self.device) + stream_handle = int(stream.cuda_stream) + + self._ctx.set_tensor_address(self._name_in, int(X.data_ptr())) + self._ctx.set_tensor_address(self._name_y, int(y.data_ptr())) + self._ctx.set_tensor_address(self._name_s, int(s.data_ptr())) + + ok = self._ctx.execute_async_v3(stream_handle) + if not ok: + raise RuntimeError("TensorRT execute_async_v3 failed") + + return y, s + + +class CellposeModelTRT(models.CellposeModel): + """Drop-in replacement for CellposeModel (eval) using TensorRT. + + Preparation + - Build an engine for your model first with scripts/trt_build.py, for example: + python scripts/trt_build.py PRETRAINED -o OUTPUT.plan --batch-size 4 --bsize 256 + Then pass engine_path=OUTPUT.plan to this class. 
+ + Contract + - Uses a TensorRT engine whose forward returns exactly (y, style) as defined + in TRTEngineModule; aligns with the main segmentation pipeline. + - Not intended for denoise/perceptual-loss training utilities or BioImage.IO + export paths that expect additional tensors beyond (y, style). + """ + + def __init__( + self, + gpu=False, + pretrained_model="cyto2", + model_type=None, + diam_mean=None, + device=None, + nchan=None, + use_bfloat16=True, + ): + engine_path = pretrained_model + if engine_path is None: + raise ValueError("TensorRT engine (.plan) must be generated from `trt_build.py` and provided via `pretrained_model`.") + engine_path = Path(engine_path) + if not engine_path.is_file(): + raise FileNotFoundError(f"TensorRT engine not found at {engine_path}") + self.engine_path = engine_path + + super().__init__( + gpu=gpu, + pretrained_model="cpsam", # dummy, not used + model_type=model_type, + diam_mean=diam_mean, + device=device, + nchan=nchan, + use_bfloat16=True, + ) + dev = torch.device("cuda" if device is None else device) + if not use_bfloat16: + raise ValueError("CellposeModelTRT only supports use_bfloat16=True") + + self.net = TRTEngineModule(engine_path, device=dev) + + def eval(self, x, **kwargs): + if kwargs.get("bsize", 256) != self.net._in_dims[2]: + raise ValueError(f"This engine only supports bsize={self.net._in_dims[2]} (built with this bsize)") + return super().eval(x, **kwargs) diff --git a/cellpose/contrib/cellposetrt/trt_benchmark.py b/cellpose/contrib/cellposetrt/trt_benchmark.py new file mode 100644 index 00000000..5084081f --- /dev/null +++ b/cellpose/contrib/cellposetrt/trt_benchmark.py @@ -0,0 +1,275 @@ +# %% +"""Cellpose vs TensorRT benchmarking + +Compare full pipeline (masks/flows/styles) between Torch and TRT +using the same inputs, report IoU and percent error (sMAPE), +and time both implementations. 
+ +Example usage: +python 'trt_benchmark.py' \ + --image=/data/registered/reg-0076.tif \ + --pretrained cpsam \ + --engine cpsam.plan \ + --batch-size=4 + +Example output: + Loaded tile: (2, 512, 512) uint16 + Engine path: /home/chaichontat/cellpose/scripts/builds/cpsam_b4_sm120_bf16.plan + Using CUDA device: cuda:0 | NVIDIA GeForce RTX 5090 + + [TEST] Full pipeline parity + masks: torch=(512, 512) trt=(512, 512) IoU=0.9986 + flow[0]: shape=(512, 512, 3) | sMAPE=2.257% MAE=0.176858 + flow[1]: shape=(2, 512, 512) | sMAPE=27.623% MAE=0.0060048 + flow[2]: shape=(512, 512) | sMAPE=0.816% MAE=0.0170394 + + [TIMING] Full pipeline eval(tile3) + Torch eval: 222.155 ms/iter (avg over 5, warmup=1) + TRT eval: 138.330 ms/iter (avg over 5, warmup=1) + Speedup vs Torch: x1.61 + + [TIMING] Net-only forward (Nx3x256x256) + Torch net: 15.930 ms/iter (CUDA events, iters=50, warmup=10) + TRT net : 7.110 ms/iter (CUDA events, iters=50, warmup=10) + Speedup (net-only): x2.24 + + [TEST] IoU parity on first 20 images from: /data/registered + 1/20 processed... IoU=0.9994 + ⋮ + 20/20 processed... 
IoU=0.9991 + IoU range: min=0.9987 median=0.9991 max=0.9996 (N=20) +""" + +from __future__ import annotations + +import argparse +import time +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import tifffile +import torch + +from cellpose import models +from cellpose.contrib.cellposetrt import CellposeModelTRT + +TILE_SLICE = np.s_[5, :, :512, :512] + + +def parse_args(): + ap = argparse.ArgumentParser(description="Cellpose vs TensorRT benchmarking") + ap.add_argument( + "--image", type=Path, required=True, help="Path to a test image (TIF)" + ) + ap.add_argument( + "--pretrained", + type=str, + required=True, + help="Path/name of pretrained Cellpose model", + ) + ap.add_argument( + "--engine", type=Path, required=True, help="TensorRT engine (.plan) path" + ) + ap.add_argument( + "--n-samples", + type=int, + default=20, + help="Number of folder images to test IoU on", + ) + ap.add_argument( + "--folder", + type=Path, + default=None, + help="Folder of images for IoU test; defaults to image's parent", + ) + ap.add_argument("--batch-size", type=int, default=4, help="Eval/engine batch size") + ap.add_argument( + "--save-masks", + type=Path, + default=None, + help="Optional output path (directory or .tif file) to save stacked masks from the IoU parity test", + ) + return ap.parse_args() + + +def print_smape(name: str, ref, tst) -> None: + r = torch.as_tensor(ref).float().flatten() + t = torch.as_tensor(tst).float().flatten() + diff = (t - r).abs() + mae = float(diff.mean()) + smape = float((2.0 * diff / (r.abs() + t.abs() + 1e-12)).mean() * 100.0) + print( + f"{name}: shape={tuple(torch.as_tensor(ref).shape)} | sMAPE={smape:.3f}% MAE={mae:.6g}" + ) + + +def time_op( + name: str, + fn: Callable, + *, + warmup: int = 1, + iters: int = 5, +) -> float: + # Warmup + for _ in range(warmup): + _ = fn() + + # Run + t0 = time.perf_counter() + for _ in range(iters): + _ = fn() + + torch.cuda.synchronize() + dt = (time.perf_counter() - t0) / 
iters
+    ms = dt * 1000.0
+    print(f"{name}: {ms:.3f} ms/iter (avg over {iters}, warmup={warmup})")
+    return ms
+
+
+def time_op_cuda(
+    name: str,
+    fn,
+    *,
+    warmup: int = 10,
+    iters: int = 50,
+) -> float:
+    """GPU kernel timing using CUDA events (net-only).
+
+    Records elapsed time on the current CUDA stream across `iters` calls. Does
+    not include Python/host sync beyond the final event synchronize.
+    """
+    # Warmup to stabilize autotuning/caches
+    for _ in range(warmup):
+        _ = fn()
+    torch.cuda.synchronize()
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    for _ in range(iters):
+        _ = fn()
+    end.record()
+    end.synchronize()
+    ms = start.elapsed_time(end) / iters
+    print(f"{name}: {ms:.3f} ms/iter (CUDA events, iters={iters}, warmup={warmup})")
+    return ms
+
+
+def iou_binary(a: np.ndarray, b: np.ndarray) -> float:
+    a = a.astype(bool)
+    b = b.astype(bool)
+    inter = np.logical_and(a, b).sum()
+    union = np.logical_or(a, b).sum()
+    return float(inter) / max(1, float(union))
+
+
+args = parse_args()
+
+save_masks_target: Path | None = None
+if args.save_masks is not None:
+    save_masks_target = Path(args.save_masks)
+    save_masks_target.parent.mkdir(parents=True, exist_ok=True)
+
+eval_kwargs = dict(
+    batch_size=args.batch_size,
+    flow_threshold=0,
+    compute_masks=True,
+)
+
+tile = tifffile.imread(args.image)[TILE_SLICE]
+print("Loaded tile:", tile.shape, tile.dtype)
+print(f"Engine path: {args.engine}")
+
+# ---- Build models ----
+device = torch.device("cuda:0")
+print(f"Using CUDA device: {device} | {torch.cuda.get_device_name(device)}")
+
+base = models.CellposeModel(gpu=True, device=device, pretrained_model=args.pretrained)
+trt_model = CellposeModelTRT(
+    gpu=True,
+    device=device,
+    # CellposeModelTRT loads the .plan engine via `pretrained_model`
+    pretrained_model=str(args.engine),
+)
+
+with torch.inference_mode():
+    base_out = base.eval(tile, **eval_kwargs)
+    trt_out = trt_model.eval(tile, **eval_kwargs)
+
+print("\n[TEST] Full
pipeline parity") +masks_pt, masks_trt = base_out[0], trt_out[0] +print( + f" masks: torch={masks_pt.shape} trt={masks_trt.shape} IoU={iou_binary(masks_pt != 0, masks_trt != 0):.4f}" +) + +flows_pt = base_out[1] +flows_trt = trt_out[1] +for k, (fpt, ftrt) in enumerate(zip(flows_pt, flows_trt)): + print_smape(f" flow[{k}]", fpt, ftrt) + +# Timing (full pipeline): +with torch.inference_mode(): + print("\n[TIMING] Full pipeline eval(tile3)") + ms_base = time_op(" Torch eval", lambda: base.eval(tile, **eval_kwargs)) + ms_trt = time_op( + " TRT eval", + lambda: models.CellposeModel.eval(trt_model, tile, **eval_kwargs), + ) + +spd = ms_base / ms_trt +print(f" Speedup vs Torch: x{spd:.2f}") + +# Net-only timing on representative Nx3x256x256 batch (CUDA events) +with torch.inference_mode(): + print(f"\n[TIMING] Net-only forward ({args.batch_size}x3x256x256)") + Xb = torch.randn(args.batch_size, 3, 256, 256, device=device, dtype=torch.bfloat16) + ms_torch_net = time_op_cuda(" Torch net", lambda: base.net(Xb)) + ms_trt_net = time_op_cuda(" TRT net ", lambda: trt_model.net(Xb)) + if ms_trt_net > 0: + print(f" Speedup (net-only): x{ms_torch_net / ms_trt_net:.2f}") + +# ---- TEST: Folder IoU on first N images (Torch vs TRT masks) ---- +folder = args.folder or args.image.parent +files = [p for p in sorted(folder.glob("*.tif"))] +sub = files[: args.n_samples] + +print(f"\n[TEST] IoU parity on first {len(sub)} images from: {folder}") +ious: list[float] = [] +saved_masks: list[np.ndarray] = [] +for idx, f in enumerate(sub): + try: + arr = tifffile.imread(f)[TILE_SLICE] + with torch.inference_mode(): + out_t = base.eval(arr, **eval_kwargs) + out_r = trt_model.eval(arr, **eval_kwargs) + if not np.any(out_t[0]) or not np.any(out_r[0]): + print( + f" [warn] skipping {f.name}: empty masks from at least one of the models" + ) + continue + + m_t = out_t[0] + m_r = out_r[0] + iou = iou_binary(m_t != 0, m_r != 0) + ious.append(iou) + if save_masks_target is not None: + 
saved_masks.append(np.stack((m_t, m_r), axis=0)) + + print(f" {idx + 1}/{len(sub)} processed... IoU={iou:.4f}") + except Exception as e: + print(f" [warn] skipping {f.name}: {e}") + +a = np.array(ious, dtype=float) +print( + f" IoU range: min={a.min():.4f} median={np.median(a):.4f} max={a.max():.4f} (N={len(a)})" +) + +if save_masks_target is not None: + stacked = np.stack(saved_masks, axis=0) + tifffile.imwrite( + save_masks_target, + stacked, + metadata={"axes": "TCYX"}, + compression="zstd", + ) + print(f"Saved masks to {save_masks_target} with shape {stacked.shape}") diff --git a/cellpose/contrib/cellposetrt/trt_build.py b/cellpose/contrib/cellposetrt/trt_build.py new file mode 100644 index 00000000..aa7d2de6 --- /dev/null +++ b/cellpose/contrib/cellposetrt/trt_build.py @@ -0,0 +1,165 @@ +"""Builds a TensorRT engine (.plan) from a Cellpose model. + +Speed up of 1.7x (batch size 4) - 2.2x (batch size 1) observed for +CellposeSAM net on RTX 5090 with BF16 engine +compared to the native PyTorch bfloat16 inference. + +Requirements +- NVIDIA GPU with BF16 support (SM80+, e.g., Ampere or newer). +- TensorRT >= 10 (Python bindings with the tensors API). +- PyTorch with ONNX exporter. +- CellposeSAM bfloat16 weights (pretrained_model path). + +Dependencies can be installed via pip: + `pip install tensorrt-cu12 nvidia-cuda-runtime-cu12` +Ensure that the requested CUDA version matches your environment's CUDA version. + +Behavior +- Exports ONNX that returns exactly (y, style), matching Cellpose segmentation. +- Dynamic batch profile: N in [1, batch-size], C=3, H=W=bsize. +- BF16 engine with builder_optimization_level=3 and OBEY_PRECISION_CONSTRAINTS. + +Gotchas +- Plan files are not portable: they are specific to GPU arch (SM) and + TensorRT/CUDA/driver. Rebuild on each host/GPU family; +- Spatial tensor size is fixed (H=W=bsize). Batch is dynamic in [1, batch-size]. 
+ +Usage + python trt_build.py PRETRAINED -o OUTPUT.plan [--batch-size N] [--bsize 256] [--vram 12000] [--opset 20] + +Runs in ~2 minutes on a Threadripper 7990X with RTX 5090 +Tested on tensorrt-cu12 10.13.3.9, torch 2.9.0+cu128, Python 3.13.9 +nvidia-driver-open 570.195.03, Ubuntu 24.04.2 +""" + +import argparse +import os +from pathlib import Path + +import tensorrt as trt +import torch + +from cellpose import models + + +class _CPNetWrapper(torch.nn.Module): + """Wrap Cellpose net to expose exactly (y, style) for ONNX export. + + Contract matches the main segmentation workflow and the TensorRT engine. + """ + + def __init__(self, net: torch.nn.Module): + super().__init__() + self.net = net + + def forward(self, x): + y, style = self.net(x)[:2] + return y, style + + +def export_onnx(pretrained_model: str, onnx_out: str, *, batch_size: int, bsize: int, opset: int = 20): + device = torch.device("cuda") + model = models.CellposeModel(gpu=True, pretrained_model=pretrained_model, use_bfloat16=True) + net = model.net.to(device).eval() + + # Ensure weights are BF16 as expected + param_dtypes = {p.dtype for p in net.parameters()} + if torch.float32 in param_dtypes: + raise RuntimeError(f"Loaded model contains FP32 parameters: {param_dtypes}. 
Expected BF16 only.") + wrapper = _CPNetWrapper(net) + + dummy = torch.randn(batch_size, 3, bsize, bsize, device=device, dtype=torch.bfloat16) + Path(os.path.dirname(onnx_out) or ".").mkdir(parents=True, exist_ok=True) + with torch.no_grad(): + torch.onnx.export( + wrapper, + dummy, + onnx_out, + opset_version=opset, + dynamo=False, + input_names=["input"], + output_names=["y", "style"], + dynamic_axes={ + "input": {0: "batch"}, + "y": {0: "batch"}, + "style": {0: "batch"} + }, + do_constant_folding=True, + ) + print(f"Exported ONNX to {onnx_out}.") + + +def build_engine(onnx_path: str, plan_path: str, *, bsize: int, vram: int, batch_size: int): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required to build a TensorRT engine.") + logger = trt.Logger(trt.Logger.ERROR) + builder = trt.Builder(logger) + network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + parser = trt.OnnxParser(network, logger) + + if not parser.parse_from_file(onnx_path): + for i in range(parser.num_errors): + print(parser.get_error(i)) + raise RuntimeError(f"Failed to parse ONNX file {onnx_path}") + + config = builder.create_builder_config() + config.builder_optimization_level = 3 # 4 doesn't help, 5 is too slow + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, vram * (1 << 20)) + config.set_flag(trt.BuilderFlag.BF16) + + # Improves performance (1.7x -> 2x on RTX 5090) but worse IoU and rough masks + # Need mixed precision training to use this properly + # config.set_flag(trt.BuilderFlag.FP16) + + # Sanity check: input must be NCHW with static C + inp = network.get_input(0) + if inp.shape is None or len(inp.shape) != 4: + raise ValueError(f"Expected NCHW input, got {inp.shape}") + + _, C, _, _ = tuple(inp.shape) + if not isinstance(C, int) or C <= 0: + raise ValueError(f"Channel dimension must be static/int in ONNX, got {C}") + + Nmax = int(batch_size) + if Nmax < 1: + raise ValueError("--batch-size must be >= 1") + + # 
Dynamic batch: allow [1, Nmax] to handle remainders during eval + min_shape = (1, C, bsize, bsize) + opt_shape = (Nmax, C, bsize, bsize) + max_shape = (Nmax, C, bsize, bsize) + profile = builder.create_optimization_profile() + profile.set_shape(inp.name, min_shape, opt_shape, max_shape) + config.add_optimization_profile(profile) + + engine_blob = builder.build_serialized_network(network, config) + if engine_blob is None: + raise RuntimeError("TensorRT engine build failed or returned empty blob.") + data = bytes(engine_blob) + + out_dir = Path(plan_path).parent + out_dir.mkdir(parents=True, exist_ok=True) + with open(plan_path, "wb") as f: + f.write(data) + print(f"Saved TensorRT engine: {plan_path} (N∈[1,{Nmax}], C={C}, H=W={bsize}, dtype=bf16)") + + +def main(): + ap = argparse.ArgumentParser(description="Export Cellpose net to ONNX and build TensorRT engine") + ap.add_argument("pretrained_model", type=str, help="Path/name of pretrained model (e.g., cpsam)") + ap.add_argument("-o", "--output", type=str, required=True, help="TensorRT engine output path (.plan)") + ap.add_argument("--vram", type=int, default=12000, help="Amount of GPU memory available (in MB) for TensorRT to optimize for") + ap.add_argument("--batch-size", type=int, default=1, help="Max batch dimension N (engine supports dynamic [1..N])") + ap.add_argument("--bsize", type=int, default=256, help="Tile size (256x256 by default)") + ap.add_argument("--opset", type=int, default=20, help="ONNX opset version to use for export") + args = ap.parse_args() + + plan_path = args.output + p = Path(plan_path) + onnx_out = str(p.with_suffix(".onnx")) + export_onnx(args.pretrained_model, onnx_out, batch_size=args.batch_size, bsize=args.bsize, opset=args.opset) + build_engine(onnx_out, plan_path, batch_size=args.batch_size, bsize=args.bsize, vram=args.vram) + + +if __name__ == "__main__": + main()