Skip to content

Commit 39635ea

Browse files
author
Sailesh Mukil
committed
Standardize return type for execute_query() functions
- Also removes cloud-spanner specific fields from the return type. Specifically `StructType.Field` is removed from the return type. Removing this tightly coupled logic is required to allow new DB implementations.
1 parent 0f99a10 commit 39635ea

File tree

3 files changed

+90
-51
lines changed

3 files changed

+90
-51
lines changed

spanner_graphs/database.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020
from __future__ import annotations
21-
from typing import Any, Dict, List, Tuple
21+
from typing import Any, Dict, List, Tuple, NamedTuple
2222
import json
2323
import os
2424
import csv
@@ -29,6 +29,19 @@
2929
from google.cloud.spanner_v1.types import StructType, TypeCode, Type
3030
import pydata_google_auth
3131

32+
class SpannerQueryResult(NamedTuple):
33+
# A dict where each key is a field name returned in the query and the list
34+
# contains all items of the same type found for the given field
35+
data: Dict[str, List[Any]]
36+
# A list representing the fields in the result set.
37+
fields: List[Any]
38+
# A list of rows as returned by the query execution.
39+
rows: List[Any]
40+
# An optional field to return the schema as JSON
41+
schema_json: Any | None
42+
# The error message if any
43+
error: Exception | None
44+
3245
def _get_default_credentials_with_project():
3346
return pydata_google_auth.default(
3447
scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False)
@@ -87,7 +100,7 @@ def execute_query(
87100
query: str,
88101
limit: int = None,
89102
is_test_query: bool = False,
90-
):
103+
) -> SpannerQueryResult:
91104
"""
92105
This method executes the provided `query`
93106
@@ -96,13 +109,7 @@ def execute_query(
96109
limit: An optional limit for the number of rows to return
97110
98111
Returns:
99-
A tuple containing:
100-
- Dict[str, List[Any]]: A dict where each key is a field name
101-
returned in the query and the list contains all items of the same
102-
type found for the given field.
103-
- A list of StructType.Fields representing the fields in the result set.
104-
- A list of rows as returned by the query execution.
105-
- The error message if any.
112+
A SpannerQueryResult tuple
106113
"""
107114
self.schema_json = None
108115
if not is_test_query:
@@ -131,9 +138,14 @@ def execute_query(
131138
data[field.name].append(json.loads(value.serialize()))
132139
else:
133140
data[field.name].append(value)
141+
return SpannerQueryResult(
142+
data=data,
143+
fields=fields,
144+
rows=rows,
145+
schema_json=self.schema_json,
146+
error=None
147+
)
134148

135-
return data, fields, rows, self.schema_json, None
136-
137149
class MockSpannerResult:
138150

139151
def __init__(self, file_path: str):
@@ -180,7 +192,7 @@ def execute_query(
180192
self,
181193
_: str,
182194
limit: int = 5
183-
) -> Tuple[Dict[str, List[Any]], List[StructType.Field], List, str]:
195+
) -> SpannerQueryResult:
184196
"""Mock execution of query"""
185197

