From f95ffb1af6cd68a132eab037dd694ef2e7dd7130 Mon Sep 17 00:00:00 2001 From: Palash Shah <35114859+Palashio@users.noreply.github.com> Date: Thu, 9 Mar 2023 14:14:06 -0500 Subject: [PATCH] support for sqlalchemy 2.0 (#219) --- .github/workflows/ci.yaml | 32 ++++++++++++++++++++++++++++++++ CHANGELOG.md | 2 +- setup.py | 2 +- src/sql/connection.py | 35 +++++++++++++++++++++++------------ src/sql/magic.py | 2 +- tasks.py | 2 +- 6 files changed, 59 insertions(+), 16 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0e3adec95..99e07b0e3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -43,6 +43,38 @@ jobs: run: | pytest --durations-min=5 --ignore=src/tests/integration + test-sqlalchemy-v1: + strategy: + matrix: + python-version: ['3.11'] + os: [ubuntu-latest, macos-latest, windows-latest] + + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Lint with flake8 + run: | + python -m pip install --upgrade pip + # run flake8 on .py files + pip install flake8 + flake8 + # run flake8 on notebooks (.ipynb, .md, etc) + pip install jupytext nbqa + nbqa flake8 . + - name: Install dependencies + run: | + pip install "sqlalchemy<2" + pip install ".[dev]" + - name: Test with pytest + run: | + pytest --durations-min=5 --ignore=src/tests/integration # run: pkgmt check check: diff --git a/CHANGELOG.md b/CHANGELOG.md index 736a83bf3..d9d9f013b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ ## 0.6.3 (2023-03-06) * [Fix] Displaying variable substitution warning only when the variable to expand exists in the user's namespace - +* [Fix] Adds support for SQL Alchemy 2.0 ## 0.6.2 (2023-03-05) * [Fix] Deprecation warning incorrectly displayed [#213](https://github.com/ploomber/jupysql/issues/213) diff --git a/setup.py b/setup.py index d27fbafc4..b672e287b 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ install_requires = [ "prettytable", "ipython>=1.0", - "sqlalchemy>=0.6.7,<2.0", + "sqlalchemy", "sqlparse", "ipython-genutils>=0.1.0", "jinja2", diff --git a/src/sql/connection.py b/src/sql/connection.py index 516d79fd9..ea340f06d 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -11,6 +11,7 @@ "For technical support: https://ploomber.io/community" "\nDocumentation: https://jupysql.ploomber.io/en/latest/connecting.html" ) +IS_SQLALCHEMY_ONE = int(sqlalchemy.__version__.split(".")[0]) == 1 # Check Full List: https://docs.sqlalchemy.org/en/20/dialects MISSING_PACKAGE_LIST_EXCEPT_MATCHERS = { @@ -193,11 +194,23 @@ def _error_module_not_found(cls, e): return ModuleNotFoundError("test") def __init__(self, engine, alias=None): - self.dialect = engine.url.get_dialect() - self.metadata = sqlalchemy.MetaData(bind=engine) + self.url = engine.url self.name = self.assign_name(engine) + self.dialect = self.url.get_dialect() self.session = engine.connect() - self.connections[alias or repr(self.metadata.bind.url)] = self + + if IS_SQLALCHEMY_ONE: + self.metadata = sqlalchemy.MetaData(bind=engine) + + self.connections[ + alias + or ( + repr(sqlalchemy.MetaData(bind=engine).bind.url) + if IS_SQLALCHEMY_ONE + else repr(engine.url) + ) + ] = self + self.connect_args = None self.alias = alias Connection.current = self @@ -298,7 +311,7 @@ def connection_list(cls): result = [] for key in sorted(cls.connections): conn = cls.connections[key] - engine_url = conn.metadata.bind.url # type: sqlalchemy.engine.url.URL + engine_url = conn.metadata.bind.url if IS_SQLALCHEMY_ONE else conn.url prefix = "* " if conn == cls.current else " " @@ -312,7 +325,7 @@ def connection_list(cls): return "\n".join(result) @classmethod - def _close(cls, descriptor): + def close(cls, descriptor): if isinstance(descriptor, Connection): conn = descriptor else: @@ -328,12 +341,10 @@ def _close(cls, descriptor): if descriptor in cls.connections: cls.connections.pop(descriptor) else: - cls.connections.pop(str(conn.metadata.bind.url)) - - conn.session.close() - - def close(self): - self.__class__._close(self) + cls.connections.pop( + str(conn.metadata.bind.url) if IS_SQLALCHEMY_ONE else str(conn.url) + ) + conn.session.close() @classmethod def _get_curr_connection_info(cls): @@ -341,7 +352,7 @@ def _get_curr_connection_info(cls): if not cls.current: return None - engine = cls.current.metadata.bind + engine = cls.current.metadata.bind if IS_SQLALCHEMY_ONE else cls.current return { "dialect": getattr(engine.dialect, "name", None), "driver": getattr(engine.dialect, "driver", None), diff --git a/src/sql/magic.py b/src/sql/magic.py index c15cdcd8e..5c8e50381 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -266,7 +266,7 @@ def _execute(self, payload, line, cell, local_ns): if args.connections: return sql.connection.Connection.connections elif args.close: - return sql.connection.Connection._close(args.close) + return sql.connection.Connection.close(args.close) connect_arg = command.connection diff --git a/tasks.py b/tasks.py index a84a784c1..8cdea7e48 100644 --- a/tasks.py +++ b/tasks.py @@ -32,7 +32,7 @@ def setup(c, version=None, doc=False): @task(aliases=["d"]) def doc(c): - with c.cd('doc'): + with c.cd("doc"): c.run( "python3 -m sphinx -T -E -W --keep-going -b html \ -d _build/doctrees -D language=en . _build/html"