|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import json |
| 16 | +from typing import Any |
| 17 | +from typing import Dict |
| 18 | +from typing import List |
| 19 | + |
| 20 | +from google.auth.credentials import Credentials |
| 21 | +from google.cloud import bigquery |
| 22 | +import requests |
| 23 | + |
| 24 | +from . import client |
| 25 | +from .config import BigQueryToolConfig |
| 26 | + |
| 27 | + |
def chat(
    project_id: str,
    user_query_with_context: str,
    table_references: List[Dict[str, str]],
    credentials: Credentials,
    config: BigQueryToolConfig,
) -> Dict[str, Any]:
  """Answers questions about structured data in BigQuery tables using natural language.

  This function acts as a client for a "chat-with-your-data" service. It takes
  a user's question (which can include conversational history for context) and
  references to specific BigQuery tables, and sends them to a stateless
  conversational API.

  The API uses a GenAI agent to understand the question, generate and execute
  SQL queries and Python code, and formulate an answer. This function returns
  a detailed, sequential log of this entire process, which includes any
  generated SQL or Python code, the data retrieved, and the final text answer.

  Use this tool to perform data analysis, get insights, or answer complex
  questions about the contents of specific BigQuery tables.

  Args:
    project_id (str): The project that the chat is performed in.
    user_query_with_context (str): The user's question, potentially including
      conversation history and system instructions for context.
    table_references (List[Dict[str, str]]): A list of dictionaries, each
      specifying a BigQuery table to be used as context for the question.
    credentials (Credentials): The credentials to use for the request.
    config (BigQueryToolConfig): The configuration for the tool.

  Returns:
    A dictionary with two keys:
    - 'status': A string indicating the final status (e.g., "SUCCESS").
    - 'response': A list of dictionaries, where each dictionary represents a
      step in the API's execution process (e.g., SQL generation, data
      retrieval, final answer).
    On failure, 'status' is "ERROR" and 'error_details' describes the cause.

  Example:
    A query joining multiple tables, showing the full return structure.
    >>> chat(
    ...     project_id="some-project-id",
    ...     user_query_with_context="Which customer from New York spent the "
    ...     "most last month? Context: The 'customers' table joins with "
    ...     "the 'orders' table on the 'customer_id' column.",
    ...     table_references=[
    ...         {
    ...             "projectId": "my-gcp-project",
    ...             "datasetId": "sales_data",
    ...             "tableId": "customers"
    ...         },
    ...         {
    ...             "projectId": "my-gcp-project",
    ...             "datasetId": "sales_data",
    ...             "tableId": "orders"
    ...         }
    ...     ]
    ... )
    {
        "status": "SUCCESS",
        "response": [
            {
                "SQL Generated": "SELECT t1.customer_name, ... "
            },
            {
                "Data Retrieved": {
                    "headers": ["customer_name", "total_spent"],
                    "rows": [["Jane Doe", 1234.56]],
                    "summary": "Showing all 1 rows."
                }
            },
            {
                "Answer": "The customer who spent the most was Jane Doe."
            }
        ]
    }
  """
  try:
    # NOTE(review): the endpoint is hard-coded to the "global" location;
    # confirm before adding regional endpoint support.
    location = "global"
    if not credentials.token:
      # Fail fast: without an access token the HTTP call below would fail
      # with an opaque 401. Surface the detailed guidance in the returned
      # error instead of discarding it (previously this message was built
      # but never used).
      return {
          "status": "ERROR",
          "error_details": (
              "Chat requires a valid access token. The provided credentials"
              " object does not have one. This is often because the"
              " credentials need to be refreshed or require specific API"
              " scopes. Please ensure the credentials are prepared correctly"
              " before calling this function. There may be other underlying"
              " causes as well."
          ),
      }
    headers = {
        "Authorization": f"Bearer {credentials.token}",
        "Content-Type": "application/json",
    }
    chat_url = f"https://geminidataanalytics.googleapis.com/v1alpha/projects/{project_id}/locations/{location}:chat"

    chat_payload = {
        "project": f"projects/{project_id}",
        "messages": [{"userMessage": {"text": user_query_with_context}}],
        "inlineContext": {
            "datasourceReferences": {
                "bq": {"tableReferences": table_references}
            },
            # Disable chart image generation; responses stay text/data only.
            "options": {"chart": {"image": {"noImage": {}}}},
        },
    }

    resp = _get_stream(
        chat_url, chat_payload, headers, config.max_query_result_rows
    )
  except Exception as ex:  # pylint: disable=broad-except
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
  return {"status": "SUCCESS", "response": resp}
