Skip to content

Commit d9d5d00

Browse files
committed
Add functools.lru_cache plugin support
- Add lru_cache callback to functools plugin for type validation - Register callbacks in default plugin for decorator and wrapper calls - Support different lru_cache patterns: @lru_cache, @lru_cache(), @lru_cache(maxsize=N) Fixes issue #16261
1 parent db67888 commit d9d5d00

File tree

4 files changed

+279
-0
lines changed

4 files changed

+279
-0
lines changed

mypy/plugins/default.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@
4949
)
5050
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
5151
from mypy.plugins.functools import (
52+
functools_lru_cache_callback,
5253
functools_total_ordering_maker_callback,
5354
functools_total_ordering_makers,
55+
lru_cache_wrapper_call_callback,
5456
partial_call_callback,
5557
partial_new_callback,
5658
)
@@ -101,6 +103,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
101103
return create_singledispatch_function_callback
102104
elif fullname == "functools.partial":
103105
return partial_new_callback
106+
elif fullname == "functools.lru_cache":
107+
return functools_lru_cache_callback
104108
elif fullname == "enum.member":
105109
return enum_member_callback
106110
return None
@@ -160,6 +164,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
160164
return call_singledispatch_function_after_register_argument
161165
elif fullname == "functools.partial.__call__":
162166
return partial_call_callback
167+
elif fullname == "functools._lru_cache_wrapper.__call__":
168+
return lru_cache_wrapper_call_callback
163169
return None
164170

165171
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:

mypy/plugins/functools.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
4242

4343
PARTIAL: Final = "functools.partial"
44+
LRU_CACHE: Final = "functools.lru_cache"
4445

4546

4647
class _MethodInfo(NamedTuple):
@@ -393,3 +394,135 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
393394
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
394395

395396
return result
397+
398+
399+
def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type:
400+
"""Infer a more precise return type for functools.lru_cache decorator"""
401+
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
402+
return ctx.default_return_type
403+
404+
# Only handle the very specific case: @lru_cache (without parentheses)
405+
# where a single function is passed directly as the only argument
406+
if (
407+
len(ctx.arg_types) == 1
408+
and len(ctx.arg_types[0]) == 1
409+
and len(ctx.args) == 1
410+
and len(ctx.args[0]) == 1
411+
):
412+
413+
first_arg_type = ctx.arg_types[0][0]
414+
415+
# Explicitly reject literal types, instances, and None
416+
from mypy.types import Instance, LiteralType, NoneType
417+
418+
proper_first_arg_type = get_proper_type(first_arg_type)
419+
if isinstance(proper_first_arg_type, (LiteralType, Instance, NoneType)):
420+
return ctx.default_return_type
421+
422+
# Try to extract callable type
423+
fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type)
424+
if fn_type is not None:
425+
# This is the @lru_cache case (function passed directly)
426+
return fn_type
427+
428+
# For all other cases (parameterized, multiple args, etc.), don't interfere
429+
return ctx.default_return_type
430+
431+
432+
def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
433+
"""Handle calls to functools._lru_cache_wrapper objects to provide parameter validation"""
434+
if not isinstance(ctx.api, mypy.checker.TypeChecker):
435+
return ctx.default_return_type
436+
437+
# Safety check: ensure we have the required context
438+
if not ctx.context or not ctx.args or not ctx.arg_types:
439+
return ctx.default_return_type
440+
441+
# Try to find the original function signature using AST/symbol table analysis
442+
original_signature = _find_original_function_signature(ctx)
443+
444+
if original_signature is not None:
445+
# Validate the call against the original function signature
446+
actual_args = []
447+
actual_arg_kinds = []
448+
actual_arg_names = []
449+
seen_args = set()
450+
451+
for i, param in enumerate(ctx.args):
452+
for j, a in enumerate(param):
453+
if a in seen_args:
454+
continue
455+
seen_args.add(a)
456+
actual_args.append(a)
457+
actual_arg_kinds.append(ctx.arg_kinds[i][j])
458+
actual_arg_names.append(ctx.arg_names[i][j])
459+
460+
# Check the call against the original signature
461+
try:
462+
result, _ = ctx.api.expr_checker.check_call(
463+
callee=original_signature,
464+
args=actual_args,
465+
arg_kinds=actual_arg_kinds,
466+
arg_names=actual_arg_names,
467+
context=ctx.context,
468+
)
469+
return result
470+
except Exception:
471+
# If check_call fails, fall back gracefully
472+
pass
473+
474+
return ctx.default_return_type
475+
476+
477+
def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None:
478+
"""
479+
Attempt to find the original function signature from the call context.
480+
481+
Returns the CallableType of the original function if found, None otherwise.
482+
This function safely traverses the AST structure to locate the original
483+
function signature that was decorated with @lru_cache.
484+
"""
485+
from mypy.nodes import CallExpr, Decorator, NameExpr
486+
487+
try:
488+
# Ensure we have the required context structure
489+
if not isinstance(ctx.context, CallExpr):
490+
return None
491+
492+
callee = ctx.context.callee
493+
if not isinstance(callee, NameExpr) or not callee.name:
494+
return None
495+
496+
func_name = callee.name
497+
498+
# Safely access the API globals
499+
if not hasattr(ctx.api, "globals") or not isinstance(ctx.api.globals, dict):
500+
return None
501+
502+
if func_name not in ctx.api.globals:
503+
return None
504+
505+
symbol = ctx.api.globals[func_name]
506+
507+
# Validate symbol structure before accessing node
508+
if not hasattr(symbol, "node") or symbol.node is None:
509+
return None
510+
511+
# Check if this is a decorator node containing our function
512+
if isinstance(symbol.node, Decorator):
513+
decorator_node = symbol.node
514+
515+
# Safely access the decorated function
516+
if not hasattr(decorator_node, "func") or decorator_node.func is None:
517+
return None
518+
519+
func_def = decorator_node.func
520+
521+
# Verify we have a callable type
522+
if hasattr(func_def, "type") and isinstance(func_def.type, CallableType):
523+
return func_def.type
524+
525+
return None
526+
except (AttributeError, TypeError, KeyError):
527+
# If anything goes wrong in AST traversal, fail gracefully
528+
return None

