diff --git a/polymorphic/query.py b/polymorphic/query.py index 86774898..5eca6334 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -112,6 +112,10 @@ def __init__(self, *args, **kwargs): # to that queryset as well). self.polymorphic_deferred_loading = (set(), True) + self._polymorphic_select_related = {} + self._polymorphic_prefetch_related = {} + self._polymorphic_custom_queryset = {} + def _clone(self, *args, **kwargs): # Django's _clone only copies its own variables, so we need to copy ours here new = super()._clone(*args, **kwargs) @@ -120,6 +124,9 @@ def _clone(self, *args, **kwargs): copy.copy(self.polymorphic_deferred_loading[0]), self.polymorphic_deferred_loading[1], ) + new._polymorphic_select_related = copy.copy(self._polymorphic_select_related) + new._polymorphic_prefetch_related = copy.copy(self._polymorphic_prefetch_related) + new._polymorphic_custom_queryset = copy.copy(self._polymorphic_custom_queryset) return new def as_manager(cls): @@ -417,12 +424,30 @@ class self.model, but as a class derived from self.model. We want to re-fetch # TODO: defer(), only(): support for these would be around here for real_concrete_class, idlist in idlist_per_model.items(): indices = indexlist_per_model[real_concrete_class] - real_objects = real_concrete_class._base_objects.db_manager(self.db).filter( + if self._polymorphic_custom_queryset.get(real_concrete_class): + real_objects = self._polymorphic_custom_queryset[real_concrete_class] + else: + real_objects = real_concrete_class._base_objects.db_manager(self.db) + + real_objects = real_objects.filter( **{("%s__in" % pk_name): idlist} ) - # copy select related configuration to new qs + + # copy select_related() fields from base objects to real objects real_objects.query.select_related = self.query.select_related + # polymorphic select_related() fields if any + if real_concrete_class in self._polymorphic_select_related: + real_objects = real_objects.select_related( + *self._polymorphic_select_related[real_concrete_class] + ) + + # polymorphic prefetch related configuration to new qs + if real_concrete_class in self._polymorphic_prefetch_related: + real_objects = real_objects.prefetch_related( + *self._polymorphic_prefetch_related[real_concrete_class] + ) + # Copy deferred fields configuration to the new queryset deferred_loading_fields = [] existing_fields = self.polymorphic_deferred_loading[0] @@ -535,3 +560,22 @@ def get_real_instances(self, base_result_objects=None): return olist clist = PolymorphicQuerySet._p_list_class(olist) return clist + + def select_polymorphic_related(self, polymorphic_subclass, *fields): + if self.query.select_related is True: + raise ValueError( + "select_polymorphic_related() cannot be used together with select_related=True" + ) + clone = self._clone() + clone._polymorphic_select_related[polymorphic_subclass] = fields + return clone + + def prefetch_polymorphic_related(self, polymorphic_subclass, *lookups): + clone = self._clone() + clone._polymorphic_prefetch_related[polymorphic_subclass] = lookups + return clone + + def custom_queryset(self, polymorphic_subclass, queryset): + clone = self._clone() + clone._polymorphic_custom_queryset[polymorphic_subclass] = queryset + return clone