Commit ac38a61

Author: Sailesh Mukil (committed)
Abstract away "database" implementations and remove strong coupling of DB implementation with APIs
1. Abstracts SpannerDatabase and MockSpannerDatabase with clear APIs
2. Introduces CloudSpannerDatabase as an implementation of SpannerDatabase
3. Removes further tight coupling with the Cloud Spanner client by adding a SpannerFieldInfo dataclass to replace usage of StructType.Field (a sketch of these abstractions follows the file summary below)
1 parent 39635ea commit ac38a61

File tree: 7 files changed, +337 −175 lines

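The abstract base classes and the SpannerFieldInfo dataclass named in the commit message live in spanner_graphs/database.py, which is one of the seven changed files but is not shown in this excerpt. Below is a minimal sketch of what those definitions might look like, inferred only from how cloud_database.py and conversion.py use them in the diffs that follow; everything beyond the attribute names name/typename, the SpannerQueryResult constructor arguments, and get_as_field_info_list is an assumption.

# Hypothetical sketch of spanner_graphs/database.py (not part of this excerpt);
# shapes are inferred from usage in cloud_database.py and conversion.py below.
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from google.cloud.spanner_v1.types import StructType


@dataclass
class SpannerFieldInfo:
    """Client-agnostic description of a result column (replaces StructType.Field)."""
    name: str
    typename: str  # e.g. "JSON", "ARRAY", "STRING"


@dataclass
class SpannerQueryResult:
    """Container returned by every database implementation."""
    data: Dict[str, List[Any]]
    fields: List[SpannerFieldInfo]
    rows: List[List[Any]]
    schema_json: Optional[Any]
    error: Optional[Exception]


def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldInfo]:
    """Convert Cloud Spanner field metadata into the neutral dataclass."""
    return [SpannerFieldInfo(name=f.name, typename=f.type_.code.name) for f in fields]


class SpannerDatabase(ABC):
    """Abstract API that every concrete database must provide."""

    @abstractmethod
    def execute_query(self, query: str, limit: int | None = None,
                      is_test_query: bool = False) -> SpannerQueryResult:
        ...


class MockSpannerDatabase(ABC):
    """Abstract API for mock databases used in tests and demos."""

    @abstractmethod
    def execute_query(self, _: str, limit: int = 5) -> SpannerQueryResult:
        ...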

spanner_graphs/cloud_database.py

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
# Copyright 2024 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains the cloud-specific implementation for talking to a Spanner database.
"""

from __future__ import annotations

import csv
import json
import os
from typing import Any, Dict, List, Tuple

from google.cloud import spanner
from google.cloud.spanner_v1 import JsonObject
from google.api_core.client_options import ClientOptions
from google.cloud.spanner_v1.types import StructType, Type, TypeCode
import pydata_google_auth

from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo, get_as_field_info_list


def _get_default_credentials_with_project():
    return pydata_google_auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False)


class CloudSpannerDatabase(SpannerDatabase):
    """Concrete implementation for Spanner database on the cloud."""

    def __init__(self, project_id: str, instance_id: str,
                 database_id: str) -> None:
        credentials, _ = _get_default_credentials_with_project()
        self.client = spanner.Client(
            project=project_id, credentials=credentials,
            client_options=ClientOptions(quota_project_id=project_id))
        self.instance = self.client.instance(instance_id)
        self.database = self.instance.database(database_id)
        self.schema_json: Any | None = None

    def __repr__(self) -> str:
        return (f"<CloudSpannerDatabase["
                f"project:{self.client.project_name},"
                f"instance:{self.instance.name},"
                f"db:{self.database.name}]>")

    def _extract_graph_name(self, query: str) -> str:
        words = query.strip().split()
        if len(words) < 3:
            raise ValueError("invalid query: must contain at least (GRAPH, graph_name and query)")

        if words[0].upper() != "GRAPH":
            raise ValueError("invalid query: GRAPH must be the first word")

        return words[1]

    def _get_schema_for_graph(self, graph_query: str) -> Any | None:
        try:
            graph_name = self._extract_graph_name(graph_query)
        except ValueError:
            return None

        with self.database.snapshot() as snapshot:
            schema_query = """
                SELECT property_graph_name, property_graph_metadata_json
                FROM information_schema.property_graphs
                WHERE property_graph_name = @graph_name
            """
            params = {"graph_name": graph_name}
            param_type = {"graph_name": spanner.param_types.STRING}

            result = snapshot.execute_sql(schema_query, params=params, param_types=param_type)
            schema_rows = list(result)

            if schema_rows:
                return schema_rows[0][1]
            else:
                return None

    def execute_query(
        self,
        query: str,
        limit: int | None = None,
        is_test_query: bool = False,
    ) -> SpannerQueryResult:
        """
        Executes the provided `query`.

        Args:
            query: The SQL query to execute against the database
            limit: An optional limit for the number of rows to return
            is_test_query: If true, skips schema fetching for graph queries.

        Returns:
            A `SpannerQueryResult`
        """
        self.schema_json = None
        if not is_test_query:
            self.schema_json = self._get_schema_for_graph(query)

        with self.database.snapshot() as snapshot:
            params = None
            param_types = None
            if limit and limit > 0:
                params = dict(limit=limit)

            try:
                results = snapshot.execute_sql(query, params=params, param_types=param_types)
                rows = list(results)
            except Exception as e:
                # Surface the error through the result object instead of raising.
                return SpannerQueryResult(
                    data={},
                    fields=[],
                    rows=[],
                    schema_json=self.schema_json,
                    error=e
                )

            fields: List[SpannerFieldInfo] = get_as_field_info_list(results.fields)
            data = {field.name: [] for field in fields}

            if len(fields) == 0:
                return SpannerQueryResult(
                    data=data,
                    fields=fields,
                    rows=rows,
                    schema_json=self.schema_json,
                    error=None
                )

            for row_data in rows:
                for field, value in zip(fields, row_data):
                    if isinstance(value, JsonObject):
                        data[field.name].append(json.loads(value.serialize()))
                    else:
                        data[field.name].append(value)

            return SpannerQueryResult(
                data=data,
                fields=fields,
                rows=rows,
                schema_json=self.schema_json,
                error=None
            )


class CloudMockSpannerResult:
    """Reads mock query results from a CSV file of JSON-encoded columns."""

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.fields: List[StructType.Field] = []
        self._rows: List[List[Any]] = []
        self._load_data()

    def _load_data(self):
        with open(self.file_path, "r", encoding="utf-8") as csvfile:
            csv_reader = csv.reader(csvfile)
            headers = next(csv_reader)
            self.fields = [
                StructType.Field(name=header, type_=Type(code=TypeCode.JSON))
                for header in headers
            ]

            for row in csv_reader:
                parsed_row = []
                for value in row:
                    try:
                        js = bytes(value, "utf-8").decode("unicode_escape")
                        parsed_row.append(json.loads(js))
                    except json.JSONDecodeError:
                        pass
                self._rows.append(parsed_row)

    def __iter__(self):
        return iter(self._rows)


class CloudMockSpannerDatabase(MockSpannerDatabase):
    """Cloud Mock database class"""

    def __init__(self):
        dirname = os.path.dirname(__file__)
        self.graph_csv_path = os.path.join(
            dirname, "graph_mock_data.csv")
        self.schema_json_path = os.path.join(
            dirname, "graph_mock_schema.json")
        self.schema_json: dict = {}

    def execute_query(
        self,
        _: str,
        limit: int = 5
    ) -> SpannerQueryResult:
        """Mock execution of query"""

        # Before the actual query we fetch the schema as well
        with open(self.schema_json_path, "r", encoding="utf-8") as js:
            self.schema_json = json.load(js)

        results = CloudMockSpannerResult(self.graph_csv_path)
        fields: List[SpannerFieldInfo] = get_as_field_info_list(results.fields)
        rows = list(results)
        data = {field.name: [] for field in fields}

        if len(fields) == 0:
            return SpannerQueryResult(
                data=data,
                fields=fields,
                rows=rows,
                schema_json=self.schema_json,
                error=None
            )

        for i, row in enumerate(results):
            if limit is not None and i >= limit:
                break
            for field, value in zip(fields, row):
                data[field.name].append(value)

        return SpannerQueryResult(
            data=data,
            fields=fields,
            rows=rows,
            schema_json=self.schema_json,
            error=None
        )
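
A hedged usage sketch of the new class follows. The project, instance, and database IDs and the graph query are placeholders rather than values taken from this commit; the call shape mirrors execute_query above.

# Hypothetical usage of CloudSpannerDatabase; all identifiers are placeholders.
from spanner_graphs.cloud_database import CloudSpannerDatabase

db = CloudSpannerDatabase(project_id="my-project",
                          instance_id="my-instance",
                          database_id="my-database")

result = db.execute_query("GRAPH MyGraph MATCH (n) RETURN TO_JSON(n) AS n LIMIT 10")

if result.error is not None:
    raise result.error

# Each result column is keyed by its SpannerFieldInfo name.
for field in result.fields:
    print(field.name, field.typename, len(result.data[field.name]))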

spanner_graphs/conversion.py

Lines changed: 6 additions & 5 deletions
@@ -23,10 +23,11 @@

 from google.cloud.spanner_v1.types import TypeCode, StructType

+from spanner_graphs.database import SpannerFieldInfo
 from spanner_graphs.graph_entities import Node, Edge
 from spanner_graphs.schema_manager import SchemaManager

-def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field], schema_json: dict = None) -> Tuple[List[Node], List[Edge]]:
+def get_nodes_edges(data: Dict[str, List[Any]], fields: List[SpannerFieldInfo], schema_json: dict = None) -> Tuple[List[Node], List[Edge]]:
     schema_manager = SchemaManager(schema_json)
     nodes: List[Node] = []
     edges: List[Edge] = []
@@ -37,15 +38,15 @@ def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field],
     for field in fields:
         column_name = field.name
         column_data = data[column_name]
-
+
         # Only process JSON and Array of JSON types
-        if field.type_.code not in [TypeCode.JSON, TypeCode.ARRAY]:
+        if field.typename not in ["JSON", "ARRAY"]:
             continue

         # Process each value in the column
         for value in column_data:
             items_to_process = []
-
+
             # Handle both single JSON and arrays of JSON
             if isinstance(value, list):
                 items_to_process.extend(value)
@@ -92,4 +93,4 @@ def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field],
                 nodes.append(Node.make_intermediate(identifier))
                 node_identifiers.add(identifier)

-    return nodes, edges
+    return nodes, edges
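
A practical effect of this change is that get_nodes_edges no longer depends on Cloud Spanner client types, so callers and tests can build the field metadata directly. A minimal sketch, assuming the SpannerFieldInfo shape inferred earlier and that SchemaManager accepts a None schema, as the default argument suggests:

# Hypothetical test-style call; no google.cloud.spanner types are needed anywhere.
from spanner_graphs.conversion import get_nodes_edges
from spanner_graphs.database import SpannerFieldInfo

fields = [SpannerFieldInfo(name="result", typename="JSON")]
data = {"result": []}  # no rows; this only exercises the decoupled signature

nodes, edges = get_nodes_edges(data, fields, schema_json=None)
assert nodes == [] and edges == []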
