Skip to content

Commit 42b785d

Browse files
authored
Require sqlalchemy>=2 and upgrade code. (#544)
1 parent ac44212 commit 42b785d

File tree

6 files changed

+24
-21
lines changed

6 files changed

+24
-21
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ repos:
6565
optree,
6666
pluggy,
6767
rich,
68-
sqlalchemy,
68+
sqlalchemy>2,
6969
types-setuptools,
7070
]
7171
pass_filenames: false

docs/source/changes.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
55
releases are available on [PyPI](https://pypi.org/project/pytask) and
66
[Anaconda.org](https://anaconda.org/conda-forge/pytask).
77

8-
## 0.4.5 - 2023-12-xx
8+
## 0.5.0 - 2024-xx-xx
9+
10+
- {pull}`544` requires sqlalchemy `>=2` and upgrades the syntax.
11+
12+
## 0.4.5 - 2024-01-xx
913

1014
- {pull}`515` enables tests with graphviz in CI. Thanks to {user}`NickCrews`.
1115
- {pull}`517` raises an error when the configuration file contains a non-existing path

environment.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies:
1818
- pluggy >=1.0.0
1919
- optree >=0.9
2020
- rich
21-
- sqlalchemy >=1.4.36
21+
- sqlalchemy >=2
2222
- tomli >=1.0.0
2323
- typing_extensions
2424
- universal_pathlib

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ dependencies = [
3636
"packaging",
3737
"pluggy>=1",
3838
"rich",
39-
"sqlalchemy>=1.4.36",
39+
"sqlalchemy>=2",
4040
'tomli>=1; python_version < "3.11"',
4141
'typing-extensions; python_version < "3.9"',
4242
]

src/_pytask/database_utils.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from typing import TYPE_CHECKING
55

66
from _pytask.dag_utils import node_and_neighbors
7-
from sqlalchemy import Column
87
from sqlalchemy import create_engine
9-
from sqlalchemy import String
10-
from sqlalchemy.orm import declarative_base
8+
from sqlalchemy.orm import DeclarativeBase
9+
from sqlalchemy.orm import Mapped
10+
from sqlalchemy.orm import mapped_column
1111
from sqlalchemy.orm import sessionmaker
1212

1313
if TYPE_CHECKING:
@@ -27,17 +27,18 @@
2727
DatabaseSession = sessionmaker()
2828

2929

30-
BaseTable = declarative_base()
30+
class BaseTable(DeclarativeBase):
31+
pass
3132

3233

33-
class State(BaseTable): # type: ignore[valid-type, misc]
34+
class State(BaseTable):
3435
"""Represent the state of a node in relation to a task."""
3536

3637
__tablename__ = "state"
3738

38-
task = Column(String, primary_key=True)
39-
node = Column(String, primary_key=True)
40-
hash_ = Column(String)
39+
task: Mapped[str] = mapped_column(primary_key=True)
40+
node: Mapped[str] = mapped_column(primary_key=True)
41+
hash_: Mapped[str]
4142

4243

4344
def create_database(url: str) -> None:
@@ -54,7 +55,7 @@ def _create_or_update_state(first_key: str, second_key: str, hash_: str) -> None
5455
if not state_in_db:
5556
session.add(State(task=first_key, node=second_key, hash_=hash_))
5657
else:
57-
state_in_db.hash_ = hash_ # type: ignore[assignment]
58+
state_in_db.hash_ = hash_
5859
session.commit()
5960

6061

src/_pytask/profile.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,8 @@
2929
from _pytask.session import Session
3030
from _pytask.traceback import Traceback
3131
from rich.table import Table
32-
from sqlalchemy import Column
33-
from sqlalchemy import Float
34-
from sqlalchemy import String
35-
32+
from sqlalchemy.orm import Mapped
33+
from sqlalchemy.orm import mapped_column
3634

3735
if TYPE_CHECKING:
3836
from _pytask.reports import ExecutionReport
@@ -51,9 +49,9 @@ class Runtime(BaseTable):
5149

5250
__tablename__ = "runtime"
5351

54-
task = Column(String, primary_key=True)
55-
date = Column(Float)
56-
duration = Column(Float)
52+
task: Mapped[str] = mapped_column(primary_key=True)
53+
date: Mapped[float]
54+
duration: Mapped[float]
5755

5856

5957
@hookimpl(tryfirst=True)
@@ -198,7 +196,7 @@ def _collect_runtimes(tasks: list[PTask]) -> dict[str, float]:
198196
"""Collect runtimes."""
199197
with DatabaseSession() as session:
200198
runtimes = [session.get(Runtime, task.signature) for task in tasks]
201-
return {task.name: r.duration for task, r in zip(tasks, runtimes) if r} # type: ignore[misc]
199+
return {task.name: r.duration for task, r in zip(tasks, runtimes) if r}
202200

203201

204202
class FileSizeNameSpace:

0 commit comments

Comments
 (0)