diff --git a/examples/dynamo/compile_with_dynamic_inputs.py b/examples/dynamo/compile_with_dynamic_inputs.py
new file mode 100644
index 0000000000..b2f61c6caa
--- /dev/null
+++ b/examples/dynamo/compile_with_dynamic_inputs.py
@@ -0,0 +1,58 @@
+import logging
+
+import torch
+import torch.nn as nn
+import torch_tensorrt
+
+logging.basicConfig(level=logging.DEBUG)
+
+torch.manual_seed(0)
+
+
+class ExpandReshapeModel(nn.Module):
+    def __init__(self, embed_dim: int):
+        super().__init__()
+        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
+        self.embed_dim = embed_dim
+        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
+
+    def forward(self, x: torch.Tensor):
+        batch_size = x.shape[0]
+        cls_token = self.cls_token.expand(batch_size, -1, -1)
+        x = torch.cat([cls_token, x], dim=1)
+        x = self.qkv_proj(x)
+        reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1)
+        return reshaped_qkv
+
+
+model = ExpandReshapeModel(embed_dim=768).cuda().eval()
+x = torch.randn(4, 196, 768).cuda()
+
+# 1. JIT: torch.compile
+x1 = x.clone()
+torch._dynamo.mark_dynamic(x1, index=0, min=2, max=32)
+trt_module = torch.compile(model, backend="tensorrt")
+out1 = trt_module(x1)
+
+# 2. AOT: torch_tensorrt.compile
+x2 = x.clone()
+example_input = torch_tensorrt.Input(
+    min_shape=[1, 196, 768],
+    opt_shape=[4, 196, 768],
+    max_shape=[32, 196, 768],
+    dtype=torch.float32,
+)
+trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=example_input)
+out2 = trt_module(x2)
+
+# 3. AOT: torch.export + Dynamo compile
+x3 = x.clone()
+bs = torch.export.Dim("bs", min=1, max=32)
+dynamic_shapes = {"x": {0: bs}}
+exp_program = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes)
+trt_module = torch_tensorrt.dynamo.compile(exp_program, (x3,))
+out3 = trt_module(x3)
+
+assert torch.allclose(out1, out2)
+assert torch.allclose(out1, out3)
+assert torch.allclose(out2, out3)
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
index 9f69572059..4af92e7afd 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
@@ -15,6 +15,8 @@ def remove_sym_nodes(
     """Remove sym_int placeholders which get inserted due to torch.compile's
     dynamic=True behavior
     """
+    gm = replace_symint_with_sym_size(gm)
+
     # Extract SymInt placeholder Tensors
     placeholder_idx_sym_ints = [
         (idx, node)
@@ -36,3 +38,42 @@
     logger.debug(f"Removed SymInt placeholders:\n{gm.graph}")
 
     return gm
+
+
+def replace_symint_with_sym_size(
+    gm: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    """Replace SymInt placeholders with sym_size nodes"""
+    # Find all SymInt placeholders and their args
+    symint_node_arg_dict = {}
+    for node in gm.graph.nodes:
+        if (
+            node.op == "placeholder"
+            and isinstance(node.type, type)
+            and issubclass(node.type, torch.SymInt)
+        ):
+            ga = node.meta["grapharg"]
+            src = ga.source  # TensorPropertySource
+            symint_node_arg_dict[node] = (src.base.local_name, src.idx)
+
+    # Replace SymInt placeholders with sym_size nodes
+    for node in gm.graph.nodes:
+        if (
+            node.op == "placeholder"
+            and isinstance(node.type, type)
+            and issubclass(node.type, torch.Tensor)
+        ):
+            for symint_node, (arg_name, idx) in symint_node_arg_dict.items():
+                if node.target == "L_" + arg_name + "_":
+                    with gm.graph.inserting_after(node):
+                        size_node = gm.graph.call_function(
+                            torch.ops.aten.sym_size, args=(node, idx)
+                        )
+                        symint_node.replace_all_uses_with(size_node)
+                        # the symint_node is not used anymore; it will be removed outside of this function
+
+    gm.graph.lint()
+    gm.recompile()
+    logger.debug(f"Added sym_size nodes for SymInt placeholders:\n{gm.graph}")
+
+    return gm