diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index b879d73..d3e06fc 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -18,7 +18,7 @@ """ from __future__ import annotations -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, NamedTuple import json import os import csv @@ -29,6 +29,19 @@ from google.cloud.spanner_v1.types import StructType, TypeCode, Type import pydata_google_auth +class SpannerQueryResult(NamedTuple): + # A dict where each key is a field name returned in the query and the list + # contains all items of the same type found for the given field + data: Dict[str, List[Any]] + # A list representing the fields in the result set. + fields: List[Any] + # A list of rows as returned by the query execution. + rows: List[Any] + # An optional field to return the schema as JSON + schema_json: Any | None + # The error message if any + error: Exception | None + def _get_default_credentials_with_project(): return pydata_google_auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False) @@ -87,7 +100,7 @@ def execute_query( query: str, limit: int = None, is_test_query: bool = False, - ): + ) -> SpannerQueryResult: """ This method executes the provided `query` @@ -96,13 +109,7 @@ def execute_query( limit: An optional limit for the number of rows to return Returns: - A tuple containing: - - Dict[str, List[Any]]: A dict where each key is a field name - returned in the query and the list contains all items of the same - type found for the given field. - - A list of StructType.Fields representing the fields in the result set. - - A list of rows as returned by the query execution. - - The error message if any. + A SpannerQueryResult tuple """ self.schema_json = None if not is_test_query: @@ -131,9 +138,14 @@ def execute_query( data[field.name].append(json.loads(value.serialize())) else: data[field.name].append(value) + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) - return data, fields, rows, self.schema_json, None - class MockSpannerResult: def __init__(self, file_path: str): @@ -180,7 +192,7 @@ def execute_query( self, _: str, limit: int = 5 - ) -> Tuple[Dict[str, List[Any]], List[StructType.Field], List, str]: + ) -> SpannerQueryResult: """Mock execution of query""" # Before the actual query we fetch the schema as well @@ -201,7 +213,13 @@ def execute_query( for field, value in zip(fields, row): data[field.name].append(value) - return data, fields, rows, self.schema_json, None + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) database_instances: dict[str, SpannerDatabase | MockSpannerDatabase] = { @@ -221,4 +239,4 @@ def get_database_instance(project: str, instance: str, database: str, mock = Fal db = SpannerDatabase(project, instance, database) database_instances[key] = db - return db \ No newline at end of file + return db diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index 7860c49..fe34d36 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -17,7 +17,7 @@ import json import threading from enum import Enum -from typing import Union +from typing import Union, Dict, Any import requests import portpicker @@ -206,49 +206,70 @@ def execute_node_expansion( return execute_query(project, instance, database, query, mock=False) -def execute_query(project: str, instance: str, database: str, query: str, mock = False): - database = get_database_instance(project, instance, database, mock) +def execute_query( + project: str, + instance: str, + database: str, + query: str, + mock: bool = False, +) -> Dict[str, Any]: + """Executes a query against a database and formats the result. + Connects to a database, runs the query, and processes the resulting object. + On success, it formats the data into nodes and edges for graph visualization. + If the query fails, it returns a detailed error message, optionally + including the database schema to aid in debugging. + + Args: + project: The cloud project ID. + instance: The database instance name. + database: The database name. + query: The query string to execute. + mock: If True, use a mock database instance for testing. Defaults to False. + + Returns: + A dictionary containing either the structured 'response' with nodes, + edges, and other data, or an 'error' key with a descriptive message. + """ try: - query_result, fields, rows, schema_json, err = database.execute_query(query) - if len(rows) == 0 and err : # if query returned an error - if schema_json: # if the schema exists - return { - "response": { - "schema": schema_json, - "query_result": query_result, - "nodes": [], - "edges": [], - "rows": [] - }, - "error": f"We've detected an error in your query. To help you troubleshoot, the graph schema's layout is shown above." + "\n\n" + f"Query error: \n{getattr(err, 'message', str(err))}" - } - if not schema_json: # if the schema does not exist - return { - "response": { - "schema": schema_json, - "query_result": query_result, - "nodes": [], - "edges": [], - "rows": [] - }, - "error": f"Query error: \n{getattr(err, 'message', str(err))}" - } - nodes, edges = get_nodes_edges(query_result, fields, schema_json) - + db_instance = get_database_instance(project, instance, database, mock) + result: SpannerQueryResult = db_instance.execute_query(query) + + if len(result.rows) == 0 and result.err: + error_message = f"Query error: \n{getattr(result.err, 'message', str(result.err))}" + if result.schema_json: + # Prepend a helpful message if the schema is available + error_message = ( + "We've detected an error in your query. To help you troubleshoot, " + "the graph schema's layout is shown above.\n\n" + error_message + ) + + # Consolidate the repetitive error response into a single return + return { + "response": { + "schema": result.schema_json, + "query_result": result.data, + "nodes": [], + "edges": [], + "rows": [], + }, + "error": error_message, + } + + # Process a successful query result + nodes, edges = get_nodes_edges(result.data, result.fields, result.schema_json) + return { "response": { "nodes": [node.to_json() for node in nodes], "edges": [edge.to_json() for edge in edges], - "schema": schema_json, - "rows": rows, - "query_result": query_result + "schema": result.schema_json, + "rows": result.rows, + "query_result": result.data, } } except Exception as e: - return { - "error": getattr(e, "message", str(e)) - } + return {"error": getattr(e, "message", str(e))} class GraphServer: diff --git a/tests/conversion_test.py b/tests/conversion_test.py index e53c56c..b014d3e 100644 --- a/tests/conversion_test.py +++ b/tests/conversion_test.py @@ -38,10 +38,10 @@ def test_get_nodes_edges(self) -> None: """ # Get data from mock database mock_db = MockSpannerDatabase() - data, fields, _, schema_json = mock_db.execute_query("") + query_result = mock_db.execute_query("") # Convert data to nodes and edges - nodes, edges = get_nodes_edges(data, fields) + nodes, edges = get_nodes_edges(query_result.data, query_result.fields) # Verify we got some nodes and edges self.assertTrue(len(nodes) > 0, "Should have at least one node")