diff --git a/autointent/_callbacks/tensorboard.py b/autointent/_callbacks/tensorboard.py
index e3020b21..2da29053 100644
--- a/autointent/_callbacks/tensorboard.py
+++ b/autointent/_callbacks/tensorboard.py
@@ -57,16 +57,12 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]
         for key, value in module_kwargs.items():
             self.module_writer.add_text(f"module_params/{key}", str(value))  # type: ignore[no-untyped-call]
 
-    def log_value(self, **kwargs: dict[str, Any]) -> None:
+    def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
         """
         Log data.
 
         :param kwargs: Data to log.
         """
-        if self.module_writer is None:
-            msg = "start_run must be called before log_value."
-            raise RuntimeError(msg)
-
         for key, value in kwargs.items():
             if isinstance(value, int | float):
                 self.module_writer.add_scalar(key, value)
@@ -79,10 +75,6 @@ def log_metrics(self, metrics: dict[str, Any]) -> None:
 
         :param metrics: Metrics to log.
         """
-        if self.module_writer is None:
-            msg = "start_run must be called before log_value."
-            raise RuntimeError(msg)
-
         for key, value in metrics.items():
             if isinstance(value, int | float):
                 self.module_writer.add_scalar(key, value)  # type: ignore[no-untyped-call]
diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py
index 54bbd5e0..f3644f9f 100644
--- a/autointent/_pipeline/_pipeline.py
+++ b/autointent/_pipeline/_pipeline.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import yaml
+from typing_extensions import assert_never
 
 from autointent import Context, Dataset
 from autointent.configs import DataConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
@@ -52,8 +53,7 @@ def __init__(
             self.vector_index_config = VectorIndexConfig()
             self.data_config = DataConfig()
         elif not isinstance(nodes[0], InferenceNode):
-            msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
-            raise TypeError(msg)
+            assert_never(nodes)
 
     def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) -> None:
         """
@@ -68,8 +68,7 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) ->
         elif isinstance(config, DataConfig):
             self.data_config = config
         else:
-            msg = "unknown config type"
-            raise TypeError(msg)
+            assert_never(config)
 
     @classmethod
     def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline":
@@ -180,7 +179,7 @@ def fit(
             )
 
         if sampler is None:
-            sampler = self.sampler or "brute"
+            sampler = self.sampler
 
         self._fit(context, sampler)
 
diff --git a/autointent/configs/_inference_node.py b/autointent/configs/_inference_node.py
index 320feacf..b99c684d 100644
--- a/autointent/configs/_inference_node.py
+++ b/autointent/configs/_inference_node.py
@@ -18,7 +18,3 @@ class InferenceNodeConfig:
     """Configuration of the module"""
     load_path: str | None = None
     """Path to the module dump. If None, the module will be trained from scratch"""
-
-    def __post_init__(self) -> None:
-        if not isinstance(self.node_type, NodeType):
-            self.node_type = NodeType(self.node_type)
diff --git a/autointent/context/_context.py b/autointent/context/_context.py
index 62d37ab3..79f73029 100644
--- a/autointent/context/_context.py
+++ b/autointent/context/_context.py
@@ -94,10 +94,6 @@ def dump(self) -> None:
         optimization_results = self.optimization_info.dump_evaluation_results()
 
         logs_dir = self.logging_config.dirpath
-        if logs_dir is None:
-            msg = "something's wrong with LoggingConfig"
-            raise ValueError(msg)
-
         logs_dir.mkdir(parents=True, exist_ok=True)
         logs_path = logs_dir / "logs.json"
 
