11
11
"For technical support: https://ploomber.io/community"
12
12
"\n Documentation: https://jupysql.ploomber.io/en/latest/connecting.html"
13
13
)
14
+ IS_SQLALCHEMY_ONE = int (sqlalchemy .__version__ .split ("." )[0 ]) == 1
14
15
15
16
# Check Full List: https://docs.sqlalchemy.org/en/20/dialects
16
17
MISSING_PACKAGE_LIST_EXCEPT_MATCHERS = {
@@ -193,11 +194,23 @@ def _error_module_not_found(cls, e):
193
194
return ModuleNotFoundError ("test" )
194
195
195
196
def __init__ (self , engine , alias = None ):
196
- self .dialect = engine .url .get_dialect ()
197
- self .metadata = sqlalchemy .MetaData (bind = engine )
197
+ self .url = engine .url
198
198
self .name = self .assign_name (engine )
199
+ self .dialect = self .url .get_dialect ()
199
200
self .session = engine .connect ()
200
- self .connections [alias or repr (self .metadata .bind .url )] = self
201
+
202
+ if IS_SQLALCHEMY_ONE :
203
+ self .metadata = sqlalchemy .MetaData (bind = engine )
204
+
205
+ self .connections [
206
+ alias
207
+ or (
208
+ repr (sqlalchemy .MetaData (bind = engine ).bind .url )
209
+ if IS_SQLALCHEMY_ONE
210
+ else repr (engine .url )
211
+ )
212
+ ] = self
213
+
201
214
self .connect_args = None
202
215
self .alias = alias
203
216
Connection .current = self
@@ -298,7 +311,7 @@ def connection_list(cls):
298
311
result = []
299
312
for key in sorted (cls .connections ):
300
313
conn = cls .connections [key ]
301
- engine_url = conn .metadata .bind .url # type: sqlalchemy.engine. url.URL
314
+ engine_url = conn .metadata .bind .url if IS_SQLALCHEMY_ONE else conn . url
302
315
303
316
prefix = "* " if conn == cls .current else " "
304
317
@@ -312,7 +325,7 @@ def connection_list(cls):
312
325
return "\n " .join (result )
313
326
314
327
@classmethod
315
- def _close (cls , descriptor ):
328
+ def close (cls , descriptor ):
316
329
if isinstance (descriptor , Connection ):
317
330
conn = descriptor
318
331
else :
@@ -328,20 +341,18 @@ def _close(cls, descriptor):
328
341
if descriptor in cls .connections :
329
342
cls .connections .pop (descriptor )
330
343
else :
331
- cls .connections .pop (str (conn .metadata .bind .url ))
332
-
333
- conn .session .close ()
334
-
335
- def close (self ):
336
- self .__class__ ._close (self )
344
+ cls .connections .pop (
345
+ str (conn .metadata .bind .url ) if IS_SQLALCHEMY_ONE else str (conn .url )
346
+ )
347
+ conn .session .close ()
337
348
338
349
@classmethod
339
350
def _get_curr_connection_info (cls ):
340
351
"""Returns the dialect, driver, and database server version info"""
341
352
if not cls .current :
342
353
return None
343
354
344
- engine = cls .current .metadata .bind
355
+ engine = cls .current .metadata .bind if IS_SQLALCHEMY_ONE else cls . current
345
356
return {
346
357
"dialect" : getattr (engine .dialect , "name" , None ),
347
358
"driver" : getattr (engine .dialect , "driver" , None ),
0 commit comments