186198
# Before the actual query we fetch the schema as well
@@ -201,7 +213,13 @@ def execute_query(
201213
for field, value in zip(fields, row):
202214
data[field.name].append(value)
203215

204-
return data, fields, rows, self.schema_json, None
216+
return SpannerQueryResult(
217+
data=data,
218+
fields=fields,
219+
rows=rows,
220+
schema_json=self.schema_json,
221+
error=None
222+
)
205223

206224

207225
database_instances: dict[str, SpannerDatabase | MockSpannerDatabase] = {
@@ -221,4 +239,4 @@ def get_database_instance(project: str, instance: str, database: str, mock = Fal
221239
db = SpannerDatabase(project, instance, database)
222240
database_instances[key] = db
223241

224-
return db
242+
return db

spanner_graphs/graph_server.py

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import json
1818
import threading
1919
from enum import Enum
20-
from typing import Union
20+
from typing import Union, Dict, Any
2121

2222
import requests
2323
import portpicker
@@ -206,49 +206,70 @@ def execute_node_expansion(
206206

207207
return execute_query(project, instance, database, query, mock=False)
208208

209-
def execute_query(project: str, instance: str, database: str, query: str, mock = False):
210-
database = get_database_instance(project, instance, database, mock)
209+
def execute_query(
210+
project: str,
211+
instance: str,
212+
database: str,
213+
query: str,
214+
mock: bool = False,
215+
) -> Dict[str, Any]:
216+
"""Executes a query against a database and formats the result.
211217
218+
Connects to a database, runs the query, and processes the resulting object.
219+
On success, it formats the data into nodes and edges for graph visualization.
220+
If the query fails, it returns a detailed error message, optionally
221+
including the database schema to aid in debugging.
222+
223+
Args:
224+
project: The cloud project ID.
225+
instance: The database instance name.
226+
database: The database name.
227+
query: The query string to execute.
228+
mock: If True, use a mock database instance for testing. Defaults to False.
229+
230+
Returns:
231+
A dictionary containing either the structured 'response' with nodes,
232+
edges, and other data, or an 'error' key with a descriptive message.
233+
"""
212234
try:
213-
query_result, fields, rows, schema_json, err = database.execute_query(query)
214-
if len(rows) == 0 and err : # if query returned an error
215-
if schema_json: # if the schema exists
216-
return {
217-
"response": {
218-
"schema": schema_json,
219-
"query_result": query_result,
220-
"nodes": [],
221-
"edges": [],
222-
"rows": []
223-
},
224-
"error": f"We've detected an error in your query. To help you troubleshoot, the graph schema's layout is shown above." + "\n\n" + f"Query error: \n{getattr(err, 'message', str(err))}"
225-
}
226-
if not schema_json: # if the schema does not exist
227-
return {
228-
"response": {
229-
"schema": schema_json,
230-
"query_result": query_result,
231-
"nodes": [],
232-
"edges": [],
233-
"rows": []
234-
},
235-
"error": f"Query error: \n{getattr(err, 'message', str(err))}"
236-
}
237-
nodes, edges = get_nodes_edges(query_result, fields, schema_json)
238-
235+
db_instance = get_database_instance(project, instance, database, mock)
236+
result: SpannerQueryResult = db_instance.execute_query(query)
237+
238+
if len(result.rows) == 0 and result.err:
239+
error_message = f"Query error: \n{getattr(result.err, 'message', str(result.err))}"
240+
if result.schema_json:
241+
# Prepend a helpful message if the schema is available
242+
error_message = (
243+
"We've detected an error in your query. To help you troubleshoot, "
244+
"the graph schema's layout is shown above.\n\n" + error_message
245+
)
246+
247+
# Consolidate the repetitive error response into a single return
248+
return {
249+
"response": {
250+
"schema": result.schema_json,
251+
"query_result": result.data,
252+
"nodes": [],
253+
"edges": [],
254+
"rows": [],
255+
},
256+
"error": error_message,
257+
}
258+
259+
# Process a successful query result
260+
nodes, edges = get_nodes_edges(result.data, result.fields, result.schema_json)
261+
239262
return {
240263
"response": {
241264
"nodes": [node.to_json() for node in nodes],
242265
"edges": [edge.to_json() for edge in edges],
243-
"schema": schema_json,
244-
"rows": rows,
245-
"query_result": query_result
266+
"schema": result.schema_json,
267+
"rows": result.rows,
268+
"query_result": result.data,
246269
}
247270
}
248271
except Exception as e:
249-
return {
250-
"error": getattr(e, "message", str(e))
251-
}
272+
return {"error": getattr(e, "message", str(e))}
252273

253274

254275
class GraphServer:

tests/conversion_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ def test_get_nodes_edges(self) -> None:
3838
"""
3939
# Get data from mock database
4040
mock_db = MockSpannerDatabase()
41-
data, fields, _, schema_json = mock_db.execute_query("")
41+
query_result = mock_db.execute_query("")
4242

4343
# Convert data to nodes and edges
44-
nodes, edges = get_nodes_edges(data, fields)
44+
nodes, edges = get_nodes_edges(query_result.data, query_result.fields)
4545

4646
# Verify we got some nodes and edges
4747
self.assertTrue(len(nodes) > 0, "Should have at least one node")

0 commit comments

Comments
 (0)