Skip to content

Commit 1f79761

Browse files
authored
(bugfix)(torchx/components) handle BinOp style optional when decoding component arguments
Differential Revision: D82856462 Pull Request resolved: #1123
1 parent 6ab9f69 commit 1f79761

File tree

4 files changed

+28
-9
lines changed

4 files changed

+28
-9
lines changed

torchx/specs/builders.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-strict
7+
# pyre-unsafe
88

99
import argparse
1010
import inspect
@@ -39,7 +39,7 @@ def _create_args_parser(
3939

4040

4141
def _create_args_parser_from_parameters(
42-
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
42+
cmpnt_fn: Callable[..., AppDef],
4343
parameters: Mapping[str, inspect.Parameter],
4444
cmpnt_defaults: Optional[Dict[str, str]] = None,
4545
config: Optional[Dict[str, Any]] = None,
@@ -120,7 +120,7 @@ def _merge_config_values_with_args(
120120

121121

122122
def parse_args(
123-
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
123+
cmpnt_fn: Callable[..., AppDef],
124124
cmpnt_args: List[str],
125125
cmpnt_defaults: Optional[Dict[str, Any]] = None,
126126
config: Optional[Dict[str, Any]] = None,
@@ -149,7 +149,7 @@ def parse_args(
149149

150150

151151
def component_args_from_str(
152-
cmpnt_fn: Callable[..., Any], # pyre-fixme[2]: Enforce AppDef type
152+
cmpnt_fn: Callable[..., AppDef],
153153
cmpnt_args: list[str],
154154
cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
155155
config: Optional[Dict[str, Any]] = None,
@@ -238,7 +238,7 @@ def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
238238

239239

240240
def materialize_appdef(
241-
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
241+
cmpnt_fn: Callable[..., AppDef],
242242
cmpnt_args: List[str],
243243
cmpnt_defaults: Optional[Dict[str, Any]] = None,
244244
config: Optional[Dict[str, Any]] = None,

torchx/specs/finder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from inspect import getmembers, isfunction
1818
from pathlib import Path
1919
from types import ModuleType
20-
from typing import Any, Callable, Dict, Generator, List, Optional, Union
20+
from typing import Callable, Dict, Generator, List, Optional, Union
21+
22+
from torchx.specs import AppDef
2123

2224
from torchx.specs.file_linter import (
2325
ComponentFunctionValidator,
@@ -59,8 +61,7 @@ class _Component:
5961
description: str
6062
fn_name: str
6163

62-
# pyre-ignore[4] TODO temporary until PipelineDef is decoupled and can be exposed as type to OSS
63-
fn: Callable[..., Any]
64+
fn: Callable[..., AppDef]
6465

6566
validation_errors: List[str]
6667

torchx/specs/test/builders_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def example_test_complex_fn(
9393
nnodes: int = 4,
9494
first_arg: Optional[str] = None,
9595
nested_arg: Optional[Dict[str, List[str]]] = None,
96+
env: dict[str, str] | None = None,
9697
*roles_args: str,
9798
) -> AppDef:
9899
"""Creates complex application, testing all possible complex types
@@ -127,6 +128,7 @@ def example_test_complex_fn(
127128
args=args,
128129
resource=Resource(cpu=cpus, gpu=gpus, memMB=1),
129130
num_replicas=nnodes,
131+
env=env or {},
130132
)
131133
roles.append(role)
132134
return AppDef(app_name, roles)
@@ -193,6 +195,7 @@ def _get_expected_app_with_default(self) -> AppDef:
193195
4,
194196
None,
195197
None,
198+
None,
196199
*role_args,
197200
)
198201

@@ -220,6 +223,7 @@ def _get_expected_app_with_all_args(self) -> AppDef:
220223
8,
221224
"first_arg",
222225
None,
226+
{"FOO": "BAR", "HELLO": "WORLD"},
223227
*role_args,
224228
)
225229

@@ -240,6 +244,8 @@ def _get_app_args(self) -> List[str]:
240244
"8",
241245
"--first_arg",
242246
"first_arg",
247+
"--env",
248+
"FOO=BAR,HELLO=WORLD",
243249
"--",
244250
*role_args,
245251
]
@@ -256,6 +262,7 @@ def _get_expected_app_with_nested_objects(self) -> AppDef:
256262
8,
257263
"first_arg",
258264
defaults,
265+
None,
259266
*role_args,
260267
)
261268

torchx/util/types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import inspect
1010
import re
11+
from types import UnionType
1112
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
1213

1314

@@ -234,10 +235,20 @@ def decode_optional(param_type: Any) -> Any:
234235
If ``param_type`` is type Optional[INNER_TYPE], method returns INNER_TYPE
235236
Otherwise returns ``param_type``
236237
"""
238+
237239
if not hasattr(param_type, "__origin__"):
238-
return param_type
240+
if isinstance(param_type, UnionType):
241+
# handle BinOp style Optional (e.g. `T | None`)
242+
if len(param_type.__args__) == 2 and param_type.__args__[1] is type(None):
243+
return param_type.__args__[0]
244+
else:
245+
return param_type
246+
else:
247+
return param_type
248+
239249
if param_type.__origin__ is not Union:
240250
return param_type
251+
241252
args = param_type.__args__
242253
if len(args) == 2 and args[1] is type(None):
243254
return args[0]

0 commit comments

Comments
 (0)