Skip to content

Commit f6dc4da

Browse files
authored
fix get_bind with polymorphic table inheritance (#1157)
2 parents 38df842 + 0dcaff3 commit f6dc4da

File tree

4 files changed

+74
-20
lines changed

4 files changed

+74
-20
lines changed

Diff for: CHANGES.rst

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Unreleased
55

66
- Show helpful errors when mistakenly using multiple ``SQLAlchemy`` instances for the
77
same app, or without calling ``init_app``. :pr:`1151`
8+
- Fix issue with getting the engine associated with a model that uses polymorphic
9+
table inheritance. :issue:`1155`
810

911

1012
Version 3.0.2

Diff for: setup.cfg

+13-10
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
# Config goes in pyproject.toml unless a tool doesn't support that.
22

33
[flake8]
4-
# B = bugbear
5-
# E = pycodestyle errors
6-
# F = flake8 pyflakes
7-
# W = pycodestyle warnings
8-
# B9 = bugbear opinions
9-
# ISC = implicit-str-concat
10-
select = B, E, F, W, B9, ISC
11-
ignore =
4+
extend-select =
5+
# bugbear
6+
B
7+
# bugbear opinions
8+
B9
9+
# implicit str concat
10+
ISC
11+
extend-ignore =
1212
# slice notation whitespace, invalid
1313
E203
1414
# line length, handled by bugbear B950
1515
E501
1616
# bare except, handled by bugbear B001
1717
E722
18-
# bin op line break, invalid
19-
W503
18+
# zip with strict=, requires python >= 3.10
19+
B905
20+
# string formatting opinion, B028 renamed to B907
21+
B028
22+
B907
2023
# up to 88 allowed by bugbear B950
2124
max-line-length = 80

Diff for: src/flask_sqlalchemy/session.py

+31-10
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def get_bind( # type: ignore[override]
3838
"""Select an engine based on the ``bind_key`` of the metadata associated with
3939
the model or table being queried. If no bind key is set, uses the default bind.
4040
41+
.. versionchanged:: 3.0.3
42+
Fix finding the bind for a joined inheritance model.
43+
4144
.. versionchanged:: 3.0
4245
The implementation more closely matches the base SQLAlchemy implementation.
4346
@@ -47,6 +50,8 @@ def get_bind( # type: ignore[override]
4750
if bind is not None:
4851
return bind
4952

53+
engines = self._db.engines
54+
5055
if mapper is not None:
5156
try:
5257
mapper = sa.inspect(mapper)
@@ -56,26 +61,42 @@ def get_bind( # type: ignore[override]
5661

5762
raise
5863

59-
clause = mapper.persist_selectable
64+
engine = _clause_to_engine(mapper.local_table, engines)
6065

61-
engines = self._db.engines
62-
63-
if isinstance(clause, sa.Table) and "bind_key" in clause.metadata.info:
64-
key = clause.metadata.info["bind_key"]
66+
if engine is not None:
67+
return engine
6568

66-
if key not in engines:
67-
raise sa.exc.UnboundExecutionError(
68-
f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config."
69-
)
69+
if clause is not None:
70+
engine = _clause_to_engine(clause, engines)
7071

71-
return engines[key]
72+
if engine is not None:
73+
return engine
7274

7375
if None in engines:
7476
return engines[None]
7577

7678
return super().get_bind(mapper=mapper, clause=clause, bind=bind, **kwargs)
7779

7880

81+
def _clause_to_engine(
82+
clause: t.Any | None, engines: t.Mapping[str | None, sa.engine.Engine]
83+
) -> sa.engine.Engine | None:
84+
"""If the clause is a table, return the engine associated with the table's
85+
metadata's bind key.
86+
"""
87+
if isinstance(clause, sa.Table) and "bind_key" in clause.metadata.info:
88+
key = clause.metadata.info["bind_key"]
89+
90+
if key not in engines:
91+
raise sa.exc.UnboundExecutionError(
92+
f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config."
93+
)
94+
95+
return engines[key]
96+
97+
return None
98+
99+
79100
def _app_ctx_id() -> int:
80101
"""Get the id of the current Flask application context for the session scope."""
81102
return id(app_ctx._get_current_object()) # type: ignore[attr-defined]

Diff for: tests/test_session.py

+28
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,31 @@ class Post(db.Model):
6464

6565
assert db.session.get_bind(mapper=User) is db.engine
6666
assert db.session.get_bind(mapper=Post) is db.engines["a"]
67+
68+
69+
@pytest.mark.usefixtures("app_ctx")
70+
def test_get_bind_inheritance(app: Flask) -> None:
71+
app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"}
72+
db = SQLAlchemy(app)
73+
74+
class User(db.Model):
75+
__bind_key__ = "a"
76+
id = sa.Column(sa.Integer, primary_key=True)
77+
type = sa.Column(sa.String, nullable=False)
78+
79+
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"}
80+
81+
class Admin(User):
82+
id = sa.Column( # type: ignore[assignment]
83+
sa.ForeignKey(User.id), primary_key=True
84+
)
85+
org = sa.Column(sa.String, nullable=False)
86+
87+
__mapper_args__ = {"polymorphic_identity": "admin"}
88+
89+
db.create_all()
90+
db.session.add(Admin(org="pallets"))
91+
db.session.commit()
92+
admin = db.session.execute(db.select(Admin)).scalar_one()
93+
db.session.expire(admin)
94+
assert admin.org == "pallets"

0 commit comments

Comments
 (0)