diff --git a/torchx/components/dist.py b/torchx/components/dist.py index 55718474d..20fc840aa 100644 --- a/torchx/components/dist.py +++ b/torchx/components/dist.py @@ -92,6 +92,7 @@ def spmd( h: str = "gpu.small", j: str = "1x1", env: Optional[Dict[str, str]] = None, + metadata: Optional[Dict[str, str]] = None, max_retries: int = 0, mounts: Optional[List[str]] = None, debug: bool = False, @@ -131,6 +132,7 @@ def spmd( h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3) + metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3) max_retries: the number of scheduler retries allowed rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous. Only takes effect when running multi-node. When running single node, this parameter @@ -153,6 +155,7 @@ def spmd( h=h, j=str(StructuredJArgument.parse_from(h, j)), env=env, + metadata=metadata, max_retries=max_retries, mounts=mounts, debug=debug, @@ -171,6 +174,7 @@ def ddp( memMB: int = 1024, j: str = "1x2", env: Optional[Dict[str, str]] = None, + metadata: Optional[Dict[str, str]] = None, max_retries: int = 0, rdzv_port: int = 29500, rdzv_backend: str = "c10d", @@ -203,6 +207,7 @@ def ddp( h: a registered named resource (if specified takes precedence over cpu, gpu, memMB) j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3) + metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3) max_retries: the number of scheduler retries allowed rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous. Only takes effect when running multi-node. When running single node, this parameter @@ -238,8 +243,8 @@ def ddp( # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument) rdzv_endpoint = _noquote(f"$${{{macros.rank0_env}:=localhost}}:{rdzv_port}") - if env is None: - env = {} + env = env or {} + metadata = metadata or {} argname = StructuredNameArgument.parse_from( name=name, @@ -299,6 +304,7 @@ def ddp( mounts=specs.parse_mounts(mounts) if mounts else [], ) ], + metadata=metadata, ) diff --git a/torchx/components/test/dist_test.py b/torchx/components/test/dist_test.py index ac57a1bf0..98499346e 100644 --- a/torchx/components/test/dist_test.py +++ b/torchx/components/test/dist_test.py @@ -40,6 +40,13 @@ def test_ddp_debug(self) -> None: for k, v in _TORCH_DEBUG_FLAGS.items(): self.assertEqual(env[k], v) + def test_ddp_metadata(self) -> None: + metadata = {"key": "value"} + app = ddp(script="foo.py", metadata=metadata) + for k, v in metadata.items(): + self.assertEqual(app.metadata[k], v) + self.assertEqual(len(metadata), len(app.metadata)) + def test_ddp_rdzv_backend_static(self) -> None: app = ddp(script="foo.py", rdzv_backend="static") cmd = app.roles[0].args[1] @@ -53,6 +60,13 @@ def test_validate_spmd(self) -> None: self.validate(dist, "ddp") + def test_spmd_metadata(self) -> None: + metadata = {"key": "value"} + app = spmd(script="foo.py", metadata=metadata) + for k, v in metadata.items(): + self.assertEqual(app.metadata[k], v) + self.assertEqual(len(metadata), len(app.metadata)) + def test_spmd_call_by_module_or_script_no_name(self) -> None: appdef = spmd(script="foo/bar.py") self.assertEqual("bar", appdef.name)