Skip to content

Commit bc8c9f8

Browse files
committed
Fix bulk_update raising syntax error for ForeignKey fields
When bulk_update was called with a ForeignKey field name (e.g., fields=['post']), the generated SQL used the Python object repr instead of the FK column value, causing a syntax error. Now FK fields are properly resolved to their source_field column name and the related object's PK value is used in the SQL.
1 parent 314273c commit bc8c9f8

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

tests/test_update.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,24 @@ async def test_bulk_update_pk_uuid(db):
9797
assert (await UUIDFields.get(pk=objs[1].pk)).data == objs[1].data
9898

9999

100+
@pytest.mark.asyncio
101+
async def test_bulk_update_foreign_key(db):
102+
tournament1 = await Tournament.create(name="t1")
103+
tournament2 = await Tournament.create(name="t2")
104+
events = [
105+
await Event.create(name="e1", tournament=tournament1),
106+
await Event.create(name="e2", tournament=tournament1),
107+
]
108+
events[0].tournament = tournament2
109+
events[1].tournament = tournament2
110+
rows_affected = await Event.bulk_update(events, fields=["tournament"])
111+
assert rows_affected == 2
112+
e1 = await Event.get(pk=events[0].pk).select_related("tournament")
113+
e2 = await Event.get(pk=events[1].pk).select_related("tournament")
114+
assert e1.tournament.pk == tournament2.pk
115+
assert e2.tournament.pk == tournament2.pk
116+
117+
100118
@pytest.mark.asyncio
101119
async def test_bulk_renamed_pk_source_field(db):
102120
objs = [

tortoise/queryset.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1959,25 +1959,38 @@ def _make_queries(self) -> list[tuple[str, list[Any]]]:
19591959
for field in self.fields:
19601960
case = Case()
19611961
pk_list = []
1962+
field_obj = self.model._meta.fields_map[field]
1963+
is_fk = isinstance(field_obj, (ForeignKeyFieldInstance, OneToOneFieldInstance))
1964+
if is_fk:
1965+
fk_field = field_obj.source_field
1966+
underlying_field_obj = self.model._meta.fields_map[fk_field]
1967+
db_column = underlying_field_obj.source_field
1968+
else:
1969+
underlying_field_obj = field_obj
1970+
db_column = self.model._meta.fields_db_projection[field]
19621971
for obj in objects_item:
19631972
pk_value = self.model._meta.fields_map[pk_attr].to_db_value(obj.pk, None)
1964-
field_obj = obj._meta.fields_map[field]
1965-
field_value = field_obj.to_db_value(getattr(obj, field), obj)
1966-
case.when(
1967-
pk == pk_value,
1968-
(
1969-
Cast(
1970-
self.query._wrapper_cls(field_value),
1971-
field_obj.get_for_dialect(
1972-
self._db.schema_generator.DIALECT, "SQL_TYPE"
1973-
),
1974-
)
1975-
if self._db.schema_generator.DIALECT == "postgres"
1976-
else self.query._wrapper_cls(field_value)
1977-
),
1978-
)
1973+
if is_fk:
1974+
related_obj = getattr(obj, field)
1975+
self.model._validate_relation_type(field, related_obj)
1976+
field_value = underlying_field_obj.to_db_value(
1977+
getattr(related_obj, field_obj.to_field_instance.model_field_name),
1978+
None,
1979+
)
1980+
else:
1981+
field_value = underlying_field_obj.to_db_value(getattr(obj, field), obj)
1982+
if self._db.schema_generator.DIALECT == "postgres":
1983+
value_expr = Cast(
1984+
self.query._wrapper_cls(field_value),
1985+
underlying_field_obj.get_for_dialect(
1986+
self._db.schema_generator.DIALECT, "SQL_TYPE"
1987+
),
1988+
)
1989+
else:
1990+
value_expr = self.query._wrapper_cls(field_value)
1991+
case.when(pk == pk_value, value_expr)
19791992
pk_list.append(pk_value)
1980-
query = query.set(field, case)
1993+
query = query.set(db_column, case)
19811994
query = query.where(pk.isin(pk_list))
19821995
self._queries.append(query)
19831996
return [query.get_parameterized_sql() for query in self._queries]

0 commit comments

Comments
 (0)