-
Notifications
You must be signed in to change notification settings - Fork 104
/
Copy pathcompiler.py
515 lines (395 loc) · 16.9 KB
/
compiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
import inspect
import os
import sys
from collections.abc import Iterable
from typing import Optional, Tuple, Union
import django
from django.conf import settings
from django.core.exceptions import SuspiciousOperation
from django.db.models import Expression, Field, Model, Q
from django.db.models.fields.related import RelatedField
from django.db.models.sql import compiler as django_compiler
from django.db.utils import ProgrammingError
from .expressions import HStoreValue
from .types import ConflictAction, UpsertOperation
def append_caller_to_sql(sql) -> str:
    """Append the caller to SQL queries.

    Adds the calling file and function as an SQL comment to each query.
    Examples:
        INSERT INTO "tests_47ee19d1" ("id", "title")
        VALUES (1, 'Test')
        RETURNING "tests_47ee19d1"."id"
        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 55 */

        SELECT "tests_47ee19d1"."id", "tests_47ee19d1"."title"
        FROM "tests_47ee19d1"
        WHERE "tests_47ee19d1"."id" = 1
        LIMIT 1
        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 69 */

        UPDATE "tests_47ee19d1"
        SET "title" = 'success'
        WHERE "tests_47ee19d1"."id" = 1
        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 64 */

        DELETE FROM "tests_47ee19d1"
        WHERE "tests_47ee19d1"."id" IN (1)
        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 74 */

    Slow and blocking queries could be easily tracked down to their originator
    within the source code using the "pg_stat_activity" table.

    Set "POSTGRES_EXTRA_ANNOTATE_SQL" to a truthy value in the Django
    settings to enable this feature.

    Arguments:
        sql:
            The SQL text to annotate.

    Returns:
        The SQL text followed by a block comment holding the process id,
        the calling function, its file name and line number. When no
        caller outside of Django/psqlextra is found, the comment holds the
        process id and `sys.argv[0]` instead. The SQL is returned
        unchanged when the feature is disabled or when building the
        annotation fails for any reason.
    """

    if not getattr(settings, "POSTGRES_EXTRA_ANNOTATE_SQL", None):
        return sql

    try:
        # context=0 keeps inspect from reading source lines for every
        # frame; only the file name, line number and function are needed.
        for frame_info in inspect.stack(0)[1:]:
            # Compare against a separator-normalized path so that frames
            # are also recognized when the path uses backslashes.
            normalized_filename = frame_info.filename.replace("\\", "/")
            if (
                "/django/" in normalized_filename
                or "/psqlextra/" in normalized_filename
            ):
                continue

            # First non-Django, non-psqlextra caller
            annotation = (
                f"{os.getpid()} {frame_info.function} "
                f"{frame_info.filename} {frame_info.lineno}"
            )
            break
        else:
            # Django internal commands (like migrations) end up here
            annotation = f"{os.getpid()} {sys.argv[0]}"

        # Keep the annotation inert:
        # - "*/" would terminate the comment early and "/*" would open a
        #   nested comment (PostgreSQL block comments nest).
        # - a bare "%" would be treated as a placeholder by the database
        #   driver when the statement is executed with parameters.
        annotation = (
            annotation.replace("*/", "* /")
            .replace("/*", "/ *")
            .replace("%", "%%")
        )

        return f"{sql} /* {annotation} */"
    except Exception:
        # Don't break anything because this convenience function runs
        # into an unexpected situation
        return sql
class SQLCompiler(django_compiler.SQLCompiler):  # type: ignore [attr-defined]
    """Django `SQLCompiler` whose generated SQL is passed through
    :see:append_caller_to_sql."""

    def as_sql(self, *args, **kwargs):
        """Builds the statement with the parent compiler and annotates the
        SQL text; the parameters are returned untouched."""

        statement, bindings = super().as_sql(*args, **kwargs)
        annotated_statement = append_caller_to_sql(statement)
        return annotated_statement, bindings