diff --git a/autointent/metrics/_converter.py b/autointent/metrics/_converter.py
index cb0d765c..59bb99cb 100644
--- a/autointent/metrics/_converter.py
+++ b/autointent/metrics/_converter.py
@@ -22,8 +22,6 @@ def transform(
     :param y_pred: Y_pred values
     :return:
     """
-    if isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray):
-        return y_true, y_pred
    y_pred_ = np.array(y_pred)
    y_true_ = np.array(y_true)
    return y_true_, y_pred_
diff --git a/autointent/modules/abc/_base.py b/autointent/modules/abc/_base.py
index 29fad416..f44f5390 100644
--- a/autointent/modules/abc/_base.py
+++ b/autointent/modules/abc/_base.py
@@ -40,15 +40,15 @@ def score(self, context: Context, metrics: list[str]) -> dict[str, float]:
         Calculate metric on test set and return metric value.
 
         :param context: Context to score
-        :param split: Split to score on
+        :param metrics: Metrics to score
         :return: Computed metrics value for the test set or error code of metrics
         """
         if context.data_handler.config.scheme == "ho":
             return self.score_ho(context, metrics)
         if context.data_handler.config.scheme == "cv":
             return self.score_cv(context, metrics)
-        msg = "Something's wrong with validation schemas"
-        raise RuntimeError(msg)
+        msg = f"Unknown scheme: {context.data_handler.config.scheme}"
+        raise ValueError(msg)
 
     @abstractmethod
     def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: ...
diff --git a/autointent/nodes/_optimization/_node_optimizer.py b/autointent/nodes/_optimization/_node_optimizer.py
index bb5be151..33fedf05 100644
--- a/autointent/nodes/_optimization/_node_optimizer.py
+++ b/autointent/nodes/_optimization/_node_optimizer.py
@@ -11,6 +11,7 @@
 import torch
 from optuna.trial import Trial
 from pydantic import BaseModel, Field
+from typing_extensions import assert_never
 
 from autointent import Dataset
 from autointent.context import Context
@@ -65,11 +66,12 @@ def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
         Fit the node optimizer.
 
         :param context: Context
+        :param sampler: Sampler to use for optimization
         """
         self._logger.info("starting %s node optimization...", self.node_info.node_type)
 
         for search_space in deepcopy(self.modules_search_spaces):
-            self._counter = 0
+            self._counter: int = 0
             module_name = search_space.pop("module_name")
             n_trials = None
             if "n_trials" in search_space:
@@ -84,8 +86,7 @@ def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
                 sampler_instance = optuna.samplers.RandomSampler(seed=context.seed)  # type: ignore[assignment]
                 n_trials = n_trials or 10
             else:
-                msg = f"Unexpected sampler: {sampler}"
-                raise ValueError(msg)
+                assert_never(sampler)
             study = optuna.create_study(direction="maximize", sampler=sampler_instance)
             optuna.logging.set_verbosity(optuna.logging.WARNING)
             obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context)
diff --git a/autointent/schemas/_schemas.py b/autointent/schemas/_schemas.py
index b8ae4691..9c048954 100644
--- a/autointent/schemas/_schemas.py
+++ b/autointent/schemas/_schemas.py
@@ -197,7 +197,7 @@ def get_prompt_config(self) -> dict[str, str] | None:
             prompts[TaskTypeEnum.sts.value] = self.sts_prompt
         return prompts if len(prompts) > 0 else None
 
-    def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None:  # noqa: PLR0911
+    def get_prompt_type(self, prompt_type: TaskTypeEnum | str | None) -> str | None:  # noqa: PLR0911
         """Get the prompt type for the given task type.
 
         :param prompt_type: Task type for which to get the prompt.
diff --git a/pyproject.toml b/pyproject.toml
index e9dd1069..e8867b7b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -166,6 +166,9 @@ skip_empty = true
 python_version = "3.10"
 strict = true
 warn_redundant_casts = true
+# align with mypy 2.0 release
+warn_unreachable = true
+local_partial_types = true
 plugins = [
     "pydantic.mypy",
     "numpy.typing.mypy_plugin",
@@ -193,3 +196,10 @@ module = [
     "wandb",
 ]
 ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = [
+    "autointent._callbacks.*",
+    "autointent.modules.abc.*",
+]
+warn_unreachable = false
\ No newline at end of file