Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 28 additions & 149 deletions klippy/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
import collections
import gc
import importlib
import importlib.util
import logging
import optparse
import os
import pkgutil
import sys
import time
from collections import defaultdict
from enum import Enum
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Generator, Optional, Union

from klippy.configfile import ConfigWrapper

Expand Down Expand Up @@ -84,56 +83,6 @@ class WaitInterruption(gcode.CommandError):
pass


class PrinterModuleType(Enum):
EXTRA = "klippy.extras."
PLUGIN = ("klippy.extras.", True)
PLUGIN_OVERRIDE_EXTRA = ("klippy.extras.", True, True)
PLUGIN_DIRECTORY = ("klippy.plugins.", True)
PLUGIN_DIRECTORY_OVERRIDE_EXTRA = ("klippy.plugins.", True, True)

def __init__(
self,
module_root,
custom_loading: bool = False,
is_override: bool = False,
):
self.module_root = module_root
self.custom_loading = custom_loading
self.is_override = is_override

def import_module(self, module_name: str, module_path: Path) -> ModuleType:
full_name = self.module_root + module_name
if self.custom_loading:
return self._module_from_spec(full_name, module_path)
return self._import_module(full_name)

@staticmethod
def _import_module(module_name: str) -> ModuleType:
"""
Import a module when its physical path on disk matches it module path
All extras and plugins in a directory
"""
return importlib.import_module(module_name)

@staticmethod
def _module_from_spec(module_name: str, module_path: Path) -> ModuleType:
"""
Import a module when its module path doesn't match its physical path
Default for plugin files
"""
path = module_path
if path.is_dir():
path = module_path.joinpath("__init__.py")
mod_spec = importlib.util.spec_from_file_location(module_name, path)
if mod_spec is None:
raise ModuleNotFoundError(f"Module {module_name} failed to load")
module = importlib.util.module_from_spec(mod_spec)
mod_spec.loader.exec_module(module)
# TODO: insert into sys_modules?
# sys.modules[module_name] = module
return module


class SubsystemComponentCollection:
def __init__(self, config_error):
self._subsystems: dict[str, dict[str, Any]] = defaultdict(dict)
Expand Down Expand Up @@ -163,38 +112,26 @@ def register_component(


class PrinterModule:
path: Path
name: str
module_type: PrinterModuleType
exception: Optional[Exception] = None
module_info: pkgutil.ModuleInfo
module: Optional[ModuleType] = None
allow_plugin_override: bool
config_error: Callable
exception: Optional[Exception] = None

def __init__(
self,
path: Path,
module_type: PrinterModuleType,
allow_plugin_override: bool,
config_error: Callable,
):
self.path = path
self.name = path.stem
self.module_type = module_type
self.allow_plugin_override = allow_plugin_override
self.config_error = config_error
def __init__(self, name: str, module_info: pkgutil.ModuleInfo):
self.name = name
self.module_info = module_info

def load(self):
try:
self.module = self.module_type.import_module(self.name, self.path)
self.module = importlib.import_module(self.module_info.name)
except Exception as ex:
logging.exception(f"Failed to load module '{self.name}'.")
self.exception = ex

def get_init_function(self, section: str):
# if loading failed, raise that exception now
self.verify_loaded()
self.validate_plugin_overrides()
if self.exception is not None:
raise self.exception
# find the right init function
is_prefix = self.name != section
init_func_name = "load_config_prefix" if is_prefix else "load_config"
Expand All @@ -208,22 +145,8 @@ def register_components(self, collector: SubsystemComponentCollection):
register_func = self.get_method("register_components")
if register_func is None:
return
# only validate now that the call will actually happen
self.validate_plugin_overrides()
register_func(collector)

def validate_plugin_overrides(self):
if not self.module_type.is_override:
return
if not self.allow_plugin_override:
raise self.config_error(
f"Module '{self.name}' found in both extras and plugins!"
)

def verify_loaded(self):
if self.exception is not None:
raise self.exception

def get_method(self, function_name):
if self.module is None:
return None
Expand Down Expand Up @@ -255,74 +178,31 @@ def __init__(self, main_reactor, bglogger, start_args):
m.add_early_printer_objects(self)

@staticmethod
def _list_modules(search_path: str) -> list[Path]:
"""
list files + directories and filter to only those that could be a module
"""
path_list: list[Path] = []
for path_string in os.listdir(search_path):
path = Path(os.path.join(search_path, path_string))
# don't include hidden files or directories
# don't include __init__.py
if path.name.startswith(".") or path.name.startswith("__"):
continue
# only include files that are .py files
if path.is_file() and not path.name.endswith(".py"):
continue
path_list.append(path)
return path_list
def _iter_modules(prefix: str, path: Path) -> Generator[PrinterModule]:
for module_info in pkgutil.iter_modules([str(path)], prefix=prefix):
name = module_info.name.rsplit(".", 1)[-1]
yield PrinterModule(name, module_info)

def _load_modules(self, config: ConfigWrapper):
allow_overrides = self._allow_plugin_override(config)
extra_modules: dict[str, PrinterModule] = {}
extras_path = os.path.join(os.path.dirname(__file__), "extras")
extras = self._list_modules(extras_path)
extra_names = [extra.stem for extra in extras]
plugin_modules: dict[str, PrinterModule] = {}
plugins_path = os.path.join(os.path.dirname(__file__), "plugins")
plugins = self._list_modules(plugins_path)
plugin_names = [plugin.stem for plugin in plugins]

for plugin in plugins:
is_dir = plugin.is_dir()
is_override = plugin.name in extra_names
if is_override:
if is_dir:
module_type = (
PrinterModuleType.PLUGIN_DIRECTORY_OVERRIDE_EXTRA
)
else:
module_type = PrinterModuleType.PLUGIN_OVERRIDE_EXTRA
else:
if is_dir:
module_type = PrinterModuleType.PLUGIN_DIRECTORY
else:
module_type = PrinterModuleType.PLUGIN
pm = PrinterModule(
plugin, module_type, allow_overrides, self.config_error
)
plugin_modules[pm.name] = pm
pm.load()
klippy_dir = Path(__file__).parent

for extra in extras:
# don't load extras that were overridden by plugins
if extra.name in plugin_names:
continue
pm = PrinterModule(
extra,
PrinterModuleType.EXTRA,
allow_overrides,
self.config_error,
)
pm.load()
extra_modules[pm.name] = pm
for pm in self._iter_modules("klippy.extras.", klippy_dir / "extras"):
self.printer_modules[pm.name] = pm

# plugins override extras:
self.printer_modules = extra_modules | plugin_modules
for pm in self._iter_modules("klippy.plugins.", klippy_dir / "plugins"):
if pm.name in self.printer_modules and not allow_overrides:
raise configfile.error(
f"Module '{pm.name}' found in both extras and plugins!"
)
self.printer_modules[pm.name] = pm

for pm in self.printer_modules.values():
pm.load()

def _register_subsystem_components(self):
for name, module in self.printer_modules.items():
module.register_components(self.components)
for printer_module in self.printer_modules.values():
printer_module.register_components(self.components)

@staticmethod
def _allow_plugin_override(config) -> bool:
Expand Down Expand Up @@ -402,7 +282,6 @@ def load_object(
module_name = module_parts[0]
if module_name in self.printer_modules:
printer_module = self.printer_modules[module_name]
printer_module.verify_loaded()
init_func = printer_module.get_init_function(section)
if init_func is None:
if default is not configfile.sentinel:
Expand Down
Loading