From e3909ace8bd9c625d7b08ce94beff1ec4b9d3261 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Tue, 6 Jan 2026 00:38:14 +0000
Subject: [PATCH 1/2] tests: Add additional test cases for the unowned tensor
 feature

---
 .../runtime/test_pre_allocated_outputs.py     | 232 +++++++++++++++++-
 1 file changed, 229 insertions(+), 3 deletions(-)

diff --git a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py
index 35dab61161..95faf6e2bd 100644
--- a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py
+++ b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py
@@ -125,7 +125,7 @@ def forward(self, x):
         )
         torch._dynamo.reset()
 
-    def test_pre_allocated_outputs_unowned_outputs(self):
+    def test_pre_allocated_outputs_unowned_outputs_py_api_check_no_realloc(self):
         class SampleModel(torch.nn.Module):
             def forward(self, x):
                 return torch.softmax(x * 7 + 2, dim=0)
@@ -146,21 +146,247 @@ def forward(self, x):
         )
 
         with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
-            optimized_model(inputs[0])
+            _ = optimized_model(inputs[0])
             output_tensors = [
                 trt_mod.pre_allocated_outputs
                 for name, trt_mod in optimized_model.named_children()
                 if "_run_on_acc" in name
             ]
 
-            optimized_model(inputs[0])
+            _ = optimized_model(inputs[0])
             new_output_tensors = [
                 trt_mod.pre_allocated_outputs
                 for name, trt_mod in optimized_model.named_children()
                 if "_run_on_acc" in name
             ]
+        # Run to run, the output of the intermediate engine is not reallocated
         self.assertTrue(output_tensors[0] is new_output_tensors[0])
+        # Run to run, the output of the output engine is reallocated
         self.assertTrue(output_tensors[1] is not new_output_tensors[1])
+
+    @parameterized.expand(
+            [
+                ("python_runtime", True),
+                ("cpp_runtime", False),
+            ]
+    )
+    def test_pre_allocated_outputs_unowned_outputs_api_check(self, _, use_python_runtime):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.softmax(x * 7 + 2, dim=0)
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=use_python_runtime,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            _ = optimized_model(inputs[0])
+            if use_python_runtime:
+                self.assertTrue(all(seen == expected for seen, expected in zip([
+                    optimized_model._run_on_acc_0.are_output_tensors_unowned(),
+                    optimized_model._run_on_acc_2.are_output_tensors_unowned()
+                ], [False, True])))
+
+            else:
+                self.assertTrue(all(seen == expected for seen, expected in zip([
+                    optimized_model._run_on_acc_0.engine.are_output_tensors_unowned(),
+                    optimized_model._run_on_acc_2.engine.are_output_tensors_unowned()
+                ], [False, True])))
+
+    @parameterized.expand(
+        [
+            ("python_runtime", True),
+            ("cpp_runtime", False),
+        ]
+    )
+    def test_pre_allocated_outputs_unowned_outputs(self, _, use_python_runtime):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.softmax(x * 7 + 2, dim=0)
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=use_python_runtime,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        torch_res = model(inputs[0])
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            res_1 = optimized_model(inputs[0])
+            res_2 = optimized_model(inputs[0])
+
+        # Results are correct
+        torch.testing.assert_close(
+            torch_res,
+            res_1,
+            rtol=5e-03,
+            atol=5e-03,
+            equal_nan=True,
+            check_dtype=True,
+        )
+
+        # Results between runs are identical
+        torch.testing.assert_close(
+            res_1,
+            res_2,
+            rtol=5e-03,
+            atol=5e-03,
+            equal_nan=True,
+            check_dtype=True,
+        )
+
+        torch._dynamo.reset()
+
+
+    def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_py_api_check_no_realloc(self):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ops.aten.mul(x, 7)
+                z = torch.ops.aten.add(y, 2)
+                a = torch.ops.aten.softmax(z, dim=0)
+                return y, z, a
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=True,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            res1 = optimized_model(inputs[0])
+            output_tensors = [
+                [t.data_ptr() for t in trt_mod.pre_allocated_outputs]
+                for name, trt_mod in optimized_model.named_children()
+                if "_run_on_acc" in name
+            ]
+
+            _ = optimized_model(inputs[0])
+            new_output_tensors = [
+                [t.data_ptr() for t in trt_mod.pre_allocated_outputs]
+                for name, trt_mod in optimized_model.named_children()
+                if "_run_on_acc" in name
+            ]
+
+        # Run to run, the output of the intermediate engine is reallocated
+        self.assertTrue(output_tensors[0] != new_output_tensors[0])
+        # Run to run, the output of the output engine is reallocated
+        self.assertTrue(output_tensors[1] != new_output_tensors[1])
+
+    @parameterized.expand(
+        [
+            ("python_runtime", True),
+            ("cpp_runtime", False),
+        ]
+    )
+    def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_api_check(self, _, use_python_runtime):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ops.aten.mul(x, 7)
+                z = torch.ops.aten.add(y, 2)
+                a = torch.ops.aten.softmax(z, dim=0)
+                return y, z, a
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=use_python_runtime,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            _ = optimized_model(inputs[0])
+            if use_python_runtime:
+                self.assertTrue(all(seen == expected for seen, expected in zip([
+                    optimized_model._run_on_acc_0.are_output_tensors_unowned(),
+                    optimized_model._run_on_acc_2.are_output_tensors_unowned()
+                ], [True, True])))
+
+            else:
+                self.assertTrue(all(seen == expected for seen, expected in zip([
+                    optimized_model._run_on_acc_0.engine.are_output_tensors_unowned(),
+                    optimized_model._run_on_acc_2.engine.are_output_tensors_unowned()
+                ], [True, True])))
+
+    @parameterized.expand(
+        [
+            ("python_runtime", True),
+            ("cpp_runtime", False),
+        ]
+    )
+    def test_pre_allocated_outputs_unowned_outputs_multi_outputs(self, _, use_python_runtime):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ops.aten.mul(x, 7)
+                z = torch.ops.aten.add(y, 2)
+                a = torch.ops.aten.softmax(z, dim=0)
+                return y, z, a
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=use_python_runtime,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            res_1 = optimized_model(inputs[0])
+            res_2 = optimized_model(inputs[0])
+
+        torch.testing.assert_close(
+            res_1,
+            res_2,
+            rtol=5e-03,
+            atol=5e-03,
+            equal_nan=True,
+            check_dtype=True,
+        )
+        torch._dynamo.reset()
 
 if __name__ == "__main__":
     run_tests()

From 58fcc3e33b1a9a8c076771decac120150e709850 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 6 Jan 2026 01:05:05 +0000
Subject: [PATCH 2/2] Changed the tests and fixed a bug in unowned output
 marking for multi-output graphs

---
 py/torch_tensorrt/dynamo/_compiler.py         |  9 +--
 .../runtime/test_pre_allocated_outputs.py     | 71 +++++++++++--------
 2 files changed, 45 insertions(+), 35 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 67be775d3d..cbde956a88 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -1084,10 +1084,11 @@ def preserve_module_specs(
     output_node = list(partitioned_module.graph.nodes)[-1]
 
     for arg in output_node.args:
-        target = arg[0].target
-        if "_run_on_acc" not in str(target):
-            continue
-        getattr(partitioned_module, target).set_output_tensors_as_unowned(True)
+        for output in arg:
+            target = output.target
+            if "_run_on_acc" not in str(target):
+                continue
+            getattr(partitioned_module, target).set_output_tensors_as_unowned(True)
 
     # Reset settings object to user specification after fallback to global partitioning mode
     if fast_partitioner_failed:
diff --git a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py
index 95faf6e2bd..a9f8cfbbe5 100644
--- a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py
+++ b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py
@@ -165,12 +165,14 @@ def forward(self, x):
         self.assertTrue(output_tensors[1] is not new_output_tensors[1])
 
     @parameterized.expand(
-            [
-                ("python_runtime", True),
-                ("cpp_runtime", False),
-            ]
+        [
+            ("python_runtime", True),
+            ("cpp_runtime", False),
+        ]
     )
-    def test_pre_allocated_outputs_unowned_outputs_api_check(self, _, use_python_runtime):
+    def test_pre_allocated_outputs_unowned_outputs_api_check(
+        self, _, use_python_runtime
+    ):
         class SampleModel(torch.nn.Module):
             def forward(self, x):
                 return torch.softmax(x * 7 + 2, dim=0)
@@ -192,17 +194,18 @@ def forward(self, x):
 
         with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
             _ = optimized_model(inputs[0])
-            if use_python_runtime:
-                self.assertTrue(all(seen == expected for seen, expected in zip([
-                    optimized_model._run_on_acc_0.are_output_tensors_unowned(),
-                    optimized_model._run_on_acc_2.are_output_tensors_unowned()
-                ], [False, True])))
-
-            else:
-                self.assertTrue(all(seen == expected for seen, expected in zip([
-                    optimized_model._run_on_acc_0.engine.are_output_tensors_unowned(),
-                    optimized_model._run_on_acc_2.engine.are_output_tensors_unowned()
-                ], [False, True])))
+            self.assertTrue(
+                all(
+                    seen == expected
+                    for seen, expected in zip(
+                        [
+                            optimized_model._run_on_acc_0.are_output_tensors_unowned(),
+                            optimized_model._run_on_acc_2.are_output_tensors_unowned(),
+                        ],
+                        [False, True],
+                    )
+                )
+            )
 
     @parameterized.expand(
         [
@@ -258,8 +261,9 @@ def forward(self, x):
 
         torch._dynamo.reset()
 
-
-    def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_py_api_check_no_realloc(self):
+    def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_py_api_check_no_realloc(
+        self,
+    ):
         class SampleModel(torch.nn.Module):
             def forward(self, x):
                 y = torch.ops.aten.mul(x, 7)
@@ -308,7 +312,9 @@ def forward(self, x):
             ("cpp_runtime", False),
         ]
     )
-    def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_api_check(self, _, use_python_runtime):
+    def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_api_check(
+        self, _, use_python_runtime
+    ):
         class SampleModel(torch.nn.Module):
             def forward(self, x):
                 y = torch.ops.aten.mul(x, 7)
@@ -333,17 +339,18 @@ def forward(self, x):
 
         with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
             _ = optimized_model(inputs[0])
-            if use_python_runtime:
-                self.assertTrue(all(seen == expected for seen, expected in zip([
-                    optimized_model._run_on_acc_0.are_output_tensors_unowned(),
-                    optimized_model._run_on_acc_2.are_output_tensors_unowned()
-                ], [True, True])))
-
-            else:
-                self.assertTrue(all(seen == expected for seen, expected in zip([
-                    optimized_model._run_on_acc_0.engine.are_output_tensors_unowned(),
-                    optimized_model._run_on_acc_2.engine.are_output_tensors_unowned()
-                ], [True, True])))
+            self.assertTrue(
+                all(
+                    seen == expected
+                    for seen, expected in zip(
+                        [
+                            optimized_model._run_on_acc_0.are_output_tensors_unowned(),
+                            optimized_model._run_on_acc_2.are_output_tensors_unowned(),
+                        ],
+                        [True, True],
+                    )
+                )
+            )
 
     @parameterized.expand(
         [
@@ -351,7 +358,9 @@ def forward(self, x):
             ("cpp_runtime", False),
         ]
     )
-    def test_pre_allocated_outputs_unowned_outputs_multi_outputs(self, _, use_python_runtime):
+    def test_pre_allocated_outputs_unowned_outputs_multi_outputs(
+        self, _, use_python_runtime
+    ):
         class SampleModel(torch.nn.Module):
             def forward(self, x):
                 y = torch.ops.aten.mul(x, 7)
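Note for reviewers: a minimal, self-contained sketch of what the _compiler.py change in patch 2 does. The old loop read only arg[0] from the graph's output node, so in a partition with multiple outputs only the first output-producing engine was marked as having unowned output tensors; the fixed loop walks every output. MockEngine, Partitioned, LeafTracer, and mark_unowned_outputs below are hypothetical stand-ins for illustration, not the Torch-TensorRT runtime API; only set_output_tensors_as_unowned mirrors the call used in the patch.

import torch
import torch.fx


class MockEngine(torch.nn.Module):
    """Stand-in for a TRT engine submodule named like `_run_on_acc_N`."""

    def __init__(self) -> None:
        super().__init__()
        self.unowned = False

    def set_output_tensors_as_unowned(self, value: bool) -> None:
        self.unowned = value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class Partitioned(torch.nn.Module):
    """Toy partitioned module whose graph output is a tuple of engine outputs."""

    def __init__(self) -> None:
        super().__init__()
        self._run_on_acc_0 = MockEngine()
        self._run_on_acc_1 = MockEngine()

    def forward(self, x: torch.Tensor):
        return self._run_on_acc_0(x), self._run_on_acc_1(x)


class LeafTracer(torch.fx.Tracer):
    # Keep the engine submodules as opaque call_module nodes instead of tracing into them.
    def is_leaf_module(self, m: torch.nn.Module, qualname: str) -> bool:
        return "_run_on_acc" in qualname or super().is_leaf_module(m, qualname)


def mark_unowned_outputs(partitioned_module: torch.fx.GraphModule) -> None:
    output_node = list(partitioned_module.graph.nodes)[-1]
    # output_node.args is a 1-tuple; with multiple outputs its single element is
    # itself a tuple of producer nodes. The pre-fix code read only arg[0] and
    # therefore marked just the first producer.
    for arg in output_node.args:
        for output in arg:
            target = output.target
            if "_run_on_acc" not in str(target):
                continue
            getattr(partitioned_module, target).set_output_tensors_as_unowned(True)


root = Partitioned()
gm = torch.fx.GraphModule(root, LeafTracer().trace(root))
mark_unowned_outputs(gm)
# Both producers of graph outputs are now marked, not just the first one.
assert gm._run_on_acc_0.unowned and gm._run_on_acc_1.unowned

This is also why the multi-output no-realloc test above expects both engines' pre_allocated_outputs to change run to run: once every output-producing engine is marked unowned, none of them may reuse its previous output allocation.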