58 changes: 58 additions & 0 deletions examples/dynamo/compile_with_dynamic_inputs.py
@@ -0,0 +1,58 @@
import logging

import torch
import torch.nn as nn
import torch_tensorrt

logging.basicConfig(level=logging.DEBUG)

torch.manual_seed(0)


class ExpandReshapeModel(nn.Module):
def __init__(self, embed_dim: int):
super().__init__()
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
self.embed_dim = embed_dim
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

def forward(self, x: torch.Tensor):
batch_size = x.shape[0]
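        # expand/cat/reshape below all consume the batch dimension, which is the
        # pattern that exercises dynamic-shape handling in the compiled engine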
cls_token = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_token, x], dim=1)
x = self.qkv_proj(x)
reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1)
return reshaped_qkv


model = ExpandReshapeModel(embed_dim=768).cuda().eval()
x = torch.randn(4, 196, 768).cuda()

# 1. JIT: torch.compile
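# mark_dynamic declares dim 0 of x1 dynamic within [2, 32], so Dynamo compiles
# one engine covering that batch range instead of specializing on batch 4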
x1 = x.clone()
torch._dynamo.mark_dynamic(x1, index=0, min=2, max=32)
trt_module = torch.compile(model, backend="tensorrt")
out1 = trt_module(x1)

# 2. AOT: torch_tensorrt.compile
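# min/opt/max shapes define the TensorRT optimization profile: the engine
# accepts batch sizes 1-32 and is tuned for batch 4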
x2 = x.clone()
example_input = torch_tensorrt.Input(
min_shape=[1, 196, 768],
opt_shape=[4, 196, 768],
max_shape=[32, 196, 768],
dtype=torch.float32,
)
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=[example_input])
out2 = trt_module(x2)

# 3. AOT: torch.export + Dynamo compile
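# torch.export.Dim declares a symbolic batch dimension; dynamic_shapes binds it
# to dim 0 of the forward argument `x`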
x3 = x.clone()
bs = torch.export.Dim("bs", min=1, max=32)
dynamic_shapes = {"x": {0: bs}}
exp_program = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes)
trt_module = torch_tensorrt.dynamo.compile(exp_program, (x3,))
out3 = trt_module(x3)

assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)
assert torch.allclose(out2, out3)
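
# Extra sanity check (a sketch, not part of the original example; assumes any
# batch size inside the declared 1-32 range runs without recompilation):
x4 = torch.randn(8, 196, 768).cuda()
out4 = trt_module(x4)
assert out4.shape == (8, 197, 3, 12, 64)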
41 changes: 41 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
@@ -15,6 +15,8 @@ def remove_sym_nodes(
"""Remove sym_int placeholders which get inserted due to torch.compile's
dynamic=True behavior
"""
gm = replace_symint_with_sym_size(gm)

# Extract SymInt placeholder Tensors
placeholder_idx_sym_ints = [
(idx, node)
@@ -36,3 +38,42 @@
logger.debug(f"Removed SymInt placeholders:\n{gm.graph}")

return gm


def replace_symint_with_sym_size(
gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
"""Replace SymInt placeholders with sym_size nodes"""
# Find all SymInt placeholders and their args
symint_node_arg_dict = {}
for node in gm.graph.nodes:
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.SymInt)
):
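            # node.meta["grapharg"] is Dynamo's GraphArg; its TensorPropertySource
            # records which tensor argument and which dimension this SymInt tracks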
ga = node.meta["grapharg"]
src = ga.source # TensorPropertySource
symint_node_arg_dict[node] = (src.base.local_name, src.idx)

# Replace SymInt placeholders with sym_size nodes
for node in gm.graph.nodes:
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.Tensor)
):
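            # Dynamo names the graph input for a local variable `x` as "L_x_";
            # use that convention to match the tensor placeholder back to the
            # SymInt's source argument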
for symint_node, (arg_name, idx) in symint_node_arg_dict.items():
if node.target == "L_" + arg_name + "_":
with gm.graph.inserting_after(node):
size_node = gm.graph.call_function(
torch.ops.aten.sym_size, args=(node, idx)
)
symint_node.replace_all_uses_with(size_node)
                        # the SymInt placeholder has no remaining users; remove_sym_nodes
                        # deletes it after this function returns

gm.graph.lint()
gm.recompile()
logger.debug(f"Added sym_size nodes for SymInt placeholders:\n{gm.graph}")

return gm