Skip to content

Commit aee0801

Browse files
andywagfacebook-github-bot
authored andcommitted
Make torchx scheduler opts support enums (pytorch#870)
Summary: Added Support for Enumerations in scheduler options. This is a bit more generic as it takes in a creator function which converts a CfgVal to another type. There are some limitations on how general it can be based on how the typing is setup. This also makes it tricky to make a convenience function to handle only Enums. Differential Revision: D55551233
1 parent 45bc4ce commit aee0801

File tree

2 files changed

+54
-12
lines changed

2 files changed

+54
-12
lines changed

torchx/schedulers/test/api_test.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import unittest
1212
from datetime import datetime
13+
from enum import Enum
1314
from typing import Iterable, List, Mapping, Optional, TypeVar, Union
1415
from unittest.mock import MagicMock, patch
1516

@@ -36,6 +37,16 @@
3637
T = TypeVar("T")
3738

3839

40+
class EnumConfig(str, Enum):
41+
option1 = "option1"
42+
option2 = "option2"
43+
44+
45+
class IntEnumConfig(int, Enum):
46+
option1 = 1
47+
option2 = 2
48+
49+
3950
class SchedulerTest(unittest.TestCase):
4051
class MockScheduler(Scheduler[T], WorkspaceMixin[None]):
4152
def __init__(self, session_name: str) -> None:
@@ -78,6 +89,21 @@ def list(self) -> List[ListAppResponse]:
7889
def _run_opts(self) -> runopts:
7990
opts = runopts()
8091
opts.add("foo", type_=str, required=True, help="required option")
92+
opts.add(
93+
"bar",
94+
type_=EnumConfig,
95+
required=True,
96+
help=f"Test Enum Config {[m.name for m in EnumConfig]}",
97+
creator=lambda x: EnumConfig(x),
98+
),
99+
opts.add(
100+
"ienum",
101+
type_=IntEnumConfig,
102+
required=False,
103+
help=f"Test Enum Config {[m.name for m in IntEnumConfig]}",
104+
creator=lambda x: IntEnumConfig(x),
105+
),
106+
81107
return opts
82108

83109
def resolve_resource(self, resource: Union[str, Resource]) -> Resource:
@@ -92,12 +118,16 @@ def test_invalid_run_cfg(self) -> None:
92118
scheduler_mock = SchedulerTest.MockScheduler("test_session")
93119
app_mock = MagicMock()
94120

121+
empty_cfg = {}
95122
with self.assertRaises(InvalidRunConfigException):
96-
empty_cfg = {}
97123
scheduler_mock.submit(app_mock, empty_cfg)
98124

125+
bad_type_cfg = {"foo": 100}
126+
with self.assertRaises(InvalidRunConfigException):
127+
scheduler_mock.submit(app_mock, bad_type_cfg)
128+
129+
bad_type_cfg = {"foo": "here", "bar": "temp"}
99130
with self.assertRaises(InvalidRunConfigException):
100-
bad_type_cfg = {"foo": 100}
101131
scheduler_mock.submit(app_mock, bad_type_cfg)
102132

103133
def test_submit_workspace(self) -> None:
@@ -110,7 +140,7 @@ def test_submit_workspace(self) -> None:
110140

111141
scheduler_mock = SchedulerTest.MockScheduler("test_session")
112142

113-
cfg = {"foo": "asdf"}
143+
cfg = {"foo": "asdf", "bar": "option1", "ienum": 1}
114144
scheduler_mock.submit(app, cfg, workspace="some_workspace")
115145
self.assertEqual(app.roles[0].image, "some_workspace")
116146

@@ -131,7 +161,7 @@ def test_role_preproc_called(self) -> None:
131161
app_mock = MagicMock()
132162
app_mock.roles = [MagicMock()]
133163

134-
cfg = {"foo": "bar"}
164+
cfg = {"foo": "bar", "bar": "option2"}
135165
scheduler_mock.submit_dryrun(app_mock, cfg)
136166
role_mock = app_mock.roles[0]
137167
role_mock.pre_proc.assert_called_once()

torchx/specs/api.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ class runopt:
702702
opt_type: Type[CfgVal]
703703
is_required: bool
704704
help: str
705+
creator: Optional[Callable[[CfgVal], CfgVal]] = None
705706

706707

707708
class runopts:
@@ -793,13 +794,23 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
793794
)
794795

795796
# check type (None matches all types)
796-
if val is not None and not runopts.is_type(val, runopt.opt_type):
797-
raise InvalidRunConfigException(
798-
f"Run option: {cfg_key}, must be of type: {get_type_name(runopt.opt_type)},"
799-
f" but was: {val} ({type(val).__name__})",
800-
cfg_key,
801-
cfg,
802-
)
797+
if val is not None:
798+
if runopt.creator is not None:
799+
try:
800+
val = runopt.creator(val)
801+
except Exception as e:
802+
raise InvalidRunConfigException(
803+
f"Run option failed with error: {e}",
804+
cfg_key,
805+
cfg,
806+
)
807+
if not runopts.is_type(val, runopt.opt_type):
808+
raise InvalidRunConfigException(
809+
f"Run option: {cfg_key}, must be of type: {get_type_name(runopt.opt_type)},"
810+
f" but was: {val} ({type(val).__name__})",
811+
cfg_key,
812+
cfg,
813+
)
803814

804815
# not required and not set, set to default
805816
if val is None:
@@ -892,6 +903,7 @@ def add(
892903
help: str,
893904
default: CfgVal = None,
894905
required: bool = False,
906+
creator: Optional[Callable[[CfgVal], CfgVal]] = None,
895907
) -> None:
896908
"""
897909
Adds the ``config`` option with the given help string and ``default``
@@ -909,7 +921,7 @@ def add(
909921
f" Given: {default} ({type(default).__name__})"
910922
)
911923

912-
self._opts[cfg_key] = runopt(default, type_, required, help)
924+
self._opts[cfg_key] = runopt(default, type_, required, help, creator)
913925

914926
def update(self, other: "runopts") -> None:
915927
self._opts.update(other._opts)

0 commit comments

Comments
 (0)