2 files changed: +26 -1 lines changed
src/sagemaker_pytorch_container (2 files changed: +26 -1 lines changed)
Original file line number | Diff line number | Diff line change
2121
# NOTE(review): MASTER_PORT is not referenced in the visible hunks; presumably
# it is the rendezvous port exported for distributed training -- confirm
# against _set_distributed_environment.
MASTER_PORT = '7777'

# Keys that train() looks up in
# training_environment.additional_framework_parameters to decide which runner
# type launches the user entry point.
LAUNCH_SMDATAPARALLEL_ENV_NAME = 'sagemaker_distributed_dataparallel_enabled'
LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
LAUNCH_PYTORCH_DDP_ENV_NAME = 'sagemaker_pytorch_ddp_enabled'

logger = logging.getLogger(__name__)
2628
@@ -49,7 +51,11 @@ def train(training_environment):
4951
5052 _set_distributed_environment (training_environment )
5153
52- mpi_enabled = training_environment .additional_framework_parameters .get ('sagemaker_mpi_enabled' )
54+ mpi_enabled = training_environment .additional_framework_parameters .get (LAUNCH_MPI_ENV_NAME )
55+
56+ pytorch_ddp_enabled = training_environment .additional_framework_parameters .get (
57+ LAUNCH_PYTORCH_DDP_ENV_NAME , False
58+ )
5359
5460 smdataparallel_enabled = training_environment .additional_framework_parameters .get (
5561 LAUNCH_SMDATAPARALLEL_ENV_NAME , False
@@ -60,6 +66,9 @@ def train(training_environment):
6066 if training_environment .current_instance_group in training_environment .distribution_instance_groups :
6167 if mpi_enabled :
6268 runner_type = runner .MPIRunnerType
69+ elif pytorch_ddp_enabled :
70+ runner_type = runner .SMDataParallelRunnerType
71+ logger .info ('Invoking SMDataParallel for native PT DDP job' )
6372 elif smdataparallel_enabled :
6473 runner_type = runner .SMDataParallelRunnerType
6574 logger .info ('Invoking SMDataParallel' )
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,22 @@ def test_train_smdataparallel(run_module, training_env):
9090 )
9191
9292
@patch("sagemaker_training.entry_point.run")
@patch('socket.gethostbyname', MagicMock())
def test_train_pytorch_ddp(run_module, training_env):
    """Check that enabling the PyTorch DDP flag launches with the SMDataParallel runner.

    Args:
        run_module: Mock injected by the ``sagemaker_training.entry_point.run``
            patch. The ``socket.gethostbyname`` patch passes an explicit
            replacement, so it does not add a parameter.
        training_env: Training-environment object supplied by the test suite
            (NOTE(review): presumably a pytest fixture defined elsewhere in
            this module -- confirm).
    """
    # Turn on the flag that train() reads from the framework parameters.
    training_env.additional_framework_parameters["sagemaker_pytorch_ddp_enabled"] = True

    train(training_env)

    # The entry point must be started once with the environment's own settings
    # and the SMDataParallel runner type.
    run_module.assert_called_with(
        uri=training_env.module_dir,
        user_entry_point=training_env.user_entry_point,
        args=training_env.to_cmd_args(),
        env_vars=training_env.to_env_vars(),
        capture_error=True,
        runner_type=runner.SMDataParallelRunnerType,
    )
107+
108+
93109@patch ('sagemaker_training.entry_point.run' , MagicMock ())
94110@patch ('socket.gethostbyname' , MagicMock ())
95111def test_environment (training_env ):
You can’t perform that action at this time.
0 commit comments