diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md
index 050ce1332..885da85bd 100644
--- a/contributing/samples/bigquery/README.md
+++ b/contributing/samples/bigquery/README.md
@@ -25,6 +25,13 @@ distributed via the `google.adk.tools.bigquery` module. These tools include:
 
    Runs a SQL query in BigQuery.
 
+1. `chat`
+
+   Natural language-in, natural language-out chat tool that answers questions
+   about structured data in BigQuery. Provides a one-stop solution for
+   generating insights from data.
+
+
 ## How to use
 
 Set up environment variables in your `.env` file for using
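To make the new tool concrete, here is a minimal sketch of wiring the toolset (including `chat`) into an agent. This is not part of the diff: `BigQueryToolset`, `BigQueryCredentialsConfig`, and `BigQueryToolConfig` follow the names used in this PR's tests and the ADK public API, while the model name, agent name, and instruction are placeholders.

```python
# Sketch only: attach the BigQuery toolset, including the new chat tool,
# to an ADK agent. Assumes application-default credentials are available.
import google.auth

from google.adk.agents import Agent
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
from google.adk.tools.bigquery.config import BigQueryToolConfig

# Any google.auth Credentials object works; ADC is the simplest for a sample.
adc, _ = google.auth.default()
credentials_config = BigQueryCredentialsConfig(credentials=adc)

toolset = BigQueryToolset(
    credentials_config=credentials_config,
    bigquery_tool_config=BigQueryToolConfig(write_mode="blocked"),
)

# Placeholder model/instruction; the toolset exposes all six tools,
# so the agent can answer metadata, SQL, and natural-language questions.
agent = Agent(
    model="gemini-2.0-flash",
    name="bigquery_agent",
    instruction="Answer questions about data stored in BigQuery.",
    tools=[toolset],
)
```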
diff --git a/src/google/adk/tools/bigquery/bigquery_tool.py b/src/google/adk/tools/bigquery/bigquery_tool.py
index 50d49ff77..0b231edb6 100644
--- a/src/google/adk/tools/bigquery/bigquery_tool.py
+++ b/src/google/adk/tools/bigquery/bigquery_tool.py
@@ -65,7 +65,9 @@ def __init__(
         if credentials_config
         else None
     )
-    self._tool_config = bigquery_tool_config
+    self._tool_config = (
+        bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig()
+    )
 
   @override
   async def run_async(
diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py
index 313cf4990..6f6e9f94a 100644
--- a/src/google/adk/tools/bigquery/bigquery_toolset.py
+++ b/src/google/adk/tools/bigquery/bigquery_toolset.py
@@ -21,6 +21,7 @@
 from google.adk.agents.readonly_context import ReadonlyContext
 from typing_extensions import override
 
+from . import chat_tool
 from . import metadata_tool
 from . import query_tool
 from ...tools.base_tool import BaseTool
@@ -78,6 +79,7 @@ async def get_tools(
             metadata_tool.list_dataset_ids,
             metadata_tool.list_table_ids,
             query_tool.get_execute_sql(self._tool_config),
+            chat_tool.chat,
         ]
     ]
diff --git a/src/google/adk/tools/bigquery/chat_tool.py b/src/google/adk/tools/bigquery/chat_tool.py
new file mode 100644
index 000000000..2e1ebecfa
--- /dev/null
+++ b/src/google/adk/tools/bigquery/chat_tool.py
@@ -0,0 +1,337 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Any
+from typing import Dict
+from typing import List
+
+from google.auth.credentials import Credentials
+import requests
+
+from .config import BigQueryToolConfig
+
+
+def chat(
+    project_id: str,
+    user_query_with_context: str,
+    table_references: List[Dict[str, str]],
+    credentials: Credentials,
+    config: BigQueryToolConfig,
+) -> Dict[str, Any]:
+  """Answers questions about structured data in BigQuery tables using natural language.
+
+  This function acts as a client for a "chat-with-your-data" service. It takes
+  a user's question (which can include conversational history for context) and
+  references to specific BigQuery tables, and sends them to a stateless
+  conversational API.
+
+  The API uses a GenAI agent to understand the question, generate and execute
+  SQL queries and Python code, and formulate an answer. This function returns
+  a detailed, sequential log of this entire process, which includes any
+  generated SQL or Python code, the data retrieved, and the final text answer.
+
+  Use this tool to perform data analysis, get insights, or answer complex
+  questions about the contents of specific BigQuery tables.
+
+  Args:
+    project_id (str): The project that the chat is performed in.
+    user_query_with_context (str): The user's question, potentially including
+      conversation history and system instructions for context.
+    table_references (List[Dict[str, str]]): A list of dictionaries, each
+      specifying a BigQuery table to be used as context for the question.
+    credentials (Credentials): The credentials to use for the request.
+    config (BigQueryToolConfig): The configuration for the tool.
+
+  Returns:
+    A dictionary with two keys:
+    - 'status': A string indicating the final status (e.g., "SUCCESS").
+    - 'response': A list of dictionaries, where each dictionary
+      represents a step in the API's execution process (e.g., SQL
+      generation, data retrieval, final answer).
+
+  Example:
+    A query joining multiple tables, showing the full return structure.
+    >>> chat(
+    ...     project_id="some-project-id",
+    ...     user_query_with_context="Which customer from New York spent the"
+    ...     " most last month? Context: The 'customers' table joins with"
+    ...     " the 'orders' table on the 'customer_id' column.",
+    ...     table_references=[
+    ...         {
+    ...             "projectId": "my-gcp-project",
+    ...             "datasetId": "sales_data",
+    ...             "tableId": "customers"
+    ...         },
+    ...         {
+    ...             "projectId": "my-gcp-project",
+    ...             "datasetId": "sales_data",
+    ...             "tableId": "orders"
+    ...         }
+    ...     ]
+    ... )
+    {
+      "status": "SUCCESS",
+      "response": [
+        {
+          "SQL Generated": "SELECT t1.customer_name, SUM(t2.order_total) ... "
+        },
+        {
+          "Data Retrieved": {
+            "headers": ["customer_name", "total_spent"],
+            "rows": [["Jane Doe", 1234.56]],
+            "summary": "Showing all 1 rows."
+          }
+        },
+        {
+          "Answer": "The customer who spent the most was Jane Doe."
+        }
+      ]
+    }
+  """
+  try:
+    location = "global"
+    if not credentials.token:
+      error_message = (
+          "Error: The provided credentials object does not have a valid access"
+          " token.\n\nThis is often because the credentials need to be"
+          " refreshed or require specific API scopes. Please ensure the"
+          " credentials are prepared correctly before calling this"
+          " function.\n\nThere may be other underlying causes as well."
+      )
+      return {
+          "status": "ERROR",
+          "error_details": error_message,
+      }
+    headers = {
+        "Authorization": f"Bearer {credentials.token}",
+        "Content-Type": "application/json",
+    }
+    chat_url = f"https://geminidataanalytics.googleapis.com/v1alpha/projects/{project_id}/locations/{location}:chat"
+
+    chat_payload = {
+        "project": f"projects/{project_id}",
+        "messages": [{"userMessage": {"text": user_query_with_context}}],
+        "inlineContext": {
+            "datasourceReferences": {
+                "bq": {"tableReferences": table_references}
+            },
+            "options": {"chart": {"image": {"noImage": {}}}},
+        },
+    }
+
+    resp = _get_stream(
+        chat_url, chat_payload, headers, config.max_query_result_rows
+    )
+  except Exception as ex:  # pylint: disable=broad-except
+    return {
+        "status": "ERROR",
+        "error_details": str(ex),
+    }
+  return {"status": "SUCCESS", "response": resp}
+
+
+def _get_stream(
+    url: str,
+    chat_payload: Dict[str, Any],
+    headers: Dict[str, str],
+    max_query_result_rows: int,
+) -> List[Dict[str, Any]]:
+  """Sends a JSON request to a streaming API and returns a list of messages."""
+  s = requests.Session()
+
+  accumulator = ""
+  messages = []
+
+  with s.post(url, json=chat_payload, headers=headers, stream=True) as resp:
+    for line in resp.iter_lines():
+      if not line:
+        continue
+
+      decoded_line = str(line, encoding="utf-8")
+
+      if decoded_line == "[{":
+        accumulator = "{"
+      elif decoded_line == "}]":
+        accumulator += "}"
+      elif decoded_line == ",":
+        continue
+      else:
+        accumulator += decoded_line
+
+      if not _is_json(accumulator):
+        continue
+
+      data_json = json.loads(accumulator)
+      # Reset the buffer as soon as a complete object has been parsed, so a
+      # non-system message (e.g. an error) cannot poison later parses.
+      accumulator = ""
+
+      if "systemMessage" not in data_json:
+        if "error" in data_json:
+          _append_message(messages, _handle_error(data_json["error"]))
+        continue
+
+      system_message = data_json["systemMessage"]
+      if "text" in system_message:
+        _append_message(messages, _handle_text_response(system_message["text"]))
+      elif "schema" in system_message:
+        _append_message(
+            messages,
+            _handle_schema_response(system_message["schema"]),
+        )
+      elif "data" in system_message:
+        _append_message(
+            messages,
+            _handle_data_response(
+                system_message["data"], max_query_result_rows
+            ),
+        )
+  return messages
+
+
+def _is_json(s: str) -> bool:
+  """Checks if a string is a valid JSON object."""
+  try:
+    json.loads(s)
+  except ValueError:
+    return False
+  return True
+
+
+def _get_property(
+    data: Dict[str, Any], field_name: str, default: Any = ""
+) -> Any:
+  """Safely gets a property from a dictionary."""
+  return data.get(field_name, default)
+
+
+def _format_bq_table_ref(table_ref: Dict[str, str]) -> str:
+  """Formats a BigQuery table reference dictionary into a string."""
+  return f"{table_ref.get('projectId')}.{table_ref.get('datasetId')}.{table_ref.get('tableId')}"
+
+
+def _format_schema_as_dict(
+    data: Dict[str, Any],
+) -> Dict[str, List[Any]]:
+  """Extracts schema fields into a headers/rows dictionary."""
+  headers = ["Column", "Type", "Description", "Mode"]
+  fields = data.get("fields", [])
+  if not fields:
+    return {"headers": headers, "rows": []}
+
+  rows: List[List[str]] = []
+  for field in fields:
+    row_list = [
+        _get_property(field, "name"),
+        _get_property(field, "type"),
+        _get_property(field, "description", ""),
+        _get_property(field, "mode"),
+    ]
+    rows.append(row_list)
+
+  return {"headers": headers, "rows": rows}
+
+
+def _format_datasource_as_dict(datasource: Dict[str, Any]) -> Dict[str, Any]:
+  """Formats a full datasource object into a dictionary with its name and schema."""
+  source_name = _format_bq_table_ref(datasource["bigqueryTableReference"])
+
+  schema = _format_schema_as_dict(datasource["schema"])
+  return {"source_name": source_name, "schema": schema}
+
+
+def _handle_text_response(resp: Dict[str, Any]) -> Dict[str, str]:
+  """Formats a text response into a dictionary."""
+  parts = resp.get("parts", [])
+  return {"Answer": "".join(parts)}
+
+
+def _handle_schema_response(resp: Dict[str, Any]) -> Dict[str, Any]:
+  """Formats a schema response into a dictionary."""
+  if "query" in resp:
+    return {"Question": resp["query"].get("question", "")}
+  elif "result" in resp:
+    datasources = resp["result"].get("datasources", [])
+    # Format each datasource into its resolved name and schema
+    formatted_sources = [_format_datasource_as_dict(ds) for ds in datasources]
+    return {"Schema Resolved": formatted_sources}
+  return {}
+
+
+def _handle_data_response(
+    resp: Dict[str, Any], max_query_result_rows: int
+) -> Dict[str, Any]:
+  """Formats a data response into a dictionary."""
+  if "query" in resp:
+    query = resp["query"]
+    return {
+        "Retrieval Query": {
+            "Query Name": query.get("name", "N/A"),
+            "Question": query.get("question", "N/A"),
+        }
+    }
+  elif "generatedSql" in resp:
+    return {"SQL Generated": resp["generatedSql"]}
+  elif "result" in resp:
+    schema = resp["result"]["schema"]
+    headers = [field.get("name") for field in schema.get("fields", [])]
+
+    all_rows = resp["result"]["data"]
+    total_rows = len(all_rows)
+
+    compact_rows = []
+    for row_dict in all_rows[:max_query_result_rows]:
+      row_values = [row_dict.get(header) for header in headers]
+      compact_rows.append(row_values)
+
+    summary_string = f"Showing all {total_rows} rows."
+    if total_rows > max_query_result_rows:
+      summary_string = (
+          f"Showing the first {len(compact_rows)} of {total_rows} total rows."
+      )
+
+    return {
+        "Data Retrieved": {
+            "headers": headers,
+            "rows": compact_rows,
+            "summary": summary_string,
+        }
+    }
+
+  return {}
+
+
+def _handle_error(resp: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
+  """Formats an error response into a dictionary."""
+  return {
+      "Error": {
+          "Code": resp.get("code", "N/A"),
+          "Message": resp.get("message", "No message provided."),
+      }
+  }
+
+
+def _append_message(
+    messages: List[Dict[str, Any]], new_message: Dict[str, Any]
+):
+  """Appends a message, replacing a trailing intermediate data message."""
+  if not new_message:
+    return
+
+  # A "Data Retrieved" message is an intermediate result; a newer message
+  # supersedes it.
+  if messages and ("Data Retrieved" in messages[-1]):
+    messages.pop()
+
+  messages.append(new_message)
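The `_get_stream` helper above reassembles complete JSON objects from a pretty-printed JSON array that arrives one line at a time: it rewrites the array-open `[{` and array-close `}]` lines into plain object braces, skips the `,` separator lines, and accumulates everything else until the buffer parses. A minimal, self-contained illustration of that accumulation loop (same branch logic, hardcoded lines standing in for the HTTP response):

```python
# Illustration only: how the accumulator turns a streamed, pretty-printed
# JSON array into one parsed object per element.
import json

stream_lines = [
    "[{",                                                  # array opens
    '  "systemMessage": {"text": {"parts": ["Hello"]}}',
    "}",                                                   # first object done
    ",",                                                   # separator line
    "{",
    '  "systemMessage": {"text": {"parts": [" world"]}}',
    "}]",                                                  # array closes
]

accumulator = ""
objects = []
for line in stream_lines:
  if line == "[{":          # array opens: start the first object
    accumulator = "{"
  elif line == "}]":        # array closes: finish the last object
    accumulator += "}"
  elif line == ",":         # separator between objects: ignore
    continue
  else:
    accumulator += line
  try:                      # emit once the buffer parses as complete JSON
    objects.append(json.loads(accumulator))
    accumulator = ""
  except ValueError:
    continue                # not complete yet; keep accumulating

assert [o["systemMessage"]["text"]["parts"] for o in objects] == (
    [["Hello"], [" world"]]
)
```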
diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py
index 8b2816ebe..bc2f638b5 100644
--- a/src/google/adk/tools/bigquery/client.py
+++ b/src/google/adk/tools/bigquery/client.py
@@ -14,6 +14,8 @@
 
 from __future__ import annotations
 
+from typing import Optional
+
 import google.api_core.client_info
 from google.auth.credentials import Credentials
 from google.cloud import bigquery
@@ -24,7 +26,7 @@
 
 
 def get_bigquery_client(
-    *, project: str, credentials: Credentials
+    *, project: Optional[str], credentials: Credentials
 ) -> bigquery.Client:
   """Get a BigQuery client."""
 
diff --git a/src/google/adk/tools/bigquery/config.py b/src/google/adk/tools/bigquery/config.py
index a6f8eeb5e..b2c02cfd2 100644
--- a/src/google/adk/tools/bigquery/config.py
+++ b/src/google/adk/tools/bigquery/config.py
@@ -54,3 +54,8 @@ class BigQueryToolConfig(BaseModel):
   By default, the tool will allow only read operations. This behaviour may
   change in future versions.
   """
+
+  max_query_result_rows: int = 50
+  """Maximum number of rows to return from a query.
+
+  By default, the query result will be limited to 50 rows."""
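For reference, a short sketch of setting the new knob; the values shown are illustrative. The same config object now drives both `execute_sql` (via `max_results`) and the chat tool's data previews.

```python
# Sketch: a read-only config with a larger row cap than the default of 50.
# Both kwargs are grounded in this PR's tests; 200 is an arbitrary example.
from google.adk.tools.bigquery.config import BigQueryToolConfig

config = BigQueryToolConfig(
    write_mode="blocked",       # read-only, as used in the tests below
    max_query_result_rows=200,  # override the default of 50
)
```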
diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py
index 7406d9a4d..1b5bfdafb 100644
--- a/src/google/adk/tools/bigquery/query_tool.py
+++ b/src/google/adk/tools/bigquery/query_tool.py
@@ -24,7 +24,6 @@
 from .config import BigQueryToolConfig
 from .config import WriteMode
 
-MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
 BIGQUERY_SESSION_INFO_KEY = "bigquery_session_info"
 
 
@@ -157,17 +156,17 @@ def execute_sql(
         query,
         job_config=job_config,
         project=project_id,
-        max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS,
+        max_results=config.max_query_result_rows,
     )
     rows = [{key: val for key, val in row.items()} for row in row_iterator]
     result = {"status": "SUCCESS", "rows": rows}
     if (
-        MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
-        and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
+        config.max_query_result_rows is not None
+        and len(rows) == config.max_query_result_rows
     ):
       result["result_is_likely_truncated"] = True
     return result
-  except Exception as ex:
+  except Exception as ex:  # pylint: disable=broad-except
     return {
         "status": "ERROR",
         "error_details": str(ex),
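The truncation flag relies on a simple heuristic: the row iterator is capped at `max_query_result_rows`, so getting back exactly that many rows suggests more data exists than was downloaded. A tiny standalone illustration with assumed values:

```python
# Illustration of the truncation heuristic above: 50 rows back with a cap of
# 50 means the result was probably cut off, so callers get a warning flag.
rows = [{"id": i} for i in range(50)]  # pretend the capped iterator gave 50
max_rows = 50

result = {"status": "SUCCESS", "rows": rows}
if max_rows is not None and len(rows) == max_rows:
  result["result_is_likely_truncated"] = True

assert result["result_is_likely_truncated"]
```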
diff --git a/tests/unittests/tools/bigquery/test_bigquery_chat_tool.py b/tests/unittests/tools/bigquery/test_bigquery_chat_tool.py
new file mode 100644
index 000000000..90064e50e
--- /dev/null
+++ b/tests/unittests/tools/bigquery/test_bigquery_chat_tool.py
@@ -0,0 +1,271 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+from unittest import mock
+
+from google.adk.tools.bigquery import chat_tool
+import pytest
+import yaml
+
+
+@pytest.mark.parametrize(
+    "case_file_path",
+    [
+        pytest.param("test_data/chat_penguins_highest_mass.yaml"),
+    ],
+)
+@mock.patch("google.adk.tools.bigquery.chat_tool.requests.Session.post")
+def test_chat_pipeline_from_file(mock_post, case_file_path):
+  """Runs a full integration test for the chat pipeline using data from a specific file."""
+  # 1. Construct the full, absolute path to the data file
+  full_path = pathlib.Path(__file__).parent / case_file_path
+
+  # 2. Load the test case data from the specified YAML file
+  with open(full_path, "r", encoding="utf-8") as f:
+    case_data = yaml.safe_load(f)
+
+  # 3. Prepare the mock stream and expected output from the loaded data
+  mock_stream_str = case_data["mock_api_stream"]
+  fake_stream_lines = [
+      line.encode("utf-8") for line in mock_stream_str.splitlines()
+  ]
+  # Load the expected output as a list of dictionaries, not a single string
+  expected_final_list = case_data["expected_output"]
+
+  # 4. Configure the mock for requests.post
+  mock_response = mock.Mock()
+  mock_response.iter_lines.return_value = fake_stream_lines
+  # Mock raise_for_status defensively; it is a no-op if the implementation
+  # never calls it
+  mock_response.raise_for_status.return_value = None
+  mock_post.return_value.__enter__.return_value = mock_response
+
+  # 5. Call the function under test
+  result = chat_tool._get_stream(  # pylint: disable=protected-access
+      url="fake_url",
+      chat_payload={},
+      headers={},
+      max_query_result_rows=50,
+  )
+
+  # 6. Assert that the final list of dicts matches the expected output
+  assert result == expected_final_list
+
+
+@mock.patch("google.adk.tools.bigquery.chat_tool._get_stream")
+def test_chat_success(mock_get_stream):
+  """Tests the success path of chat with mocked dependencies."""
+  # 1. Configure the behavior of the mocked functions
+  mock_get_stream.return_value = "Final formatted string from stream"
+
+  # 2. Create mock inputs for the function call
+  mock_creds = mock.Mock()
+  mock_creds.token = "fake-token"
+  mock_config = mock.Mock()
+  mock_config.max_query_result_rows = 100
+
+  # 3. Call the function under test
+  result = chat_tool.chat(
+      project_id="test-project",
+      user_query_with_context="test query",
+      table_references=[],
+      credentials=mock_creds,
+      config=mock_config,
+  )
+
+  # 4. Assert the results are as expected
+  assert result["status"] == "SUCCESS"
+  assert result["response"] == "Final formatted string from stream"
+  mock_get_stream.assert_called_once()
+
+
+@mock.patch("google.adk.tools.bigquery.chat_tool._get_stream")
+def test_chat_handles_exception(mock_get_stream):
+  """Tests the exception path of chat with mocked dependencies."""
+  # 1. Configure one of the mocks to raise an error
+  mock_get_stream.side_effect = Exception("API call failed!")
+
+  # 2. Create mock inputs
+  mock_creds = mock.Mock()
+  mock_creds.token = "fake-token"
+  mock_config = mock.Mock()
+
+  # 3. Call the function
+  result = chat_tool.chat(
+      project_id="test-project",
+      user_query_with_context="test query",
+      table_references=[],
+      credentials=mock_creds,
+      config=mock_config,
+  )
+
+  # 4. Assert that the error was caught and formatted correctly
+  assert result["status"] == "ERROR"
+  assert "API call failed!" in result["error_details"]
+  mock_get_stream.assert_called_once()
+
+
+@pytest.mark.parametrize(
+    "initial_messages, new_message, expected_list",
+    [
+        pytest.param(
+            [{"Thinking": None}, {"Schema Resolved": {}}],
+            {"SQL Generated": "SELECT 1"},
+            [
+                {"Thinking": None},
+                {"Schema Resolved": {}},
+                {"SQL Generated": "SELECT 1"},
+            ],
+            id="append_when_last_message_is_not_data",
+        ),
+        pytest.param(
+            [{"Thinking": None}, {"Data Retrieved": {"rows": [1]}}],
+            {"Data Retrieved": {"rows": [1, 2]}},
+            [{"Thinking": None}, {"Data Retrieved": {"rows": [1, 2]}}],
+            id="replace_when_last_message_is_data",
+        ),
+        pytest.param(
+            [],
+            {"Answer": "First Message"},
+            [{"Answer": "First Message"}],
+            id="append_to_an_empty_list",
+        ),
+        pytest.param(
+            [{"Data Retrieved": {}}],
+            {},
+            [{"Data Retrieved": {}}],
+            id="should_not_append_an_empty_new_message",
+        ),
+    ],
+)
+def test_append_message(initial_messages, new_message, expected_list):
+  """Tests the logic of replacing the last message if it's a data message."""
+  messages_copy = initial_messages.copy()
+  chat_tool._append_message(messages_copy, new_message)  # pylint: disable=protected-access
+  assert messages_copy == expected_list
+
+
+@pytest.mark.parametrize(
+    "response_dict, expected_output",
+    [
+        pytest.param(
+            {"parts": ["The answer", " is 42."]},
+            {"Answer": "The answer is 42."},
+            id="multiple_parts",
+        ),
+        pytest.param(
+            {"parts": ["Hello"]}, {"Answer": "Hello"}, id="single_part"
+        ),
+        pytest.param({}, {"Answer": ""}, id="empty_response"),
+    ],
+)
+def test_handle_text_response(response_dict, expected_output):
+  """Tests the text response handler."""
+  result = chat_tool._handle_text_response(response_dict)  # pylint: disable=protected-access
+  assert result == expected_output
+
+
+@pytest.mark.parametrize(
+    "response_dict, expected_output",
+    [
+        pytest.param(
+            {"query": {"question": "What is the schema?"}},
+            {"Question": "What is the schema?"},
+            id="schema_query_path",
+        ),
+        pytest.param(
+            {
+                "result": {
+                    "datasources": [{
+                        "bigqueryTableReference": {
+                            "projectId": "p",
+                            "datasetId": "d",
+                            "tableId": "t",
+                        },
+                        "schema": {
+                            "fields": [{"name": "col1", "type": "STRING"}]
+                        },
+                    }]
+                }
+            },
+            {
+                "Schema Resolved": [{
+                    "source_name": "p.d.t",
+                    "schema": {
+                        "headers": ["Column", "Type", "Description", "Mode"],
+                        "rows": [["col1", "STRING", "", ""]],
+                    },
+                }]
+            },
+            id="schema_result_path",
+        ),
+    ],
+)
+def test_handle_schema_response(response_dict, expected_output):
+  """Tests different paths of the schema response handler."""
+  result = chat_tool._handle_schema_response(response_dict)  # pylint: disable=protected-access
+  assert result == expected_output
+
+
+@pytest.mark.parametrize(
+    "response_dict, expected_output",
+    [
+        pytest.param(
+            {"generatedSql": "SELECT 1;"},
+            {"SQL Generated": "SELECT 1;"},
+            id="format_generated_sql",
+        ),
+        pytest.param(
+            {
+                "result": {
+                    "schema": {"fields": [{"name": "id"}, {"name": "name"}]},
+                    "data": [{"id": 1, "name": "A"}, {"id": 2, "name": "B"}],
+                }
+            },
+            {
+                "Data Retrieved": {
+                    "headers": ["id", "name"],
+                    "rows": [[1, "A"], [2, "B"]],
+                    "summary": "Showing all 2 rows.",
+                }
+            },
+            id="format_data_result_table",
+        ),
+    ],
+)
+def test_handle_data_response(response_dict, expected_output):
+  """Tests different paths of the data response handler."""
+  result = chat_tool._handle_data_response(response_dict, 100)  # pylint: disable=protected-access
+  assert result == expected_output
+
+
+@pytest.mark.parametrize(
+    "response_dict, expected_output",
+    [
+        pytest.param(
+            {"code": 404, "message": "Not Found"},
+            {"Error": {"Code": 404, "Message": "Not Found"}},
+            id="full_error_message",
+        ),
+        pytest.param(
+            {"code": 500},
+            {"Error": {"Code": 500, "Message": "No message provided."}},
+            id="error_with_missing_message",
+        ),
+    ],
+)
+def test_handle_error(response_dict, expected_output):
+  """Tests the error response handler."""
+  result = chat_tool._handle_error(response_dict)  # pylint: disable=protected-access
+  assert result == expected_output
diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py
index e8b373416..0bf71381b 100644
--- a/tests/unittests/tools/bigquery/test_bigquery_client.py
+++ b/tests/unittests/tools/bigquery/test_bigquery_client.py
@@ -1,3 +1,4 @@
+# Copyright 2025 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tests/unittests/tools/bigquery/test_bigquery_tool.py b/tests/unittests/tools/bigquery/test_bigquery_tool.py
index b4ea75b16..2e00e6007 100644
--- a/tests/unittests/tools/bigquery/test_bigquery_tool.py
+++ b/tests/unittests/tools/bigquery/test_bigquery_tool.py
@@ -20,6 +20,7 @@
 from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
 from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
 from google.adk.tools.bigquery.bigquery_tool import BigQueryTool
+from google.adk.tools.bigquery.config import BigQueryToolConfig
 # Mock the Google OAuth and API dependencies
 from google.oauth2.credentials import Credentials
 import pytest
@@ -267,3 +268,35 @@ def complex_function(
     assert "required_param" in mandatory_args
     assert "credentials" not in mandatory_args
     assert "optional_param" not in mandatory_args
+
+  @pytest.mark.parametrize(
+      "input_config, expected_config",
+      [
+          pytest.param(
+              BigQueryToolConfig(
+                  write_mode="blocked", max_query_result_rows=50
+              ),
+              BigQueryToolConfig(
+                  write_mode="blocked", max_query_result_rows=50
+              ),
+              id="with_provided_config",
+          ),
+          pytest.param(
+              None,
+              BigQueryToolConfig(),
+              id="with_none_config_creates_default",
+          ),
+      ],
+  )
+  def test_tool_config_initialization(self, input_config, expected_config):
+    """Tests that self._tool_config is correctly initialized.
+
+    Compares the final state of the tool's config to an expected
+    configuration object.
+    """
+    # 1. Initialize the tool with the parameterized config
+    tool = BigQueryTool(func=None, bigquery_tool_config=input_config)
+
+    # 2. Assert that the tool's config has the same attribute values as the
+    #    expected config. Comparing the __dict__ is a robust way to check
+    #    for value equality.
+    assert tool._tool_config.__dict__ == expected_config.__dict__  # pylint: disable=protected-access
diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py
index 4129dc512..e0c411c79 100644
--- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py
+++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py
@@ -34,7 +34,7 @@ async def test_bigquery_toolset_tools_default():
   tools = await toolset.get_tools()
   assert tools is not None
 
-  assert len(tools) == 5
+  assert len(tools) == 6
   assert all([isinstance(tool, BigQueryTool) for tool in tools])
 
   expected_tool_names = set([
@@ -43,6 +43,7 @@ async def test_bigquery_toolset_tools_default():
       "list_table_ids",
       "get_table_info",
       "execute_sql",
+      "chat",
   ])
   actual_tool_names = set([tool.name for tool in tools])
   assert actual_tool_names == expected_tool_names
diff --git a/tests/unittests/tools/bigquery/test_data/chat_penguins_highest_mass.yaml b/tests/unittests/tools/bigquery/test_data/chat_penguins_highest_mass.yaml
new file mode 100644
index 000000000..7c0f213aa
--- /dev/null
+++ b/tests/unittests/tools/bigquery/test_data/chat_penguins_highest_mass.yaml
@@ -0,0 +1,336 @@
+description: "Tests a full, realistic stream about finding the penguin island with the highest body mass."
+
+user_question: "Penguins on which island have the highest average body mass?"
+
+mock_api_stream: |
+  [{
+    "timestamp": "2025-07-17T17:25:28.231Z",
+    "systemMessage": {
+      "schema": {
+        "query": {
+          "question": "Penguins on which island have the highest average body mass?"
+        }
+      }
+    }
+  }
+  ,
+  {
+    "timestamp": "2025-07-17T17:25:29.406Z",
+    "systemMessage": {
+      "schema": {
+        "result": {
+          "datasources": [
+            {
+              "bigqueryTableReference": {
+                "projectId": "bigframes-dev-perf",
+                "datasetId": "bigframes_testing_eu",
+                "tableId": "penguins"
+              },
+              "schema": {
+                "fields": [
+                  {
+                    "name": "species",
+                    "type": "STRING",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "island",
+                    "type": "STRING",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "culmen_length_mm",
+                    "type": "FLOAT64",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "culmen_depth_mm",
+                    "type": "FLOAT64",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "flipper_length_mm",
+                    "type": "FLOAT64",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "body_mass_g",
+                    "type": "FLOAT64",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "sex",
+                    "type": "STRING",
+                    "mode": "NULLABLE"
+                  }
+                ]
+              }
+            }
+          ]
+        }
+      }
+    }
+  }
+  ,
+  {
+    "timestamp": "2025-07-17T17:25:30.431Z",
+    "systemMessage": {
+      "data": {
+        "query": {
+          "question": "What is the average body mass for each island?",
+          "datasources": [
+            {
+              "bigqueryTableReference": {
+                "projectId": "bigframes-dev-perf",
+                "datasetId": "bigframes_testing_eu",
+                "tableId": "penguins"
+              },
+              "schema": {
+                "fields": [
+                  {
+                    "name": "species",
+                    "type": "STRING",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "island",
+                    "type": "STRING",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "culmen_length_mm",
+                    "type": "FLOAT64",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "culmen_depth_mm",
+                    "type": "FLOAT64",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "flipper_length_mm",
+                    "type": "FLOAT64",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "body_mass_g",
+                    "type": "FLOAT64",
+                    "mode": "NULLABLE"
+                  },
+                  {
+                    "name": "sex",
+                    "type": "STRING",
+                    "mode": "NULLABLE"
+                  }
+                ]
+              }
+            }
+          ],
+          "name": "average_body_mass_by_island"
+        }
+      }
+    }
+  }
+  ,
+  {
+    "timestamp": "2025-07-17T17:25:31.171Z",
+    "systemMessage": {
+      "data": {
+        "generatedSql": "SELECT island, AVG(body_mass_g) AS average_body_mass\nFROM `bigframes-dev-perf`.`bigframes_testing_eu`.`penguins`\nGROUP BY island;"
+      }
+    }
+  }
+  ,
+  {
+    "timestamp": "2025-07-17T17:25:32.378Z",
+    "systemMessage": {
+      "data": {
+        "bigQueryJob": {
+          "projectId": "bigframes-dev-perf",
+          "jobId": "job_S4PGRwxO78_FrVmCHW_sklpeZFps",
+          "destinationTable": {
+            "projectId": "bigframes-dev-perf",
+            "datasetId": "_376b2bd1b83171a540d39ff3d58f39752e2724c9",
+            "tableId": "anonev_4a9PK1uHzAHwAOpSNOxMVhpUppM2sllR68riN6t41kM"
+          },
+          "location": "EU",
+          "schema": {
+            "fields": [
+              {
+                "name": "island",
+                "type": "STRING",
+                "mode": "NULLABLE"
+              },
+              {
+                "name": "average_body_mass",
+                "type": "FLOAT",
+                "mode": "NULLABLE"
+              }
+            ]
+          }
+        }
+      }
+    }
+  }
+  ,
+  {
+    "timestamp": "2025-07-17T17:25:32.664Z",
+    "systemMessage": {
+      "data": {
+        "result": {
+          "data": [
+            {
+              "island": "Biscoe",
+              "average_body_mass": "4716.017964071853"
+            },
+            {
+              "island": "Dream",
+              "average_body_mass": "3712.9032258064512"
+            },
+            {
+              "island": "Torgersen",
+              "average_body_mass": "3706.3725490196075"
+            }
+          ],
+          "name": "average_body_mass_by_island",
+          "schema": {
+            "fields": [
+              {
+                "name": "island",
+                "type": "STRING",
+                "mode": "NULLABLE"
+              },
+              {
+                "name": "average_body_mass",
+                "type": "FLOAT",
+                "mode": "NULLABLE"
+              }
+            ]
+          }
+        }
+      }
+    }
+  }
+  ,
+  {
+    "timestamp": "2025-07-17T17:25:33.808Z",
+    "systemMessage": {
+      "chart": {
+        "query": {
+          "instructions": "Create a bar chart showing the average body mass for each island. The island should be on the x axis and the average body mass should be on the y axis.",
+          "dataResultName": "average_body_mass_by_island"
+        }
+      }
+    }
+  }
+  ,
+  {
+    "timestamp": "2025-07-17T17:25:38.999Z",
+    "systemMessage": {
+      "chart": {
+        "result": {
+          "vegaConfig": {
+            "mark": {
+              "type": "bar",
+              "tooltip": true
+            },
+            "encoding": {
+              "x": {
+                "field": "island",
+                "type": "nominal",
+                "title": "Island",
+                "axis": {
+                  "labelOverlap": true
+                },
+                "sort": {}
+              },
+              "y": {
+                "field": "average_body_mass",
+                "type": "quantitative",
+                "title": "Average Body Mass",
+                "axis": {
+                  "labelOverlap": true
+                },
+                "sort": {}
+              }
+            },
+            "title": "Average Body Mass for Each Island",
+            "data": {
+              "values": [
+                {
+                  "island": "Biscoe",
+                  "average_body_mass": 4716.0179640718534
+                },
+                {
+                  "island": "Dream",
+                  "average_body_mass": 3712.9032258064512
+                },
+                {
+                  "island": "Torgersen",
+                  "average_body_mass": 3706.3725490196075
+                }
+              ]
+            }
+          },
+          "image": {}
+        }
+      }
+    }
+  }
+  ,
+  {
+    "timestamp": "2025-07-17T17:25:40.018Z",
+    "systemMessage": {
+      "text": {
+        "parts": [
+          "Penguins on Biscoe island have the highest average body mass, with an average of 4716.02g."
+        ]
+      }
+    }
+  }
+  ]
+
+expected_output:
+- Question: Penguins on which island have the highest average body mass?
+- Schema Resolved:
+  - source_name: bigframes-dev-perf.bigframes_testing_eu.penguins
+    schema:
+      headers:
+      - Column
+      - Type
+      - Description
+      - Mode
+      rows:
+      - - species
+        - STRING
+        - ''
+        - NULLABLE
+      - - island
+        - STRING
+        - ''
+        - NULLABLE
+      - - culmen_length_mm
+        - FLOAT64
+        - ''
+        - NULLABLE
+      - - culmen_depth_mm
+        - FLOAT64
+        - ''
+        - NULLABLE
+      - - flipper_length_mm
+        - FLOAT64
+        - ''
+        - NULLABLE
+      - - body_mass_g
+        - FLOAT64
+        - ''
+        - NULLABLE
+      - - sex
+        - STRING
+        - ''
+        - NULLABLE
+- Retrieval Query:
+    Query Name: average_body_mass_by_island
+    Question: What is the average body mass for each island?
+- SQL Generated: "SELECT island, AVG(body_mass_g) AS average_body_mass\nFROM `bigframes-dev-perf`.`bigframes_testing_eu`.`penguins`\nGROUP BY island;"
+- Answer: Penguins on Biscoe island have the highest average body mass, with an average of 4716.02g.
\ No newline at end of file