test-data/unit/check-functools.test

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,3 +726,135 @@ def outer_c(arg: Tc) -> None:
726726
use_int_callable(partial(inner, b="")) # E: Argument 1 to "use_int_callable" has incompatible type "partial[str]"; expected "Callable[[int], int]" \
727727
# N: "partial[str].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], str]"
728728
[builtins fixtures/tuple.pyi]
729+
730+
[case testLruCacheBasicValidation]
731+
from functools import lru_cache
732+
733+
@lru_cache
734+
def f(v: str, at: int) -> str:
735+
return v
736+
737+
f() # E: Missing positional arguments "v", "at" in call to "f"
738+
f("abc") # E: Missing positional argument "at" in call to "f"
739+
f("abc", 123) # OK
740+
f("abc", at=123) # OK
741+
f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int"
742+
[builtins fixtures/dict.pyi]
743+
744+
[case testLruCacheWithReturnType]
745+
from functools import lru_cache
746+
747+
@lru_cache
748+
def multiply(x: int, y: int) -> int:
749+
return 42
750+
751+
reveal_type(multiply) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int"
752+
reveal_type(multiply(2, 3)) # N: Revealed type is "builtins.int"
753+
multiply("a", 3) # E: Argument 1 to "multiply" has incompatible type "str"; expected "int"
754+
multiply(2, "b") # E: Argument 2 to "multiply" has incompatible type "str"; expected "int"
755+
multiply(2) # E: Missing positional argument "y" in call to "multiply"
756+
multiply(1, 2, 3) # E: Too many arguments for "multiply"
757+
[builtins fixtures/dict.pyi]
758+
759+
[case testLruCacheWithOptionalArgs]
760+
from functools import lru_cache
761+
762+
@lru_cache
763+
def greet(name: str, greeting: str = "Hello") -> str:
764+
return "result"
765+
766+
greet("World") # OK
767+
greet("World", "Hi") # OK
768+
greet("World", greeting="Hi") # OK
769+
greet() # E: Missing positional argument "name" in call to "greet"
770+
greet(123) # E: Argument 1 to "greet" has incompatible type "int"; expected "str"
771+
greet("World", 123) # E: Argument 2 to "greet" has incompatible type "int"; expected "str"
772+
[builtins fixtures/dict.pyi]
773+
774+
[case testLruCacheGenericFunction]
775+
from functools import lru_cache
776+
from typing import TypeVar
777+
778+
T = TypeVar('T')
779+
780+
@lru_cache
781+
def identity(x: T) -> T:
782+
return x
783+
784+
reveal_type(identity(42)) # N: Revealed type is "builtins.int"
785+
reveal_type(identity("hello")) # N: Revealed type is "builtins.str"
786+
identity() # E: Missing positional argument "x" in call to "identity"
787+
[builtins fixtures/dict.pyi]
788+
789+
[case testLruCacheWithParentheses]
790+
from functools import lru_cache
791+
792+
@lru_cache()
793+
def f(v: str, at: int) -> str:
794+
return v
795+
796+
f() # E: Missing positional arguments "v", "at" in call to "f"
797+
f("abc") # E: Missing positional argument "at" in call to "f"
798+
f("abc", 123) # OK
799+
f("abc", at=123) # OK
800+
f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int"
801+
[builtins fixtures/dict.pyi]
802+
803+
[case testLruCacheWithMaxsize]
804+
from functools import lru_cache
805+
806+
@lru_cache(maxsize=128)
807+
def g(v: str, at: int) -> str:
808+
return v
809+
810+
g() # E: Missing positional arguments "v", "at" in call to "g"
811+
g("abc") # E: Missing positional argument "at" in call to "g"
812+
g("abc", 123) # OK
813+
g("abc", at=123) # OK
814+
g("abc", at="wrong_type") # E: Argument "at" to "g" has incompatible type "str"; expected "int"
815+
[builtins fixtures/dict.pyi]
816+
817+
[case testLruCacheGenericWithParameters]
818+
from functools import lru_cache
819+
from typing import TypeVar
820+
821+
T = TypeVar('T')
822+
823+
@lru_cache()
824+
def identity_empty(x: T) -> T:
825+
return x
826+
827+
@lru_cache(maxsize=128)
828+
def identity_maxsize(x: T) -> T:
829+
return x
830+
831+
reveal_type(identity_empty(42)) # N: Revealed type is "builtins.int"
832+
reveal_type(identity_maxsize("hello")) # N: Revealed type is "builtins.str"
833+
identity_empty() # E: Missing positional argument "x" in call to "identity_empty"
834+
identity_maxsize() # E: Missing positional argument "x" in call to "identity_maxsize"
835+
[builtins fixtures/dict.pyi]
836+
837+
[case testLruCacheMaxsizeNone]
838+
from functools import lru_cache
839+
840+
@lru_cache(maxsize=None)
841+
def unlimited_cache(x: int, y: str) -> str:
842+
return y
843+
844+
unlimited_cache(42, "test") # OK
845+
unlimited_cache() # E: Missing positional arguments "x", "y" in call to "unlimited_cache"
846+
unlimited_cache(42) # E: Missing positional argument "y" in call to "unlimited_cache"
847+
unlimited_cache("wrong", "test") # E: Argument 1 to "unlimited_cache" has incompatible type "str"; expected "int"
848+
[builtins fixtures/dict.pyi]
849+
850+
[case testLruCacheMaxsizeZero]
851+
from functools import lru_cache
852+
853+
@lru_cache(maxsize=0)
854+
def no_cache(value: str) -> str:
855+
return value
856+
857+
no_cache("hello") # OK
858+
no_cache() # E: Missing positional argument "value" in call to "no_cache"
859+
no_cache(123) # E: Argument 1 to "no_cache" has incompatible type "int"; expected "str"
860+
[builtins fixtures/dict.pyi]

test-data/unit/lib-stub/functools.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,11 @@ class cached_property(Generic[_T]):
3737
class partial(Generic[_T]):
3838
def __new__(cls, __func: Callable[..., _T], *args: Any, **kwargs: Any) -> Self: ...
3939
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...
40+
41+
class _lru_cache_wrapper(Generic[_T]):
42+
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...
43+
44+
@overload
45+
def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ...
46+
@overload
47+
def lru_cache(__func: Callable[..., _T]) -> _lru_cache_wrapper[_T]: ...

0 commit comments

Comments
 (0)