Skip to content

Commit

Permalink
refactor: Use version-specific refresh_from_db patch
Browse files Browse the repository at this point in the history
  • Loading branch information
last-partizan committed Dec 1, 2024
1 parent 1ec8e73 commit a371e1e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 12 deletions.
44 changes: 43 additions & 1 deletion modeltranslation/_compat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Callable

import django
from typing import Iterable
from typing import Optional

if TYPE_CHECKING:
from django.db.models import QuerySet
from django.db.models.fields.reverse_related import ForeignObjectRel


Expand All @@ -25,7 +28,46 @@ def clear_ForeignObjectRel_caches(field: ForeignObjectRel):
field.__dict__.pop(name, None)


def build_refresh_from_db(
old_refresh_from_db: Callable[
[Any, Optional[str], Optional[Iterable[str]], QuerySet[Any] | None], None
],
):
from modeltranslation.manager import append_translated

def refresh_from_db(
self: Any,
using: str | None = None,
fields: Iterable[str] | None = None,
from_queryset: QuerySet[Any] | None = None,
) -> None:
if fields is not None:
fields = append_translated(self.__class__, fields)
return old_refresh_from_db(self, using, fields, from_queryset)

return refresh_from_db


if django.VERSION <= (5, 1):

def is_hidden(field: ForeignObjectRel) -> bool:
return field.is_hidden()


if django.VERSION <= (5, 0):
# Django versions below 5.0 do not have `from_queryset` argument.
def build_refresh_from_db( # type: ignore[misc]
old_refresh_from_db: Callable[[Any, Optional[str], Optional[Iterable[str]]], None],
):
from modeltranslation.manager import append_translated

def refresh_from_db(
self: Any,
using: str | None = None,
fields: Iterable[str] | None = None,
) -> None:
if fields is not None:
fields = append_translated(self.__class__, fields)
return old_refresh_from_db(self, using, fields)

return refresh_from_db
13 changes: 2 additions & 11 deletions modeltranslation/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from modeltranslation.manager import (
MultilingualManager,
MultilingualQuerysetManager,
append_translated,
rewrite_lookup_key,
)
from modeltranslation.thread_context import auto_populate_mode
Expand All @@ -43,7 +42,7 @@
# Re-export the decorator for convenience
from modeltranslation.decorators import register

from ._compat import is_hidden
from ._compat import is_hidden, build_refresh_from_db
from ._typing import _ListOrTuple

__all__ = [
Expand Down Expand Up @@ -374,16 +373,8 @@ def patch_refresh_from_db(model: type[Model]) -> None:
"""
if not hasattr(model, "refresh_from_db"):
return
old_refresh_from_db = model.refresh_from_db

def new_refresh_from_db(
self, using: str | None = None, fields: Iterable[str] | None = None
) -> None:
if fields is not None:
fields = append_translated(self.__class__, fields)
return old_refresh_from_db(self, using, fields)

model.refresh_from_db = new_refresh_from_db
model.refresh_from_db = build_refresh_from_db(model.refresh_from_db)


def delete_cache_fields(model: type[Model]) -> None:
Expand Down

0 comments on commit a371e1e

Please sign in to comment.