class SQLDeleteCompiler(django_compiler.SQLDeleteCompiler):  # type: ignore [name-defined]
    """Django `SQLDeleteCompiler` whose generated SQL is passed through
    :see:append_caller_to_sql."""

    def as_sql(self, *args, **kwargs):
        """Builds the statement with the parent compiler and annotates the
        SQL text; the parameters are returned untouched."""

        statement, bindings = super().as_sql(*args, **kwargs)
        annotated_statement = append_caller_to_sql(statement)
        return annotated_statement, bindings
class SQLAggregateCompiler(django_compiler.SQLAggregateCompiler):  # type: ignore [name-defined]
    """Django `SQLAggregateCompiler` whose generated SQL is passed through
    :see:append_caller_to_sql."""

    def as_sql(self, *args, **kwargs):
        """Builds the statement with the parent compiler and annotates the
        SQL text; the parameters are returned untouched."""

        statement, bindings = super().as_sql(*args, **kwargs)
        annotated_statement = append_caller_to_sql(statement)
        return annotated_statement, bindings
class SQLUpdateCompiler(django_compiler.SQLUpdateCompiler):  # type: ignore [name-defined]
    """Compiler for SQL UPDATE statements that accepts expressions as
    values inside dictionaries assigned to a field.

    Like:

        .update(name=dict(en=F('test')))
    """

    def as_sql(self, *args, **kwargs):
        """Converts eligible dictionary values first, then builds the
        statement with the parent compiler and annotates the SQL text."""

        self._prepare_query_values()
        statement, bindings = super().as_sql(*args, **kwargs)
        return append_caller_to_sql(statement), bindings

    def _prepare_query_values(self):
        """Replaces every dictionary value on the query that contains at
        least one expression with a :see:HStoreValue wrapping a copy of
        that dictionary.

        All other values (non-dictionaries and dictionaries without
        expressions) are kept as they are. Nothing happens when the query
        has no values.
        """

        current_values = self.query.values
        if not current_values:
            return

        def wrap_if_needed(candidate):
            # Only dictionaries holding expressions need the wrapper.
            needs_wrapper = isinstance(
                candidate, dict
            ) and self._does_dict_contain_expression(candidate)
            if needs_wrapper:
                return HStoreValue(dict(candidate))
            return candidate

        self.query.values = [
            (field, model, wrap_if_needed(val))
            for field, model, val in current_values
        ]

    @staticmethod
    def _does_dict_contain_expression(data: dict) -> bool:
        """Tells whether any value in the dictionary exposes
        `resolve_expression` or `as_sql`, i.e. needs to be resolved."""

        return any(
            hasattr(item, "resolve_expression") or hasattr(item, "as_sql")
            for item in data.values()
        )
class SQLInsertCompiler(django_compiler.SQLInsertCompiler):  # type: ignore [name-defined]
    """Compiler for SQL INSERT statements whose generated SQL is passed
    through :see:append_caller_to_sql."""

    def as_sql(self, *args, **kwargs):
        """Builds the SQL INSERT statement(s).

        Returns:
            A list of (sql, params) tuples, one per statement produced by
            the parent compiler, with each SQL text annotated.
        """

        annotated_queries = []
        for statement, bindings in super().as_sql(*args, **kwargs):
            annotated_queries.append(
                (append_caller_to_sql(statement), bindings)
            )

        return annotated_queries
