Skip to content

Commit 82a56bb

Browse files
committed
Inspect base database back-end instance to figure out base classes
The PostGIS back-end doesn't set `DatabaseWrapper.ops_class` properly and instead initializes its own `Operations` class in the constructor. This results in django-postgres-extra subclassing the wrong base class. As a solution, we create an instance and inspect what class is actually used for `DatabaseWrapper.ops_class` and `DatabaseWrapper.introspection_class` and subclass from that. [#181799346]
1 parent bf9cc99 commit 82a56bb

File tree

5 files changed

+73
-47
lines changed

5 files changed

+73
-47
lines changed

psqlextra/backend/base.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,34 @@ class DatabaseWrapper(base_impl.backend()):
2424

2525
def __init__(self, *args, **kwargs):
2626
super().__init__(*args, **kwargs)
27-
if not isinstance(self.ops, PostgresOperations):
28-
# PostGis replaces the ops object instead of setting the ops_class attribute
29-
if self.ops.compiler_module != 'django.db.models.sql.compiler':
30-
raise NotImplementedError(
31-
f'''The Django ops object has been replaced by {self.ops} and a custom compiler module {self.ops.compiler_module} has been set.
32-
Replacing both at the same time is incompatible with psqlextra. '''
33-
)
34-
self.ops._compiler_cache = None
35-
self.ops.compiler_module = 'psqlextra.compiler'
3627

28+
# Some base back-ends such as the PostGIS back-end don't properly
29+
# set `ops_class` and `introspection_class` and initialize these
30+
# classes themselves.
31+
#
32+
# This can lead to broken functionality. We fix this automatically.
33+
34+
if not isinstance(self.ops, self.introspection_class):
35+
self.introspection = self.introspection_class(self)
36+
37+
if not isinstance(self.ops, self.ops_class):
38+
self.ops = self.ops_class(self)
39+
40+
for expected_compiler_class in self.ops.compiler_classes:
41+
compiler_class = self.ops.compiler(expected_compiler_class.__name__)
42+
43+
if not issubclass(compiler_class, expected_compiler_class):
44+
logger.warning(
45+
"Compiler '%s.%s' is not properly deriving from '%s.%s'."
46+
% (
47+
compiler_class.__module__,
48+
compiler_class.__name__,
49+
expected_compiler_class.__module__,
50+
expected_compiler_class.__name__,
51+
)
52+
)
3753

38-
def prepare_database(self):
54+
def prepare_database(self):
3955
"""Ran to prepare the configured database.
4056
4157
This is where we enable the `hstore` extension if it wasn't

psqlextra/backend/base_impl.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import importlib
2-
import os
32

43
from django.conf import settings
54
from django.core.exceptions import ImproperlyConfigured
5+
from django.db import DEFAULT_DB_ALIAS, connections
66

77
from django.db.backends.postgresql.base import ( # isort:skip
88
DatabaseWrapper as Psycopg2DatabaseWrapper,
99
)
1010

1111

12-
def backend():
13-
"""Gets the base class for the custom database back-end.
12+
def base_backend_instance():
13+
"""Gets an instance of the base class for the custom database back-end.
1414
1515
This should be the Django PostgreSQL back-end. However,
1616
some people are already using a custom back-end from
@@ -20,11 +20,15 @@ def backend():
2020
As long as the specified base eventually also has
2121
the PostgreSQL back-end as a base, then everything should
2222
work as intended.
23+
24+
We create an instance to inspect what classes to subclass
25+
because not all back-ends set properties such as `ops_class`
26+
properly. The PostGIS back-end is a good example.
2327
"""
2428
base_class_name = getattr(
2529
settings,
2630
"POSTGRES_EXTRA_DB_BACKEND_BASE",
27-
os.environ.get("POSTGRES_EXTRA_DB_BACKEND_BASE") or "django.db.backends.postgresql",
31+
"django.db.backends.postgresql",
2832
)
2933

3034
base_class_module = importlib.import_module(base_class_name + ".base")
@@ -50,7 +54,24 @@ def backend():
5054
% base_class_name
5155
)
5256

53-
return base_class
57+
base_instance = base_class(connections.databases[DEFAULT_DB_ALIAS])
58+
if base_instance.connection:
59+
raise ImproperlyConfigured(
60+
(
61+
"'%s' establishes a connection during initialization."
62+
" This is not expected and can lead to more connections"
63+
" being established than neccesarry."
64+
)
65+
% base_class_name
66+
)
67+
68+
return base_instance
69+
70+
71+
def backend():
72+
"""Gets the base class for the database back-end."""
73+
74+
return base_backend_instance().__class__
5475

5576

5677
def schema_editor():
@@ -60,7 +81,7 @@ def schema_editor():
6081
this.
6182
"""
6283

63-
return backend().SchemaEditorClass
84+
return base_backend_instance().SchemaEditorClass
6485

6586

6687
def introspection():
@@ -70,7 +91,7 @@ def introspection():
7091
for this.
7192
"""
7293

73-
return backend().introspection_class
94+
return base_backend_instance().introspection.__class__
7495

7596