| 147 | + |
| 148 | + |
def _get_stream(
    url: str,
    chat_payload: Dict[str, Any],
    headers: Dict[str, str],
    max_query_result_rows: int,
) -> List[Dict[str, Any]]:
  """Sends a JSON request to a streaming API and returns a list of messages.

  The endpoint streams a JSON array of objects across many lines. Lines are
  accumulated until they form one complete JSON object, which is then parsed
  and routed to the matching handler.

  Args:
    url: The fully-qualified ":chat" endpoint URL.
    chat_payload: The JSON-serializable request body.
    headers: HTTP headers, including the Authorization bearer token.
    max_query_result_rows: Row cap forwarded to the data-response handler.

  Returns:
    The formatted messages, in the order they were received.
  """
  session = requests.Session()

  accumulator = ""
  messages: List[Dict[str, Any]] = []

  with session.post(url, json=chat_payload, headers=headers, stream=True) as resp:
    for line in resp.iter_lines():
      if not line:
        continue

      decoded_line = str(line, encoding="utf-8")

      # The stream frames an array: "[{" opens the first object, "}]"
      # closes the last one, and bare "," lines separate objects.
      if decoded_line == "[{":
        accumulator = "{"
      elif decoded_line == "}]":
        accumulator += "}"
      elif decoded_line == ",":
        continue
      else:
        accumulator += decoded_line

      if not _is_json(accumulator):
        continue  # Keep accumulating until the object is complete.

      data_json = json.loads(accumulator)
      # Reset immediately after a successful parse. Previously the reset was
      # skipped when the object had no "systemMessage" (e.g. an error), which
      # left stale JSON in the accumulator and silently dropped every
      # subsequent message in the stream.
      accumulator = ""

      if "systemMessage" not in data_json:
        if "error" in data_json:
          _append_message(messages, _handle_error(data_json["error"]))
        continue

      system_message = data_json["systemMessage"]
      if "text" in system_message:
        _append_message(messages, _handle_text_response(system_message["text"]))
      elif "schema" in system_message:
        _append_message(
            messages,
            _handle_schema_response(system_message["schema"]),
        )
      elif "data" in system_message:
        _append_message(
            messages,
            _handle_data_response(
                system_message["data"], max_query_result_rows
            ),
        )
  return messages
| 203 | + |
| 204 | + |
| 205 | +def _is_json(s: str) -> bool: |
| 206 | + """Checks if a string is a valid JSON object.""" |
| 207 | + try: |
| 208 | + json.loads(s) |
| 209 | + except ValueError: |
| 210 | + return False |
| 211 | + return True |
| 212 | + |
| 213 | + |
| 214 | +def _get_property( |
| 215 | + data: Dict[str, Any], field_name: str, default: Any = "" |
| 216 | +) -> Any: |
| 217 | + """Safely gets a property from a dictionary.""" |
| 218 | + return data.get(field_name, default) |
| 219 | + |
| 220 | + |
| 221 | +def _format_bq_table_ref(table_ref: Dict[str, str]) -> str: |
| 222 | + """Formats a BigQuery table reference dictionary into a string.""" |
| 223 | + return f"{table_ref.get('projectId')}.{table_ref.get('datasetId')}.{table_ref.get('tableId')}" |
| 224 | + |
| 225 | + |
def _format_schema_as_dict(
    data: Dict[str, Any],
) -> Dict[str, List[Any]]:
  """Extracts schema fields into a headers/rows dictionary.

  Args:
    data: A BigQuery schema object with an optional "fields" list.

  Returns:
    {"headers": [...], "rows": [...]} where each row is a
    (name, type, description, mode) tuple, or {"columns": []} when the
    schema has no fields (shape preserved for existing callers).
  """
  fields = data.get("fields", [])
  if not fields:
    return {"columns": []}

  # Removed an unused local ("column_details") and an annotation that
  # referenced typing.Tuple without importing it.
  headers = ["Column", "Type", "Description", "Mode"]
  rows = [
      (
          _get_property(field, "name"),
          _get_property(field, "type"),
          _get_property(field, "description", ""),
          _get_property(field, "mode"),
      )
      for field in fields
  ]

  return {"headers": headers, "rows": rows}
