diff --git a/src/maturin_import_hook/_resolve_project.py b/src/maturin_import_hook/_resolve_project.py index 86b7d25..aaafc30 100644 --- a/src/maturin_import_hook/_resolve_project.py +++ b/src/maturin_import_hook/_resolve_project.py @@ -108,6 +108,9 @@ class MaturinProject: # the name of the compiled extension module without any suffix # (i.e. "some_package.my_module" instead of "some_package/my_module.cpython-311-x86_64-linux-gnu") module_full_name: str + # the name of the package containing the module, accounting for namespacing + # (e.g. foo.a for an implicit package at foo/a/__init__.py, foo for one at foo/__init__.py) + package_name: str # the root of the python part of the project (or the project root if there is none) python_dir: Path # the path to the top level python package if the project is mixed @@ -119,10 +122,6 @@ class MaturinProject: # all path dependencies including transitive dependencies _all_path_dependencies: Optional[list[Path]] = None - @property - def package_name(self) -> str: - return self.module_full_name.split(".")[0] - @property def module_name(self) -> str: return self.module_full_name.split(".")[-1] @@ -192,13 +191,15 @@ def _resolve_project(project_dir: Path) -> MaturinProject: python_module, extension_module_dir, extension_module_name = _resolve_rust_module(python_dir, module_full_name) immediate_path_dependencies = _get_immediate_path_dependencies(manifest_path.parent, cargo) + package_name = ".".join(python_module.relative_to(python_dir).parts) if not python_module.exists(): extension_module_dir = None python_module = None - + package_name = module_full_name.split(".")[0] return MaturinProject( cargo_manifest_path=manifest_path, module_full_name=module_full_name, + package_name=package_name, python_dir=python_dir, python_module=python_module, extension_module_dir=extension_module_dir, @@ -221,8 +222,14 @@ def _resolve_rust_module(python_dir: Path, module_name: str) -> tuple[Path, Path """ parts = module_name.split(".") if len(parts) > 1: - python_module = python_dir / parts[0] + # Find the first level between the module name and the python source dir containing + # an __init__.py extension_module_dir = python_dir / Path(*parts[:-1]) + python_module = extension_module_dir + while python_module != python_dir: + if (python_module / "__init__.py").exists(): + break + python_module = python_module.parent() extension_module_name = parts[-1] else: python_module = python_dir / module_name diff --git a/src/maturin_import_hook/project_importer.py b/src/maturin_import_hook/project_importer.py index 55dc9b3..86ec9a3 100644 --- a/src/maturin_import_hook/project_importer.py +++ b/src/maturin_import_hook/project_importer.py @@ -6,6 +6,7 @@ import json import logging import os +import re import site import sys import tempfile @@ -15,10 +16,11 @@ from abc import ABC, abstractmethod from collections.abc import Iterator, Sequence from functools import lru_cache +from importlib._bootstrap_external import _NamespacePath from importlib.machinery import ExtensionFileLoader, ModuleSpec, PathFinder from pathlib import Path from types import ModuleType -from typing import ClassVar, Optional, Union +from typing import Any, ClassVar, Optional, Union from maturin_import_hook._building import ( BuildCache, @@ -49,6 +51,7 @@ "DefaultProjectFileSearcher", ] +_DIST_INFO_REGEX = re.compile(r"^(?P.+?)-\d[\w\.\-]*\.dist-info$") class ProjectFileSearcher(ABC): @abstractmethod @@ -112,13 +115,16 @@ def invalidate_caches(self) -> None: def find_spec( self, fullname: str, - path: Optional[Sequence[Union[str, bytes]]] = None, + path: Optional[Sequence[Union[str, bytes]] | _NamespacePath] = None, target: Optional[ModuleType] = None, ) -> Optional[ModuleSpec]: - is_top_level_import = path is None + is_in_namespace = path is not None and not isinstance(path, Sequence) + is_top_level_import = path is None or is_in_namespace if not is_top_level_import: return None - assert "." not in fullname + if not is_in_namespace: + assert "." not in fullname + # Impossible to tell if fullname corresponds to a bare namespace at this point package_name = fullname already_loaded = package_name in sys.modules @@ -140,7 +146,10 @@ def find_spec( spec = None rebuilt = False for search_path in search_paths: - project_dir, is_editable = _load_dist_info(search_path, package_name) + # Account for namespaced packages + dist_name = package_name.replace(".", "_") + project_dir, is_editable = _load_dist_info(search_path, dist_name) + # namespaces do not have dist-infos of their own if project_dir is not None: logger.debug('found project linked by dist-info: "%s"', project_dir) if not is_editable and not self._enable_automatic_installation: @@ -148,18 +157,18 @@ def find_spec( "package not installed in editable-mode and enable_automatic_installation=False. not rebuilding" ) else: - spec, rebuilt = self._rebuild_project(package_name, project_dir) + spec, rebuilt = self._rebuild_project(package_name, project_dir, dist_name) if spec is not None: break project_dir = _find_maturin_project_above(search_path) - if project_dir is not None: + if project_dir is not None and (search_path / package_name / "__init__.py").exists(): logger.debug( 'found project above the search path: "%s" ("%s")', project_dir, search_path, ) - spec, rebuilt = self._rebuild_project(package_name, project_dir) + spec, rebuilt = self._rebuild_project(package_name, project_dir, dist_name) if spec is not None: break @@ -221,6 +230,7 @@ def _rebuild_project( self, package_name: str, project_dir: Path, + dist_name: str ) -> tuple[Optional[ModuleSpec], bool]: resolved = self._resolver.resolve(project_dir) if resolved is None: @@ -238,7 +248,7 @@ def _rebuild_project( ) return None, False - if not self._enable_automatic_installation and not _is_editable_installed_package(project_dir, package_name): + if not self._enable_automatic_installation and not _is_editable_installed_package(project_dir, dist_name): logger.debug( 'package "%s" is not already installed and enable_automatic_installation=False. Not importing', package_name, @@ -344,7 +354,15 @@ def _log_build_warnings(self, module_path: str, maturin_output: str, is_fresh: b def _find_spec_for_package(package_name: str) -> Optional[ModuleSpec]: path_finder = PathFinder() - spec = path_finder.find_spec(package_name) + current_path: Optional[Path] = None + for item in package_name.split("."): + spec = path_finder.find_spec(item, current_path) + if not spec: + break + if spec.submodule_search_locations and not isinstance(spec.submodule_search_locations, Sequence): + current_path = spec.submodule_search_locations + if isinstance(spec.submodule_search_locations, Sequence): + break if spec is not None: return spec logger.debug('spec for package "%s" not found', package_name) @@ -396,7 +414,8 @@ def _find_dist_info_path(directory: Path, package_name: str) -> Optional[Path]: except FileNotFoundError: return None for name in names: - if name.startswith(package_name) and name.endswith(".dist-info"): + match_res = _DIST_INFO_REGEX.match(name) + if match_res is not None and match_res.group("package_name") == package_name: return Path(directory / name) return None