Skip to content

Commit

Permalink
Add: test - error when there is no active connection (catherinedevlin…
Browse files Browse the repository at this point in the history
…#171) (catherinedevlin#189)

* Add: test

* Fix: pytest

* Revert

* Revert

* Move: _get_curr_connection_info to classmethod level

* Move: pytest fixture cleanup

* pins jupyter-book<0.14 (catherinedevlin#181)

* ci

* pins jupyter book

* Add: f-string fix

* Fix: arg self -> cls

* Fix: arg self -> cls

---------

Co-authored-by: Eduardo Blancas <[email protected]>
  • Loading branch information
tonykploomber and edublancas authored Mar 1, 2023
1 parent ecd91c1 commit 5199b51
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 18 deletions.
11 changes: 5 additions & 6 deletions src/sql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def get_missing_package_suggestion_str(e):
module_name, MISSING_PACKAGE_LIST_EXCEPT_MATCHERS.keys()
)
if close_matches:
return "Perhaps you meant to use driver the dialect: \"{}\"".format(
close_matches[0]
)
return f"Perhaps you meant to use driver the dialect: \"{close_matches[0]}\""
# Not found
return (
suggestion_prefix + "make sure you are using correct driver name:\n"
Expand Down Expand Up @@ -337,12 +335,13 @@ def _close(cls, descriptor):
def close(self):
self.__class__._close(self)

def _get_curr_connection_info(self):
@classmethod
def _get_curr_connection_info(cls):
"""Returns the dialect, driver, and database server version info"""
if not self.current:
if not cls.current:
return None

engine = self.current.metadata.bind
engine = cls.current.metadata.bind
return {
"dialect": getattr(engine.dialect, "name", None),
"driver": getattr(engine.dialect, "driver", None),
Expand Down
6 changes: 3 additions & 3 deletions src/sql/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ class SqlMagic(Magics, Configurable):

@telemetry.log_call("init")
def __init__(self, shell):

self._store = store

Configurable.__init__(self, config=shell.config)
Expand Down Expand Up @@ -296,7 +295,9 @@ def _execute(self, payload, line, cell, local_ns):
creator=args.creator,
alias=args.alias,
)
payload["connection_info"] = conn._get_curr_connection_info()
payload[
"connection_info"
] = sql.connection.Connection._get_curr_connection_info()
if args.persist:
return self._persist_dataframe(
command.sql, conn, user_ns, append=False, index=not args.no_index
Expand Down Expand Up @@ -344,7 +345,6 @@ def _execute(self, payload, line, cell, local_ns):

return None
else:

if command.result_var:
self.shell.user_ns.update({command.result_var: result})
return None
Expand Down
10 changes: 4 additions & 6 deletions src/sql/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,9 @@ def boxplot(payload, table, column, *, orient="v", with_=None, conn=None):
if not conn:
conn = sql.connection.Connection.current.session

if sql.connection.Connection.current:
payload[
payload[
"connection_info"
] = sql.connection.Connection.current._get_curr_connection_info()
] = sql.connection.Connection._get_curr_connection_info()

ax = plt.gca()
vert = orient == "v"
Expand Down Expand Up @@ -329,10 +328,9 @@ def histogram(payload, table, column, bins, with_=None, conn=None):
.. plot:: ../examples/plot_histogram_many.py
"""
ax = plt.gca()
if sql.connection.Connection.current:
payload[
payload[
"connection_info"
] = sql.connection.Connection.current._get_curr_connection_info()
] = sql.connection.Connection._get_curr_connection_info()
if isinstance(column, str):
bin_, height = _histogram(table, column, bins, with_=with_, conn=conn)
ax.bar(bin_, height, align="center", width=bin_[-1] - bin_[-2])
Expand Down
2 changes: 1 addition & 1 deletion src/sql/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def DataFrame(self, payload):
frame = pd.DataFrame(self, columns=(self and self.keys) or [])
payload[
"connection_info"
] = sql.connection.Connection.current._get_curr_connection_info()
] = sql.connection.Connection._get_curr_connection_info()
return frame

@telemetry.log_call("polars-data-frame")
Expand Down
9 changes: 7 additions & 2 deletions src/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test_alias(cleanup):


def test_get_curr_connection_info(mock_postgres):
conn = Connection.from_connect_str("postgresql://user:[email protected]/db")
assert conn._get_curr_connection_info() == {
Connection.from_connect_str("postgresql://user:[email protected]/db")
assert Connection._get_curr_connection_info() == {
"dialect": "postgresql",
"driver": "psycopg2",
"server_version_info": None,
Expand Down Expand Up @@ -92,3 +92,8 @@ def test_missing_driver(
Connection.from_connect_str(connect_str)

assert "try to install package: " + missing_pkg in str(error.value)


def test_no_current_connection_and_get_info(monkeypatch):
monkeypatch.setattr(Connection, "current", None)
assert Connection._get_curr_connection_info() is None

0 comments on commit 5199b51

Please sign in to comment.