File tree: 2 files changed, +22 -1 lines changed
src/sagemaker_pytorch_container
2 files changed, +22 -1 lines changed
Original file line number Diff line number Diff line change
@@ -81,12 +81,19 @@ def train(training_environment):
81 81 runner_type = runner .PyTorchXLARunnerType
82 82 logger .info ('Invoking PT-XLA Runner' )
83 83 logger .info ('Invoking user training script.' )
84+
85+ # get capture_error from framework parameters
86+ capture_error = True
87+ if training_environment .additional_framework_parameters .get ("sagemaker_toolkit_native_launcher_enabled" ):
88+ capture_error = False
89+ logger .info (f'capture_error is { capture_error } . Default is True' )
90+
84 91 try :
85 92 entry_point .run (uri = training_environment .module_dir ,
86 93 user_entry_point = training_environment .user_entry_point ,
87 94 args = training_environment .to_cmd_args (),
88 95 env_vars = training_environment .to_env_vars (),
89- capture_error = True ,
96+ capture_error = capture_error ,
90 97 runner_type = runner_type )
91 98 except errors .ExecuteUserScriptError as err :
92 99 message = str (err )
Original file line number Diff line number Diff line change
@@ -74,6 +74,20 @@ def test_train(run_entry_point, training_env):
74 74 runner_type = runner .ProcessRunnerType )
75 75
76 76
77+ @patch ('sagemaker_training.entry_point.run' )
78+ @patch ('socket.gethostbyname' , MagicMock ())
79+ def test_train_no_capture_error (run_entry_point , training_env ):
80+ training_env .additional_framework_parameters ["sagemaker_toolkit_native_launcher_enabled" ] = True
81+ train (training_env )
82+
83+ run_entry_point .assert_called_with (uri = training_env .module_dir ,
84+ user_entry_point = training_env .user_entry_point ,
85+ args = training_env .to_cmd_args (),
86+ env_vars = training_env .to_env_vars (),
87+ capture_error = False ,
88+ runner_type = runner .ProcessRunnerType )
89+
90+
77 91 @patch ("sagemaker_training.entry_point.run" )
78 92 @patch ('socket.gethostbyname' , MagicMock ())
79 93 def test_train_smdataparallel (run_module , training_env ):
You can’t perform that action at this time.
0 commit comments