Skip to content

Commit 5199b51

Browse files
Add: test - error when there is no active connection (catherinedevlin#171) (catherinedevlin#189)
* Add: test * Fix: pytest * Revert * Revert * Move: _get_curr_connection_info to classmethod level * Move: pytest fixture cleanup * pins jupyter-book<0.14 (catherinedevlin#181) * ci * pins jupyter book * Add: f-string fix * Fix: arg self -> cls * Fix: arg self -> cls --------- Co-authored-by: Eduardo Blancas <[email protected]>
1 parent ecd91c1 commit 5199b51

File tree

5 files changed

+20
-18
lines changed

5 files changed

+20
-18
lines changed

src/sql/connection.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ def get_missing_package_suggestion_str(e):
7878
module_name, MISSING_PACKAGE_LIST_EXCEPT_MATCHERS.keys()
7979
)
8080
if close_matches:
81-
return "Perhaps you meant to use driver the dialect: \"{}\"".format(
82-
close_matches[0]
83-
)
81+
return f"Perhaps you meant to use driver the dialect: \"{close_matches[0]}\""
8482
# Not found
8583
return (
8684
suggestion_prefix + "make sure you are using correct driver name:\n"
@@ -337,12 +335,13 @@ def _close(cls, descriptor):
337335
def close(self):
338336
self.__class__._close(self)
339337

340-
def _get_curr_connection_info(self):
338+
@classmethod
339+
def _get_curr_connection_info(cls):
341340
"""Returns the dialect, driver, and database server version info"""
342-
if not self.current:
341+
if not cls.current:
343342
return None
344343

345-
engine = self.current.metadata.bind
344+
engine = cls.current.metadata.bind
346345
return {
347346
"dialect": getattr(engine.dialect, "name", None),
348347
"driver": getattr(engine.dialect, "driver", None),

src/sql/magic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ class SqlMagic(Magics, Configurable):
116116

117117
@telemetry.log_call("init")
118118
def __init__(self, shell):
119-
120119
self._store = store
121120

122121
Configurable.__init__(self, config=shell.config)
@@ -296,7 +295,9 @@ def _execute(self, payload, line, cell, local_ns):
296295
creator=args.creator,
297296
alias=args.alias,
298297
)
299-
payload["connection_info"] = conn._get_curr_connection_info()
298+
payload[
299+
"connection_info"
300+
] = sql.connection.Connection._get_curr_connection_info()
300301
if args.persist:
301302
return self._persist_dataframe(
302303
command.sql, conn, user_ns, append=False, index=not args.no_index
@@ -344,7 +345,6 @@ def _execute(self, payload, line, cell, local_ns):
344345

345346
return None
346347
else:
347-
348348
if command.result_var:
349349
self.shell.user_ns.update({command.result_var: result})
350350
return None

src/sql/plot.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,9 @@ def boxplot(payload, table, column, *, orient="v", with_=None, conn=None):
244244
if not conn:
245245
conn = sql.connection.Connection.current.session
246246

247-
if sql.connection.Connection.current:
248-
payload[
247+
payload[
249248
"connection_info"
250-
] = sql.connection.Connection.current._get_curr_connection_info()
249+
] = sql.connection.Connection._get_curr_connection_info()
251250

252251
ax = plt.gca()
253252
vert = orient == "v"
@@ -329,10 +328,9 @@ def histogram(payload, table, column, bins, with_=None, conn=None):
329328
.. plot:: ../examples/plot_histogram_many.py
330329
"""
331330
ax = plt.gca()
332-
if sql.connection.Connection.current:
333-
payload[
331+
payload[
334332
"connection_info"
335-
] = sql.connection.Connection.current._get_curr_connection_info()
333+
] = sql.connection.Connection._get_curr_connection_info()
336334
if isinstance(column, str):
337335
bin_, height = _histogram(table, column, bins, with_=with_, conn=conn)
338336
ax.bar(bin_, height, align="center", width=bin_[-1] - bin_[-2])

src/sql/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def DataFrame(self, payload):
174174
frame = pd.DataFrame(self, columns=(self and self.keys) or [])
175175
payload[
176176
"connection_info"
177-
] = sql.connection.Connection.current._get_curr_connection_info()
177+
] = sql.connection.Connection._get_curr_connection_info()
178178
return frame
179179

180180
@telemetry.log_call("polars-data-frame")

src/tests/test_connection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def test_alias(cleanup):
3737

3838

3939
def test_get_curr_connection_info(mock_postgres):
40-
conn = Connection.from_connect_str("postgresql://user:[email protected]/db")
41-
assert conn._get_curr_connection_info() == {
40+
Connection.from_connect_str("postgresql://user:[email protected]/db")
41+
assert Connection._get_curr_connection_info() == {
4242
"dialect": "postgresql",
4343
"driver": "psycopg2",
4444
"server_version_info": None,
@@ -92,3 +92,8 @@ def test_missing_driver(
9292
Connection.from_connect_str(connect_str)
9393

9494
assert "try to install package: " + missing_pkg in str(error.value)
95+
96+
97+
def test_no_current_connection_and_get_info(monkeypatch):
98+
monkeypatch.setattr(Connection, "current", None)
99+
assert Connection._get_curr_connection_info() is None

0 commit comments

Comments
 (0)