Skip to content

Commit b064be3

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: add chat first-party tool
This tool answers questions about structured data in BigQuery using natural language. PiperOrigin-RevId: 776833839
1 parent 96a0d4b commit b064be3

File tree

12 files changed

+984
-8
lines changed

12 files changed

+984
-8
lines changed

contributing/samples/bigquery/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ distributed via the `google.adk.tools.bigquery` module. These tools include:
2525

2626
Runs a SQL query in BigQuery.
2727

28+
1. `chat`
29+
30+
Natural language-in, natural language-out chat tool that answers questions
31+
about structured data in BigQuery. Provide a one-stop solution for generating
32+
insights from data.
33+
34+
2835
## How to use
2936

3037
Set up environment variables in your `.env` file for using

src/google/adk/tools/bigquery/bigquery_tool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def __init__(
6565
if credentials_config
6666
else None
6767
)
68-
self._tool_config = bigquery_tool_config
68+
self._tool_config = (
69+
bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig()
70+
)
6971

7072
@override
7173
async def run_async(

src/google/adk/tools/bigquery/bigquery_toolset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.adk.agents.readonly_context import ReadonlyContext
2222
from typing_extensions import override
2323

24+
from . import chat_tool
2425
from . import metadata_tool
2526
from . import query_tool
2627
from ...tools.base_tool import BaseTool
@@ -78,6 +79,7 @@ async def get_tools(
7879
metadata_tool.list_dataset_ids,
7980
metadata_tool.list_table_ids,
8081
query_tool.get_execute_sql(self._tool_config),
82+
chat_tool.chat,
8183
]
8284
]
8385

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
from typing import Any
17+
from typing import Dict
18+
from typing import List
19+
20+
from google.auth.credentials import Credentials
21+
from google.cloud import bigquery
22+
import requests
23+
24+
from . import client
25+
from .config import BigQueryToolConfig
26+
27+
28+
def chat(
29+
project_id: str,
30+
user_query_with_context: str,
31+
table_references: List[Dict[str, str]],
32+
credentials: Credentials,
33+
config: BigQueryToolConfig,
34+
) -> Dict[str, Any]:
35+
"""Answers questions about structured data in BigQuery tables using natural language.
36+
37+
This function acts as a client for a "chat-with-your-data" service. It takes a
38+
user's question (which can include conversational history for context) and
39+
references to specific BigQuery tables, and sends them to a stateless
40+
conversational API.
41+
42+
The API uses a GenAI agent to understand the question, generate and execute
43+
SQL queries and Python code, and formulate an answer. This function returns a
44+
detailed, sequential log of this entire process, which includes any generated
45+
SQL or Python code, the data retrieved, and the final text answer.
46+
47+
Use this tool to perform data analysis, get insights, or answer complex
48+
questions about the contents of specific BigQuery tables.
49+
50+
Args:
51+
project_id (str): The project that the chat is performed in.
52+
user_query_with_context (str): The user's question, potentially including
53+
conversation history and system instructions for context.
54+
table_references (List[Dict[str, str]]): A list of dictionaries, each
55+
specifying a BigQuery table to be used as context for the question.
56+
credentials (Credentials): The credentials to use for the request.
57+
config (BigQueryToolConfig): The configuration for the tool.
58+
59+
Returns:
60+
A dictionary with two keys:
61+
- 'status': A string indicating the final status (e.g., "SUCCESS").
62+
- 'response': A list of dictionaries, where each dictionary
63+
represents a step in the API's execution process (e.g., SQL
64+
generation, data retrieval, final answer).
65+
66+
Example:
67+
A query joining multiple tables, showing the full return structure.
68+
>>> chat(
69+
... project_id="some-project-id",
70+
... user_query_with_context="Which customer from New York spent the
71+
most last month? "
72+
... "Context: The 'customers' table joins with
73+
the 'orders' table "
74+
... "on the 'customer_id' column.",
75+
... table_references=[
76+
... {
77+
... "projectId": "my-gcp-project",
78+
... "datasetId": "sales_data",
79+
... "tableId": "customers"
80+
... },
81+
... {
82+
... "projectId": "my-gcp-project",
83+
... "datasetId": "sales_data",
84+
... "tableId": "orders"
85+
... }
86+
... ]
87+
... )
88+
{
89+
"status": "SUCCESS",
90+
"response": [
91+
{
92+
"SQL Generated": "SELECT t1.customer_name, SUM(t2.order_total) ... "
93+
},
94+
{
95+
"Data Retrieved": {
96+
"headers": ["customer_name", "total_spent"],
97+
"rows": [["Jane Doe", 1234.56]],
98+
"summary": "Showing all 1 rows."
99+
}
100+
},
101+
{
102+
"Answer": "The customer who spent the most was Jane Doe."
103+
}
104+
]
105+
}
106+
"""
107+
try:
108+
location = "global"
109+
if not credentials.token:
110+
error_message = (
111+
"Error: The provided credentials object does not have a valid access"
112+
" token.\n\nThis is often because the credentials need to be"
113+
" refreshed or require specific API scopes. Please ensure the"
114+
" credentials are prepared correctly before calling this"
115+
" function.\n\nThere may be other underlying causes as well."
116+
)
117+
return {
118+
"status": "ERROR",
119+
"error_details": "Chat requires a valid access token.",
120+
}
121+
headers = {
122+
"Authorization": f"Bearer {credentials.token}",
123+
"Content-Type": "application/json",
124+
}
125+
chat_url = f"https://geminidataanalytics.googleapis.com/v1alpha/projects/{project_id}/locations/{location}:chat"
126+
127+
chat_payload = {
128+
"project": f"projects/{project_id}",
129+
"messages": [{"userMessage": {"text": user_query_with_context}}],
130+
"inlineContext": {
131+
"datasourceReferences": {
132+
"bq": {"tableReferences": table_references}
133+
},
134+
"options": {"chart": {"image": {"noImage": {}}}},
135+
},
136+
}
137+
138+
resp = _get_stream(
139+
chat_url, chat_payload, headers, config.max_query_result_rows
140+
)
141+
except Exception as ex: # pylint: disable=broad-except
142+
return {
143+
"status": "ERROR",
144+
"error_details": str(ex),
145+
}
146+
return {"status": "SUCCESS", "response": resp}
147+
148+
149+
def _get_stream(
150+
url: str,
151+
chat_payload: Dict[str, Any],
152+
headers: Dict[str, str],
153+
max_query_result_rows: int,
154+
) -> List[Dict[str, Any]]:
155+
"""Sends a JSON request to a streaming API and returns a list of messages."""
156+
s = requests.Session()
157+
158+
accumulator = ""
159+
messages = []
160+
161+
with s.post(url, json=chat_payload, headers=headers, stream=True) as resp:
162+
for line in resp.iter_lines():
163+
if not line:
164+
continue
165+
166+
decoded_line = str(line, encoding="utf-8")
167+
168+
if decoded_line == "[{":
169+
accumulator = "{"
170+
elif decoded_line == "}]":
171+
accumulator += "}"
172+
elif decoded_line == ",":
173+
continue
174+
else:
175+
accumulator += decoded_line
176+
177+
if not _is_json(accumulator):
178+
continue
179+
180+
data_json = json.loads(accumulator)
181+
if "systemMessage" not in data_json:
182+
if "error" in data_json:
183+
_append_message(messages, _handle_error(data_json["error"]))
184+
continue
185+
186+
system_message = data_json["systemMessage"]
187+
if "text" in system_message:
188+
_append_message(messages, _handle_text_response(system_message["text"]))
189+
elif "schema" in system_message:
190+
_append_message(
191+
messages,
192+
_handle_schema_response(system_message["schema"]),
193+
)
194+
elif "data" in system_message:
195+
_append_message(
196+
messages,
197+
_handle_data_response(
198+
system_message["data"], max_query_result_rows
199+
),
200+
)
201+
accumulator = ""
202+
return messages
203+
204+
205+
def _is_json(s: str) -> bool:
206+
"""Checks if a string is a valid JSON object."""
207+
try:
208+
json.loads(s)
209+
except ValueError:
210+
return False
211+
return True
212+
213+
214+
def _get_property(
215+
data: Dict[str, Any], field_name: str, default: Any = ""
216+
) -> Any:
217+
"""Safely gets a property from a dictionary."""
218+
return data.get(field_name, default)
219+
220+
221+
def _format_bq_table_ref(table_ref: Dict[str, str]) -> str:
222+
"""Formats a BigQuery table reference dictionary into a string."""
223+
return f"{table_ref.get('projectId')}.{table_ref.get('datasetId')}.{table_ref.get('tableId')}"
224+
225+
226+
def _format_schema_as_dict(
227+
data: Dict[str, Any],
228+
) -> Dict[str, List[Any]]:
229+
"""Extracts schema fields into a dictionary."""
230+
fields = data.get("fields", [])
231+
if not fields:
232+
return {"columns": []}
233+
234+
column_details = []
235+
headers = ["Column", "Type", "Description", "Mode"]
236+
rows: List[Tuple[str, str, str, str]] = []
237+
for field in fields:
238+
row_tuple = (
239+
_get_property(field, "name"),
240+
_get_property(field, "type"),
241+
_get_property(field, "description", ""),
242+
_get_property(field, "mode"),
243+
)
244+
rows.append(row_tuple)
245+
246+
return {"headers": headers, "rows": rows}
247+
248+
249+
def _format_datasource_as_dict(datasource: Dict[str, Any]) -> Dict[str, Any]:
250+
"""Formats a full datasource object into a dictionary with its name and schema."""
251+
source_name = _format_bq_table_ref(datasource["bigqueryTableReference"])
252+
253+
schema = _format_schema_as_dict(datasource["schema"])
254+
return {"source_name": source_name, "schema": schema}
255+
256+
257+
def _handle_text_response(resp: Dict[str, Any]) -> Dict[str, str]:
258+
"""Formats a text response into a dictionary."""
259+
parts = resp.get("parts", [])
260+
return {"Answer": "".join(parts)}
261+
262+
263+
def _handle_schema_response(resp: Dict[str, Any]) -> Dict[str, Any]:
264+
"""Formats a schema response into a dictionary."""
265+
if "query" in resp:
266+
return {"Question": resp["query"].get("question", "")}
267+
elif "result" in resp:
268+
datasources = resp["result"].get("datasources", [])
269+
# Format each datasource and join them with newlines
270+
formatted_sources = [_format_datasource_as_dict(ds) for ds in datasources]
271+
return {"Schema Resolved": formatted_sources}
272+
return {}
273+
274+
275+
def _handle_data_response(
276+
resp: Dict[str, Any], max_query_result_rows: int
277+
) -> Dict[str, Any]:
278+
"""Formats a data response into a dictionary."""
279+
if "query" in resp:
280+
query = resp["query"]
281+
return {
282+
"Retrieval Query": {
283+
"Query Name": query.get("name", "N/A"),
284+
"Question": query.get("question", "N/A"),
285+
}
286+
}
287+
elif "generatedSql" in resp:
288+
return {"SQL Generated": resp["generatedSql"]}
289+
elif "result" in resp:
290+
schema = resp["result"]["schema"]
291+
headers = [field.get("name") for field in schema.get("fields", [])]
292+
293+
all_rows = resp["result"]["data"]
294+
total_rows = len(all_rows)
295+
296+
compact_rows = []
297+
for row_dict in all_rows[:max_query_result_rows]:
298+
row_values = [row_dict.get(header) for header in headers]
299+
compact_rows.append(row_values)
300+
301+
summary_string = f"Showing all {total_rows} rows."
302+
if total_rows > max_query_result_rows:
303+
summary_string = (
304+
f"Showing the first {len(compact_rows)} of {total_rows} total rows."
305+
)
306+
307+
return {
308+
"Data Retrieved": {
309+
"headers": headers,
310+
"rows": compact_rows,
311+
"total_rows": total_rows,
312+
"summary": summary_string,
313+
}
314+
}
315+
316+
return {}
317+
318+
319+
def _handle_error(resp: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
320+
"""Formats an error response into a dictionary."""
321+
return {
322+
"Error": {
323+
"Code": resp.get("code", "N/A"),
324+
"Message": resp.get("message", "No message provided."),
325+
}
326+
}
327+
328+
329+
def _append_message(
330+
messages: List[Dict[str, Any]], new_message: Dict[str, Any]
331+
):
332+
if not new_message:
333+
return
334+
335+
if messages and "Data Retrieved" in messages[-1]:
336+
messages.pop()
337+
338+
messages.append(new_message)

src/google/adk/tools/bigquery/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Optional
18+
1719
import google.api_core.client_info
1820
from google.auth.credentials import Credentials
1921
from google.cloud import bigquery
@@ -24,7 +26,7 @@
2426

2527

2628
def get_bigquery_client(
27-
*, project: str, credentials: Credentials
29+
*, project: Optional[str], credentials: Credentials
2830
) -> bigquery.Client:
2931
"""Get a BigQuery client."""
3032

0 commit comments

Comments
 (0)