
Commit 3fcf398

addressing review comment: unifying the shape functionality for upsample with concat
1 parent 8db4c74

File tree: 3 files changed (+125 −8 lines)


py/torch_tensorrt/dynamo/conversion/impl/cat.py

Lines changed: 69 additions & 5 deletions
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import numpy as np
 import tensorrt as trt
@@ -16,6 +16,63 @@
 )
 
 
+def unify_trt_tensors(
+    ctx: ConversionContext,
+    target: Target,
+    name: str,
+    inputs: Sequence[Union[int, np.ndarray, torch.Tensor, TRTTensor]],
+    concat_axis: int,
+    cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None,
+    force_trt_output: bool = False,
+) -> Union[TRTTensor, List[int]]:
+    """
+    Normalize all inputs to TRT tensors if needed, optionally cast, and concat if any dynamic.
+
+    Args:
+        ctx: TensorRT conversion context.
+        target: FX target for naming.
+        name: Base name for layers.
+        inputs: Sequence of ints / numpy arrays / torch tensors / TRT tensors.
+        concat_axis: Axis along which to concatenate tensors if dynamic.
+        cast_dtype: Optional target dtype for casting TRT tensors.
+        force_trt_output: If True, return TRT tensor even if all inputs are static ints.
+    """
+    has_dynamic = any(not isinstance(x, int) for x in inputs)
+    trt_tensors = []
+
+    for i, x in enumerate(inputs):
+        # convert to TRTTensor
+        if isinstance(x, TRTTensor):
+            t = x
+        elif isinstance(x, int) and not has_dynamic and not force_trt_output:
+            t = x  # pure static path
+        else:
+            t = ctx.net.add_constant((1,), np.array([x], dtype=np.int32))
+            set_layer_name(t, target, f"{name}_dim{i}_const")
+            t = t.get_output(0)
+
+        # optional cast
+        if cast_dtype and isinstance(t, TRTTensor):
+            t = cast_trt_tensor(ctx, t, cast_dtype, f"{name}_cast_{i}")
+
+        trt_tensors.append(t)
+
+    if not has_dynamic and not force_trt_output:
+        return trt_tensors  # all ints
+
+    # promote remaining ints to TRT consts before concat
+    for i, t in enumerate(trt_tensors):
+        if isinstance(t, int):
+            const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
+            set_layer_name(const, target, f"{name}_static_{i}_const")
+            trt_tensors[i] = const.get_output(0)
+
+    concat = ctx.net.add_concatenation(trt_tensors)
+    concat.axis = concat_axis
+    set_layer_name(concat, target, f"{name}_concat")
+    return concat.get_output(0)
+
+
 def cat(
     ctx: ConversionContext,
     target: Target,
@@ -54,9 +111,16 @@ def cat(
             )
             trt_casted_inputs.append(casted_input)
         trt_inputs = trt_casted_inputs
+    else:
+        trt_promoted_type = None
 
-    concat_layer = ctx.net.add_concatenation(trt_inputs)
     dim = get_positive_dim(dim, len(trt_inputs[0].shape))
-    concat_layer.axis = dim
-    set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
-    return concat_layer.get_output(0)
+    return unify_trt_tensors(
+        ctx,
+        target,
+        name,
+        trt_inputs,
+        concat_axis=dim,
+        cast_dtype=trt_promoted_type,
+        force_trt_output=True,
+    )
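
For intuition, a minimal usage sketch of the new helper (hypothetical: ctx and target stand in for a live conversion context, and h_itensor for a dynamic dimension already materialized as a TRT ITensor; only unify_trt_tensors itself comes from this commit):

    # All ints and force_trt_output=False: no layers are built, ints pass through.
    dims = unify_trt_tensors(ctx, target, "shape", [1, 3, 32, 32], concat_axis=0)
    # dims == [1, 3, 32, 32]

    # Any non-int input sets has_dynamic: each int becomes a (1,)-shaped int32
    # constant, TRT tensors are used as-is, and everything is concatenated
    # along concat_axis into one shape tensor.
    shape = unify_trt_tensors(ctx, target, "shape", [1, 3, h_itensor, 32], concat_axis=0)
    # shape is a TRTTensor of shape (4,)

    # cat() passes force_trt_output=True so it always gets an ITensor back,
    # even when every input is a static int.

The up-front has_dynamic check is what keeps the fully static path free of TRT layers.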

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 50 additions & 1 deletion
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import tensorrt as trt
@@ -159,3 +159,52 @@ def to_trt_shape_tensor(
 
     # If no ITensor found, return plain list of ints
     return shape_list
+
+
+def collect_and_concat_trt_inputs(
+    ctx: ConversionContext,
+    target: Target,
+    name: str,
+    inputs: Sequence[Union[int, TRTTensor, torch.Tensor, np.ndarray]],
+    concat_axis: int = 0,
+    allow_static_return: bool = False,
+) -> Union[TRTTensor, List[int]]:
+    """
+    Normalize a sequence of values into TRT ITensors and concatenate them.
+    If `allow_static_return=True` and all inputs are ints, return a Python
+    list of ints instead of creating any TRT layers.
+    """
+    trt_tensors = []
+    has_dynamic = False
+
+    for i, x in enumerate(inputs):
+        if isinstance(x, TRTTensor):
+            trt_tensors.append(x)
+            has_dynamic = True
+
+        elif isinstance(x, (int, np.integer)):
+            # keep raw for now, convert only if dynamic found
+            trt_tensors.append(int(x))
+
+        else:
+            # torch/np tensor -> TRT tensor
+            t = get_trt_tensor(ctx, x, f"{name}_tensor_{i}")
+            trt_tensors.append(t)
+            has_dynamic = True
+
+    # fully static shape case
+    if not has_dynamic and allow_static_return:
+        return [int(v) for v in trt_tensors]
+
+    # promote remaining ints to TRT constants
+    for i, v in enumerate(trt_tensors):
+        if isinstance(v, int):
+            const = ctx.net.add_constant((1,), np.array([v], dtype=np.int32))
+            set_layer_name(const, target, f"{name}_static_dim{i}_const")
+            trt_tensors[i] = const.get_output(0)
+
+    # concatenate
+    concat = ctx.net.add_concatenation(trt_tensors)
+    concat.axis = concat_axis
+    set_layer_name(concat, target, f"{name}_concat")
+    return concat.get_output(0)
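
A parallel sketch for the shape-side helper (hypothetical: ctx, target, and dyn_h are placeholders for a live conversion context and a dynamic dimension held as a TRT ITensor). It mirrors unify_trt_tensors, except the static short-circuit is opt-in via allow_static_return and there is no dtype cast:

    # Fully static with allow_static_return=True: plain ints, no layers built.
    static = collect_and_concat_trt_inputs(
        ctx, target, "upsample_shape", [1, 3, 64, 64], allow_static_return=True
    )
    # static == [1, 3, 64, 64]

    # Any tensor input promotes the remaining ints to (1,)-shaped int32
    # constants and concatenates everything along concat_axis.
    dynamic = collect_and_concat_trt_inputs(
        ctx, target, "upsample_shape", [1, 3, dyn_h, 64], allow_static_return=True
    )
    # dynamic is a TRTTensor of shape (4,)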

py/torch_tensorrt/dynamo/conversion/impl/upsample.py

Lines changed: 6 additions & 2 deletions
@@ -9,9 +9,11 @@
     has_dynamic_shape,
     set_layer_name,
 )
+from torch_tensorrt.dynamo.conversion.impl.cat import (
+    unify_trt_tensors as unify_trt_shape_tensors,
+)
 from torch_tensorrt.dynamo.conversion.impl.shape import (
     get_shape_with_dynamic_shape,
-    to_trt_shape_tensor,
 )
 
 
@@ -40,7 +42,9 @@ def upsample(
         )
         layer.set_input(1, shape)
     else:
-        trt_shape = to_trt_shape_tensor(ctx, target, name, shape)
+        trt_shape = unify_trt_shape_tensors(
+            ctx, target, name, shape, concat_axis=0, force_trt_output=False
+        )
         if isinstance(trt_shape, list):
             layer.shape = trt_shape
         else:
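
Spelled out, the dispatch the new else-branch relies on looks like this (a sketch: the diff is cut off after the final else:, and the set_input call is an assumption mirroring the dynamic-shape branch earlier in upsample()):

    trt_shape = unify_trt_shape_tensors(
        ctx, target, name, shape, concat_axis=0, force_trt_output=False
    )
    if isinstance(trt_shape, list):
        layer.shape = trt_shape  # fully static: assign plain Python ints
    else:
        layer.set_input(1, trt_shape)  # dynamic: feed a runtime shape tensor (assumed)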
