Skip to content

Commit

Permalink
support for sqlalchemy 2.0 (catherinedevlin#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
Palashio authored Mar 9, 2023
1 parent c4e978a commit f95ffb1
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 16 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,38 @@ jobs:
run: |
pytest --durations-min=5 --ignore=src/tests/integration
test-sqlalchemy-v1:
strategy:
matrix:
python-version: ['3.11']
os: [ubuntu-latest, macos-latest, windows-latest]

runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Lint with flake8
run: |
python -m pip install --upgrade pip
# run flake8 on .py files
pip install flake8
flake8
# run flake8 on notebooks (.ipynb, .md, etc)
pip install jupytext nbqa
nbqa flake8 .
- name: Install dependencies
run: |
pip install "sqlalchemy<2"
pip install ".[dev]"
- name: Test with pytest
run: |
pytest --durations-min=5 --ignore=src/tests/integration
# run: pkgmt check
check:
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
## 0.6.3 (2023-03-06)

* [Fix] Displaying variable substitution warning only when the variable to expand exists in the user's namespace

* [Fix] Adds support for SQL Alchemy 2.0
## 0.6.2 (2023-03-05)

* [Fix] Deprecation warning incorrectly displayed [#213](https://github.com/ploomber/jupysql/issues/213)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
install_requires = [
"prettytable",
"ipython>=1.0",
"sqlalchemy>=0.6.7,<2.0",
"sqlalchemy",
"sqlparse",
"ipython-genutils>=0.1.0",
"jinja2",
Expand Down
35 changes: 23 additions & 12 deletions src/sql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"For technical support: https://ploomber.io/community"
"\nDocumentation: https://jupysql.ploomber.io/en/latest/connecting.html"
)
IS_SQLALCHEMY_ONE = int(sqlalchemy.__version__.split(".")[0]) == 1

# Check Full List: https://docs.sqlalchemy.org/en/20/dialects
MISSING_PACKAGE_LIST_EXCEPT_MATCHERS = {
Expand Down Expand Up @@ -193,11 +194,23 @@ def _error_module_not_found(cls, e):
return ModuleNotFoundError("test")

def __init__(self, engine, alias=None):
self.dialect = engine.url.get_dialect()
self.metadata = sqlalchemy.MetaData(bind=engine)
self.url = engine.url
self.name = self.assign_name(engine)
self.dialect = self.url.get_dialect()
self.session = engine.connect()
self.connections[alias or repr(self.metadata.bind.url)] = self

if IS_SQLALCHEMY_ONE:
self.metadata = sqlalchemy.MetaData(bind=engine)

self.connections[
alias
or (
repr(sqlalchemy.MetaData(bind=engine).bind.url)
if IS_SQLALCHEMY_ONE
else repr(engine.url)
)
] = self

self.connect_args = None
self.alias = alias
Connection.current = self
Expand Down Expand Up @@ -298,7 +311,7 @@ def connection_list(cls):
result = []
for key in sorted(cls.connections):
conn = cls.connections[key]
engine_url = conn.metadata.bind.url # type: sqlalchemy.engine.url.URL
engine_url = conn.metadata.bind.url if IS_SQLALCHEMY_ONE else conn.url

prefix = "* " if conn == cls.current else " "

Expand All @@ -312,7 +325,7 @@ def connection_list(cls):
return "\n".join(result)

@classmethod
def _close(cls, descriptor):
def close(cls, descriptor):
if isinstance(descriptor, Connection):
conn = descriptor
else:
Expand All @@ -328,20 +341,18 @@ def _close(cls, descriptor):
if descriptor in cls.connections:
cls.connections.pop(descriptor)
else:
cls.connections.pop(str(conn.metadata.bind.url))

conn.session.close()

def close(self):
self.__class__._close(self)
cls.connections.pop(
str(conn.metadata.bind.url) if IS_SQLALCHEMY_ONE else str(conn.url)
)
conn.session.close()

@classmethod
def _get_curr_connection_info(cls):
"""Returns the dialect, driver, and database server version info"""
if not cls.current:
return None

engine = cls.current.metadata.bind
engine = cls.current.metadata.bind if IS_SQLALCHEMY_ONE else cls.current
return {
"dialect": getattr(engine.dialect, "name", None),
"driver": getattr(engine.dialect, "driver", None),
Expand Down
2 changes: 1 addition & 1 deletion src/sql/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _execute(self, payload, line, cell, local_ns):
if args.connections:
return sql.connection.Connection.connections
elif args.close:
return sql.connection.Connection._close(args.close)
return sql.connection.Connection.close(args.close)

connect_arg = command.connection

Expand Down
2 changes: 1 addition & 1 deletion tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def setup(c, version=None, doc=False):

@task(aliases=["d"])
def doc(c):
with c.cd('doc'):
with c.cd("doc"):
c.run(
"python3 -m sphinx -T -E -W --keep-going -b html \
-d _build/doctrees -D language=en . _build/html"
Expand Down

0 comments on commit f95ffb1

Please sign in to comment.