update mypy config #145

Merged · 8 commits · Feb 26, 2025
10 changes: 1 addition & 9 deletions autointent/_callbacks/tensorboard.py
@@ -57,16 +57,12 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]
         for key, value in module_kwargs.items():
             self.module_writer.add_text(f"module_params/{key}", str(value))  # type: ignore[no-untyped-call]

-    def log_value(self, **kwargs: dict[str, Any]) -> None:
+    def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
         """
         Log data.

         :param kwargs: Data to log.
         """
-        if self.module_writer is None:
-            msg = "start_run must be called before log_value."
-            raise RuntimeError(msg)
-
         for key, value in kwargs.items():
             if isinstance(value, int | float):
                 self.module_writer.add_scalar(key, value)
@@ -79,10 +75,6 @@ def log_metrics(self, metrics: dict[str, Any]) -> None:

         :param metrics: Metrics to log.
         """
-        if self.module_writer is None:
-            msg = "start_run must be called before log_value."
-            raise RuntimeError(msg)
-
         for key, value in metrics.items():
             if isinstance(value, int | float):
                 self.module_writer.add_scalar(key, value)  # type: ignore[no-untyped-call]
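A note on the `log_value` annotation: under PEP 484, an annotation on `**kwargs` describes each keyword value, not the mapping as a whole, so the old `**kwargs: dict[str, Any]` declared every logged value to be a dict. A minimal sketch of the rule (the function and values here are illustrative, not the real callback):

```python
from typing import Any


def log_value(**kwargs: int | float | Any) -> None:
    # PEP 484: the annotation applies to each keyword value individually,
    # so kwargs here behaves as dict[str, int | float | Any].
    for key, value in kwargs.items():
        if isinstance(value, int | float):
            print("scalar:", key, value)


log_value(loss=0.12, run_name="baseline")  # both values type-check
```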
9 changes: 4 additions & 5 deletions autointent/_pipeline/_pipeline.py
@@ -7,6 +7,7 @@

 import numpy as np
 import yaml
+from typing_extensions import assert_never

 from autointent import Context, Dataset
 from autointent.configs import DataConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
@@ -52,8 +53,7 @@ def __init__(
             self.vector_index_config = VectorIndexConfig()
             self.data_config = DataConfig()
         elif not isinstance(nodes[0], InferenceNode):
-            msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
-            raise TypeError(msg)
+            assert_never(nodes)

     def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) -> None:
         """
@@ -68,8 +68,7 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) -> None:
         elif isinstance(config, DataConfig):
             self.data_config = config
         else:
-            msg = "unknown config type"
-            raise TypeError(msg)
+            assert_never(config)

     @classmethod
     def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline":
@@ -180,7 +179,7 @@ def fit(
         )

         if sampler is None:
-            sampler = self.sampler or "brute"
+            sampler = self.sampler

         self._fit(context, sampler)
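For context on the two replacements above: `assert_never` turns the final branch of an exhaustive `isinstance` chain into a static exhaustiveness check. mypy verifies the argument has been narrowed to `Never`, so extending the union without handling the new member becomes a type error; if the branch is somehow reached at runtime anyway, the call raises. A self-contained sketch with stand-in config classes (not the real autointent types):

```python
from dataclasses import dataclass

from typing_extensions import assert_never


@dataclass
class LoggingConfig:
    level: str = "INFO"


@dataclass
class DataConfig:
    path: str = "data.json"


def set_config(config: LoggingConfig | DataConfig) -> None:
    if isinstance(config, LoggingConfig):
        print("logging:", config.level)
    elif isinstance(config, DataConfig):
        print("data:", config.path)
    else:
        # mypy checks that config is narrowed to Never here; adding a new
        # member to the union makes this line a type error until handled.
        assert_never(config)


set_config(LoggingConfig())
```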
4 changes: 0 additions & 4 deletions autointent/configs/_inference_node.py
@@ -18,7 +18,3 @@ class InferenceNodeConfig:
     """Configuration of the module"""
     load_path: str | None = None
     """Path to the module dump. If None, the module will be trained from scratch"""
-
-    def __post_init__(self) -> None:
-        if not isinstance(self.node_type, NodeType):
-            self.node_type = NodeType(self.node_type)
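Dropping `__post_init__` changes the contract slightly: a bare string is no longer coerced at runtime, so callers are expected to construct the config with a `NodeType` member, which mypy can then check statically. A tiny sketch with stand-in types (the enum member is hypothetical, not necessarily a real `NodeType` value):

```python
from dataclasses import dataclass
from enum import Enum


class NodeType(Enum):  # stand-in for autointent's NodeType
    scoring = "scoring"


@dataclass
class InferenceNodeConfig:  # simplified stand-in for the real config
    node_type: NodeType


# Previously InferenceNodeConfig(node_type="scoring") was coerced at runtime;
# now the enum member is passed explicitly and checked by the type checker.
config = InferenceNodeConfig(node_type=NodeType.scoring)
```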
4 changes: 0 additions & 4 deletions autointent/context/_context.py
@@ -94,10 +94,6 @@ def dump(self) -> None:
         optimization_results = self.optimization_info.dump_evaluation_results()

         logs_dir = self.logging_config.dirpath
-        if logs_dir is None:
-            msg = "something's wrong with LoggingConfig"
-            raise ValueError(msg)
-
         logs_dir.mkdir(parents=True, exist_ok=True)

         logs_path = logs_dir / "logs.json"
2 changes: 0 additions & 2 deletions autointent/metrics/_converter.py
@@ -22,8 +22,6 @@ def transform(
     :param y_pred: Y_pred values
     :return:
     """
-    if isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray):
-        return y_true, y_pred
    y_pred_ = np.array(y_pred)
    y_true_ = np.array(y_true)
    return y_true_, y_pred_
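The deleted fast path skipped a copy when both inputs were already arrays; without it, `np.array` copies unconditionally. If that copy ever matters, `np.asarray` is the usual no-copy spelling, as this quick illustration shows:

```python
import numpy as np

a = np.arange(3)
print(np.array(a) is a)    # False: np.array copies by default
print(np.asarray(a) is a)  # True: asarray returns ndarray input unchanged
```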
6 changes: 3 additions & 3 deletions autointent/modules/abc/_base.py
@@ -40,15 +40,15 @@ def score(self, context: Context, metrics: list[str]) -> dict[str, float]:
         Calculate metric on test set and return metric value.

         :param context: Context to score
-        :param split: Split to score on
+        :param metrics: Metrics to score
         :return: Computed metrics value for the test set or error code of metrics
         """
         if context.data_handler.config.scheme == "ho":
             return self.score_ho(context, metrics)
         if context.data_handler.config.scheme == "cv":
             return self.score_cv(context, metrics)
-        msg = "Something's wrong with validation schemas"
-        raise RuntimeError(msg)
+        msg = f"Unknown scheme: {context.data_handler.config.scheme}"
+        raise ValueError(msg)

     @abstractmethod
     def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: ...
7 changes: 4 additions & 3 deletions autointent/nodes/_optimization/_node_optimizer.py
@@ -11,6 +11,7 @@
 import torch
 from optuna.trial import Trial
 from pydantic import BaseModel, Field
+from typing_extensions import assert_never

 from autointent import Dataset
 from autointent.context import Context
@@ -65,11 +66,12 @@ def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
         Fit the node optimizer.

         :param context: Context
+        :param sampler: Sampler to use for optimization
         """
         self._logger.info("starting %s node optimization...", self.node_info.node_type)

         for search_space in deepcopy(self.modules_search_spaces):
-            self._counter = 0
+            self._counter: int = 0
             module_name = search_space.pop("module_name")
             n_trials = None
             if "n_trials" in search_space:
@@ -84,8 +86,7 @@ def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
                 sampler_instance = optuna.samplers.RandomSampler(seed=context.seed)  # type: ignore[assignment]
                 n_trials = n_trials or 10
             else:
-                msg = f"Unexpected sampler: {sampler}"
-                raise ValueError(msg)
+                assert_never(sampler)
             study = optuna.create_study(direction="maximize", sampler=sampler_instance)
             optuna.logging.set_verbosity(optuna.logging.WARNING)
             obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context)
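The `random` branch above pairs a seeded `RandomSampler` with a fallback of 10 trials (`n_trials = n_trials or 10`). A minimal standalone sketch of what that amounts to in optuna (the objective and seed are illustrative):

```python
import optuna

# seeded RandomSampler, as in the branch above
sampler = optuna.samplers.RandomSampler(seed=42)
study = optuna.create_study(direction="maximize", sampler=sampler)
# toy objective with its maximum near x = 0
study.optimize(lambda trial: -trial.suggest_float("x", -1.0, 1.0) ** 2, n_trials=10)
print(study.best_params)
```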
2 changes: 1 addition & 1 deletion autointent/schemas/_schemas.py
@@ -197,7 +197,7 @@ def get_prompt_config(self) -> dict[str, str] | None:
             prompts[TaskTypeEnum.sts.value] = self.sts_prompt
         return prompts if len(prompts) > 0 else None

-    def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None:  # noqa: PLR0911
+    def get_prompt_type(self, prompt_type: TaskTypeEnum | str | None) -> str | None:  # noqa: PLR0911
         """Get the prompt type for the given task type.

         :param prompt_type: Task type for which to get the prompt.
10 changes: 10 additions & 0 deletions pyproject.toml
@@ -166,6 +166,9 @@ skip_empty = true
 python_version = "3.10"
 strict = true
 warn_redundant_casts = true
+# align with mypy 2.0 release
+warn_unreachable = true
+local_partial_types = true
 plugins = [
     "pydantic.mypy",
     "numpy.typing.mypy_plugin",
@@ -193,3 +196,10 @@ module = [
     "wandb",
 ]
 ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = [
+    "autointent._callbacks.*",
+    "autointent.modules.abc.*",
+]
+warn_unreachable = false
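These two flags drive several of the deletions in this PR: with `warn_unreachable = true`, mypy reports statements it can prove will never run, so a runtime guard that re-checks something the annotations already rule out (for example the `logs_dir is None` check above, assuming `dirpath` is annotated as non-optional) becomes a type error rather than harmless dead code. `local_partial_types` requires variables initialized to `None` to receive their full type within the same scope; per the in-file comment, both settings anticipate mypy 2.0 defaults. The final override then opts `autointent._callbacks.*` and `autointent.modules.abc.*` back out of the unreachability check. An illustrative snippet of the kind of code `warn_unreachable` flags:

```python
def dump_logs(dirpath: str) -> None:
    if dirpath is None:
        # mypy with warn_unreachable: error, statement is unreachable,
        # because dirpath is annotated str and can never be None here
        raise ValueError("missing dirpath")
    print(f"dumping to {dirpath}")
```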