File tree Expand file tree Collapse file tree 2 files changed +5
-4
lines changed
src/sagemaker_pytorch_container Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -47,7 +47,7 @@ def train(training_environment):
4747
4848 _set_nccl_environment (training_environment .network_interface_name )
4949
50- _set_distributed_environment (training_environment . hosts )
50+ _set_distributed_environment (training_environment )
5151
5252 mpi_enabled = training_environment .additional_framework_parameters .get ('sagemaker_mpi_enabled' )
5353
@@ -88,15 +88,15 @@ def _dns_lookup(host):
8888 return socket .gethostbyname (host )
8989
9090
91- def _set_distributed_environment (hosts ):
91+ def _set_distributed_environment (training_env ):
9292 """Set environment variable for distributed training.
9393
9494 Args:
9595 hosts: list of hosts that are used for training.
9696 """
9797 # According to https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
9898 # hosts are sorted lexicographically.
99- os .environ ['MASTER_ADDR' ] = hosts [ 0 ]
99+ os .environ ['MASTER_ADDR' ] = training_env . master_hostname
100100 os .environ ['MASTER_PORT' ] = MASTER_PORT
101101
102102
Original file line number Diff line number Diff line change @@ -31,6 +31,7 @@ def fixture_training_env():
3131 env = MagicMock ()
3232 env .current_host = 'algo-1'
3333 env .hosts = ['algo-1' ]
34+ env .master_hostname = 'algo-1'
3435 env .network_interface_name = 'eth0'
3536 tmp = tempfile .mkdtemp ()
3637 os .makedirs (os .path .join (tmp , 'model' ))
@@ -96,7 +97,7 @@ def test_environment(training_env):
9697
9798 # distributed training specific environment
9899 assert MASTER_PORT == os .environ ['MASTER_PORT' ]
99- assert training_env .hosts [ 0 ] == os .environ ['MASTER_ADDR' ]
100+ assert training_env .master_hostname == os .environ ['MASTER_ADDR' ]
100101
101102 # nccl specific environment
102103 assert training_env .network_interface_name == os .environ ['NCCL_SOCKET_IFNAME' ]
You can’t perform that action at this time.
0 commit comments