Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions spanner_graphs/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

from __future__ import annotations
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, NamedTuple
import json
import os
import csv
Expand All @@ -29,6 +29,19 @@
from google.cloud.spanner_v1.types import StructType, TypeCode, Type
import pydata_google_auth

class SpannerQueryResult(NamedTuple):
# A dict where each key is a field name returned in the query and the list
# contains all items of the same type found for the given field
data: Dict[str, List[Any]]
# A list representing the fields in the result set.
fields: List[Any]
# A list of rows as returned by the query execution.
rows: List[Any]
# An optional field to return the schema as JSON
schema_json: Any | None
# The error message if any
error: Exception | None

def _get_default_credentials_with_project():
return pydata_google_auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False)
Expand Down Expand Up @@ -87,7 +100,7 @@ def execute_query(
query: str,
limit: int = None,
is_test_query: bool = False,
):
) -> SpannerQueryResult:
"""
This method executes the provided `query`

Expand All @@ -96,13 +109,7 @@ def execute_query(
limit: An optional limit for the number of rows to return

Returns:
A tuple containing:
- Dict[str, List[Any]]: A dict where each key is a field name
returned in the query and the list contains all items of the same
type found for the given field.
- A list of StructType.Fields representing the fields in the result set.
- A list of rows as returned by the query execution.
- The error message if any.
A SpannerQueryResult tuple
"""
self.schema_json = None
if not is_test_query:
Expand Down Expand Up @@ -131,9 +138,14 @@ def execute_query(
data[field.name].append(json.loads(value.serialize()))
else:
data[field.name].append(value)
return SpannerQueryResult(
data=data,
fields=fields,
rows=rows,
schema_json=self.schema_json,
error=None
)

return data, fields, rows, self.schema_json, None

class MockSpannerResult:

def __init__(self, file_path: str):
Expand Down Expand Up @@ -180,7 +192,7 @@ def execute_query(
self,
_: str,
limit: int = 5
) -> Tuple[Dict[str, List[Any]], List[StructType.Field], List, str]:
) -> SpannerQueryResult:
"""Mock execution of query"""

# Before the actual query we fetch the schema as well
Expand All @@ -201,7 +213,13 @@ def execute_query(
for field, value in zip(fields, row):
data[field.name].append(value)

return data, fields, rows, self.schema_json, None
return SpannerQueryResult(
data=data,
fields=fields,
rows=rows,
schema_json=self.schema_json,
error=None
)


database_instances: dict[str, SpannerDatabase | MockSpannerDatabase] = {
Expand All @@ -221,4 +239,4 @@ def get_database_instance(project: str, instance: str, database: str, mock = Fal
db = SpannerDatabase(project, instance, database)
database_instances[key] = db

return db
return db
91 changes: 56 additions & 35 deletions spanner_graphs/graph_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import json
import threading
from enum import Enum
from typing import Union
from typing import Union, Dict, Any

import requests
import portpicker
Expand Down Expand Up @@ -206,49 +206,70 @@ def execute_node_expansion(

return execute_query(project, instance, database, query, mock=False)

def execute_query(project: str, instance: str, database: str, query: str, mock = False):
database = get_database_instance(project, instance, database, mock)
def execute_query(
project: str,
instance: str,
database: str,
query: str,
mock: bool = False,
) -> Dict[str, Any]:
"""Executes a query against a database and formats the result.

Connects to a database, runs the query, and processes the resulting object.
On success, it formats the data into nodes and edges for graph visualization.
If the query fails, it returns a detailed error message, optionally
including the database schema to aid in debugging.

Args:
project: The cloud project ID.
instance: The database instance name.
database: The database name.
query: The query string to execute.
mock: If True, use a mock database instance for testing. Defaults to False.

Returns:
A dictionary containing either the structured 'response' with nodes,
edges, and other data, or an 'error' key with a descriptive message.
"""
try:
query_result, fields, rows, schema_json, err = database.execute_query(query)
if len(rows) == 0 and err : # if query returned an error
if schema_json: # if the schema exists
return {
"response": {
"schema": schema_json,
"query_result": query_result,
"nodes": [],
"edges": [],
"rows": []
},
"error": f"We've detected an error in your query. To help you troubleshoot, the graph schema's layout is shown above." + "\n\n" + f"Query error: \n{getattr(err, 'message', str(err))}"
}
if not schema_json: # if the schema does not exist
return {
"response": {
"schema": schema_json,
"query_result": query_result,
"nodes": [],
"edges": [],
"rows": []
},
"error": f"Query error: \n{getattr(err, 'message', str(err))}"
}
nodes, edges = get_nodes_edges(query_result, fields, schema_json)

db_instance = get_database_instance(project, instance, database, mock)
result: SpannerQueryResult = db_instance.execute_query(query)

if len(result.rows) == 0 and result.err:
error_message = f"Query error: \n{getattr(result.err, 'message', str(result.err))}"
if result.schema_json:
# Prepend a helpful message if the schema is available
error_message = (
"We've detected an error in your query. To help you troubleshoot, "
"the graph schema's layout is shown above.\n\n" + error_message
)

# Consolidate the repetitive error response into a single return
return {
"response": {
"schema": result.schema_json,
"query_result": result.data,
"nodes": [],
"edges": [],
"rows": [],
},
"error": error_message,
}

# Process a successful query result
nodes, edges = get_nodes_edges(result.data, result.fields, result.schema_json)

return {
"response": {
"nodes": [node.to_json() for node in nodes],
"edges": [edge.to_json() for edge in edges],
"schema": schema_json,
"rows": rows,
"query_result": query_result
"schema": result.schema_json,
"rows": result.rows,
"query_result": result.data,
}
}
except Exception as e:
return {
"error": getattr(e, "message", str(e))
}
return {"error": getattr(e, "message", str(e))}


class GraphServer:
Expand Down
4 changes: 2 additions & 2 deletions tests/conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def test_get_nodes_edges(self) -> None:
"""
# Get data from mock database
mock_db = MockSpannerDatabase()
data, fields, _, schema_json = mock_db.execute_query("")
query_result = mock_db.execute_query("")

# Convert data to nodes and edges
nodes, edges = get_nodes_edges(data, fields)
nodes, edges = get_nodes_edges(query_result.data, query_result.fields)

# Verify we got some nodes and edges
self.assertTrue(len(nodes) > 0, "Should have at least one node")
Expand Down