|
7 | 7 |
|
8 | 8 | import os
|
9 | 9 | import shutil
|
10 |
| -import sys |
11 | 10 | import tempfile
|
12 | 11 | import unittest
|
13 | 12 | from pathlib import Path
|
14 |
| -from unittest.mock import patch |
| 13 | +from unittest.mock import MagicMock, patch |
15 | 14 |
|
16 | 15 | import torchx.specs.finder as finder
|
| 16 | + |
| 17 | +from importlib_metadata import EntryPoints |
17 | 18 | from torchx.runner import get_runner
|
18 | 19 | from torchx.runtime.tracking import FsspecResultTracker
|
19 | 20 | from torchx.specs.api import AppDef, AppState, Role
|
|
26 | 27 | get_components,
|
27 | 28 | ModuleComponentsFinder,
|
28 | 29 | )
|
| 30 | +from torchx.util.test.entrypoints_test import EntryPoint_from_text |
29 | 31 | from torchx.util.types import none_throws
|
30 | 32 |
|
| 33 | +_METADATA_EPS: str = "torchx.util.entrypoints.metadata.entry_points" |
| 34 | + |
31 | 35 |
|
32 | 36 | def _test_component(name: str, role_name: str = "worker") -> AppDef:
|
33 | 37 | """
|
@@ -58,92 +62,126 @@ def invalid_component(name, role_name: str = "worker") -> AppDef:
|
58 | 62 | )
|
59 | 63 |
|
60 | 64 |
|
61 |
| -class DirComponentsFinderTest(unittest.TestCase): |
62 |
| - def test_get_components(self) -> None: |
63 |
| - components = _load_components() |
64 |
| - self.assertTrue(len(components) > 1) |
65 |
| - component = components["utils.echo"] |
66 |
| - self.assertEqual("utils.echo", component.name) |
67 |
| - self.assertEqual( |
68 |
| - "Echos a message to stdout (calls echo)", component.description |
| 65 | +class FinderTest(unittest.TestCase): |
| 66 | + _ENTRY_POINTS: EntryPoints = EntryPoints( |
| 67 | + EntryPoint_from_text( |
| 68 | + """ |
| 69 | +[torchx.components] |
| 70 | +_ = torchx.specs.test.finder_test |
| 71 | + """ |
69 | 72 | )
|
70 |
| - self.assertEqual("echo", component.fn_name) |
71 |
| - self.assertIsNotNone(component.fn) |
| 73 | + ) |
| 74 | + |
| 75 | + def tearDown(self) -> None: |
| 76 | + # clear the globals since find_component() has side-effects |
| 77 | + # and we load a bunch of mocks for components in the tests below |
| 78 | + finder._components = None |
| 79 | + |
| 80 | + def test_module_relname(self) -> None: |
| 81 | + import torchx.specs.test.components as c |
| 82 | + import torchx.specs.test.components.a as ca |
| 83 | + |
| 84 | + self.assertEqual("", finder.module_relname(c, relative_to=c)) |
| 85 | + self.assertEqual("a", finder.module_relname(ca, relative_to=c)) |
| 86 | + with self.assertRaises(ValueError): |
| 87 | + finder.module_relname(c, relative_to=ca) |
72 | 88 |
|
73 | 89 | def test_get_component_by_name(self) -> None:
|
74 | 90 | component = none_throws(get_component("utils.echo"))
|
75 | 91 | self.assertEqual("utils.echo", component.name)
|
76 | 92 | self.assertEqual("echo", component.fn_name)
|
77 | 93 | self.assertIsNotNone(component.fn)
|
78 | 94 |
|
79 |
| - def test_get_invalid_component_by_name(self) -> None: |
80 |
| - test_torchx_group = {"foobar": sys.modules[__name__]} |
81 |
| - finder._components = None |
82 |
| - with patch("torchx.specs.finder.entrypoints") as entrypoints_mock: |
83 |
| - entrypoints_mock.load_group.return_value = test_torchx_group |
84 |
| - with self.assertRaises(ComponentValidationException): |
85 |
| - get_component("foobar.finder_test.invalid_component") |
| 95 | + @patch(_METADATA_EPS, return_value=_ENTRY_POINTS) |
| 96 | + def test_get_invalid_component_by_name(self, _: MagicMock) -> None: |
| 97 | + with self.assertRaises(ComponentValidationException): |
| 98 | + get_component("invalid_component") |
86 | 99 |
|
87 |
| - def test_get_unknown_component_by_name(self) -> None: |
88 |
| - test_torchx_group = {"foobar": sys.modules[__name__]} |
89 |
| - finder._components = None |
90 |
| - with patch("torchx.specs.finder.entrypoints") as entrypoints_mock: |
91 |
| - entrypoints_mock.load_group.return_value = test_torchx_group |
92 |
| - with self.assertRaises(ComponentNotFoundException): |
93 |
| - get_component("foobar.finder_test.unknown_component") |
94 |
| - |
95 |
| - def test_get_invalid_component(self) -> None: |
96 |
| - test_torchx_group = {"foobar": sys.modules[__name__]} |
97 |
| - with patch("torchx.specs.finder.entrypoints") as entrypoints_mock: |
98 |
| - entrypoints_mock.load_group.return_value = test_torchx_group |
99 |
| - components = _load_components() |
100 |
| - foobar_component = components["foobar.finder_test.invalid_component"] |
| 100 | + @patch(_METADATA_EPS, return_value=_ENTRY_POINTS) |
| 101 | + def test_get_unknown_component_by_name(self, _: MagicMock) -> None: |
| 102 | + with self.assertRaises(ComponentNotFoundException): |
| 103 | + get_component("unknown_component") |
| 104 | + |
| 105 | + @patch(_METADATA_EPS, return_value=_ENTRY_POINTS) |
| 106 | + def test_get_invalid_component(self, _: MagicMock) -> None: |
| 107 | + components = _load_components() |
| 108 | + foobar_component = components["invalid_component"] |
101 | 109 | self.assertEqual(1, len(foobar_component.validation_errors))
|
102 | 110 |
|
103 |
| - def test_get_entrypoints_components(self) -> None: |
104 |
| - test_torchx_group = {"foobar": sys.modules[__name__]} |
105 |
| - with patch("torchx.specs.finder.entrypoints") as entrypoints_mock: |
106 |
| - entrypoints_mock.load_group.return_value = test_torchx_group |
107 |
| - components = _load_components() |
108 |
| - foobar_component = components["foobar.finder_test._test_component"] |
| 111 | + @patch(_METADATA_EPS, return_value=_ENTRY_POINTS) |
| 112 | + def test_get_entrypoints_components(self, _: MagicMock) -> None: |
| 113 | + components = _load_components() |
| 114 | + foobar_component = components["_test_component"] |
109 | 115 | self.assertEqual(_test_component, foobar_component.fn)
|
110 | 116 | self.assertEqual("_test_component", foobar_component.fn_name)
|
111 |
| - self.assertEqual("foobar.finder_test._test_component", foobar_component.name) |
| 117 | + self.assertEqual("_test_component", foobar_component.name) |
112 | 118 | self.assertEqual("Test component", foobar_component.description)
|
113 | 119 |
|
114 |
| - def test_get_base_module_name(self) -> None: |
115 |
| - finder = ModuleComponentsFinder(sys.modules[__name__], "") |
116 |
| - expected_name = "torchx.specs.test" |
117 |
| - actual_name = finder._get_base_module_name(sys.modules[__name__]) |
118 |
| - self.assertEqual(expected_name, actual_name) |
119 |
| - |
120 |
| - def test_get_base_module_name_for_init_module(self) -> None: |
121 |
| - finder = ModuleComponentsFinder("", "") |
122 |
| - expected_name = "torchx.specs" |
123 |
| - actual_name = finder._get_base_module_name(sys.modules["torchx.specs"]) |
124 |
| - self.assertEqual(expected_name, actual_name) |
125 |
| - |
126 |
| - def test_get_component_name(self) -> None: |
127 |
| - finder = ModuleComponentsFinder("", group="foobar") |
128 |
| - actual_name = finder._get_component_name( |
129 |
| - "test.main_module", "test.main_module.sub_module.bar", "get_component" |
130 |
| - ) |
131 |
| - expected_name = "foobar.sub_module.bar.get_component" |
132 |
| - self.assertEqual(expected_name, actual_name) |
133 |
| - |
134 |
| - def test_strip_init(self) -> None: |
135 |
| - finder = ModuleComponentsFinder("", "") |
136 |
| - self.assertEqual("foobar", finder._strip_init("foobar.__init__")) |
137 |
| - self.assertEqual("", finder._strip_init("__init__")) |
138 |
| - self.assertEqual("foobar", finder._strip_init("foobar")) |
139 |
| - |
140 |
| - def test_get_module_name(self) -> None: |
141 |
| - finder = ModuleComponentsFinder("", "") |
142 |
| - actual_name = finder._get_module_name( |
143 |
| - "/test/path/main_module/foobar.py", "/test/path", "main" |
| 120 | + @patch( |
| 121 | + _METADATA_EPS, |
| 122 | + return_value=EntryPoints( |
| 123 | + EntryPoint_from_text( |
| 124 | + """ |
| 125 | +[torchx.components] |
| 126 | +foo = torchx.specs.test.components.a |
| 127 | +bar = torchx.specs.test.components.c.d |
| 128 | +""" |
| 129 | + ) |
| 130 | + ), |
| 131 | + ) |
| 132 | + def test_load_custom_components(self, _: MagicMock) -> None: |
| 133 | + components = _load_components() |
| 134 | + |
| 135 | + # the name of the appdefs returned by each component |
| 136 | + # is the expected component name |
| 137 | + for actual_name, comp in components.items(): |
| 138 | + expected_name = comp.fn().name |
| 139 | + self.assertEqual(expected_name, actual_name) |
| 140 | + |
| 141 | + self.assertEqual(3, len(components)) |
| 142 | + |
| 143 | + @patch( |
| 144 | + _METADATA_EPS, |
| 145 | + return_value=EntryPoints( |
| 146 | + EntryPoint_from_text( |
| 147 | + """ |
| 148 | +[torchx.components] |
| 149 | +_0 = torchx.specs.test.components.a |
| 150 | +_1 = torchx.specs.test.components.c.d |
| 151 | +""" |
| 152 | + ) |
| 153 | + ), |
| 154 | + ) |
| 155 | + def test_load_custom_components_nogroup(self, _: MagicMock) -> None: |
| 156 | + components = _load_components() |
| 157 | + |
| 158 | + # test component names are hardcoded expecting |
| 159 | + # test.components.* to be grouped under foo.* |
| 160 | + # and components.a_namepace.* to be grouped under bar.* |
| 161 | + # since we are testing _* (no group prefix) remove the first prefix |
| 162 | + for actual_name, comp in components.items(): |
| 163 | + expected_name = comp.fn().name.split(".", maxsplit=1)[1] |
| 164 | + self.assertEqual(expected_name, actual_name) |
| 165 | + |
| 166 | + def test_load_builtins(self) -> None: |
| 167 | + components = _load_components() |
| 168 | + |
| 169 | + # if nothing registered in entrypoints, then builtins should be loaded |
| 170 | + expected = { |
| 171 | + c.name for c in ModuleComponentsFinder("torchx.components", group="").find() |
| 172 | + } |
| 173 | + self.assertEqual(components.keys(), expected) |
| 174 | + |
| 175 | + def test_load_builtin_echo(self) -> None: |
| 176 | + components = _load_components() |
| 177 | + self.assertTrue(len(components) > 1) |
| 178 | + component = components["utils.echo"] |
| 179 | + self.assertEqual("utils.echo", component.name) |
| 180 | + self.assertEqual( |
| 181 | + "Echos a message to stdout (calls echo)", component.description |
144 | 182 | )
|
145 |
| - expected_name = "main.main_module.foobar" |
146 |
| - self.assertEqual(expected_name, actual_name) |
| 183 | + self.assertEqual("echo", component.fn_name) |
| 184 | + self.assertIsNotNone(component.fn) |
147 | 185 |
|
148 | 186 |
|
149 | 187 | def current_file_path() -> str:
|
|
0 commit comments