Skip to content

Commit 7e3f4bf

Browse files
authored
Fix ForeignKey type for self-reference defined in the abstract model (#200)
1 parent db9ff6a commit 7e3f4bf

File tree

6 files changed

+82
-16
lines changed

6 files changed

+82
-16
lines changed

mypy_django_plugin/django/context.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,17 +209,21 @@ def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], *, method
209209
return expected_types
210210

211211
@cached_property
212-
def model_base_classes(self) -> Set[str]:
212+
def all_registered_model_classes(self) -> Set[Type[models.Model]]:
213213
model_classes = self.apps_registry.get_models()
214214

215215
all_model_bases = set()
216216
for model_cls in model_classes:
217217
for base_cls in model_cls.mro():
218218
if issubclass(base_cls, models.Model):
219-
all_model_bases.add(helpers.get_class_fullname(base_cls))
219+
all_model_bases.add(base_cls)
220220

221221
return all_model_bases
222222

223+
@cached_property
224+
def all_registered_model_class_fullnames(self) -> Set[str]:
225+
return {helpers.get_class_fullname(cls) for cls in self.all_registered_model_classes}
226+
223227
def get_attname(self, field: Field) -> str:
224228
attname = field.attname
225229
return attname

mypy_django_plugin/lib/helpers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionCont
282282

283283

