
Commit 773d4b7

google-genai-bot authored and copybara-github committed
feat: add chat first-party tool
This tool answers questions about structured data in BigQuery using natural language.

PiperOrigin-RevId: 776833839
1 parent 6e68c2d commit 773d4b7


12 files changed: 1068 additions & 8 deletions


contributing/samples/bigquery/README.md

Lines changed: 7 additions & 0 deletions
@@ -25,6 +25,13 @@ distributed via the `google.adk.tools.bigquery` module. These tools include:
 
    Runs a SQL query in BigQuery.
 
+1. `chat`
+
+   Natural language-in, natural language-out chat tool that answers questions
+   about structured data in BigQuery. Provide a one-stop solution for generating
+   insights from data.
+
+
 ## How to use
 
 Set up environment variables in your `.env` file for using
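The `chat` tool described above ships with the rest of the BigQuery tools, so an agent that already uses the toolset picks it up automatically. A minimal sketch of that wiring (not part of this commit; the `BigQueryToolset`, `BigQueryCredentialsConfig`, and `Agent` names and constructor arguments are assumptions drawn from the surrounding ADK code, and the model name is a placeholder):

import google.auth

from google.adk.agents import Agent
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig

# Application-default credentials are one option; any google.auth credentials work.
adc, _ = google.auth.default()

toolset = BigQueryToolset(
    credentials_config=BigQueryCredentialsConfig(credentials=adc),
    bigquery_tool_config=BigQueryToolConfig(),
)

agent = Agent(
    model="gemini-2.0-flash",  # placeholder model name
    name="bigquery_agent",
    instruction="Answer questions about the team's BigQuery datasets.",
    tools=[toolset],
)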

src/google/adk/tools/bigquery/bigquery_tool.py

Lines changed: 3 additions & 1 deletion
@@ -65,7 +65,9 @@ def __init__(
         if credentials_config
         else None
     )
-    self._tool_config = bigquery_tool_config
+    self._tool_config = (
+        bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig()
+    )
 
   @override
   async def run_async(
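The practical effect of this change is that a `BigQueryTool` built without an explicit `bigquery_tool_config` now carries a default `BigQueryToolConfig()` rather than `None`, so downstream code such as the new chat tool can read configuration fields without guarding. A small sketch of the same fallback pattern (the `resolve_tool_config` helper is hypothetical, not part of the commit):

from google.adk.tools.bigquery.config import BigQueryToolConfig


def resolve_tool_config(bigquery_tool_config=None) -> BigQueryToolConfig:
  # Hypothetical helper mirroring the ternary added to __init__ above:
  # prefer the caller's config, otherwise fall back to library defaults.
  return bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig()


config = resolve_tool_config()  # no explicit config -> defaults, never None
# The chat tool reads this field via config.max_query_result_rows below.
print(config.max_query_result_rows)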

src/google/adk/tools/bigquery/bigquery_toolset.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from google.adk.agents.readonly_context import ReadonlyContext
 from typing_extensions import override
 
+from . import chat_tool
 from . import metadata_tool
 from . import query_tool
 from ...tools.base_tool import BaseTool
@@ -78,6 +79,7 @@ async def get_tools(
             metadata_tool.list_dataset_ids,
             metadata_tool.list_table_ids,
             query_tool.get_execute_sql(self._tool_config),
+            chat_tool.chat,
         ]
     ]
 
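With `chat_tool.chat` registered in `get_tools`, the toolset now surfaces the chat tool next to the metadata and query tools. A hedged way to check that (assumes `get_tools()` can be called with no arguments outside a running agent, that application-default credentials are available, and that each returned tool exposes a `name` attribute):

import asyncio

import google.auth
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset

adc, _ = google.auth.default()
toolset = BigQueryToolset(
    credentials_config=BigQueryCredentialsConfig(credentials=adc)
)

# get_tools is async; outside an event loop, asyncio.run is the simplest driver.
tools = asyncio.run(toolset.get_tools())
print(sorted(tool.name for tool in tools))
# Expected to include "chat" alongside execute_sql, list_dataset_ids, and
# list_table_ids.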

src/google/adk/tools/bigquery/chat_tool.py

Lines changed: 341 additions & 0 deletions (new file)
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Any
from typing import Dict
from typing import List

from google.auth.credentials import Credentials
from google.cloud import bigquery
import requests

from . import client
from .config import BigQueryToolConfig


def chat(
    project_id: str,
    user_query_with_context: str,
    table_references: List[Dict[str, str]],
    credentials: Credentials,
    config: BigQueryToolConfig,
) -> Dict[str, str]:
  """Answers questions about structured data in BigQuery tables using natural language.

  This function acts as a client for a "chat-with-your-data" service. It takes a
  user's question (which can include conversational history for context) and
  references to specific BigQuery tables, and sends them to a stateless
  conversational API.

  The API uses a GenAI agent to understand the question, generate and execute
  SQL queries and Python code, and formulate an answer. This function returns a
  detailed, sequential log of this entire process, which includes any generated
  SQL or Python code, the data retrieved, and the final text answer.

  Use this tool to perform data analysis, get insights, or answer complex
  questions about the contents of specific BigQuery tables.

  Args:
    project_id (str): The project that the chat is performed in.
    user_query_with_context (str): The user's question, potentially including
      conversation history and system instructions for context.
    table_references (List[Dict[str, str]]): A list of dictionaries, each
      specifying a BigQuery table to be used as context for the question.
    credentials (Credentials): The credentials to use for the request.
    config (BigQueryToolConfig): The configuration for the tool.

  Returns:
    A dictionary with two keys:
    - 'status': A string indicating the final status (e.g., "SUCCESS").
    - 'response': A list of dictionaries, where each dictionary
      represents a timestamped system message from the API's execution
      process.

  Example:
    A query joining multiple tables, showing the full return structure.
    >>> chat(
    ...     project_id="some-project-id",
    ...     user_query_with_context="Which customer from New York spent the most last month? "
    ...         "Context: The 'customers' table joins with the 'orders' table "
    ...         "on the 'customer_id' column.",
    ...     table_references=[
    ...         {
    ...             "projectId": "my-gcp-project",
    ...             "datasetId": "sales_data",
    ...             "tableId": "customers"
    ...         },
    ...         {
    ...             "projectId": "my-gcp-project",
    ...             "datasetId": "sales_data",
    ...             "tableId": "orders"
    ...         }
    ...     ]
    ... )
    {
      "status": "SUCCESS",
      "response": (
          "## SQL Generated\\n"
          "```sql\\n"
          "SELECT t1.customer_name, SUM(t2.order_total) AS total_spent "
          "FROM `my-gcp-project.sales_data.customers` AS t1 JOIN "
          "`my-gcp-project.sales_data.orders` AS t2 ON t1.customer_id = "
          "t2.customer_id WHERE t1.state = 'NY' AND t2.order_date >= ... "
          "GROUP BY 1 ORDER BY 2 DESC LIMIT 1;\\n"
          "```\\n\\n"
          "Answer: The customer who spent the most from New York last "
          "month was Jane Doe."
      )
    }
  """
  try:
    location = "global"
    if not credentials.token:
      error_message = (
          "Error: The provided credentials object does not have a valid access"
          " token.\n\nThis is often because the credentials need to be"
          " refreshed or require specific API scopes. Please ensure the"
          " credentials are prepared correctly before calling this"
          " function.\n\nThere may be other underlying causes as well."
      )
      return {
          "status": "ERROR",
          "error_details": "Chat requires a valid access token.",
      }
    headers = {
        "Authorization": f"Bearer {credentials.token}",
        "Content-Type": "application/json",
    }
    chat_url = f"https://geminidataanalytics.googleapis.com/v1alpha/projects/{project_id}/locations/{location}:chat"

    chat_payload = {
        "project": f"projects/{project_id}",
        "messages": [{"userMessage": {"text": user_query_with_context}}],
        "inlineContext": {
            "datasourceReferences": {
                "bq": {"tableReferences": table_references}
            },
            "options": {"chart": {"image": {"noImage": {}}}},
        },
    }

    resp = _get_stream(
        chat_url, chat_payload, headers, config.max_query_result_rows
    )
  except Exception as ex:  # pylint: disable=broad-except
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
  return {"status": "SUCCESS", "response": resp}


def _get_stream(
    url: str,
    chat_payload: Dict[str, Any],
    headers: Dict[str, str],
    max_query_result_rows: int,
) -> str:
  """Sends a JSON request to a streaming API and returns the response as a string."""
  s = requests.Session()

  accumulator = ""
  messages = []

  with s.post(url, json=chat_payload, headers=headers, stream=True) as resp:
    for line in resp.iter_lines():
      if not line:
        continue

      decoded_line = str(line, encoding="utf-8")

      if decoded_line == "[{":
        accumulator = "{"
      elif decoded_line == "}]":
        accumulator += "}"
      elif decoded_line == ",":
        continue
      else:
        accumulator += decoded_line

      if not _is_json(accumulator):
        continue

      data_json = json.loads(accumulator)
      if "systemMessage" not in data_json:
        if "error" in data_json:
          _append_message(messages, _handle_error(data_json["error"]))
        continue

      system_message = data_json["systemMessage"]
      if "text" in system_message:
        _append_message(messages, _handle_text_response(system_message["text"]))
      elif "schema" in system_message:
        _append_message(
            messages,
            _handle_schema_response(system_message["schema"]),
        )
      elif "data" in system_message:
        _append_message(
            messages,
            _handle_data_response(
                system_message["data"], max_query_result_rows
            ),
        )
      accumulator = ""
  return "\n\n".join(messages)


def _is_json(str):
  try:
    json_object = json.loads(str)
  except ValueError as e:
    return False
  return True


def _get_property(data, field_name, default=""):
  """Safely gets a property from a dictionary."""
  return data[field_name] if field_name in data else default


def _format_section_title(text: str) -> str:
  """Formats text as a Markdown H2 title."""
  return f"## {text}"


def _format_bq_table_ref(table_ref: Dict[str, str]) -> str:
  """Formats a BigQuery table reference dictionary into a string."""
  return f"{table_ref['projectId']}.{table_ref['datasetId']}.{table_ref['tableId']}"


def _format_schema_as_markdown(data: Dict[str, Any]) -> str:
  """Converts a schema dictionary to a Markdown table string without using pandas."""
  fields = data.get("fields", [])
  if not fields:
    return "No schema fields found."

  # Define the table headers
  headers = ["Column", "Type", "Description", "Mode"]

  # Create the header and separator lines for the Markdown table
  header_line = f"| {' | '.join(headers)} |"
  separator_line = f"| {' | '.join(['---'] * len(headers))} |"

  # Create a list to hold each data row string
  data_lines = []
  for field in fields:
    # Extract each property in the correct order for a row
    row_values = [
        _get_property(field, "name"),
        _get_property(field, "type"),
        _get_property(field, "description", "-"),
        _get_property(field, "mode"),
    ]
    # Format the row by joining the values with pipes
    data_lines.append(f"| {' | '.join(map(str, row_values))} |")

  # Combine the header, separator, and data lines into the final table string
  return "\n".join([header_line, separator_line] + data_lines)


def _format_datasource_as_markdown(datasource: Dict[str, Any]) -> str:
  """Formats a full datasource object into a string with its name and schema."""
  source_name = _format_bq_table_ref(datasource["bigqueryTableReference"])

  schema_markdown = _format_schema_as_markdown(datasource["schema"])
  return f"**Source:** `{source_name}`\n{schema_markdown}"


def _handle_text_response(resp: Dict[str, Any]) -> str:
  """Joins and returns text parts from a response."""
  parts = resp.get("parts", [])
  return "Answer: " + "".join(parts)


def _handle_schema_response(resp: Dict[str, Any]) -> str:
  """Formats a schema response into a complete string."""
  if "query" in resp:
    return resp["query"].get("question", "")
  elif "result" in resp:
    title = _format_section_title("Schema Resolved")
    datasources = resp["result"].get("datasources", [])
    # Format each datasource and join them with newlines
    formatted_sources = "\n\n".join(
        [_format_datasource_as_markdown(ds) for ds in datasources]
    )
    return f"{title}\nData sources:\n{formatted_sources}"
  return ""


def _handle_data_response(
    resp: Dict[str, Any], max_query_result_rows: int
) -> str:
  """Formats a data response (query, SQL, or result) into a string."""
  if "query" in resp:
    query = resp["query"]
    title = _format_section_title("Retrieval Query")
    return (
        f"{title}\n"
        f"**Query Name:** {query.get('name', 'N/A')}\n"
        f"**Question:** {query.get('question', 'N/A')}"
    )
  elif "generatedSql" in resp:
    title = _format_section_title("SQL Generated")
    sql_code = resp["generatedSql"]
    # Format SQL in a Markdown code block
    return f"{title}\n```sql\n{sql_code}\n```"
  elif "result" in resp:
    title = _format_section_title("Data Retrieved")
    fields = [
        _get_property(field, "name")
        for field in resp["result"]["schema"]["fields"]
    ]
    data_rows = resp["result"]["data"]
    total_rows = len(data_rows)
    header_line = f"| {' | '.join(fields)} |"
    separator_line = f"| {' | '.join(['---'] * len(fields))} |"

    table_lines = [header_line, separator_line]

    for row_dict in data_rows[:max_query_result_rows]:
      row_values = [str(row_dict.get(field, "")) for field in fields]
      table_lines.append(f"| {' | '.join(row_values)} |")

    table_markdown = "\n".join(table_lines)

    if total_rows > max_query_result_rows:
      table_markdown += (
          f"\n\n... *and {total_rows - max_query_result_rows} more rows*."
      )

    return f"{title}\n{table_markdown}"
  return ""


def _handle_error(resp: Dict[str, str]) -> str:
  """Formats an error response into a string."""
  title = _format_section_title("Error")
  code = resp.get("code", "N/A")
  message = resp.get("message", "No message provided.")
  return f"{title}\n**Code:** {code}\n**Message:** {message}"


def _append_message(messages: List[str], new_message: str):
  if new_message:
    if messages and messages[-1].startswith("## Data Retrieved"):
      messages.pop()
    messages.append(new_message)
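For completeness, a sketch of calling the new tool directly rather than through the toolset (not part of the commit). The module path `google.adk.tools.bigquery.chat_tool` is inferred from the `from . import chat_tool` import in the toolset above; the project, dataset, and table names are placeholders, and the example assumes application-default credentials with a cloud-platform scope:

import google.auth
from google.auth.transport.requests import Request

from google.adk.tools.bigquery.chat_tool import chat
from google.adk.tools.bigquery.config import BigQueryToolConfig

project_id = "my-gcp-project"  # placeholder project
creds, _ = google.auth.default(
    scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
creds.refresh(Request())  # chat() checks credentials.token before calling the API

result = chat(
    project_id=project_id,
    user_query_with_context="How many orders were placed last month?",
    table_references=[{
        "projectId": project_id,
        "datasetId": "sales_data",  # placeholder dataset
        "tableId": "orders",  # placeholder table
    }],
    credentials=creds,
    config=BigQueryToolConfig(),
)
print(result["status"])
print(result["response"])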

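`_get_stream` is written against a response that streams a JSON array split across lines: an opening `[{`, one fragment per line, bare `,` separators between objects, and a closing `}]`; it accumulates lines until they parse as a complete object. The standalone sketch below replays that framing on made-up input (no network, simulated data only) to show how a `systemMessage` text part becomes an `Answer:` message:

import json

# Simulated stream lines in the framing the parser above expects; the payload
# is invented for illustration, not captured from the real API.
simulated_lines = [
    "[{",
    '"systemMessage": {"text": {"parts": ["The table has 42 rows."]}}',
    "}]",
]

accumulator = ""
messages = []
for line in simulated_lines:
  if line == "[{":
    accumulator = "{"
  elif line == "}]":
    accumulator += "}"
  elif line == ",":
    continue
  else:
    accumulator += line

  try:
    data = json.loads(accumulator)  # only succeeds once a full object is buffered
  except ValueError:
    continue

  parts = data["systemMessage"]["text"]["parts"]
  messages.append("Answer: " + "".join(parts))
  accumulator = ""

print("\n\n".join(messages))  # -> Answer: The table has 42 rows.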