From 5199b51cd4aa46e53344d3a24fd98c79370d066b Mon Sep 17 00:00:00 2001 From: Tony Kuo <123580782+tonykploomber@users.noreply.github.com> Date: Wed, 1 Mar 2023 18:32:00 -0500 Subject: [PATCH] Add: test - error when there is no active connection (#171) (#189) * Add: test * Fix: pytest * Revert * Revert * Move: _get_curr_connection_info to classmethod level * Move: pytest fixture cleanup * pins jupyter-book<0.14 (#181) * ci * pins jupyter book * Add: f-string fix * Fix: arg self -> cls * Fix: arg self -> cls --------- Co-authored-by: Eduardo Blancas --- src/sql/connection.py | 11 +++++------ src/sql/magic.py | 6 +++--- src/sql/plot.py | 10 ++++------ src/sql/run.py | 2 +- src/tests/test_connection.py | 9 +++++++-- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/sql/connection.py b/src/sql/connection.py index e8c130431..313e53d62 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -78,9 +78,7 @@ def get_missing_package_suggestion_str(e): module_name, MISSING_PACKAGE_LIST_EXCEPT_MATCHERS.keys() ) if close_matches: - return "Perhaps you meant to use driver the dialect: \"{}\"".format( - close_matches[0] - ) + return f"Perhaps you meant to use driver the dialect: \"{close_matches[0]}\"" # Not found return ( suggestion_prefix + "make sure you are using correct driver name:\n" @@ -337,12 +335,13 @@ def _close(cls, descriptor): def close(self): self.__class__._close(self) - def _get_curr_connection_info(self): + @classmethod + def _get_curr_connection_info(cls): """Returns the dialect, driver, and database server version info""" - if not self.current: + if not cls.current: return None - engine = self.current.metadata.bind + engine = cls.current.metadata.bind return { "dialect": getattr(engine.dialect, "name", None), "driver": getattr(engine.dialect, "driver", None), diff --git a/src/sql/magic.py b/src/sql/magic.py index 37085f5f2..9912350af 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -116,7 +116,6 @@ class SqlMagic(Magics, Configurable): @telemetry.log_call("init") def __init__(self, shell): - self._store = store Configurable.__init__(self, config=shell.config) @@ -296,7 +295,9 @@ def _execute(self, payload, line, cell, local_ns): creator=args.creator, alias=args.alias, ) - payload["connection_info"] = conn._get_curr_connection_info() + payload[ + "connection_info" + ] = sql.connection.Connection._get_curr_connection_info() if args.persist: return self._persist_dataframe( command.sql, conn, user_ns, append=False, index=not args.no_index @@ -344,7 +345,6 @@ def _execute(self, payload, line, cell, local_ns): return None else: - if command.result_var: self.shell.user_ns.update({command.result_var: result}) return None diff --git a/src/sql/plot.py b/src/sql/plot.py index 9319370e0..f0de99275 100644 --- a/src/sql/plot.py +++ b/src/sql/plot.py @@ -244,10 +244,9 @@ def boxplot(payload, table, column, *, orient="v", with_=None, conn=None): if not conn: conn = sql.connection.Connection.current.session - if sql.connection.Connection.current: - payload[ + payload[ "connection_info" - ] = sql.connection.Connection.current._get_curr_connection_info() + ] = sql.connection.Connection._get_curr_connection_info() ax = plt.gca() vert = orient == "v" @@ -329,10 +328,9 @@ def histogram(payload, table, column, bins, with_=None, conn=None): .. plot:: ../examples/plot_histogram_many.py """ ax = plt.gca() - if sql.connection.Connection.current: - payload[ + payload[ "connection_info" - ] = sql.connection.Connection.current._get_curr_connection_info() + ] = sql.connection.Connection._get_curr_connection_info() if isinstance(column, str): bin_, height = _histogram(table, column, bins, with_=with_, conn=conn) ax.bar(bin_, height, align="center", width=bin_[-1] - bin_[-2]) diff --git a/src/sql/run.py b/src/sql/run.py index 518bc81ec..faa5f153b 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -174,7 +174,7 @@ def DataFrame(self, payload): frame = pd.DataFrame(self, columns=(self and self.keys) or []) payload[ "connection_info" - ] = sql.connection.Connection.current._get_curr_connection_info() + ] = sql.connection.Connection._get_curr_connection_info() return frame @telemetry.log_call("polars-data-frame") diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index 4e0370d1d..982d8903d 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -37,8 +37,8 @@ def test_alias(cleanup): def test_get_curr_connection_info(mock_postgres): - conn = Connection.from_connect_str("postgresql://user:topsecret@somedomain.com/db") - assert conn._get_curr_connection_info() == { + Connection.from_connect_str("postgresql://user:topsecret@somedomain.com/db") + assert Connection._get_curr_connection_info() == { "dialect": "postgresql", "driver": "psycopg2", "server_version_info": None, @@ -92,3 +92,8 @@ def test_missing_driver( Connection.from_connect_str(connect_str) assert "try to install package: " + missing_pkg in str(error.value) + + +def test_no_current_connection_and_get_info(monkeypatch): + monkeypatch.setattr(Connection, "current", None) + assert Connection._get_curr_connection_info() is None