@@ -1,4 +1,4 @@
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import numpy as np
 import tensorrt as trt
@@ -16,6 +16,63 @@
 )
 
 
+def unify_trt_tensors(
+    ctx: ConversionContext,
+    target: Target,
+    name: str,
+    inputs: Sequence[Union[int, np.ndarray, torch.Tensor, TRTTensor]],
+    concat_axis: int,
+    cast_dtype: Optional[Union[_enums.dtype, trt.DataType, np.dtype]] = None,
+    force_trt_output: bool = False,
+) -> Union[TRTTensor, List[int]]:
+    """
+    Convert inputs to TRT tensors as needed, optionally cast them, and concatenate when any input is dynamic; return the plain ints otherwise.
+
+    Args:
+        ctx: TensorRT conversion context.
+        target: FX target for naming.
+        name: Base name for layers.
+        inputs: Sequence of ints / numpy arrays / torch tensors / TRT tensors.
+        concat_axis: Axis along which to concatenate tensors if dynamic.
+        cast_dtype: Optional target dtype for casting TRT tensors.
+        force_trt_output: If True, return a TRT tensor even if all inputs are static ints.
+    """
+    has_dynamic = any(not isinstance(x, int) for x in inputs)
+    trt_tensors = []
+
+    for i, x in enumerate(inputs):
+        # convert to TRTTensor (ints stay plain ints on the pure static path)
+        if isinstance(x, TRTTensor):
+            t = x
+        elif isinstance(x, int) and not has_dynamic and not force_trt_output:
+            t = x  # pure static path
+        else:
+            t = ctx.net.add_constant((1,), np.array([x], dtype=np.int32))
+            set_layer_name(t, target, f"{name}_dim{i}_const")
+            t = t.get_output(0)
+
+        # optional cast
+        if cast_dtype is not None and isinstance(t, TRTTensor):
+            t = cast_trt_tensor(ctx, t, cast_dtype, f"{name}_cast_{i}")
+
+        trt_tensors.append(t)
+
+    if not has_dynamic and not force_trt_output:
+        return trt_tensors  # all ints
+
+    # promote remaining ints to TRT consts before concat
+    for i, t in enumerate(trt_tensors):
+        if isinstance(t, int):
+            const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
+            set_layer_name(const, target, f"{name}_static_{i}_const")
+            trt_tensors[i] = const.get_output(0)
+
+    concat = ctx.net.add_concatenation(trt_tensors)
+    concat.axis = concat_axis
+    set_layer_name(concat, target, f"{name}_concat")
+    return concat.get_output(0)
+
+
 def cat(
     ctx: ConversionContext,
     target: Target,
@@ -54,9 +111,16 @@ def cat( |
             )
             trt_casted_inputs.append(casted_input)
         trt_inputs = trt_casted_inputs
+    else:
+        trt_promoted_type = None
 
-    concat_layer = ctx.net.add_concatenation(trt_inputs)
     dim = get_positive_dim(dim, len(trt_inputs[0].shape))
-    concat_layer.axis = dim
-    set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
-    return concat_layer.get_output(0)
+    return unify_trt_tensors(
+        ctx,
+        target,
+        name,
+        trt_inputs,
+        concat_axis=dim,
+        cast_dtype=trt_promoted_type,
+        force_trt_output=True,
+    )
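
A quick illustration of the helper's two return paths (a minimal sketch, not part of the diff: it assumes a live ConversionContext `ctx` and an FX `target` are already in scope, and `dynamic_dim` stands for a 1-element ITensor produced earlier in the network, e.g. by a shape layer; the base names used here are hypothetical):

# Pure static path: no TRT layers are emitted, the ints come back as a list.
dims = unify_trt_tensors(ctx, target, "reshape_shape", [2, 3, 4], concat_axis=0)
assert dims == [2, 3, 4]

# Mixed path: one dynamic element promotes every int to a (1,) INT32 constant,
# and the pieces are concatenated along axis 0 into a single shape tensor.
shape = unify_trt_tensors(
    ctx,
    target,
    "reshape_shape",
    [2, dynamic_dim, 4],
    concat_axis=0,
    cast_dtype=trt.DataType.INT32,
)
# `shape` is now a single TRTTensor of shape (3,), usable e.g. as a reshape input.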