diff --git a/tests/test_examples.py b/tests/test_examples.py index 03a20cd251..99f79783d4 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -532,6 +532,9 @@ def test(self): if "--use_hpu_graphs_for_inference" in extra_command_line_arguments: extra_command_line_arguments.remove("--use_hpu_graphs_for_inference") + if self.TASK_NAME == "trl-sft": + env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" + extra_command_line_arguments += self._get_dataset_args() if torch_compile and (