7697
def operations():
@@ -80,4 +101,4 @@ def operations():
80101
this.
81102
"""
82103

83-
return backend().ops_class
104+
return base_backend_instance().ops.__class__

psqlextra/backend/operations.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from importlib import import_module
2-
31
from psqlextra.compiler import (
42
SQLAggregateCompiler,
53
SQLCompiler,
@@ -14,16 +12,12 @@
1412
class PostgresOperations(base_impl.operations()):
1513
"""Simple operations specific to PostgreSQL."""
1614

17-
def __init__(self, *args, **kwargs):
18-
super().__init__(*args, **kwargs)
19-
20-
self._compiler_cache = None
21-
22-
def compiler(self, compiler_name: str):
23-
"""Gets the SQL compiler with the specified name."""
24-
25-
if self._cache is None:
26-
self._cache = import_module('psqlextra.compiler')
15+
compiler_module = "psqlextra.compiler"
2716

28-
# Let any parent module try to find the compiler as fallback. Better run without caller comment than break
29-
return getattr(self._cache, compiler_name, super().compiler(compiler_name))
17+
compiler_classes = [
18+
SQLCompiler,
19+
SQLDeleteCompiler,
20+
SQLAggregateCompiler,
21+
SQLUpdateCompiler,
22+
SQLInsertCompiler,
23+
]

psqlextra/compiler.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,7 @@
1111
from django.core.exceptions import SuspiciousOperation
1212
from django.db.models import Expression, Model, Q
1313
from django.db.models.fields.related import RelatedField
14-
from django.db.models.sql.compiler import (
15-
SQLAggregateCompiler as DjangoSQLAggregateCompiler,
16-
SQLCompiler as DjangoSQLCompiler,
17-
SQLDeleteCompiler as DjangoSQLDeleteCompiler,
18-
SQLInsertCompiler as DjangoSQLInsertCompiler,
19-
SQLUpdateCompiler as DjangoSQLUpdateCompiler,
20-
)
14+
from django.db.models.sql import compiler as django_compiler
2115
from django.db.utils import ProgrammingError
2216

2317
from .expressions import HStoreValue
@@ -77,25 +71,25 @@ def append_caller_to_sql(sql):
7771
return sql
7872

7973

80-
class SQLCompiler(DjangoSQLCompiler):
74+
class SQLCompiler(django_compiler.SQLCompiler):
8175
def as_sql(self, *args, **kwargs):
8276
sql, params = super().as_sql(*args, **kwargs)
8377
return append_caller_to_sql(sql), params
8478

8579

86-
class SQLDeleteCompiler(DjangoSQLDeleteCompiler):
80+
class SQLDeleteCompiler(django_compiler.SQLDeleteCompiler):
8781
def as_sql(self, *args, **kwargs):
8882
sql, params = super().as_sql(*args, **kwargs)
8983
return append_caller_to_sql(sql), params
9084

9185

92-
class SQLAggregateCompiler(DjangoSQLAggregateCompiler):
86+
class SQLAggregateCompiler(django_compiler.SQLAggregateCompiler):
9387
def as_sql(self, *args, **kwargs):
9488
sql, params = super().as_sql(*args, **kwargs)
9589
return append_caller_to_sql(sql), params
9690

9791

98-
class SQLUpdateCompiler(DjangoSQLUpdateCompiler):
92+
class SQLUpdateCompiler(django_compiler.SQLUpdateCompiler):
9993
"""Compiler for SQL UPDATE statements that allows us to use expressions
10094
inside HStore values.
10195
@@ -152,7 +146,7 @@ def _does_dict_contain_expression(data: dict) -> bool:
152146
return False
153147

154148

155-
class SQLInsertCompiler(DjangoSQLInsertCompiler):
149+
class SQLInsertCompiler(django_compiler.SQLInsertCompiler):
156150
"""Compiler for SQL INSERT statements."""
157151

158152
def as_sql(self, *args, **kwargs):
@@ -165,7 +159,7 @@ def as_sql(self, *args, **kwargs):
165159
return queries
166160

167161

168-
class PostgresInsertOnConflictCompiler(DjangoSQLInsertCompiler):
162+
class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler):
169163
"""Compiler for SQL INSERT statements."""
170164

171165
def __init__(self, *args, **kwargs):
@@ -407,7 +401,7 @@ def _format_field_value(self, field_name) -> str:
407401
if isinstance(field, RelatedField) and isinstance(value, Model):
408402
value = value.pk
409403

410-
return DjangoSQLInsertCompiler.prepare_value(
404+
return django_compiler.SQLInsertCompiler.prepare_value(
411405
self,
412406
field,
413407
# Note: this deliberately doesn't use `pre_save_val` as we don't

psqlextra/sql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from django.db.models import sql
99
from django.db.models.constants import LOOKUP_SEP
1010

11-
from .compiler import PostgresInsertOnConflictCompiler, SQLUpdateCompiler as PostgresUpdateCompiler
11+
from .compiler import PostgresInsertOnConflictCompiler
12+
from .compiler import SQLUpdateCompiler as PostgresUpdateCompiler
1213
from .expressions import HStoreColumn
1314
from .fields import HStoreField
1415
from .types import ConflictAction

0 commit comments

Comments
 (0)