284284
def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool:
285-
return (info.fullname() in django_context.model_base_classes
285+
return (info.fullname() in django_context.all_registered_model_class_fullnames
286286
or info.has_base(fullnames.MODEL_CLASS_FULLNAME))
287287

288288

@@ -292,3 +292,15 @@ def check_types_compatible(ctx: Union[FunctionContext, MethodContext],
292292
api.check_subtype(actual_type, expected_type,
293293
ctx.context, error_message,
294294
'got', 'expected')
295+
296+
297+
def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None:
298+
# type=: type of the variable itself
299+
var = Var(name=name, type=sym_type)
300+
# var.info: type of the object variable is bound to
301+
var.info = info
302+
var._fullname = info.fullname() + '.' + name
303+
var.is_initialized_in_class = True
304+
var.is_inferred = True
305+
info.names[name] = SymbolTableNode(MDEF, var,
306+
plugin_generated=True)

mypy_django_plugin/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def get_method_hook(self, fullname: str
219219

220220
def get_base_class_hook(self, fullname: str
221221
) -> Optional[Callable[[ClassDefContext], None]]:
222-
if (fullname in self.django_context.model_base_classes
222+
if (fullname in self.django_context.all_registered_model_class_fullnames
223223
or fullname in self._get_current_model_bases()):
224224
return partial(transform_model_class, django_context=self.django_context)
225225

mypy_django_plugin/transformers/fields.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def _get_current_field_from_assignment(ctx: FunctionContext, django_context: Dja
3737
return current_field
3838

3939

40+
def reparametrize_related_field_type(related_field_type: Instance, set_type, get_type) -> Instance:
41+
args = [
42+
helpers.convert_any_to_type(related_field_type.args[0], set_type),
43+
helpers.convert_any_to_type(related_field_type.args[1], get_type),
44+
]
45+
return helpers.reparametrize_instance(related_field_type, new_args=args)
46+
47+
4048
def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
4149
current_field = _get_current_field_from_assignment(ctx, django_context)
4250
if current_field is None:
@@ -48,6 +56,25 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
4856
if related_model_cls is None:
4957
return AnyType(TypeOfAny.from_error)
5058

59+
default_related_field_type = set_descriptor_types_for_field(ctx)
60+
61+
# self reference with abstract=True on the model where ForeignKey is defined
62+
current_model_cls = current_field.model
63+
if (current_model_cls._meta.abstract
64+
and current_model_cls == related_model_cls):
65+
# for all derived non-abstract classes, set variable with this name to
66+
# __get__/__set__ of ForeignKey of derived model
67+
for model_cls in django_context.all_registered_model_classes:
68+
if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract:
69+
derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls)
70+
if derived_model_info is not None:
71+
fk_ref_type = Instance(derived_model_info, [])
72+
derived_fk_type = reparametrize_related_field_type(default_related_field_type,
73+
set_type=fk_ref_type, get_type=fk_ref_type)
74+
helpers.add_new_sym_for_info(derived_model_info,
75+
name=current_field.name,
76+
sym_type=derived_fk_type)
77+
5178
related_model = related_model_cls
5279
related_model_to_set = related_model_cls
5380
if related_model_to_set._meta.proxy_for_model is not None:
@@ -69,13 +96,10 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
6996
else:
7097
related_model_to_set_type = Instance(related_model_to_set_info, []) # type: ignore
7198

72-
default_related_field_type = set_descriptor_types_for_field(ctx)
7399
# replace Any with referred_to_type
74-
args = [
75-
helpers.convert_any_to_type(default_related_field_type.args[0], related_model_to_set_type),
76-
helpers.convert_any_to_type(default_related_field_type.args[1], related_model_type),
77-
]
78-
return helpers.reparametrize_instance(default_related_field_type, new_args=args)
100+
return reparametrize_related_field_type(default_related_field_type,
101+
set_type=related_model_to_set_type,
102+
get_type=related_model_type)
79103

80104

81105
def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]:

mypy_django_plugin/transformers/models.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from django.db.models.fields.reverse_related import (
88
ManyToManyRel, ManyToOneRel, OneToOneRel,
99
)
10-
from mypy.nodes import (
11-
ARG_STAR2, MDEF, Argument, Context, SymbolTableNode, TypeInfo, Var,
12-
)
10+
from mypy.nodes import ARG_STAR2, Argument, Context, TypeInfo, Var
1311
from mypy.plugin import ClassDefContext
1412
from mypy.plugins import common
1513
from mypy.types import AnyType, Instance
@@ -51,8 +49,9 @@ def create_new_var(self, name: str, typ: MypyType) -> Var:
5149
return var
5250

5351
def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None:
54-
var = self.create_new_var(name, typ)
55-
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
52+
helpers.add_new_sym_for_info(self.model_classdef.info,
53+
name=name,
54+
sym_type=typ)
5655

5756
def run(self) -> None:
5857
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
@@ -114,6 +113,9 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
114113
AnyType(TypeOfAny.explicit))
115114
continue
116115

116+
if related_model_cls._meta.abstract:
117+
continue
118+
117119
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
118120
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
119121
is_nullable = self.django_context.get_field_nullability(field, None)

test-data/typecheck/fields/test_related.yml

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,4 +623,28 @@
623623
class TransactionLog(models.Model):
624624
transaction = models.ForeignKey(Transaction, on_delete=models.CASCADE)
625625
626-
Transaction().test()
626+
Transaction().test()
627+
628+
629+
- case: resolve_primary_keys_for_foreign_keys_with_abstract_self_model
630+
main: |
631+
from myapp.models import User
632+
reveal_type(User().parent) # N: Revealed type is 'myapp.models.User*'
633+
reveal_type(User().parent_id) # N: Revealed type is 'builtins.int*'
634+
635+
reveal_type(User().parent2) # N: Revealed type is 'Union[myapp.models.User, None]'
636+
reveal_type(User().parent2_id) # N: Revealed type is 'Union[builtins.int, None]'
637+
installed_apps:
638+
- myapp
639+
files:
640+
- path: myapp/__init__.py
641+
- path: myapp/models.py
642+
content: |
643+
from django.db import models
644+
class AbstractUser(models.Model):
645+
parent = models.ForeignKey('self', on_delete=models.CASCADE)
646+
parent2 = models.ForeignKey('self', null=True, on_delete=models.CASCADE)
647+
class Meta:
648+
abstract = True
649+
class User(AbstractUser):
650+
pass

0 commit comments

Comments
 (0)