Skip to content

Commit b0fc307

Browse files
authored
Multiple databases (#10)
* Manage several databases on server side * update tests * Adds a side panel to display databases summary * Get database schema * Add a database icon * Link a database to a cell * prettier * lint * lint * fix python test * get databse schema for sync engine * Add server and application tests * Adds tests on sidepanels * lint * Specify database used for tests
1 parent 95511e2 commit b0fc307

18 files changed

+1364
-51
lines changed

jupyter_sql_cell/app.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from jupyter_server.extension.application import ExtensionApp
44
from jupyter_server.utils import url_path_join
5-
from traitlets import Unicode
5+
from traitlets import Dict, Integer, List, Unicode
66

7-
from .handlers import ExampleHandler, ExecuteHandler
7+
from .handlers import DatabasesHandler, DatabaseSchemaHandler, ExampleHandler, ExecuteHandler
88
from .sqlconnector import SQLConnector
99

1010

@@ -13,9 +13,29 @@ class JupyterSqlCell(ExtensionApp):
1313
name = "JupyterSqlCell"
1414
default_url = "/jupyter-sql-cell"
1515

16-
db_url = Unicode(
17-
"",
18-
help="The database URL"
16+
database = Dict(per_key_traits={
17+
"alias": Unicode(default_value=None, allow_none=True),
18+
"database": Unicode(),
19+
"dbms": Unicode(),
20+
"driver": Unicode(default_value=None, allow_none=True),
21+
"host": Unicode(default_value=None, allow_none=True),
22+
"port": Integer(default_value=None, allow_none=True)
23+
},
24+
default_value={},
25+
help="The databases description"
26+
).tag(config=True)
27+
28+
databases = List(
29+
Dict(per_key_traits={
30+
"alias": Unicode(default_value=None, allow_none=True),
31+
"database": Unicode(),
32+
"dbms": Unicode(),
33+
"driver": Unicode(default_value=None, allow_none=True),
34+
"host": Unicode(default_value=None, allow_none=True),
35+
"port": Integer(default_value=None, allow_none=True)
36+
}),
37+
default_value=[],
38+
help="The databases description",
1939
).tag(config=True)
2040

2141

@@ -24,17 +44,36 @@ def __init__(self) -> None:
2444

2545
def initialize(self):
2646
path = pathlib.Path(__file__)
27-
if not self.db_url:
47+
if self.database:
48+
self.databases.append(self.database)
49+
50+
if not self.databases:
2851
path = pathlib.Path(__file__).parent / "tests" / "data" / "world.sqlite"
29-
self.db_url = f"sqlite+aiosqlite:///{path}"
30-
SQLConnector.db_url = self.db_url
52+
self.databases = [{
53+
"alias": "default",
54+
"database": str(path),
55+
"dbms": "sqlite",
56+
"driver": None,
57+
"host": None,
58+
"port": None
59+
}]
60+
for database in self.databases:
61+
for option in ["alias", "driver", "host", "port"]:
62+
if not option in database.keys():
63+
database[option] = None
64+
SQLConnector.add_database(database)
65+
3166
return super().initialize()
3267

3368
def initialize_handlers(self):
3469
super().initialize_handlers()
35-
example_pattern = url_path_join("/jupyter-sql-cell", "get-example")
36-
execute_pattern = url_path_join("/jupyter-sql-cell", "execute")
70+
example_pattern = url_path_join(self.default_url, "get-example")
71+
databases_pattern = url_path_join(self.default_url, "databases")
72+
execute_pattern = url_path_join(self.default_url, "execute")
73+
schema_pattern = url_path_join(self.default_url, "schema")
3774
self.handlers.extend([
75+
(databases_pattern, DatabasesHandler),
3876
(example_pattern, ExampleHandler),
39-
(execute_pattern, ExecuteHandler)
77+
(execute_pattern, ExecuteHandler),
78+
(schema_pattern, DatabaseSchemaHandler)
4079
])

jupyter_sql_cell/handlers.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,99 @@
66
from .sqlconnector import SQLConnector
77

88

9+
def reply_error(api: APIHandler, msg: StopIteration):
10+
api.set_status(500)
11+
api.log.error(msg)
12+
reply = {"message": msg}
13+
api.finish(json.dumps(reply))
14+
15+
16+
class DatabasesHandler(APIHandler):
17+
@tornado.web.authenticated
18+
def get(self):
19+
try:
20+
databases = SQLConnector.get_databases()
21+
self.finish(json.dumps(databases))
22+
except Exception as e:
23+
self.log.error(f"Databases error\n{e}")
24+
self.write_error(500, exec_info=e)
25+
26+
927
class ExecuteHandler(APIHandler):
1028
# The following decorator should be present on all verb methods (head, get, post,
1129
# patch, put, delete, options) to ensure only authorized user can request the
1230
# Jupyter server
1331
@tornado.gen.coroutine
1432
@tornado.web.authenticated
1533
def post(self):
16-
query = json.loads(self.request.body).get("query", None)
34+
body = json.loads(self.request.body)
35+
id = body.get("id", None)
36+
query = body.get("query", None)
1737

38+
if id is None:
39+
reply_error(self, "The database id has not been provided")
40+
return
41+
if not query:
42+
reply_error(self, "No query has been provided")
43+
return
1844
try:
19-
connector = SQLConnector()
45+
connector = SQLConnector(int(id))
46+
if connector.errors:
47+
reply_error(self, connector.errors[0])
48+
return
2049
except Exception as e:
2150
self.log.error(f"Connector error\n{e}")
2251
self.write_error(500, exec_info=e)
52+
return
2353

2454
try:
2555
result = yield connector.execute(query)
2656
self.finish(json.dumps({
27-
"data": result
57+
"alias": connector.database["alias"],
58+
"data": result,
59+
"id": id,
60+
"query": query,
61+
}))
62+
except Exception as e:
63+
self.log.error(f"Query error\n{e}")
64+
self.write_error(500, exec_info=e)
65+
66+
67+
class DatabaseSchemaHandler(APIHandler):
68+
@tornado.gen.coroutine
69+
@tornado.web.authenticated
70+
def get(self):
71+
id = self.get_argument("id", "")
72+
target = self.get_argument("target", "tables")
73+
table = self.get_argument("table", "")
74+
75+
if not id:
76+
reply_error(self, "The database id has not been provided")
77+
return
78+
if target not in ["tables", "columns"]:
79+
reply_error(self, "Target must be \"tables\" or \"columns\"")
80+
return
81+
if target == "columns" and not table:
82+
reply_error(self, "The table has not been provided")
83+
return
84+
85+
try:
86+
connector = SQLConnector(int(id))
87+
if connector.errors:
88+
reply_error(self, connector.errors[0])
89+
return
90+
except Exception as e:
91+
self.log.error(f"Connector error\n{e}")
92+
self.write_error(500, exec_info=e)
93+
return
94+
95+
try:
96+
data = yield connector.get_schema(target, table)
97+
self.finish(json.dumps({
98+
"data": data,
99+
"id": id,
100+
"table": table,
101+
"target": target
28102
}))
29103
except Exception as e:
30104
self.log.error(f"Query error\n{e}")

jupyter_sql_cell/sqlconnector.py

Lines changed: 135 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,71 @@
1-
from jupyter_server import log
2-
from sqlalchemy.ext.asyncio import create_async_engine
3-
from sqlalchemy import CursorResult, text
4-
from typing import Any, Dict, List
1+
from sqlalchemy.exc import InvalidRequestError, NoSuchModuleError
2+
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
3+
from sqlalchemy import CursorResult, Inspector, URL, create_engine, inspect, text
4+
from typing import Any, Dict, List, Optional, TypedDict
5+
6+
ASYNC_DRIVERS = {
7+
"mariadb": ["asyncmy", "aiomysql"],
8+
"mysql": ["asyncmy", "aiomysql"],
9+
"postgres": ["asyncpg", "psycopg"],
10+
"sqlite": ["aiosqlite"],
11+
}
12+
13+
14+
class DatabaseDesc(TypedDict):
15+
alias: Optional[str]
16+
database: str
17+
dbms: str
18+
driver: Optional[str]
19+
host: Optional[str]
20+
port: Optional[int]
21+
22+
23+
class Database(TypedDict):
24+
alias: str
25+
id: int
26+
is_async: bool
27+
url: URL
28+
29+
30+
class DatabaseSummary(DatabaseDesc):
31+
id: int
32+
is_async: bool
533

634

735
class SQLConnector:
836

9-
db_url: str = ""
37+
databases: [Database] = []
38+
warnings = []
39+
40+
def __init__(self, database_id: int):
41+
self.engine = None
42+
self.errors = []
43+
self.database: Database = next(filter(lambda db: db["id"] == database_id, self.databases), None)
1044

11-
engine = None
45+
if not self.database:
46+
self.errors.append(f"There is no registered database with id {database_id}")
47+
else:
48+
if self.database["is_async"]:
49+
self.engine = create_async_engine(self.database["url"])
50+
else:
51+
self.engine = create_engine(self.database["url"])
1252

13-
def __init__(self) -> None:
14-
if not self.db_url:
15-
log.warn("The database URL is not set")
16-
self.engine = create_async_engine(self.db_url)
53+
async def get_schema(self, target: str, table: str = "") -> [str]:
54+
if self.database["is_async"]:
55+
async with self.engine.connect() as conn:
56+
schema = await conn.run_sync(self.use_inspector, target, table)
57+
else:
58+
with self.engine.connect() as conn:
59+
schema = self.use_inspector(conn, target, table)
60+
return schema
61+
62+
def use_inspector(self, conn: AsyncConnection, target: str, table: str) -> [str]:
63+
inspector: Inspector = inspect(conn)
64+
if target == "tables":
65+
return inspector.get_table_names()
66+
elif target == "columns":
67+
columns = inspector.get_columns(table)
68+
return sorted([column['name'] for column in columns])
1769

1870
async def execute(self, query: str) -> str:
1971
if not self.engine:
@@ -27,6 +79,79 @@ async def execute_request(self, query: str) -> CursorResult[Any]:
2779
cursor: CursorResult[Any] = await connection.execute(text(query))
2880
return cursor
2981

82+
@classmethod
83+
def add_database(cls, db_desc: DatabaseDesc):
84+
id = 0
85+
if cls.databases:
86+
id = max([db["id"] for db in cls.databases]) + 1
87+
88+
if db_desc["alias"]:
89+
alias = db_desc["alias"]
90+
else:
91+
alias = f"{db_desc['dbms']}_{id}"
92+
93+
if db_desc["driver"]:
94+
drivers = [db_desc["driver"]]
95+
else:
96+
drivers = ASYNC_DRIVERS.get(db_desc["dbms"], [])
97+
98+
for driver in drivers:
99+
url = URL.create(
100+
drivername=f"{db_desc['dbms']}+{driver}",
101+
host=db_desc["host"],
102+
port=db_desc["port"],
103+
database=db_desc["database"]
104+
)
105+
try:
106+
create_async_engine(url)
107+
cls.databases.append({
108+
"alias": alias,
109+
"id": id,
110+
"url": url,
111+
"is_async": True
112+
})
113+
return
114+
except (InvalidRequestError, NoSuchModuleError):
115+
# InvalidRequestError is raised if the driver is not async.
116+
# NoSuchModuleError is raised if the driver is not installed.
117+
continue
118+
119+
driver = f"+{db_desc['driver']}" if db_desc["driver"] else ""
120+
url = URL.create(
121+
drivername=f"{db_desc['dbms']}{driver}",
122+
host=db_desc["host"],
123+
port=db_desc["port"],
124+
database=db_desc["database"]
125+
)
126+
create_engine(url)
127+
cls.databases.append({
128+
"alias": alias,
129+
"id": id,
130+
"url": url,
131+
"is_async": False
132+
})
133+
cls.warnings.append("No async driver found, the query will be executed synchronously")
134+
print(cls.warnings[-1])
135+
136+
@classmethod
137+
def get_databases(cls):
138+
summary_databases: [DatabaseSummary] = []
139+
for database in cls.databases:
140+
url: URL = database["url"]
141+
summary: DatabaseSummary = {
142+
"alias": database["alias"],
143+
"database": url.database,
144+
"driver": url.drivername,
145+
"id": database["id"],
146+
"is_async": database["is_async"]
147+
}
148+
if url.host:
149+
summary["host"] = url.host
150+
if url.port:
151+
summary["port"] = url.port
152+
summary_databases.append(summary)
153+
return summary_databases
154+
30155
@staticmethod
31156
def to_list(cursor: CursorResult[Any]) -> List[Dict]:
32157
return [row._asdict() for row in cursor.fetchall()]
864 KB
Binary file not shown.

0 commit comments

Comments
 (0)