class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler):  # type: ignore [name-defined]
    """Compiler for SQL INSERT statements that carry an ON CONFLICT clause.

    Every statement produced by the parent insert compiler is rewritten to
    append `ON CONFLICT ... DO <action>` and a `RETURNING` clause.
    """

    # SQL expression that reports whether a returned row was inserted or
    # updated, derived from the `xmax` Postgres system column.
    RETURNING_OPERATION_TYPE_CLAUSE = (
        f"CASE WHEN xmax::text::int > 0 "
        f"THEN '{UpsertOperation.UPDATE.value}' "
        f"ELSE '{UpsertOperation.INSERT.value}' END"
    )

    # Column alias under which the expression above is returned.
    RETURNING_OPERATION_TYPE_FIELD = "_operation_type"

    def __init__(self, *args, **kwargs):
        """Initializes a new instance of
        :see:PostgresInsertOnConflictCompiler."""

        super().__init__(*args, **kwargs)

        # short-hand for quoting identifiers
        self.qn = self.connection.ops.quote_name

    def as_sql(
        self,
        return_id=False,
        return_operation_type=False,
        *args,
        **kwargs,
    ):
        """Builds the SQL INSERT statement.

        Arguments:
            return_id:
                When truthy, only the primary key is put in
                the `RETURNING` clause, otherwise `*`.

            return_operation_type:
                When truthy, an extra column telling whether
                the row was inserted or updated is returned.

        Returns:
            A list of (sql, params) tuples, one for each statement
            produced by the parent compiler.
        """

        queries = [
            self._rewrite_insert(sql, params, return_id, return_operation_type)
            for sql, params in super().as_sql(*args, **kwargs)
        ]

        return queries

    def execute_sql(self, return_id=False, return_operation_type=False):
        """Executes all generated statements and collects the returned rows.

        Arguments:
            return_id:
                Passed on to :see:as_sql.

            return_operation_type:
                Passed on to :see:as_sql.

        Returns:
            A list with one dictionary per returned row, mapping
            column names to the row's values.
        """

        # execute all the generated queries
        with self.connection.cursor() as cursor:
            rows = []
            for sql, params in self.as_sql(return_id, return_operation_type):
                cursor.execute(sql, params)
                try:
                    rows.extend(cursor.fetchall())
                except ProgrammingError:
                    # NOTE(review): presumably raised when the statement
                    # produced nothing to fetch; such statements simply
                    # contribute no rows -- confirm against the driver.
                    pass

            # description of the last executed statement
            description = cursor.description

        # create a mapping between column names and column value
        return [
            {
                column.name: row[column_index]
                for column_index, column in enumerate(description)
                if row
            }
            for row in rows
        ]

    def _rewrite_insert(
        self, sql, params, return_id=False, return_operation_type=False
    ):
        """Rewrites a formed SQL INSERT query to include the ON CONFLICT
        clause.

        Arguments:
            sql:
                The SQL INSERT query to rewrite.

            params:
                The parameters passed to the query.

            return_id:
                When truthy, the `RETURNING` clause holds only
                the quoted primary key column, otherwise `*`.

            return_operation_type:
                When truthy, the `RETURNING` clause additionally
                holds :see:RETURNING_OPERATION_TYPE_CLAUSE aliased
                as :see:RETURNING_OPERATION_TYPE_FIELD.

        Returns:
            A tuple of the rewritten SQL query and new params.
        """

        returning = (
            self.qn(self.query.model._meta.pk.attname) if return_id else "*"
        )

        # Return metadata about the row, so we can tell if it was inserted or
        # updated by checking the `xmax` Postgres system column.
        if return_operation_type:
            returning += f", ({self.RETURNING_OPERATION_TYPE_CLAUSE}) AS {self.RETURNING_OPERATION_TYPE_FIELD}"

        (sql, params) = self._rewrite_insert_on_conflict(
            sql,
            params,
            self.query.conflict_action.value,
            returning,
        )

        return append_caller_to_sql(sql), params

    def _rewrite_insert_on_conflict(
        self,
        sql: str,
        params: list,
        conflict_action: ConflictAction,
        returning: str,
    ) -> Tuple[str, list]:
        """Rewrites a normal SQL INSERT query to add the 'ON CONFLICT'
        clause.

        Arguments:
            sql:
                The SQL INSERT query to rewrite.

            params:
                The parameters passed to the query. Parameters
                of compiled conditions are appended to it.

            conflict_action:
                The action to put after `DO`. Note that
                :see:_rewrite_insert passes the `.value` of the
                :see:ConflictAction member, and that is what the
                comparison below expects.

            returning:
                What to put in the `RETURNING` clause
                of the resulting query.

        Returns:
            A tuple of the rewritten SQL query and new params.
        """

        update_columns = ", ".join(
            [
                "{0} = EXCLUDED.{0}".format(self.qn(field.column))
                for field in self.query.update_fields  # type: ignore[attr-defined]
            ]
        )

        # build the conflict target, the columns to watch
        # for conflicts
        on_conflict_clause = self._build_on_conflict_clause()
        index_predicate = self.query.index_predicate  # type: ignore[attr-defined]
        update_condition = self.query.conflict_update_condition  # type: ignore[attr-defined]

        rewritten_sql = f"{sql} {on_conflict_clause}"

        if index_predicate:
            expr_sql, expr_params = self._compile_expression(index_predicate)
            rewritten_sql += f" WHERE {expr_sql}"
            params += tuple(expr_params)

        rewritten_sql += f" DO {conflict_action}"

        if conflict_action == ConflictAction.UPDATE.value:
            rewritten_sql += f" SET {update_columns}"

            # the update condition only applies to `DO UPDATE`
            if update_condition:
                expr_sql, expr_params = self._compile_expression(
                    update_condition
                )
                rewritten_sql += f" WHERE {expr_sql}"
                params += tuple(expr_params)

        rewritten_sql += f" RETURNING {returning}"

        return (rewritten_sql, params)

    def _build_on_conflict_clause(self):
        """Builds the `ON CONFLICT` clause.

        A conflict target that is a constraint or index object is
        referenced by name using `ON CONFLICT ON CONSTRAINT`, anything
        else is turned into a column list by :see:_build_conflict_target.
        """

        if django.VERSION >= (2, 2):
            from django.db.models.constraints import BaseConstraint
            from django.db.models.indexes import Index

            if isinstance(
                self.query.conflict_target, BaseConstraint
            ) or isinstance(self.query.conflict_target, Index):
                return "ON CONFLICT ON CONSTRAINT %s" % self.qn(
                    self.query.conflict_target.name
                )

        conflict_target = self._build_conflict_target()
        return f"ON CONFLICT {conflict_target}"

    def _build_conflict_target(self):
        """Builds the `conflict_target` for the ON CONFLICT clause.

        Raises:
            SuspiciousOperation:
                When the conflict target is not iterable.
        """

        if not isinstance(self.query.conflict_target, Iterable):
            raise SuspiciousOperation(
                (
                    "%s is not a valid conflict target, specify "
                    "a list of column names, or tuples with column "
                    "names and hstore key."
                )
                % str(self.query.conflict_target)
            )

        # prefer an index declared on the model that matches the target
        conflict_target = self._build_conflict_target_by_index()
        if conflict_target:
            return conflict_target

        return self._build_conflict_target_by_fields()

    def _build_conflict_target_by_fields(self):
        """Builds the `conflict_target` for the ON CONFLICT clauses by matching
        the fields specified in the specified conflict target against the
        model's fields.

        This requires some special handling because the fields names
        might not be same as the column names.

        Raises:
            SuspiciousOperation:
                When one of the entries does not refer to
                a field on the model.
        """

        conflict_target = []

        for field_name in self.query.conflict_target:
            self._assert_valid_field(field_name)

            # special handling for hstore keys
            if isinstance(field_name, tuple):
                # The key ends up inside a single-quoted SQL literal, so
                # embedded single quotes have to be doubled to keep the
                # literal intact.
                hstore_key = str(field_name[1]).replace("'", "''")
                conflict_target.append(
                    "(%s->'%s')"
                    % (self._format_field_name(field_name), hstore_key)
                )
            else:
                conflict_target.append(self._format_field_name(field_name))

        return "(%s)" % ",".join(conflict_target)

    def _build_conflict_target_by_index(self):
        """Builds the `conflict_target` for the ON CONFLICT clause by trying to
        find an index that matches the specified conflict target on the query.

        Conflict targets must match some unique constraint, usually this
        is a `UNIQUE INDEX`.

        Returns:
            The parenthesized columns of the first index on the model
            whose fields equal the conflict target, or None if there
            is no such index.
        """

        matching_index = next(
            (
                index
                for index in self.query.model._meta.indexes
                if list(index.fields) == list(self.query.conflict_target)
            ),
            None,
        )

        if not matching_index:
            return None

        # let the index itself render its columns
        with self.connection.schema_editor() as schema_editor:
            stmt = matching_index.create_sql(self.query.model, schema_editor)
            return "(%s)" % stmt.parts["columns"]

    def _get_model_field(self, name: str) -> Optional[Field]:
        """Gets the field on a model with the specified name.

        Arguments:
            name:
                The name of the field to look for.

                This can be both the actual field name, or
                the name of the column, both will work :)

        Returns:
            The field with the specified name or None if
            no such field exists.
        """

        field_name = self._normalize_field_name(name)

        if not self.query.model:
            return None

        # 'pk' has special meaning and always refers to the primary
        # key of a model, we have to respect this de-facto standard behaviour
        if field_name == "pk" and self.query.model._meta.pk:
            return self.query.model._meta.pk

        for field in self.query.model._meta.local_concrete_fields:  # type: ignore[attr-defined]
            if field.name == field_name or field.column == field_name:
                return field

        return None

    def _format_field_name(self, field_name):
        """Formats a field's name for usage in SQL.

        Arguments:
            field_name:
                The field name to format.

        Returns:
            The specified field name formatted for
            usage in SQL.
        """

        field = self._get_model_field(field_name)
        return self.qn(field.column)

    def _format_field_value(self, field_name):
        """Formats a field's value for usage in SQL.

        The value is read from the first object on the query.

        Arguments:
            field_name:
                The name of the field to format
                the value of.

        Returns:
            The field's value formatted for usage
            in SQL.
        """

        field_name = self._normalize_field_name(field_name)
        field = self._get_model_field(field_name)

        value = getattr(self.query.objs[0], field.attname)

        # related model instances are reduced to their primary key
        if isinstance(field, RelatedField) and isinstance(value, Model):
            value = value.pk

        return django_compiler.SQLInsertCompiler.prepare_value(  # type: ignore[attr-defined]
            self,
            field,
            # Note: this deliberately doesn't use `pre_save_val` as we don't
            # want things like auto_now on DateTimeField (etc.) to change the
            # value. We rely on pre_save having already been done by the
            # underlying compiler so that things like FileField have already had
            # the opportunity to save out their data.
            value,
        )

    def _compile_expression(
        self,
        expression: Union[Expression, Q, str],
    ) -> Tuple[str, Union[tuple, list]]:
        """Compiles an expression, Q object or raw SQL string into SQL and
        tuple of parameters.

        Raises:
            SuspiciousOperation:
                When a Q object is specified on a Django
                version older than 3.1.
        """

        if isinstance(expression, Q):
            if django.VERSION < (3, 1):
                raise SuspiciousOperation(
                    "Q objects in psqlextra can only be used with Django 3.1 and newer"
                )

            return self.query.build_where(expression).as_sql(
                self, self.connection
            )

        elif isinstance(expression, Expression):
            return self.compile(expression)

        # anything else is used as-is, without parameters
        return expression, tuple()

    def _assert_valid_field(self, field_name: str) -> None:
        """Asserts that a field with the specified name exists on the model and
        raises :see:SuspiciousOperation if it does not."""

        field_name = self._normalize_field_name(field_name)
        if self._get_model_field(field_name):
            return

        raise SuspiciousOperation(
            (
                "%s is not a valid conflict target, specify "
                "a list of column names, or tuples with column "
                "names and hstore key."
            )
            % str(field_name)
        )

    @staticmethod
    def _normalize_field_name(field_name: str) -> str:
        """Normalizes a field name into a string by extracting the field name
        if it was specified as a reference to a HStore key (as a tuple).

        Arguments:
            field_name:
                The field name to normalize.

        Returns:
            The normalized field name.
        """

        if isinstance(field_name, tuple):
            # (field name, hstore key)
            field_name, _ = field_name

        return field_name