Skip to content

Commit 1fccf9a

Browse files
committed
Apply converters to results return in upsert_and_get
1 parent b93345c commit 1fccf9a

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

psqlextra/query.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Dict, Iterable, List, Optional, Tuple, Union
44

55
from django.core.exceptions import SuspiciousOperation
6-
from django.db import models, router
6+
from django.db import connections, models, router
77
from django.db.models import Expression
88
from django.db.models.fields import NOT_PROVIDED
99

@@ -389,14 +389,36 @@ def is_empty(r):
389389
)
390390
return self.bulk_insert(rows, return_model, using=using)
391391

392-
def _create_model_instance(self, field_values, using: Optional[str] = None):
392+
def _create_model_instance(
393+
self, field_values: dict, using: str, apply_converters: bool = True
394+
):
393395
"""Creates a new instance of the model with the specified field.
394396
395397
Use this after the row was inserted into the database. The new
396398
instance will marked as "saved".
397399
"""
398400

399-
instance = self.model(**field_values)
401+
converted_field_values = field_values.copy()
402+
403+
if apply_converters:
404+
connection = connections[using]
405+
406+
for field in self.model._meta.local_concrete_fields:
407+
if field.attname not in converted_field_values:
408+
continue
409+
410+
# converters can be defined on the field, or by
411+
# the database back-end we're using
412+
converters = field.get_db_converters(
413+
connection
414+
) + connection.ops.get_db_converters(field)
415+
416+
for converter in converters:
417+
converted_field_values[field.attname] = converter(
418+
converted_field_values[field.attname], field, connection
419+
)
420+
421+
instance = self.model(**converted_field_values)
400422
instance._state.db = using
401423
instance._state.adding = False
402424

@@ -444,7 +466,9 @@ def _build_insert_compiler(
444466
).format(index)
445467
)
446468

447-
objs.append(self._create_model_instance(row, using))
469+
objs.append(
470+
self._create_model_instance(row, using, apply_converters=False)
471+
)
448472

449473
# get the fields to be used during update/insert
450474
insert_fields, update_fields = self._get_upsert_fields(first_row)

tests/test_upsert.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,22 @@ def test_upsert_with_update_condition():
106106
assert not obj1.active
107107

108108

109+
def test_upsert_and_get_applies_converters():
110+
"""Tests that converters are properly applied when using upsert_and_get."""
111+
112+
class MyCustomField(models.TextField):
113+
def from_db_value(self, value, expression, connection):
114+
return value.replace("hello", "bye")
115+
116+
model = get_fake_model({"title": MyCustomField(unique=True)})
117+
118+
obj = model.objects.upsert_and_get(
119+
conflict_target=["title"], fields=dict(title="hello")
120+
)
121+
122+
assert obj.title == "bye"
123+
124+
109125
def test_upsert_bulk():
110126
"""Tests whether bulk_upsert works properly."""
111127

0 commit comments

Comments
 (0)