diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 67be775d3d..cbde956a88 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -1084,10 +1084,11 @@ def preserve_module_specs(
     output_node = list(partitioned_module.graph.nodes)[-1]
     for arg in output_node.args:
-        target = arg[0].target
-        if "_run_on_acc" not in str(target):
-            continue
-        getattr(partitioned_module, target).set_output_tensors_as_unowned(True)
+        for output in arg:
+            target = output.target
+            if "_run_on_acc" not in str(target):
+                continue
+            getattr(partitioned_module, target).set_output_tensors_as_unowned(True)
 
     # Reset settings object to user specification after fallback to global partitioning mode
     if fast_partitioner_failed:
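Note on the fix above: for a graph that returns several values, torch.fx records the output node with `args` holding a single tuple (or list) of all producer nodes, so the old `arg[0].target` lookup only marked the engine behind the *first* graph output as having unowned output tensors. A minimal sketch of the structure the nested loop now walks (plain torch.fx, not Torch-TensorRT internals; `M` is a made-up module):

```python
# Sketch only: illustrates the FX output-node layout, not Torch-TensorRT code.
import torch


class M(torch.nn.Module):
    def forward(self, x):
        y = x * 7
        z = y + 2
        return y, z  # two graph outputs


gm = torch.fx.symbolic_trace(M())
output_node = list(gm.graph.nodes)[-1]
print(output_node.args)  # ((y_node, z_node),) -- a 1-tuple wrapping all outputs
for arg in output_node.args:
    for out in arg:  # the old code read only arg[0]
        print(out.target)  # visits the producer of every graph output
```

With the single `arg[0].target` lookup, any other `_run_on_acc` engine feeding a graph output kept reusing its pre-allocated output buffers; the new tests below exercise exactly that multi-output case.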
diff --git a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py
index 35dab61161..a9f8cfbbe5 100644
--- a/tests/py/dynamo/runtime/test_pre_allocated_outputs.py
+++ b/tests/py/dynamo/runtime/test_pre_allocated_outputs.py
@@ -125,7 +125,7 @@ def forward(self, x):
         )
         torch._dynamo.reset()
 
-    def test_pre_allocated_outputs_unowned_outputs(self):
+    def test_pre_allocated_outputs_unowned_outputs_py_api_check_no_realloc(self):
         class SampleModel(torch.nn.Module):
             def forward(self, x):
                 return torch.softmax(x * 7 + 2, dim=0)
@@ -146,21 +146,256 @@ def forward(self, x):
         )
 
         with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
-            optimized_model(inputs[0])
+            _ = optimized_model(inputs[0])
             output_tensors = [
                 trt_mod.pre_allocated_outputs
                 for name, trt_mod in optimized_model.named_children()
                 if "_run_on_acc" in name
             ]
-            optimized_model(inputs[0])
+            _ = optimized_model(inputs[0])
             new_output_tensors = [
                 trt_mod.pre_allocated_outputs
                 for name, trt_mod in optimized_model.named_children()
                 if "_run_on_acc" in name
             ]
+
+            # Run to run, output of intermediate engine is not reallocated
             self.assertTrue(output_tensors[0] is new_output_tensors[0])
+            # Run to run, output of output engine is reallocated
             self.assertTrue(output_tensors[1] is not new_output_tensors[1])
 
+    @parameterized.expand(
+        [
+            ("python_runtime", True),
+            ("cpp_runtime", False),
+        ]
+    )
+    def test_pre_allocated_outputs_unowned_outputs_api_check(
+        self, _, use_python_runtime
+    ):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.softmax(x * 7 + 2, dim=0)
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Compile with Torch-TensorRT, forcing a Torch fallback block
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=use_python_runtime,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            _ = optimized_model(inputs[0])
+            self.assertTrue(
+                all(
+                    seen == expected
+                    for seen, expected in zip(
+                        [
+                            optimized_model._run_on_acc_0.are_output_tensors_unowned(),
+                            optimized_model._run_on_acc_2.are_output_tensors_unowned(),
+                        ],
+                        [False, True],
+                    )
+                )
+            )
+
+    @parameterized.expand(
+        [
+            ("python_runtime", True),
+            ("cpp_runtime", False),
+        ]
+    )
+    def test_pre_allocated_outputs_unowned_outputs(self, _, use_python_runtime):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.softmax(x * 7 + 2, dim=0)
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=use_python_runtime,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        torch_res = model(inputs[0])
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            res_1 = optimized_model(inputs[0])
+            res_2 = optimized_model(inputs[0])
+
+        # Results are correct
+        torch.testing.assert_close(
+            torch_res,
+            res_1,
+            rtol=5e-03,
+            atol=5e-03,
+            equal_nan=True,
+            check_dtype=True,
+        )
+
+        # Results between runs are identical
+        torch.testing.assert_close(
+            res_1,
+            res_2,
+            rtol=5e-03,
+            atol=5e-03,
+            equal_nan=True,
+            check_dtype=True,
+        )
+
+        torch._dynamo.reset()
+
+    def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_py_api_check_no_realloc(
+        self,
+    ):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ops.aten.mul(x, 7)
+                z = torch.ops.aten.add(y, 2)
+                a = torch.ops.aten.softmax(z, dim=0)
+                return y, z, a
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Compile with Torch-TensorRT, forcing a Torch fallback block
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=True,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            _ = optimized_model(inputs[0])
+            output_tensors = [
+                [t.data_ptr() for t in trt_mod.pre_allocated_outputs]
+                for name, trt_mod in optimized_model.named_children()
+                if "_run_on_acc" in name
+            ]
+
+            _ = optimized_model(inputs[0])
+            new_output_tensors = [
+                [t.data_ptr() for t in trt_mod.pre_allocated_outputs]
+                for name, trt_mod in optimized_model.named_children()
+                if "_run_on_acc" in name
+            ]
+
+            # Run to run, output of intermediate engine is reallocated
+            self.assertTrue(output_tensors[0] != new_output_tensors[0])
+            # Run to run, output of output engine is reallocated
+            self.assertTrue(output_tensors[1] != new_output_tensors[1])
+
+    @parameterized.expand(
+        [
+            ("python_runtime", True),
+            ("cpp_runtime", False),
+        ]
+    )
+    def test_pre_allocated_outputs_unowned_outputs_multiple_outputs_api_check(
+        self, _, use_python_runtime
+    ):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ops.aten.mul(x, 7)
+                z = torch.ops.aten.add(y, 2)
+                a = torch.ops.aten.softmax(z, dim=0)
+                return y, z, a
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Compile with Torch-TensorRT, forcing a Torch fallback block
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=use_python_runtime,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            _ = optimized_model(inputs[0])
+            self.assertTrue(
+                all(
+                    seen == expected
+                    for seen, expected in zip(
+                        [
+                            optimized_model._run_on_acc_0.are_output_tensors_unowned(),
+                            optimized_model._run_on_acc_2.are_output_tensors_unowned(),
+                        ],
+                        [True, True],
+                    )
+                )
+            )
+
+    @parameterized.expand(
+        [
+            ("python_runtime", True),
+            ("cpp_runtime", False),
+        ]
+    )
+    def test_pre_allocated_outputs_unowned_outputs_multi_outputs(
+        self, _, use_python_runtime
+    ):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ops.aten.mul(x, 7)
+                z = torch.ops.aten.add(y, 2)
+                a = torch.ops.aten.softmax(z, dim=0)
+                return y, z, a
+
+        model = SampleModel().eval().cuda()
+        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # Compile with Torch-TensorRT, forcing a Torch fallback block
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            "dynamo",
+            inputs[0],
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=use_python_runtime,
+            torch_executed_ops={torch.ops.aten.add.Tensor},
+        )
+
+        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
+            res_1 = optimized_model(inputs[0])
+            res_2 = optimized_model(inputs[0])
+
+        # Results between runs are identical
+        torch.testing.assert_close(
+            res_1,
+            res_2,
+            rtol=5e-03,
+            atol=5e-03,
+            equal_nan=True,
+            check_dtype=True,
+        )
+
+        torch._dynamo.reset()