diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 3d27ca99302f..f6a962101af4 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -49,8 +49,10 @@ ) from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback from mypy.plugins.functools import ( + functools_lru_cache_callback, functools_total_ordering_maker_callback, functools_total_ordering_makers, + lru_cache_wrapper_call_callback, partial_call_callback, partial_new_callback, ) @@ -101,6 +103,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] return create_singledispatch_function_callback elif fullname == "functools.partial": return partial_new_callback + elif fullname == "functools.lru_cache": + return functools_lru_cache_callback elif fullname == "enum.member": return enum_member_callback return None @@ -160,6 +164,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No return call_singledispatch_function_after_register_argument elif fullname == "functools.partial.__call__": return partial_call_callback + elif fullname == "functools._lru_cache_wrapper.__call__": + return lru_cache_wrapper_call_callback return None def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None: diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index c8b370f15e6d..d3a033641553 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -41,6 +41,7 @@ _ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"} PARTIAL: Final = "functools.partial" +LRU_CACHE: Final = "functools.lru_cache" class _MethodInfo(NamedTuple): @@ -393,3 +394,135 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names) return result + + +def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type: + """Infer a more precise return type for functools.lru_cache decorator""" + if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals + return ctx.default_return_type + + # Only handle the very specific case: @lru_cache (without parentheses) + # where a single function is passed directly as the only argument + if ( + len(ctx.arg_types) == 1 + and len(ctx.arg_types[0]) == 1 + and len(ctx.args) == 1 + and len(ctx.args[0]) == 1 + ): + + first_arg_type = ctx.arg_types[0][0] + + # Explicitly reject literal types, instances, and None + from mypy.types import Instance, LiteralType, NoneType + + proper_first_arg_type = get_proper_type(first_arg_type) + if isinstance(proper_first_arg_type, (LiteralType, Instance, NoneType)): + return ctx.default_return_type + + # Try to extract callable type + fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type) + if fn_type is not None: + # This is the @lru_cache case (function passed directly) + return fn_type + + # For all other cases (parameterized, multiple args, etc.), don't interfere + return ctx.default_return_type + + +def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type: + """Handle calls to functools._lru_cache_wrapper objects to provide parameter validation""" + if not isinstance(ctx.api, mypy.checker.TypeChecker): + return ctx.default_return_type + + # Safety check: ensure we have the required context + if not ctx.context or not ctx.args or not ctx.arg_types: + return ctx.default_return_type + + # Try to find the original function signature using AST/symbol table analysis + original_signature = _find_original_function_signature(ctx) + + if original_signature is not None: + # Validate the call against the original function signature + actual_args = [] + actual_arg_kinds = [] + actual_arg_names = [] + seen_args = set() + + for i, param in enumerate(ctx.args): + for j, a in enumerate(param): + if a in seen_args: + continue + seen_args.add(a) + actual_args.append(a) + actual_arg_kinds.append(ctx.arg_kinds[i][j]) + actual_arg_names.append(ctx.arg_names[i][j]) + + # Check the call against the original signature + try: + result, _ = ctx.api.expr_checker.check_call( + callee=original_signature, + args=actual_args, + arg_kinds=actual_arg_kinds, + arg_names=actual_arg_names, + context=ctx.context, + ) + return result + except Exception: + # If check_call fails, fall back gracefully + pass + + return ctx.default_return_type + + +def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None: + """ + Attempt to find the original function signature from the call context. + + Returns the CallableType of the original function if found, None otherwise. + This function safely traverses the AST structure to locate the original + function signature that was decorated with @lru_cache. + """ + from mypy.nodes import CallExpr, Decorator, NameExpr + + try: + # Ensure we have the required context structure + if not isinstance(ctx.context, CallExpr): + return None + + callee = ctx.context.callee + if not isinstance(callee, NameExpr) or not callee.name: + return None + + func_name = callee.name + + # Safely access the API globals + if not hasattr(ctx.api, "globals") or not isinstance(ctx.api.globals, dict): + return None + + if func_name not in ctx.api.globals: + return None + + symbol = ctx.api.globals[func_name] + + # Validate symbol structure before accessing node + if not hasattr(symbol, "node") or symbol.node is None: + return None + + # Check if this is a decorator node containing our function + if isinstance(symbol.node, Decorator): + decorator_node = symbol.node + + # Safely access the decorated function + if not hasattr(decorator_node, "func") or decorator_node.func is None: + return None + + func_def = decorator_node.func + + # Verify we have a callable type + if hasattr(func_def, "type") and isinstance(func_def.type, CallableType): + return func_def.type + + return None + except (AttributeError, TypeError, KeyError): + # If anything goes wrong in AST traversal, fail gracefully + return None diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test index fa2cacda275d..80fd5d59c1b8 100644 --- a/test-data/unit/check-functools.test +++ b/test-data/unit/check-functools.test @@ -726,3 +726,135 @@ def outer_c(arg: Tc) -> None: use_int_callable(partial(inner, b="")) # E: Argument 1 to "use_int_callable" has incompatible type "partial[str]"; expected "Callable[[int], int]" \ # N: "partial[str].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], str]" [builtins fixtures/tuple.pyi] + +[case testLruCacheBasicValidation] +from functools import lru_cache + +@lru_cache +def f(v: str, at: int) -> str: + return v + +f() # E: Missing positional arguments "v", "at" in call to "f" +f("abc") # E: Missing positional argument "at" in call to "f" +f("abc", 123) # OK +f("abc", at=123) # OK +f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testLruCacheWithReturnType] +from functools import lru_cache + +@lru_cache +def multiply(x: int, y: int) -> int: + return 42 + +reveal_type(multiply) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int" +reveal_type(multiply(2, 3)) # N: Revealed type is "builtins.int" +multiply("a", 3) # E: Argument 1 to "multiply" has incompatible type "str"; expected "int" +multiply(2, "b") # E: Argument 2 to "multiply" has incompatible type "str"; expected "int" +multiply(2) # E: Missing positional argument "y" in call to "multiply" +multiply(1, 2, 3) # E: Too many arguments for "multiply" +[builtins fixtures/dict.pyi] + +[case testLruCacheWithOptionalArgs] +from functools import lru_cache + +@lru_cache +def greet(name: str, greeting: str = "Hello") -> str: + return "result" + +greet("World") # OK +greet("World", "Hi") # OK +greet("World", greeting="Hi") # OK +greet() # E: Missing positional argument "name" in call to "greet" +greet(123) # E: Argument 1 to "greet" has incompatible type "int"; expected "str" +greet("World", 123) # E: Argument 2 to "greet" has incompatible type "int"; expected "str" +[builtins fixtures/dict.pyi] + +[case testLruCacheGenericFunction] +from functools import lru_cache +from typing import TypeVar + +T = TypeVar('T') + +@lru_cache +def identity(x: T) -> T: + return x + +reveal_type(identity(42)) # N: Revealed type is "builtins.int" +reveal_type(identity("hello")) # N: Revealed type is "builtins.str" +identity() # E: Missing positional argument "x" in call to "identity" +[builtins fixtures/dict.pyi] + +[case testLruCacheWithParentheses] +from functools import lru_cache + +@lru_cache() +def f(v: str, at: int) -> str: + return v + +f() # E: Missing positional arguments "v", "at" in call to "f" +f("abc") # E: Missing positional argument "at" in call to "f" +f("abc", 123) # OK +f("abc", at=123) # OK +f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testLruCacheWithMaxsize] +from functools import lru_cache + +@lru_cache(maxsize=128) +def g(v: str, at: int) -> str: + return v + +g() # E: Missing positional arguments "v", "at" in call to "g" +g("abc") # E: Missing positional argument "at" in call to "g" +g("abc", 123) # OK +g("abc", at=123) # OK +g("abc", at="wrong_type") # E: Argument "at" to "g" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testLruCacheGenericWithParameters] +from functools import lru_cache +from typing import TypeVar + +T = TypeVar('T') + +@lru_cache() +def identity_empty(x: T) -> T: + return x + +@lru_cache(maxsize=128) +def identity_maxsize(x: T) -> T: + return x + +reveal_type(identity_empty(42)) # N: Revealed type is "builtins.int" +reveal_type(identity_maxsize("hello")) # N: Revealed type is "builtins.str" +identity_empty() # E: Missing positional argument "x" in call to "identity_empty" +identity_maxsize() # E: Missing positional argument "x" in call to "identity_maxsize" +[builtins fixtures/dict.pyi] + +[case testLruCacheMaxsizeNone] +from functools import lru_cache + +@lru_cache(maxsize=None) +def unlimited_cache(x: int, y: str) -> str: + return y + +unlimited_cache(42, "test") # OK +unlimited_cache() # E: Missing positional arguments "x", "y" in call to "unlimited_cache" +unlimited_cache(42) # E: Missing positional argument "y" in call to "unlimited_cache" +unlimited_cache("wrong", "test") # E: Argument 1 to "unlimited_cache" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testLruCacheMaxsizeZero] +from functools import lru_cache + +@lru_cache(maxsize=0) +def no_cache(value: str) -> str: + return value + +no_cache("hello") # OK +no_cache() # E: Missing positional argument "value" in call to "no_cache" +no_cache(123) # E: Argument 1 to "no_cache" has incompatible type "int"; expected "str" +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/lib-stub/functools.pyi b/test-data/unit/lib-stub/functools.pyi index b8d47e1da2b5..abd4bff37a90 100644 --- a/test-data/unit/lib-stub/functools.pyi +++ b/test-data/unit/lib-stub/functools.pyi @@ -37,3 +37,11 @@ class cached_property(Generic[_T]): class partial(Generic[_T]): def __new__(cls, __func: Callable[..., _T], *args: Any, **kwargs: Any) -> Self: ... def __call__(__self, *args: Any, **kwargs: Any) -> _T: ... + +class _lru_cache_wrapper(Generic[_T]): + def __call__(__self, *args: Any, **kwargs: Any) -> _T: ... + +@overload +def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ... +@overload +def lru_cache(__func: Callable[..., _T]) -> _lru_cache_wrapper[_T]: ...