5
5
import contextlib
6
6
import functools
7
7
import importlib .util
8
+ import itertools
8
9
import os
9
10
import sys
10
11
from pathlib import Path
@@ -125,11 +126,25 @@ def find_case_sensitive_path(path: Path, platform: str) -> Path:
125
126
def import_path (path : Path , root : Path ) -> ModuleType :
126
127
"""Import and return a module from the given path.
127
128
128
- The function is taken from pytest when the import mode is set to ``importlib``. It
129
- pytest's recommended import mode for new projects although the default is set to
130
- ``prepend``. More discussion and information can be found in :issue:`373`.
129
+ The functions are taken from pytest when the import mode is set to ``importlib``. It
130
+ was assumed to be the new default import mode but insurmountable tradeoffs caused
131
+ the default to be set to ``prepend``. More discussion and information can be found
132
+ in :issue:`373`.
131
133
132
134
"""
135
+ try :
136
+ pkg_root , module_name = _resolve_pkg_root_and_module_name (path )
137
+ except CouldNotResolvePathError :
138
+ pass
139
+ else :
140
+ # If the given module name is already in sys.modules, do not import it again.
141
+ with contextlib .suppress (KeyError ):
142
+ return sys .modules [module_name ]
143
+
144
+ mod = _import_module_using_spec (module_name , path , pkg_root )
145
+ if mod is not None :
146
+ return mod
147
+
133
148
module_name = _module_name_from_path (path , root )
134
149
with contextlib .suppress (KeyError ):
135
150
return sys .modules [module_name ]
@@ -147,42 +162,134 @@ def import_path(path: Path, root: Path) -> ModuleType:
147
162
return mod
148
163
149
164
165
+ def _resolve_package_path (path : Path ) -> Path | None :
166
+ """Resolve package path.
167
+
168
+ Return the Python package path by looking for the last directory upwards which still
169
+ contains an ``__init__.py``.
170
+
171
+ Returns None if it can not be determined.
172
+
173
+ """
174
+ result = None
175
+ for parent in itertools .chain ((path ,), path .parents ):
176
+ if parent .is_dir ():
177
+ if not (parent / "__init__.py" ).is_file ():
178
+ break
179
+ if not parent .name .isidentifier ():
180
+ break
181
+ result = parent
182
+ return result
183
+
184
+
185
+ def _resolve_pkg_root_and_module_name (path : Path ) -> tuple [Path , str ]:
186
+ """Resolve the root package directory and module name for the given Python file.
187
+
188
+ Return the path to the directory of the root package that contains the given Python
189
+ file, and its module name:
190
+
191
+ .. code-block:: text
192
+
193
+ src/
194
+ app/
195
+ __init__.py core/
196
+ __init__.py models.py
197
+
198
+ Passing the full path to `models.py` will yield Path("src") and "app.core.models".
199
+
200
+ Raises CouldNotResolvePathError if the given path does not belong to a package
201
+ (missing any __init__.py files).
202
+
203
+ """
204
+ pkg_path = _resolve_package_path (path )
205
+ if pkg_path is not None :
206
+ pkg_root = pkg_path .parent
207
+
208
+ names = list (path .with_suffix ("" ).relative_to (pkg_root ).parts )
209
+ if names [- 1 ] == "__init__" :
210
+ names .pop ()
211
+ module_name = "." .join (names )
212
+ return pkg_root , module_name
213
+
214
+ msg = f"Could not resolve for { path } "
215
+ raise CouldNotResolvePathError (msg )
216
+
217
+
218
+ class CouldNotResolvePathError (Exception ):
219
+ """Custom exception raised by _resolve_pkg_root_and_module_name."""
220
+
221
+
222
+ def _import_module_using_spec (
223
+ module_name : str , module_path : Path , module_location : Path
224
+ ) -> ModuleType | None :
225
+ """Import a module using its specification.
226
+
227
+ Tries to import a module by its canonical name, path to the .py file, and its parent
228
+ location.
229
+
230
+ """
231
+ # Checking with sys.meta_path first in case one of its hooks can import this module,
232
+ # such as our own assertion-rewrite hook.
233
+ for meta_importer in sys .meta_path :
234
+ spec = meta_importer .find_spec (module_name , [str (module_location )])
235
+ if spec is not None :
236
+ break
237
+ else :
238
+ spec = importlib .util .spec_from_file_location (module_name , str (module_path ))
239
+ if spec is not None :
240
+ mod = importlib .util .module_from_spec (spec )
241
+ sys .modules [module_name ] = mod
242
+ spec .loader .exec_module (mod ) # type: ignore[union-attr]
243
+ return mod
244
+
245
+ return None
246
+
247
+
150
248
def _module_name_from_path (path : Path , root : Path ) -> str :
151
249
"""Return a dotted module name based on the given path, anchored on root.
152
250
153
- For example: path="projects/src/project /task_foo.py" and root="/projects", the
154
- resulting module name will be "src.project .task_foo".
251
+ For example: path="projects/src/tasks /task_foo.py" and root="/projects", the
252
+ resulting module name will be "src.tasks .task_foo".
155
253
156
254
"""
157
255
path = path .with_suffix ("" )
158
256
try :
159
257
relative_path = path .relative_to (root )
160
258
except ValueError :
161
- # If we can't get a relative path to root, use the full path, except for the
162
- # first part ("d:\\" or "/" depending on the platform, for example).
259
+ # If we can't get a relative path to root, use the full path, except
260
+ # for the first part ("d:\\" or "/" depending on the platform, for example).
163
261
path_parts = path .parts [1 :]
164
262
else :
165
263
# Use the parts for the relative path to the root path.
166
264
path_parts = relative_path .parts
167
265
168
- # Module name for packages do not contain the __init__ file, unless the
169
- # `__init__.py` file is at the root.
266
+ # Module name for packages do not contain the __init__ file, unless
267
+ # the `__init__.py` file is at the root.
170
268
if len (path_parts ) >= 2 and path_parts [- 1 ] == "__init__" : # noqa: PLR2004
171
269
path_parts = path_parts [:- 1 ]
172
270
271
+ # Module names cannot contain ".", normalize them to "_". This prevents a directory
272
+ # having a "." in the name (".env.310" for example) causing extra intermediate
273
+ # modules. Also, important to replace "." at the start of paths, as those are
274
+ # considered relative imports.
275
+ path_parts = tuple (x .replace ("." , "_" ) for x in path_parts )
276
+
173
277
return "." .join (path_parts )
174
278
175
279
176
280
def _insert_missing_modules (modules : dict [str , ModuleType ], module_name : str ) -> None :
177
- """Insert missing modules when importing modules with :func:`import_path` .
281
+ """Insert missing modules in sys. modules.
178
282
179
- When we want to import a module as ``src.project.task_foo `` for example, we need to
180
- create empty modules ``src`` and `` src.project`` after inserting
181
- `` src.project.task_foo``, otherwise `` src.project.task_foo`` is not importable by
182
- ``__import__``.
283
+ Used by ``import_path `` to create intermediate modules when using mode=importlib.
284
+ When we want to import a module as " src.tasks.task_foo" for example, we need to
285
+ create empty modules " src" and " src.tasks" after inserting "src.tasks.task_foo",
286
+ otherwise "src.tasks.task_foo" is not importable by ``__import__``.
183
287
184
288
"""
185
289
module_parts = module_name .split ("." )
290
+ child_module : ModuleType | None = None
291
+ module : ModuleType | None = None
292
+ child_name : str = ""
186
293
while module_name :
187
294
if module_name not in modules :
188
295
try :
@@ -192,13 +299,20 @@ def _insert_missing_modules(modules: dict[str, ModuleType], module_name: str) ->
192
299
# creating a dummy module.
193
300
if not sys .meta_path :
194
301
raise ModuleNotFoundError # noqa: TRY301
195
- importlib .import_module (module_name )
302
+ module = importlib .import_module (module_name )
196
303
except ModuleNotFoundError :
197
304
module = ModuleType (
198
305
module_name ,
199
- doc = "Empty module created by pytask ." ,
306
+ doc = "Empty module created by pytest's importmode=importlib ." ,
200
307
)
201
- modules [module_name ] = module
308
+ else :
309
+ module = modules [module_name ]
310
+ # Add child attribute to the parent that can reference the child modules.
311
+ if child_module and not hasattr (module , child_name ):
312
+ setattr (module , child_name , child_module )
313
+ modules [module_name ] = module
314
+ # Keep track of the child module while moving up the tree.
315
+ child_module , child_name = module , module_name .rpartition ("." )[- 1 ]
202
316
module_parts .pop (- 1 )
203
317
module_name = "." .join (module_parts )
204
318
0 commit comments