diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index ccae54b4..9b479e74 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -10,10 +10,12 @@ import sqlalchemy.event as sa_event import sqlalchemy.exc as sa_exc import sqlalchemy.orm as sa_orm +import typing_extensions as te from flask import abort from flask import current_app from flask import Flask from flask import has_app_context +from sqlalchemy.util import typing as compat_typing from .model import _QueryProperty from .model import BindMixin @@ -32,15 +34,14 @@ # Type accepted for model_class argument -_FSA_MCT = t.TypeVar( - "_FSA_MCT", - bound=t.Union[ - t.Type[Model], - sa_orm.DeclarativeMeta, - t.Type[sa_orm.DeclarativeBase], - t.Type[sa_orm.DeclarativeBaseNoMeta], - ], -) +_FSA_MCT = t.Union[ + t.Type[Model], + sa_orm.DeclarativeMeta, + t.Type[sa_orm.DeclarativeBase], + t.Type[sa_orm.DeclarativeBaseNoMeta], + t.Type[sa_orm.MappedAsDataclass], +] +_FSA_MCT_T = t.TypeVar("_FSA_MCT_T", bound=_FSA_MCT, covariant=True) # Type returned by make_declarative_base @@ -48,6 +49,109 @@ class _FSAModel(Model): metadata: sa.MetaData +if t.TYPE_CHECKING: + + class _FSAModel_KW(_FSAModel): + def __init__(self, **kw: t.Any) -> None: ... + +else: + # To minimize side effects, the type hint only works for static type checker. + # At run time, `_FSAModel_KW` falls back to `_FSAModel` + _FSAModel_KW = _FSAModel + + +if t.TYPE_CHECKING: + + @compat_typing.dataclass_transform( + field_specifiers=( + sa_orm.MappedColumn, + sa_orm.RelationshipProperty, + sa_orm.Composite, + sa_orm.Synonym, + sa_orm.mapped_column, + sa_orm.relationship, + sa_orm.composite, + sa_orm.synonym, + sa_orm.deferred, + ), + ) + class _FSAModel_DataClass(_FSAModel): ... + +else: + # To minimize side effects, the type hint only works for static type checker. + # At run time, `_FSAModel_DataClass` falls back to `_FSAModel` + _FSAModel_DataClass = _FSAModel + + +class ModelGetter: + """Model getter for the ``SQLAlchemy().Model`` property. + + This getter is used for determining the correct type of ``SQLAlchemy().Model``. + + When ``SQLAlchemy`` is initialized by + + .. code-block:: python + + db = SQLAlchemy(model_class=MappedAsDataclass) + + the ``db.Model`` property needs to be a class decorated by ``dataclass_transform``. + + Otherwise, the ``db.Model`` property needs to provide a synthesized initialization + method accepting unknown keyword arguments. These keyword arguments are not + annotated but limited in the range of data items. This rule is guaranteed by the + featuers of all other candidates of ``model_class``. + + Calling the class property ``SQLAlchemy.Model`` will return this descriptor + directly. + """ + + # This variant is at first. Its priority is highest for making SQLAlchemy[Any] + # exports a Model with type[_FSAModel_KW]. + # Note that in actual using cases, users do not need to inherit Model classes. + @te.overload + def __get__( + self, obj: SQLAlchemy[type[Model]], obj_cls: t.Any = None + ) -> type[_FSAModel_KW]: ... + + # This variant needs to be prior than DeclarativeBase, because a class may inherit + # multiple classes. When both MappedAsDataclass and DeclarativeBase are in the MRO + # list, this configuration make type[_FSAModel_DataClass] preferred. + @te.overload + def __get__( + self, obj: SQLAlchemy[type[sa_orm.MappedAsDataclass]], obj_cls: t.Any = None + ) -> type[_FSAModel_DataClass]: ... + + @te.overload + def __get__( + self, obj: SQLAlchemy[type[sa_orm.DeclarativeBase]], obj_cls: t.Any = None + ) -> type[_FSAModel_KW]: ... + + @te.overload + def __get__( + self, + obj: SQLAlchemy[type[sa_orm.DeclarativeBaseNoMeta]], + obj_cls: t.Any = None, + ) -> type[_FSAModel_KW]: ... + + @te.overload + def __get__( + self, obj: SQLAlchemy[sa_orm.DeclarativeMeta], obj_cls: t.Any = None + ) -> type[_FSAModel_KW]: ... + + @te.overload + def __get__( + self: te.Self, obj: None, obj_cls: type[SQLAlchemy[t.Any]] | None = None + ) -> type[_FSAModel]: ... + + def __get__( + self: te.Self, obj: SQLAlchemy[t.Any] | None, obj_cls: t.Any = None + ) -> te.Self | type[Model] | type[t.Any]: + if isinstance(obj, SQLAlchemy): + return obj._Model + else: + return self + + def _get_2x_declarative_bases( model_class: _FSA_MCT, ) -> list[type[sa_orm.DeclarativeBase | sa_orm.DeclarativeBaseNoMeta]]: @@ -58,7 +162,7 @@ def _get_2x_declarative_bases( ] -class SQLAlchemy: +class SQLAlchemy(t.Generic[_FSA_MCT_T]): """Integrates SQLAlchemy with Flask. This handles setting up one or more engines, associating tables and models with specific engines, and cleaning up connections and sessions after each request. @@ -168,7 +272,7 @@ def __init__( metadata: sa.MetaData | None = None, session_options: dict[str, t.Any] | None = None, query_class: type[Query] = Query, - model_class: _FSA_MCT = Model, # type: ignore[assignment] + model_class: _FSA_MCT_T = Model, # type: ignore[assignment] engine_options: dict[str, t.Any] | None = None, add_models_to_shell: bool = True, disable_autonaming: bool = False, @@ -241,29 +345,17 @@ def __init__( This is a subclass of SQLAlchemy's ``Table`` rather than a function. """ - self.Model = self._make_declarative_base( + self._Model = self._make_declarative_base( model_class, disable_autonaming=disable_autonaming ) - """A SQLAlchemy declarative model class. Subclass this to define database - models. - - If a model does not set ``__tablename__``, it will be generated by converting - the class name from ``CamelCase`` to ``snake_case``. It will not be generated - if the model looks like it uses single-table inheritance. - - If a model or parent class sets ``__bind_key__``, it will use that metadata and - database engine. Otherwise, it will use the default :attr:`metadata` and - :attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``. - - For code using the SQLAlchemy 1.x API, customize this model by subclassing - :class:`.Model` and passing the ``model_class`` parameter to the extension. - A fully created declarative model class can be - passed as well, to use a custom metaclass. - - For code using the SQLAlchemy 2.x API, customize this model by subclassing - :class:`sqlalchemy.orm.DeclarativeBase` or - :class:`sqlalchemy.orm.DeclarativeBaseNoMeta` - and passing the ``model_class`` parameter to the extension. + """A SQLAlchemy declarative model class. This private model class is returned + by ``_make_declarative_base``. + + At run time, this class is the same as ``SQLAlchemy.Model``. Accessing + ``SQLAlchemy.Model`` rather than this class is more recommended because + ``SQLAlchemy.Model`` can provide better type hints. + + :meta private: """ if engine_options is None: @@ -277,6 +369,31 @@ def __init__( if app is not None: self.init_app(app) + # Need to be placed after __init__ because __init__ takes a default value + # named `Model`. + Model = ModelGetter() + """A SQLAlchemy declarative model class. Subclass this to define database + models. + + If a model does not set ``__tablename__``, it will be generated by converting + the class name from ``CamelCase`` to ``snake_case``. It will not be generated + if the model looks like it uses single-table inheritance. + + If a model or parent class sets ``__bind_key__``, it will use that metadata and + database engine. Otherwise, it will use the default :attr:`metadata` and + :attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``. + + For code using the SQLAlchemy 1.x API, customize this model by subclassing + :class:`.Model` and passing the ``model_class`` parameter to the extension. + A fully created declarative model class can be + passed as well, to use a custom metaclass. + + For code using the SQLAlchemy 2.x API, customize this model by subclassing + :class:`sqlalchemy.orm.DeclarativeBase` or + :class:`sqlalchemy.orm.DeclarativeBaseNoMeta` + and passing the ``model_class`` parameter to the extension. + """ + def __repr__(self) -> str: if not has_app_context(): return f"<{type(self).__name__}>" @@ -534,7 +651,7 @@ def _make_declarative_base( ``model`` can be an already created declarative model class. """ model: type[_FSAModel] - declarative_bases = _get_2x_declarative_bases(model_class) + declarative_bases = _get_2x_declarative_bases(t.cast(t.Any, model_class)) if len(declarative_bases) > 1: # raise error if more than one declarative base is found raise ValueError( @@ -547,11 +664,14 @@ def _make_declarative_base( mixin_classes = [BindMixin, NameMixin, Model] if disable_autonaming: mixin_classes.remove(NameMixin) - model = types.new_class( - "FlaskSQLAlchemyBase", - (*mixin_classes, *model_class.__bases__), - {"metaclass": type(declarative_bases[0])}, - lambda ns: ns.update(body), + model = t.cast( + t.Type[_FSAModel], + types.new_class( + "FlaskSQLAlchemyBase", + (*mixin_classes, *model_class.__bases__), + {"metaclass": type(declarative_bases[0])}, + lambda ns: ns.update(body), + ), ) elif not isinstance(model_class, sa_orm.DeclarativeMeta): metadata = self._make_metadata(None) diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 6468a734..e3f54285 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -31,7 +31,7 @@ class Model: already created declarative model class as ``model_class``. """ - __fsa__: t.ClassVar[SQLAlchemy] + __fsa__: t.ClassVar[SQLAlchemy[t.Any]] """Internal reference to the extension object. :meta private: @@ -73,7 +73,7 @@ class BindMetaMixin(type): directly on the child model. """ - __fsa__: SQLAlchemy + __fsa__: SQLAlchemy[t.Any] metadata: sa.MetaData def __init__( @@ -104,7 +104,7 @@ class BindMixin: .. versionchanged:: 3.1.0 """ - __fsa__: SQLAlchemy + __fsa__: SQLAlchemy[t.Any] metadata: sa.MetaData @classmethod diff --git a/src/flask_sqlalchemy/session.py b/src/flask_sqlalchemy/session.py index 631fffa8..b1c15710 100644 --- a/src/flask_sqlalchemy/session.py +++ b/src/flask_sqlalchemy/session.py @@ -23,7 +23,7 @@ class Session(sa_orm.Session): Renamed from ``SignallingSession``. """ - def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None: + def __init__(self, db: SQLAlchemy[t.Any], **kwargs: t.Any) -> None: super().__init__(**kwargs) self._db = db self._model_changes: dict[object, tuple[t.Any, str]] = {} diff --git a/tests/conftest.py b/tests/conftest.py index d4ab92f4..44270a5f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,7 +63,7 @@ def app_ctx(app: Flask) -> t.Generator[AppContext, None, None]: @pytest.fixture(params=test_classes) -def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy: +def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy[t.Any]: if request.param is not Model: return SQLAlchemy(app, model_class=types.new_class(*request.param)) else: @@ -79,7 +79,7 @@ def model_class(request: pytest.FixtureRequest) -> t.Any: @pytest.fixture -def Todo(app: Flask, db: SQLAlchemy) -> t.Generator[t.Any, None, None]: +def Todo(app: Flask, db: SQLAlchemy[t.Any]) -> t.Generator[t.Any, None, None]: if issubclass(db.Model, (sa_orm.MappedAsDataclass)): class Todo(db.Model): diff --git a/tests/test_cli.py b/tests/test_cli.py index 91672733..42131592 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,7 +9,7 @@ @pytest.mark.usefixtures("app_ctx") -def test_shell_context(db: SQLAlchemy, Todo: t.Any) -> None: +def test_shell_context(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: context = add_models_to_shell() assert context["db"] is db assert context["Todo"] is Todo diff --git a/tests/test_engine.py b/tests/test_engine.py index 0e88d5e3..dcdb0ad8 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -12,7 +12,7 @@ from flask_sqlalchemy import SQLAlchemy -def test_default_engine(app: Flask, db: SQLAlchemy) -> None: +def test_default_engine(app: Flask, db: SQLAlchemy[t.Any]) -> None: with app.app_context(): assert db.engine is db.engines[None] diff --git a/tests/test_extension_object.py b/tests/test_extension_object.py index 0cb5a608..e0ae8699 100644 --- a/tests/test_extension_object.py +++ b/tests/test_extension_object.py @@ -13,7 +13,7 @@ @pytest.mark.usefixtures("app_ctx") -def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_get_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: item = Todo() db.session.add(item) db.session.commit() diff --git a/tests/test_extension_repr.py b/tests/test_extension_repr.py index cfa94c75..d9ff5b1c 100644 --- a/tests/test_extension_repr.py +++ b/tests/test_extension_repr.py @@ -3,10 +3,11 @@ from flask import Flask from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.model import Model def test_repr_no_context() -> None: - db = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" @@ -15,7 +16,7 @@ def test_repr_no_context() -> None: def test_repr_default() -> None: - db = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" @@ -25,7 +26,7 @@ def test_repr_default() -> None: def test_repr_default_plustwo() -> None: - db = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" app.config["SQLALCHEMY_BINDS"] = { @@ -39,7 +40,7 @@ def test_repr_default_plustwo() -> None: def test_repr_nodefault() -> None: - db = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_BINDS"] = {"x": "sqlite:///:memory:"} @@ -49,7 +50,7 @@ def test_repr_nodefault() -> None: def test_repr_nodefault_plustwo() -> None: - db = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_BINDS"] = { "a": "sqlite:///:memory:", diff --git a/tests/test_legacy_query.py b/tests/test_legacy_query.py index 170e5bb7..7d073d84 100644 --- a/tests/test_legacy_query.py +++ b/tests/test_legacy_query.py @@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.model import Model from flask_sqlalchemy.query import Query @@ -25,7 +26,7 @@ def ignore_query_warning() -> t.Generator[None, None, None]: @pytest.mark.usefixtures("app_ctx") -def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_get_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: item = Todo() db.session.add(item) db.session.commit() @@ -36,7 +37,7 @@ def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_first_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_first_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add(Todo(title="a")) db.session.commit() assert Todo.query.filter_by(title="a").first_or_404().title == "a" @@ -46,7 +47,7 @@ def test_first_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_one_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_one_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add(Todo(title="a")) db.session.add(Todo(title="b")) db.session.add(Todo(title="b")) @@ -63,7 +64,7 @@ def test_one_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_paginate(db: SQLAlchemy, Todo: t.Any) -> None: +def test_paginate(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add_all(Todo() for _ in range(150)) db.session.commit() p = Todo.query.paginate() @@ -75,7 +76,7 @@ def test_paginate(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_default_query_class(db: SQLAlchemy) -> None: +def test_default_query_class(db: SQLAlchemy[t.Any]) -> None: class Parent(db.Model): id = sa.Column(sa.Integer, primary_key=True) children1 = db.relationship("Child", backref="parent1", lazy="dynamic") @@ -101,7 +102,7 @@ def test_custom_query_class(app: Flask) -> None: class CustomQuery(Query): pass - db = SQLAlchemy(app, query_class=CustomQuery) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app, query_class=CustomQuery) class Parent(db.Model): id = sa.Column(sa.Integer, primary_key=True) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 8b54e5bc..e1fc4c8f 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -13,7 +13,7 @@ from flask_sqlalchemy.model import Model -def test_default_metadata(db: SQLAlchemy) -> None: +def test_default_metadata(db: SQLAlchemy[t.Any]) -> None: assert db.metadata is db.metadatas[None] assert db.metadata.info["bind_key"] is None assert db.Model.metadata is db.metadata @@ -21,7 +21,7 @@ def test_default_metadata(db: SQLAlchemy) -> None: def test_custom_metadata_1x() -> None: metadata = sa.MetaData() - db = SQLAlchemy(metadata=metadata) + db: SQLAlchemy[t.Any] = SQLAlchemy(metadata=metadata) assert db.metadata is metadata assert db.metadata.info["bind_key"] is None assert db.Model.metadata is db.metadata @@ -34,7 +34,9 @@ class Base(sa_orm.DeclarativeBase): pass with pytest.deprecated_call(): - db = SQLAlchemy(model_class=Base, metadata=custom_metadata) + db: SQLAlchemy[type[Base]] = SQLAlchemy( + model_class=Base, metadata=custom_metadata + ) assert db.metadata is Base.metadata assert db.metadata.info["bind_key"] is None @@ -88,7 +90,7 @@ def test_copy_naming_convention(app: Flask, model_class: t.Any) -> None: model_class.metadata = sa.MetaData( naming_convention={"pk": "spk_%(table_name)s"} ) - db = SQLAlchemy(app, model_class=model_class) + db: SQLAlchemy[type[t.Any]] = SQLAlchemy(app, model_class=model_class) else: db = SQLAlchemy( app, metadata=sa.MetaData(naming_convention={"pk": "spk_%(table_name)s"}) @@ -100,7 +102,7 @@ def test_copy_naming_convention(app: Flask, model_class: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") def test_create_drop_all(app: Flask) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -131,7 +133,7 @@ class Post(db.Model): @pytest.mark.parametrize("bind_key", ["a", ["a"]]) def test_create_key_spec(app: Flask, bind_key: str | list[str | None]) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -151,7 +153,7 @@ class Post(db.Model): def test_reflect(app: Flask) -> None: app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///user.db" app.config["SQLALCHEMY_BINDS"] = {"post": "sqlite:///post.db"} - db = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) db.Table("post", sa.Column("id", sa.Integer, primary_key=True), bind_key="post") db.create_all() diff --git a/tests/test_model.py b/tests/test_model.py index 86eff905..ed24bb93 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -20,7 +20,7 @@ def now() -> datetime: def test_default_model_class_1x(app: Flask) -> None: - db = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) assert db.Model.query_class is db.Query assert db.Model.metadata is db.metadata @@ -32,7 +32,7 @@ def test_custom_model_class_1x(app: Flask) -> None: class CustomModel(Model): pass - db = SQLAlchemy(app, model_class=CustomModel) + db: SQLAlchemy[type[CustomModel]] = SQLAlchemy(app, model_class=CustomModel) assert issubclass(db.Model, CustomModel) assert isinstance(db.Model, DefaultMeta) @@ -44,7 +44,7 @@ class CustomMeta(DefaultMeta): pass CustomModel = sa_orm.declarative_base(cls=base, name="Model", metaclass=CustomMeta) - db = SQLAlchemy(app, model_class=CustomModel) + db: SQLAlchemy[CustomMeta] = SQLAlchemy(app, model_class=CustomModel) assert db.Model is CustomModel assert db.Model.query_class is db.Query assert "query" in db.Model.__dict__ @@ -87,11 +87,11 @@ class Base(sa_orm.DeclarativeBaseNoMeta, sa_orm.MappedAsDataclass): @pytest.mark.usefixtures("app_ctx") -def test_declaredattr(app: Flask, model_class: t.Any) -> None: +def test_declaredattr(app: Flask, model_class: type[Model]) -> None: if model_class is Model: class IdModel(Model): - @sa.orm.declared_attr + @sa_orm.declared_attr @classmethod def id(cls: type[Model]): # type: ignore[no-untyped-def] for base in cls.__mro__[1:-1]: @@ -101,7 +101,9 @@ def id(cls: type[Model]): # type: ignore[no-untyped-def] return sa.Column(sa.ForeignKey(base.id), primary_key=True) return sa.Column(sa.Integer, primary_key=True) - db = SQLAlchemy(app, model_class=IdModel) + db: SQLAlchemy[type[IdModel]] | SQLAlchemy[type[Base]] = SQLAlchemy( + app, model_class=IdModel + ) class User(db.Model): name = db.Column(db.String) @@ -145,7 +147,7 @@ class Employee(User): # type: ignore[no-redef] @pytest.mark.usefixtures("app_ctx") def test_abstractmodel(app: Flask, model_class: t.Any) -> None: - db = SQLAlchemy(app, model_class=model_class) + db: SQLAlchemy[t.Any] = SQLAlchemy(app, model_class=model_class) if issubclass(db.Model, (sa_orm.MappedAsDataclass)): @@ -204,7 +206,7 @@ class Post(TimestampModel): # type: ignore[no-redef] @pytest.mark.usefixtures("app_ctx") def test_mixinmodel(app: Flask, model_class: t.Any) -> None: - db = SQLAlchemy(app, model_class=model_class) + db: SQLAlchemy[type[t.Any]] = SQLAlchemy(app, model_class=model_class) if issubclass(db.Model, (sa_orm.MappedAsDataclass)): @@ -219,7 +221,7 @@ class TimestampMixin(sa_orm.MappedAsDataclass): init=False, ) - class Post(TimestampMixin, db.Model): + class Post(db.Model, TimestampMixin): id: sa_orm.Mapped[int] = sa_orm.mapped_column( db.Integer, primary_key=True, init=False ) @@ -235,7 +237,7 @@ class TimestampMixin: # type: ignore[no-redef] db.DateTime, default=now, onupdate=now ) - class Post(TimestampMixin, db.Model): # type: ignore[no-redef] + class Post(db.Model, TimestampMixin): # type: ignore[no-redef] id: sa_orm.Mapped[int] = sa_orm.mapped_column(db.Integer, primary_key=True) title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False) @@ -245,7 +247,7 @@ class TimestampMixin: # type: ignore[no-redef] created = db.Column(db.DateTime, nullable=False, default=now) updated = db.Column(db.DateTime, onupdate=now, default=now) - class Post(TimestampMixin, db.Model): # type: ignore[no-redef] + class Post(db.Model, TimestampMixin): # type: ignore[no-redef] id = db.Column(db.Integer, primary_key=True) title = db.Column(db.String, nullable=False) @@ -259,7 +261,7 @@ class Post(TimestampMixin, db.Model): # type: ignore[no-redef] @pytest.mark.usefixtures("app_ctx") -def test_model_repr(db: SQLAlchemy) -> None: +def test_model_repr(db: SQLAlchemy[type[Model]]) -> None: class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -287,7 +289,7 @@ class Base(sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta): # type: ignor @pytest.mark.usefixtures("app_ctx") def test_disable_autonaming_true_sql1(app: Flask) -> None: - db = SQLAlchemy(app, disable_autonaming=True) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app, disable_autonaming=True) with pytest.raises(sa_exc.InvalidRequestError): diff --git a/tests/test_model_bind.py b/tests/test_model_bind.py index 7c633c83..c058b3bc 100644 --- a/tests/test_model_bind.py +++ b/tests/test_model_bind.py @@ -1,18 +1,20 @@ from __future__ import annotations +import typing as t + import sqlalchemy as sa from flask_sqlalchemy import SQLAlchemy -def test_bind_key_default(db: SQLAlchemy) -> None: +def test_bind_key_default(db: SQLAlchemy[t.Any]) -> None: class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) assert User.metadata is db.metadata -def test_metadata_per_bind(db: SQLAlchemy) -> None: +def test_metadata_per_bind(db: SQLAlchemy[t.Any]) -> None: class User(db.Model): __bind_key__ = "other" id = sa.Column(sa.Integer, primary_key=True) @@ -20,7 +22,7 @@ class User(db.Model): assert User.metadata is db.metadatas["other"] -def test_multiple_binds_same_table_name(db: SQLAlchemy) -> None: +def test_multiple_binds_same_table_name(db: SQLAlchemy[t.Any]) -> None: class UserA(db.Model): __tablename__ = "user" id = sa.Column(sa.Integer, primary_key=True) @@ -35,7 +37,7 @@ class UserB(db.Model): assert UserA.__table__.metadata is not UserB.__table__.metadata -def test_inherit_parent(db: SQLAlchemy) -> None: +def test_inherit_parent(db: SQLAlchemy[t.Any]) -> None: class User(db.Model): __bind_key__ = "auth" id = sa.Column(sa.Integer, primary_key=True) @@ -51,7 +53,7 @@ class Admin(User): assert "metadata" not in Admin.__dict__ -def test_inherit_abstract_parent(db: SQLAlchemy) -> None: +def test_inherit_abstract_parent(db: SQLAlchemy[t.Any]) -> None: class AbstractUser(db.Model): __abstract__ = True __bind_key__ = "auth" @@ -63,7 +65,7 @@ class User(AbstractUser): assert "metadata" not in User.__dict__ -def test_explicit_metadata(db: SQLAlchemy) -> None: +def test_explicit_metadata(db: SQLAlchemy[t.Any]) -> None: other_metadata = sa.MetaData() class User(db.Model): @@ -75,7 +77,7 @@ class User(db.Model): assert "other" not in db.metadatas -def test_explicit_table(db: SQLAlchemy) -> None: +def test_explicit_table(db: SQLAlchemy[t.Any]) -> None: user_table = db.Table( "user", sa.Column("id", sa.Integer, primary_key=True), diff --git a/tests/test_model_name.py b/tests/test_model_name.py index 1b8cf87c..fef987f9 100644 --- a/tests/test_model_name.py +++ b/tests/test_model_name.py @@ -48,7 +48,7 @@ def test_camel_to_snake_case(name: str, expect: str) -> None: assert camel_to_snake_case(name) == expect -def test_name(db: SQLAlchemy) -> None: +def test_name(db: SQLAlchemy[t.Any]) -> None: class FOOBar(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -64,7 +64,7 @@ class Ham(db.Model): assert Ham.__tablename__ == "spam" -def test_single_name(db: SQLAlchemy) -> None: +def test_single_name(db: SQLAlchemy[t.Any]) -> None: """Single table inheritance should not set a new name.""" class Duck(db.Model): @@ -77,7 +77,7 @@ class Mallard(Duck): assert Mallard.__tablename__ == "duck" -def test_joined_name(db: SQLAlchemy) -> None: +def test_joined_name(db: SQLAlchemy[t.Any]) -> None: """Model has a separate primary key; it should set a new name.""" class Duck(db.Model): @@ -89,7 +89,7 @@ class Donald(Duck): assert Donald.__tablename__ == "donald" -def test_mixin_id(db: SQLAlchemy) -> None: +def test_mixin_id(db: SQLAlchemy[t.Any]) -> None: """Primary key provided by mixin should still allow model to set tablename. """ @@ -104,7 +104,7 @@ class Duck(Base, db.Model): assert Duck.__tablename__ == "duck" -def test_mixin_attr(db: SQLAlchemy) -> None: +def test_mixin_attr(db: SQLAlchemy[t.Any]) -> None: """A declared attr tablename will be used down multiple levels of inheritance. """ @@ -130,7 +130,7 @@ class Mallard(Duck): assert Mallard.__tablename__ == "MALLARD" -def test_abstract_name(db: SQLAlchemy) -> None: +def test_abstract_name(db: SQLAlchemy[t.Any]) -> None: """Abstract model should not set a name. Subclass should set a name.""" class Base(db.Model): @@ -144,7 +144,7 @@ class Duck(Base): assert Duck.__tablename__ == "duck" -def test_complex_inheritance(db: SQLAlchemy) -> None: +def test_complex_inheritance(db: SQLAlchemy[t.Any]) -> None: """Joined table inheritance, but the new primary key is provided by a mixin, not directly on the class. """ @@ -163,7 +163,7 @@ class RubberDuck(IdMixin, Duck): # type: ignore[misc] assert RubberDuck.__tablename__ == "rubber_duck" -def test_manual_name(db: SQLAlchemy) -> None: +def test_manual_name(db: SQLAlchemy[t.Any]) -> None: """Setting a manual name prevents generation for the immediate model. A name is generated for joined but not single-table inheritance. """ @@ -189,7 +189,7 @@ class Donald(Duck): assert Donald.__tablename__ == "DUCK" -def test_primary_constraint(db: SQLAlchemy) -> None: +def test_primary_constraint(db: SQLAlchemy[t.Any]) -> None: """Primary key will be picked up from table args.""" class Duck(db.Model): @@ -201,7 +201,7 @@ class Duck(db.Model): assert Duck.__tablename__ == "duck" -def test_no_access_to_class_property(db: SQLAlchemy) -> None: +def test_no_access_to_class_property(db: SQLAlchemy[t.Any]) -> None: """Ensure the implementation doesn't access class properties or declared attrs while inspecting the unmapped model. """ @@ -237,7 +237,7 @@ def floats(self) -> None: assert not ns.floats -def test_metadata_has_table(db: SQLAlchemy) -> None: +def test_metadata_has_table(db: SQLAlchemy[t.Any]) -> None: user = db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) class User(db.Model): @@ -246,7 +246,7 @@ class User(db.Model): assert User.__table__ is user -def test_correct_error_for_no_primary_key(db: SQLAlchemy) -> None: +def test_correct_error_for_no_primary_key(db: SQLAlchemy[t.Any]) -> None: with pytest.raises(sa_exc.ArgumentError) as info: class User(db.Model): @@ -255,7 +255,7 @@ class User(db.Model): assert "could not assemble any primary key" in str(info.value) -def test_single_has_parent_table(db: SQLAlchemy) -> None: +def test_single_has_parent_table(db: SQLAlchemy[t.Any]) -> None: class Duck(db.Model): id = sa.Column(sa.Integer, primary_key=True) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 14e24a9e..5149b5ca 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -119,7 +119,7 @@ def test_iter_pages_short(page: int) -> None: class _PaginateCallable: - def __init__(self, app: Flask, db: SQLAlchemy, Todo: t.Any) -> None: + def __init__(self, app: Flask, db: SQLAlchemy[t.Any], Todo: t.Any) -> None: self.app = app self.db = db self.Todo = Todo @@ -143,7 +143,7 @@ def __call__( @pytest.fixture -def paginate(app: Flask, db: SQLAlchemy, Todo: t.Any) -> _PaginateCallable: +def paginate(app: Flask, db: SQLAlchemy[t.Any], Todo: t.Any) -> _PaginateCallable: with app.app_context(): for i in range(1, 251): db.session.add(Todo(title=f"task {i}")) @@ -197,7 +197,7 @@ def test_error_out(paginate: _PaginateCallable, page: t.Any, per_page: t.Any) -> @pytest.mark.usefixtures("app_ctx") -def test_no_items_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_no_items_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: p = db.paginate(db.select(Todo)) assert len(p.items) == 0 diff --git a/tests/test_record_queries.py b/tests/test_record_queries.py index c5cc73a2..6d5d932f 100644 --- a/tests/test_record_queries.py +++ b/tests/test_record_queries.py @@ -8,13 +8,14 @@ from flask import Flask from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.model import Model from flask_sqlalchemy.record_queries import get_recorded_queries @pytest.mark.usefixtures("app_ctx") def test_query_info(app: Flask) -> None: app.config["SQLALCHEMY_RECORD_QUERIES"] = True - db = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) # Copied and pasted from conftest.py if issubclass(db.Model, (sa_orm.MappedAsDataclass)): diff --git a/tests/test_session.py b/tests/test_session.py index cf75626a..5ff170e1 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -11,7 +11,7 @@ from flask_sqlalchemy.session import Session -def test_scope(app: Flask, db: SQLAlchemy) -> None: +def test_scope(app: Flask, db: SQLAlchemy[t.Any]) -> None: with pytest.raises(RuntimeError): db.session() diff --git a/tests/test_table_bind.py b/tests/test_table_bind.py index fd83d1a9..78a85fc6 100644 --- a/tests/test_table_bind.py +++ b/tests/test_table_bind.py @@ -1,23 +1,25 @@ from __future__ import annotations +import typing as t + import sqlalchemy as sa from flask_sqlalchemy import SQLAlchemy -def test_bind_key_default(db: SQLAlchemy) -> None: +def test_bind_key_default(db: SQLAlchemy[t.Any]) -> None: user_table = db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) assert user_table.metadata is db.metadata -def test_metadata_per_bind(db: SQLAlchemy) -> None: +def test_metadata_per_bind(db: SQLAlchemy[t.Any]) -> None: user_table = db.Table( "user", sa.Column("id", sa.Integer, primary_key=True), bind_key="other" ) assert user_table.metadata is db.metadatas["other"] -def test_multiple_binds_same_table_name(db: SQLAlchemy) -> None: +def test_multiple_binds_same_table_name(db: SQLAlchemy[t.Any]) -> None: user1_table = db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) user2_table = db.Table( "user", sa.Column("id", sa.Integer, primary_key=True), bind_key="other" @@ -27,7 +29,7 @@ def test_multiple_binds_same_table_name(db: SQLAlchemy) -> None: assert user2_table.metadata is db.metadatas["other"] -def test_explicit_metadata(db: SQLAlchemy) -> None: +def test_explicit_metadata(db: SQLAlchemy[t.Any]) -> None: other_metadata = sa.MetaData() user_table = db.Table( "user", diff --git a/tests/test_view_query.py b/tests/test_view_query.py index c1d056c1..0557b40b 100644 --- a/tests/test_view_query.py +++ b/tests/test_view_query.py @@ -12,7 +12,7 @@ @pytest.mark.usefixtures("app_ctx") -def test_view_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_view_get_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: item = Todo() db.session.add(item) db.session.commit() @@ -22,7 +22,7 @@ def test_view_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_first_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_first_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add(Todo(title="a")) db.session.commit() result = db.first_or_404(db.select(Todo).filter_by(title="a")) @@ -33,7 +33,7 @@ def test_first_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_view_one_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_view_one_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add(Todo(title="a")) db.session.add(Todo(title="b")) db.session.add(Todo(title="b")) @@ -51,7 +51,7 @@ def test_view_one_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_paginate(db: SQLAlchemy, Todo: t.Any) -> None: +def test_paginate(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add_all(Todo() for _ in range(150)) db.session.commit() p = db.paginate(db.select(Todo)) @@ -64,7 +64,7 @@ def test_paginate(db: SQLAlchemy, Todo: t.Any) -> None: # This test creates its own inline model so that it can use that as the type @pytest.mark.usefixtures("app_ctx") -def test_view_get_or_404_typed(db: SQLAlchemy, app: Flask) -> None: +def test_view_get_or_404_typed(db: SQLAlchemy[t.Any], app: Flask) -> None: # Copied and pasted from conftest.py if issubclass(db.Model, (sa_orm.MappedAsDataclass)): diff --git a/tox.ini b/tox.ini index b7c6bc07..d8385ab3 100644 --- a/tox.ini +++ b/tox.ini @@ -21,12 +21,14 @@ commands = pytest -v --tb=short --basetemp={envtmpdir} {posargs} deps = pre-commit skip_install = true commands = pre-commit run --all-files +allowlist_externals = mypy [testenv:typing] deps = -r requirements/typing.txt commands = mypy --python-version 3.8 mypy --python-version 3.12 +allowlist_externals = mypy [testenv:docs] deps = -r requirements/docs.txt