| 247 | + |
| 248 | + |
def _format_datasource_as_dict(datasource: Dict[str, Any]) -> Dict[str, Any]:
  """Formats a full datasource object into a dictionary with its name and schema."""
  return {
      "source_name": _format_bq_table_ref(
          datasource["bigqueryTableReference"]
      ),
      "schema": _format_schema_as_dict(datasource["schema"]),
  }
| 255 | + |
| 256 | + |
| 257 | +def _handle_text_response(resp: Dict[str, Any]) -> Dict[str, str]: |
| 258 | + """Formats a text response into a dictionary.""" |
| 259 | + parts = resp.get("parts", []) |
| 260 | + return {"Answer": "".join(parts)} |
| 261 | + |
| 262 | + |
| 263 | +def _handle_schema_response(resp: Dict[str, Any]) -> Dict[str, Any]: |
| 264 | + """Formats a schema response into a dictionary.""" |
| 265 | + if "query" in resp: |
| 266 | + return {"Question": resp["query"].get("question", "")} |
| 267 | + elif "result" in resp: |
| 268 | + datasources = resp["result"].get("datasources", []) |
| 269 | + # Format each datasource and join them with newlines |
| 270 | + formatted_sources = [_format_datasource_as_dict(ds) for ds in datasources] |
| 271 | + return {"Schema Resolved": formatted_sources} |
| 272 | + return {} |
| 273 | + |
| 274 | + |
| 275 | +def _handle_data_response( |
| 276 | + resp: Dict[str, Any], max_query_result_rows: int |
| 277 | +) -> Dict[str, Any]: |
| 278 | + """Formats a data response into a dictionary.""" |
| 279 | + if "query" in resp: |
| 280 | + query = resp["query"] |
| 281 | + return { |
| 282 | + "Retrieval Query": { |
| 283 | + "Query Name": query.get("name", "N/A"), |
| 284 | + "Question": query.get("question", "N/A"), |
| 285 | + } |
| 286 | + } |
| 287 | + elif "generatedSql" in resp: |
| 288 | + return {"SQL Generated": resp["generatedSql"]} |
| 289 | + elif "result" in resp: |
| 290 | + schema = resp["result"]["schema"] |
| 291 | + headers = [field.get("name") for field in schema.get("fields", [])] |
| 292 | + |
| 293 | + all_rows = resp["result"]["data"] |
| 294 | + total_rows = len(all_rows) |
| 295 | + |
| 296 | + compact_rows = [] |
| 297 | + for row_dict in all_rows[:max_query_result_rows]: |
| 298 | + row_values = [row_dict.get(header) for header in headers] |
| 299 | + compact_rows.append(row_values) |
| 300 | + |
| 301 | + summary_string = f"Showing all {total_rows} rows." |
| 302 | + if total_rows > max_query_result_rows: |
| 303 | + summary_string = ( |
| 304 | + f"Showing the first {len(compact_rows)} of {total_rows} total rows." |
| 305 | + ) |
| 306 | + |
| 307 | + return { |
| 308 | + "Data Retrieved": { |
| 309 | + "headers": headers, |
| 310 | + "rows": compact_rows, |
| 311 | + "total_rows": total_rows, |
| 312 | + "summary": summary_string, |
| 313 | + } |
| 314 | + } |
| 315 | + |
| 316 | + return {} |
| 317 | + |
| 318 | + |
| 319 | +def _handle_error(resp: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: |
| 320 | + """Formats an error response into a dictionary.""" |
| 321 | + return { |
| 322 | + "Error": { |
| 323 | + "Code": resp.get("code", "N/A"), |
| 324 | + "Message": resp.get("message", "No message provided."), |
| 325 | + } |
| 326 | + } |
| 327 | + |
| 328 | + |
| 329 | +def _append_message( |
| 330 | + messages: List[Dict[str, Any]], new_message: Dict[str, Any] |
| 331 | +): |
| 332 | + if not new_message: |
| 333 | + return |
| 334 | + |
| 335 | + if messages and "Data Retrieved" in messages[-1]: |
| 336 | + messages.pop() |
| 337 | + |
| 338 | + messages.append(new_message) |
0 commit comments