From b1610f7f22b7f4104eaab13b409418710d089501 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Wed, 25 Feb 2026 05:55:27 +0000 Subject: [PATCH 01/15] feat(bigquery): support authorized views with dataset restrictions Refactors BigQuery query validation to support Authorized Views by combining behavioral analysis (Dry Run) with explicit reference auditing. ### Working Logic: The validation process uses a coordinated check: 1. Behavioral Detection (Dry Run): Identifies physical tables accessed by the query. If tables are found and all are allowed, the query is permitted immediately. 2. Explicit Reference Audit (Lexical Search): If the dry run identifies disallowed tables, a lexical scanner checks the SQL text for explicit mentions. A query that explicitly names a disallowed table is rejected; if none are named, the access is treated as coming through an Authorized View and the query continues to the fallback check. 3. Fallback Intent Check: When the dry run reports no tables (e.g. DDL/Dynamic SQL), or after the explicit reference audit passes, a local SQL parser validates the query to catch restricted operations and disallowed identifiers. ### Limitations: - Structural vs. Lexical Gap: SQL parsing complexity makes it difficult to understand all structural nuances; lexical search provides a safer layer to detect user intent where the structural parser might fall short. - False Positives: Queries may be blocked if non-table identifiers (like column aliases) exactly match a restricted table name. This prioritizes data security over permissive access. ### Changes: - Added IsAnyTableExplicitlyReferenced using a lexical state machine. - Reorganized ValidateQueryAgainstAllowedDatasets to optimize for performance (early return) and safety (coordinated checks). - Simplified tool code by removing redundant mock indirection. - Added tests for Authorized Views and complex table ID formats. 
--- .../tools/bigquery/bigquery-execute-sql.md | 9 +- .../bigqueryanalyzecontribution.go | 19 +- .../bigquerycommon/table_name_parser.go | 396 ++++++++++++++---- .../bigquerycommon/table_name_parser_test.go | 199 +++++++++ .../tools/bigquery/bigquerycommon/util.go | 94 +++++ .../bigqueryexecutesql/bigqueryexecutesql.go | 71 +--- .../bigqueryforecast/bigqueryforecast.go | 19 +- src/google-cloud-bigquery-storage | 1 + tests/bigquery/bigquery_integration_test.go | 145 ++++++- 9 files changed, 766 insertions(+), 187 deletions(-) create mode 160000 src/google-cloud-bigquery-storage diff --git a/docs/en/resources/tools/bigquery/bigquery-execute-sql.md b/docs/en/resources/tools/bigquery/bigquery-execute-sql.md index b59ae5824971..561408551ec8 100644 --- a/docs/en/resources/tools/bigquery/bigquery-execute-sql.md +++ b/docs/en/resources/tools/bigquery/bigquery-execute-sql.md @@ -39,11 +39,10 @@ layer of security by controlling which datasets can be accessed: - **Without `allowedDatasets` restriction:** The tool can execute any valid GoogleSQL query. -- **With `allowedDatasets` restriction:** Before execution, the tool performs a - dry run to analyze the query. - It will reject the query if it attempts to access any table outside the - allowed `datasets` list. To enforce this restriction, the following operations - are also disallowed: +- **With `allowedDatasets` restriction:** The tool analyzes the query before execution to ensure that it only accesses the allowed datasets. + This check also supports authorized views by validating direct references against the allowed list. + To enforce this restriction, the following operations are also disallowed: + - **Dataset-level operations** (e.g., `CREATE SCHEMA`, `ALTER SCHEMA`). - **Unanalyzable operations** where the accessed tables cannot be determined statically (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`, `CALL`). 
diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index 7f50803b0a34..d026ad89a3ad 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -214,24 +214,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para {Key: "session_id", Value: session.ID}, } } - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps) - if err != nil { - return nil, fmt.Errorf("query validation failed: %w", err) - } - statementType := dryRunJob.Statistics.Query.StatementType - if statementType != "SELECT" { - return nil, fmt.Errorf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType) - } - queryStats := dryRunJob.Statistics.Query - if queryStats != nil { - for _, tableRef := range queryStats.ReferencedTables { - if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { - return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) - } - } - } else { - return nil, fmt.Errorf("could not analyze query in input_data to validate against allowed datasets") + if _, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps, source); err != nil { + return nil, err } } inputDataSource = fmt.Sprintf("(%s)", inputData) diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser.go b/internal/tools/bigquery/bigquerycommon/table_name_parser.go index 941ed533e415..23c9438477c5 100644 --- 
a/internal/tools/bigquery/bigquerycommon/table_name_parser.go +++ b/internal/tools/bigquery/bigquerycommon/table_name_parser.go @@ -80,20 +80,32 @@ var tableContextExitKeywords = map[string]bool{ func TableParser(sql, defaultProjectID string) ([]string, error) { tableIDSet := make(map[string]struct{}) visitedSQLs := make(map[string]struct{}) - if _, err := parseSQL(sql, defaultProjectID, tableIDSet, visitedSQLs, false); err != nil { + aliases := make(map[string]struct{}) + if _, err := parseSQL(sql, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { return nil, err } tableIDs := make([]string, 0, len(tableIDSet)) for id := range tableIDSet { - tableIDs = append(tableIDs, id) + isAlias := false + parts := strings.Split(id, ".") + for j := 0; j < len(parts); j++ { + suffix := strings.ToLower(strings.Join(parts[j:], ".")) + if _, ok := aliases[suffix]; ok { + isAlias = true + break + } + } + if !isAlias { + tableIDs = append(tableIDs, id) + } } return tableIDs, nil } // parseSQL is the core recursive function that processes SQL strings. // It uses a state machine to find table names and recursively parse EXECUTE IMMEDIATE. -func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, inSubquery bool) (int, error) { +func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, aliases map[string]struct{}, inSubquery bool) (int, error) { // Prevent infinite recursion. 
if _, ok := visitedSQLs[sql]; ok { return len(sql), nil @@ -101,13 +113,13 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi visitedSQLs[sql] = struct{}{} state := stateNormal - expectingTable := false + expectingTable, expectingAlias, expectingCTE := false, false, false var lastTableKeyword, lastToken, statementVerb string runes := []rune(sql) for i := 0; i < len(runes); { + remaining := string(runes[i:]) char := runes[i] - remaining := sql[i:] switch state { case stateNormal: @@ -126,20 +138,29 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi i += 2 continue } + if char == ',' { + if lastTableKeyword == "from" { + expectingTable = true + expectingAlias = false + } else if statementVerb == "with" { + expectingCTE = true + expectingAlias = false + } + i++ + continue + } if char == '(' { - if expectingTable { - // The subquery starts after '('. - consumed, err := parseSQL(remaining[1:], defaultProjectID, tableIDSet, visitedSQLs, true) + if expectingTable || expectingCTE || lastToken == "as" { + consumed, err := parseSQL(string(runes[i+1:]), defaultProjectID, tableIDSet, visitedSQLs, aliases, true) if err != nil { return 0, err } - // Advance i by the length of the subquery + the opening parenthesis. - // The recursive call returns what it consumed, including the closing parenthesis. i += consumed + 1 - // For most keywords, we expect only one table. `from` can have multiple "tables" (subqueries). if lastTableKeyword != "from" { expectingTable = false } + expectingAlias = true + expectingCTE = false continue } } @@ -148,32 +169,32 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi return i + 1, nil } } - if char == ';' { statementVerb = "" lastToken = "" + expectingTable = false + expectingAlias = false + expectingCTE = false i++ continue - } - - // Raw strings must be checked before regular strings. 
- if strings.HasPrefix(remaining, "r'''") || strings.HasPrefix(remaining, "R'''") { + remLow := strings.ToLower(remaining) + if strings.HasPrefix(remLow, "r'''") { state = stateInRawTripleSingleQuoteString i += 4 continue } - if strings.HasPrefix(remaining, `r"""`) || strings.HasPrefix(remaining, `R"""`) { + if strings.HasPrefix(remLow, `r"""`) { state = stateInRawTripleDoubleQuoteString i += 4 continue } - if strings.HasPrefix(remaining, "r'") || strings.HasPrefix(remaining, "R'") { + if strings.HasPrefix(remLow, "r'") { state = stateInRawSingleQuoteString i += 2 continue } - if strings.HasPrefix(remaining, `r"`) || strings.HasPrefix(remaining, `R"`) { + if strings.HasPrefix(remLow, `r"`) { state = stateInRawDoubleQuoteString i += 2 continue @@ -199,7 +220,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi continue } - if unicode.IsLetter(char) || char == '`' { + if unicode.IsLetter(char) || char == '`' || char == '_' { parts, consumed, err := parseIdentifierSequence(remaining) if err != nil { return 0, err @@ -208,9 +229,11 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi i++ continue } + keyword := strings.ToLower(parts[0]) + fullID := strings.ToLower(strings.Join(parts, ".")) + // Handle security-restricted operations and verb identification. if len(parts) == 1 { - keyword := strings.ToLower(parts[0]) switch keyword { case "call": return 0, fmt.Errorf("CALL is not allowed when dataset restrictions are in place, as the called procedure's contents cannot be safely analyzed") @@ -233,13 +256,76 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi return 0, fmt.Errorf("dataset-level operations like '%s %s' are not allowed when dataset restrictions are in place", strings.ToUpper(statementVerb), strings.ToUpper(keyword)) } } + } + + // Resolve aliases and identify table references. 
+ isKnownAlias := false + if _, ok := aliases[fullID]; ok { + isKnownAlias = true + } + if !isKnownAlias && len(parts) > 1 { + if _, ok := aliases[strings.ToLower(parts[0])]; ok { + isKnownAlias = true + } + } + + if expectingCTE { + aliases[fullID] = struct{}{} + aliases[strings.ToLower(parts[0])] = struct{}{} + expectingCTE = false + } else if expectingAlias { + if len(parts) == 1 && (tableContextExitKeywords[keyword] || tableFollowsKeywords[keyword] || keyword == "select" || keyword == "with") { + expectingAlias = false + } else { + aliases[fullID] = struct{}{} + aliases[strings.ToLower(parts[0])] = struct{}{} + expectingAlias = false + isKnownAlias = true + } + } + + // Re-check aliases after potential registration. + if !isKnownAlias { + if _, ok := aliases[fullID]; ok { + isKnownAlias = true + } + } + + if expectingTable && !isKnownAlias { + if len(parts) >= 2 { + tableID, err := formatTableID(parts, defaultProjectID) + if err != nil { + return 0, err + } + if tableID != "" { + tableIDSet[tableID] = struct{}{} + } + } + // For most keywords, we expect only one table. + if lastTableKeyword != "from" { + expectingTable = false + } + expectingAlias = true + } - if _, ok := tableFollowsKeywords[keyword]; ok { + // Update state machine based on the current keyword. 
+ if len(parts) == 1 { + if keyword == "with" { + expectingCTE = true + statementVerb = "with" + } else if keyword == "as" { + if statementVerb != "with" { + expectingAlias = true + } + expectingTable = false + } else if _, ok := tableFollowsKeywords[keyword]; ok { expectingTable = true lastTableKeyword = keyword + expectingAlias = false } else if _, ok := tableContextExitKeywords[keyword]; ok { expectingTable = false lastTableKeyword = "" + expectingAlias = false } if lastToken == "create" && keyword == "or" { lastToken = "create or" @@ -248,29 +334,13 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } else { lastToken = keyword } - } else if len(parts) >= 2 { - // This is a multi-part identifier. If we were expecting a table, this is it. - if expectingTable { - tableID, err := formatTableID(parts, defaultProjectID) - if err != nil { - return 0, err - } - if tableID != "" { - tableIDSet[tableID] = struct{}{} - } - // For most keywords, we expect only one table. - if lastTableKeyword != "from" { - expectingTable = false - } - } + } else { lastToken = "" } - i += consumed continue } i++ - case stateInSingleQuoteString: if char == '\\' { i += 2 // Skip backslash and the escaped character. 
@@ -290,14 +360,14 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } i++ case stateInTripleSingleQuoteString: - if strings.HasPrefix(remaining, "'''") { + if strings.HasPrefix(string(runes[i:]), "'''") { state = stateNormal i += 3 } else { i++ } case stateInTripleDoubleQuoteString: - if strings.HasPrefix(remaining, `"""`) { + if strings.HasPrefix(string(runes[i:]), `"""`) { state = stateNormal i += 3 } else { @@ -309,7 +379,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } i++ case stateInMultiLineComment: - if strings.HasPrefix(remaining, "*/") { + if strings.HasPrefix(string(runes[i:]), "*/") { state = stateNormal i += 2 } else { @@ -326,14 +396,14 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } i++ case stateInRawTripleSingleQuoteString: - if strings.HasPrefix(remaining, "'''") { + if strings.HasPrefix(string(runes[i:]), "'''") { state = stateNormal i += 3 } else { i++ } case stateInRawTripleDoubleQuoteString: - if strings.HasPrefix(remaining, `"""`) { + if strings.HasPrefix(string(runes[i:]), `"""`) { state = stateNormal i += 3 } else { @@ -341,11 +411,10 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } } } - if inSubquery { return 0, fmt.Errorf("unclosed subquery parenthesis") } - return len(sql), nil + return len(runes), nil } // parseIdentifierSequence parses a sequence of dot-separated identifiers. 
@@ -353,56 +422,40 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi func parseIdentifierSequence(s string) ([]string, int, error) { var parts []string var totalConsumed int - + runes := []rune(s) for { - remaining := s[totalConsumed:] - trimmed := strings.TrimLeftFunc(remaining, unicode.IsSpace) - totalConsumed += len(remaining) - len(trimmed) - current := s[totalConsumed:] - - if len(current) == 0 { + for totalConsumed < len(runes) && unicode.IsSpace(runes[totalConsumed]) { + totalConsumed++ + } + if totalConsumed >= len(runes) { break } var part string var consumed int - if current[0] == '`' { - end := strings.Index(current[1:], "`") + if runes[totalConsumed] == '`' { + end := strings.Index(string(runes[totalConsumed+1:]), "`") if end == -1 { return nil, 0, fmt.Errorf("unclosed backtick identifier") } - part = current[1 : end+1] + part = string(runes[totalConsumed+1 : totalConsumed+end+1]) consumed = end + 2 - } else if len(current) > 0 && unicode.IsLetter(rune(current[0])) { - end := strings.IndexFunc(current, func(r rune) bool { - return !unicode.IsLetter(r) && !unicode.IsNumber(r) && r != '_' && r != '-' - }) - if end == -1 { - part = current - consumed = len(current) - } else { - part = current[:end] - consumed = end + } else if unicode.IsLetter(runes[totalConsumed]) || runes[totalConsumed] == '_' { + end := totalConsumed + for end < len(runes) && (unicode.IsLetter(runes[end]) || unicode.IsNumber(runes[end]) || runes[end] == '_' || runes[end] == '-') { + end++ } + part = string(runes[totalConsumed:end]) + consumed = end - totalConsumed } else { break } - if current[0] == '`' && strings.Contains(part, ".") { - // This handles cases like `project.dataset.table` but not `project.dataset`.table. - // If the character after the quoted identifier is not a dot, we treat it as a full name. - if len(current) <= consumed || current[consumed] != '.' { - parts = append(parts, strings.Split(part, ".")...) 
- totalConsumed += consumed - break - } - } - parts = append(parts, strings.Split(part, ".")...) totalConsumed += consumed - if len(s) <= totalConsumed || s[totalConsumed] != '.' { + if totalConsumed >= len(runes) || runes[totalConsumed] != '.' { break } totalConsumed++ @@ -410,22 +463,193 @@ func parseIdentifierSequence(s string) ([]string, int, error) { return parts, totalConsumed, nil } +// IsAnyTableExplicitlyReferenced checks if any target tables are explicitly mentioned as +// identifiers in the SQL, correctly skipping comments and strings. +func IsAnyTableExplicitlyReferenced(sql, defaultProjectID string, targetTableIDs []string) (bool, error) { + if len(targetTableIDs) == 0 { + return false, nil + } + + targets := make(map[string]struct{}) + for _, id := range targetTableIDs { + targets[strings.ToLower(id)] = struct{}{} + } + + state := stateNormal + runes := []rune(sql) + + for i := 0; i < len(runes); { + remaining := string(runes[i:]) + char := runes[i] + + switch state { + case stateNormal: + if strings.HasPrefix(remaining, "--") { + state = stateInSingleLineCommentDash + i += 2 + continue + } + if strings.HasPrefix(remaining, "#") { + state = stateInSingleLineCommentHash + i++ + continue + } + if strings.HasPrefix(remaining, "/*") { + state = stateInMultiLineComment + i += 2 + continue + } + + if unicode.IsLetter(char) || char == '`' || char == '_' { + parts, consumed, err := parseIdentifierSequence(remaining) + if err != nil { + return false, err + } + if consumed > 0 { + if len(parts) < 2 { + i += consumed + continue + } + fullID := strings.ToLower(strings.Join(parts, ".")) + for target := range targets { + // Match exact table name or as a prefix for column references. + if fullID == target || strings.HasPrefix(fullID, target+".") { + return true, nil + } + // Also try matching with the default project ID prefix. + if defaultProjectID != "" { + withDefault := strings.ToLower(defaultProjectID + "." 
+ fullID) + if withDefault == target || strings.HasPrefix(withDefault, target+".") { + return true, nil + } + } + } + i += consumed + continue + } + } + + // Handle various BigQuery string literal formats. + remLow := strings.ToLower(remaining) + if strings.HasPrefix(remLow, "r'''") { + state = stateInRawTripleSingleQuoteString + i += 4 + continue + } + if strings.HasPrefix(remLow, `r"""`) { + state = stateInRawTripleDoubleQuoteString + i += 4 + continue + } + if strings.HasPrefix(remLow, "r'") { + state = stateInRawSingleQuoteString + i += 2 + continue + } + if strings.HasPrefix(remLow, `r"`) { + state = stateInRawDoubleQuoteString + i += 2 + continue + } + if strings.HasPrefix(remaining, "'''") { + state = stateInTripleSingleQuoteString + i += 3 + continue + } + if strings.HasPrefix(remaining, `"""`) { + state = stateInTripleDoubleQuoteString + i += 3 + continue + } + if char == '\'' { + state = stateInSingleQuoteString + i++ + continue + } + if char == '"' { + state = stateInDoubleQuoteString + i++ + continue + } + + case stateInSingleQuoteString: + if char == '\\' { + i += 2 + continue + } + if char == '\'' { + state = stateNormal + } + case stateInDoubleQuoteString: + if char == '\\' { + i += 2 + continue + } + if char == '"' { + state = stateNormal + } + case stateInTripleSingleQuoteString: + if strings.HasPrefix(remaining, "'''") { + state = stateNormal + i += 3 + continue + } + case stateInTripleDoubleQuoteString: + if strings.HasPrefix(remaining, `"""`) { + state = stateNormal + i += 3 + continue + } + case stateInSingleLineCommentDash, stateInSingleLineCommentHash: + if char == '\n' { + state = stateNormal + } + case stateInMultiLineComment: + if strings.HasPrefix(remaining, "*/") { + state = stateNormal + i += 2 + continue + } + case stateInRawSingleQuoteString: + if char == '\'' { + state = stateNormal + } + case stateInRawDoubleQuoteString: + if char == '"' { + state = stateNormal + } + case stateInRawTripleSingleQuoteString: + if 
strings.HasPrefix(remaining, "'''") { + state = stateNormal + i += 3 + continue + } + case stateInRawTripleDoubleQuoteString: + if strings.HasPrefix(remaining, `"""`) { + state = stateNormal + i += 3 + continue + } + } + i++ + } + + return false, nil +} + func formatTableID(parts []string, defaultProjectID string) (string, error) { if len(parts) < 2 || len(parts) > 3 { // Not a table identifier (could be a CTE, column, etc.). - // Return the consumed length so the main loop can skip this identifier. return "", nil } - var tableID string if len(parts) == 3 { // project.dataset.table - tableID = strings.Join(parts, ".") - } else { // dataset.table - if defaultProjectID == "" { - return "", fmt.Errorf("query contains table '%s' without project ID, and no default project ID is provided", strings.Join(parts, ".")) - } - tableID = fmt.Sprintf("%s.%s", defaultProjectID, strings.Join(parts, ".")) + return strings.Join(parts, "."), nil } - return tableID, nil + // dataset.table + if defaultProjectID == "" { + return "", fmt.Errorf("query contains table '%s' without project ID, and no default project ID is provided", strings.Join(parts, ".")) + } + return fmt.Sprintf("%s.%s", defaultProjectID, strings.Join(parts, ".")), nil } diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go b/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go index 662c5c4071f0..18855ec88472 100644 --- a/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go +++ b/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go @@ -471,6 +471,77 @@ func TestTableParser(t *testing.T) { wantErr: true, wantErrMsg: "unanalyzable statements like 'CREATE FUNCTION' are not allowed", }, + { + name: "alias filtering simple", + sql: "SELECT t1.col FROM proj.data.table AS t1", + defaultProjectID: "default-proj", + want: []string{"proj.data.table"}, + wantErr: false, + }, + { + name: "alias filtering complex", + sql: "SELECT t1.col1, t2.col2 FROM proj.data.tbl1 t1 
JOIN proj.data.tbl2 AS t2 ON t1.id = t2.id", + defaultProjectID: "default-proj", + want: []string{"proj.data.tbl1", "proj.data.tbl2"}, + wantErr: false, + }, + { + name: "alias filtering in where clause", + sql: "SELECT * FROM proj.data.tbl1 AS t1 WHERE t1.id > 10", + defaultProjectID: "default-proj", + want: []string{"proj.data.tbl1"}, + wantErr: false, + }, + { + name: "unnest column reference", + sql: "SELECT x FROM `proj.ds.tbl` AS t, UNNEST(t.arr) AS x", + defaultProjectID: "default-proj", + want: []string{"proj.ds.tbl"}, + wantErr: false, + }, + { + name: "CTE with dots", + sql: "WITH `my.cte` AS (SELECT 1) SELECT * FROM `my.cte`", + defaultProjectID: "default-proj", + want: []string{}, + wantErr: false, + }, + { + name: "nested CTEs with dot-containing aliases", + sql: ` + WITH raw_metrics AS ( + SELECT id, score FROM production-data.analytics.events + ), + derived.results AS ( + SELECT id, score * 2 as double_score FROM raw_metrics + ) + SELECT * FROM derived.results WHERE double_score > 100 + `, + defaultProjectID: "default-proj", + want: []string{"production-data.analytics.events"}, + wantErr: false, + }, + { + name: "implicit join with comma", + sql: "SELECT * FROM proj.data.tbl1, proj.data.tbl2", + defaultProjectID: "default-proj", + want: []string{"proj.data.tbl1", "proj.data.tbl2"}, + wantErr: false, + }, + { + name: "implicit alias", + sql: "SELECT t.col FROM proj.data.tbl t", + defaultProjectID: "default-proj", + want: []string{"proj.data.tbl"}, + wantErr: false, + }, + { + name: "unnest column reference complex", + sql: "SELECT x FROM `proj.ds.tbl` AS t, UNNEST(t.arr) AS x JOIN `other.ds.tbl2` as o ON t.id = o.id", + defaultProjectID: "default-proj", + want: []string{"proj.ds.tbl", "other.ds.tbl2"}, + wantErr: false, + }, } for _, tc := range testCases { @@ -494,3 +565,131 @@ func TestTableParser(t *testing.T) { }) } } + +func TestIsAnyTableExplicitlyReferenced(t *testing.T) { + testCases := []struct { + name string + sql string + defaultProjectID 
string + targetTableIDs []string + want bool + }{ + { + name: "simple match", + sql: "SELECT * FROM `proj.ds.tbl`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "match without project id in sql", + sql: "SELECT * FROM `ds.tbl`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"def-proj.ds.tbl"}, + want: true, + }, + { + name: "no match", + sql: "SELECT * FROM `ds.view`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"def-proj.ds.tbl"}, + want: false, + }, + { + name: "ignore in strings", + sql: "SELECT 'proj.ds.tbl' FROM `ds.view` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: false, + }, + { + name: "ignore in comments", + sql: "SELECT * FROM `ds.view` -- referencing proj.ds.tbl here", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: false, + }, + { + name: "match in join", + sql: "SELECT * FROM `ds.view` JOIN `proj.ds.tbl` ON 1=1", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "match as column reference", + sql: "SELECT proj.ds.tbl.col FROM `ds.view`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "match with different casing", + sql: "SELECT * FROM `PROJ.ds.TBL`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "raw string ignore", + sql: "SELECT r'proj.ds.tbl' FROM `ds.view` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: false, + }, + { + name: "mixed quoting style", + sql: "SELECT * FROM `proj`.ds.`tbl` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "all parts quoted separately", + sql: "SELECT * FROM `proj`.`ds`.`tbl` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: 
"middle part quoted", + sql: "SELECT * FROM proj.`ds`.tbl ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "fully qualified column reference", + sql: "SELECT proj.ds.tbl.col FROM `something` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "domain scoped project ID match", + sql: "SELECT * FROM `google.com:project.dataset.table` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"google.com:project.dataset.table"}, + want: true, + }, + { + name: "domain scoped project ID match with column", + sql: "SELECT `google.com:project.dataset.table`.col FROM `something_else` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"google.com:project.dataset.table"}, + want: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := bigquerycommon.IsAnyTableExplicitlyReferenced(tc.sql, tc.defaultProjectID, tc.targetTableIDs) + if err != nil { + t.Fatalf("IsAnyTableExplicitlyReferenced() error = %v", err) + } + if got != tc.want { + t.Errorf("IsAnyTableExplicitlyReferenced() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index 5486ac36eda1..c72c703c947b 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -57,6 +57,100 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj return insertResponse, nil } +// DatasetValidator defines the interface for checking if a dataset is allowed. +type DatasetValidator interface { + IsDatasetAllowed(projectID, datasetID string) bool +} + +// ValidateQueryAgainstAllowedDatasets validates a SQL query against a list of allowed datasets. +// It uses both dry run and a local parser to support authorized views. 
+func ValidateQueryAgainstAllowedDatasets( + ctx context.Context, + restService *bigqueryrestapi.Service, + projectID string, + location string, + sql string, + params []*bigqueryrestapi.QueryParameter, + connProps []*bigqueryapi.ConnectionProperty, + validator DatasetValidator, +) (*bigqueryrestapi.Job, error) { + dryRunJob, err := DryRunQuery(ctx, restService, projectID, location, sql, params, connProps) + if err != nil { + return nil, fmt.Errorf("query validation failed: %w", err) + } + + if dryRunJob.Statistics == nil || dryRunJob.Statistics.Query == nil { + return nil, fmt.Errorf("dry run failed to return query statistics") + } + statementType := dryRunJob.Statistics.Query.StatementType + // Common restricted operations + switch statementType { + case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA": + return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType) + case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE": + return nil, fmt.Errorf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType) + case "CALL": + return nil, fmt.Errorf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType) + } + + // Use a map to avoid duplicate table names from the dry run result. 
+ tableIDSet := make(map[string]struct{}) + queryStats := dryRunJob.Statistics.Query + if queryStats != nil { + for _, tableRef := range queryStats.ReferencedTables { + tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} + } + if tableRef := queryStats.DdlTargetTable; tableRef != nil { + tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} + } + if tableRef := queryStats.DdlDestinationTable; tableRef != nil { + tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} + } + } + + var violatingTables []string + for tableID := range tableIDSet { + parts := strings.Split(tableID, ".") + if len(parts) == 3 { + if !validator.IsDatasetAllowed(parts[0], parts[1]) { + violatingTables = append(violatingTables, tableID) + } + } + } + + if len(tableIDSet) > 0 && len(violatingTables) == 0 { + return dryRunJob, nil + } + + // If violations were found, check if they are explicitly in the SQL to support authorized views. + if len(violatingTables) > 0 { + explicitlyReferenced, err := IsAnyTableExplicitlyReferenced(sql, projectID, violatingTables) + if err != nil { + return nil, fmt.Errorf("failed to analyze query for explicit table references: %w", err) + } + if explicitlyReferenced { + return nil, fmt.Errorf("query explicitly accesses dataset '%s', which is not in the allowed list", strings.Join(strings.Split(violatingTables[0], ".")[:2], ".")) + } + } + + // Fall back to TableParser for final intent verification or if dry run was inconclusive. 
+ parsedTables, parseErr := TableParser(sql, projectID) + if parseErr != nil { + return nil, fmt.Errorf("could not safely analyze query with dataset restrictions: %w", parseErr) + } + + for _, tableID := range parsedTables { + parts := strings.Split(tableID, ".") + if len(parts) == 3 { + if !validator.IsDatasetAllowed(parts[0], parts[1]) { + return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", parts[0], parts[1]) + } + } + } + + return dryRunJob, nil +} + // BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string. func BQTypeStringFromToolType(toolType string) (string, error) { switch toolType { diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index 6f2fc245c9b6..54331e4f1abf 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -185,11 +185,22 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps) - if err != nil { - return nil, fmt.Errorf("query validation failed: %w", err) + var dryRunJob *bigqueryrestapi.Job + if len(source.BigQueryAllowedDatasets()) > 0 { + dryRunJob, err = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, source) + if err != nil { + return nil, err + } + } else { + dryRunJob, err = bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps) + if err != nil { + return nil, fmt.Errorf("query validation failed: %w", err) + } } + if dryRunJob.Statistics == nil || dryRunJob.Statistics.Query == nil { + return nil, fmt.Errorf("dry run failed to return query statistics") + } statementType := 
dryRunJob.Statistics.Query.StatementType switch source.BigQueryWriteMode() { @@ -206,60 +217,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - if len(source.BigQueryAllowedDatasets()) > 0 { - switch statementType { - case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA": - return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType) - case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE": - return nil, fmt.Errorf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType) - case "CALL": - return nil, fmt.Errorf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType) - } - - // Use a map to avoid duplicate table names. - tableIDSet := make(map[string]struct{}) - - // Get all tables from the dry run result. This is the most reliable method. - queryStats := dryRunJob.Statistics.Query - if queryStats != nil { - for _, tableRef := range queryStats.ReferencedTables { - tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} - } - if tableRef := queryStats.DdlTargetTable; tableRef != nil { - tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} - } - if tableRef := queryStats.DdlDestinationTable; tableRef != nil { - tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} - } - } - - var tableNames []string - if len(tableIDSet) > 0 { - for tableID := range tableIDSet { - tableNames = append(tableNames, tableID) - } - } else if statementType != "SELECT" { - // If dry run yields no tables, fall back to the parser for non-SELECT statements - // to catch unsafe operations like EXECUTE IMMEDIATE. 
- parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project()) - if parseErr != nil { - // If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail. - return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr) - } - tableNames = parsedTables - } - - for _, tableID := range tableNames { - parts := strings.Split(tableID, ".") - if len(parts) == 3 { - projectID, datasetID := parts[0], parts[1] - if !source.IsDatasetAllowed(projectID, datasetID) { - return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID) - } - } - } - } - if dryRun { if dryRunJob != nil { jobJSON, err := json.MarshalIndent(dryRunJob, "", " ") diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index b3d56fb46520..646fa286a50b 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -192,24 +192,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para {Key: "session_id", Value: session.ID}, } } - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps) - if err != nil { - return nil, fmt.Errorf("query validation failed: %w", err) - } - statementType := dryRunJob.Statistics.Query.StatementType - if statementType != "SELECT" { - return nil, fmt.Errorf("the 'history_data' parameter only supports a table ID or a SELECT query. 
The provided query has statement type '%s'", statementType) - } - queryStats := dryRunJob.Statistics.Query - if queryStats != nil { - for _, tableRef := range queryStats.ReferencedTables { - if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { - return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) - } - } - } else { - return nil, fmt.Errorf("could not analyze query in history_data to validate against allowed datasets") + if _, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps, source); err != nil { + return nil, err } } historyDataSource = fmt.Sprintf("(%s)", historyData) diff --git a/src/google-cloud-bigquery-storage b/src/google-cloud-bigquery-storage new file mode 160000 index 000000000000..316faaff534c --- /dev/null +++ b/src/google-cloud-bigquery-storage @@ -0,0 +1 @@ +Subproject commit 316faaff534c3a989d92ff78efd1cfaa0e45a10b diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index de5126cd2410..0f12873cb6ad 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -368,6 +368,141 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName2, disallowedAnalyzeContributionTableFullName) } +func TestBigQueryAuthorizedViewWithRestriction(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + client, err := initBigQueryConnection(BigqueryProject) + if err != nil { + t.Fatalf("unable to create BigQuery client: %s", err) + } + + baseName := strings.ReplaceAll(uuid.New().String(), "-", "") + allowedDatasetName := fmt.Sprintf("allowed_ds_%s", baseName) + disallowedDatasetName := 
fmt.Sprintf("disallowed_ds_%s", baseName) + tableName := "source_table" + viewName := "auth_view" + + // Create datasets + if err := client.Dataset(allowedDatasetName).Create(ctx, &bigqueryapi.DatasetMetadata{Location: "US"}); err != nil { + t.Fatalf("failed to create allowed dataset: %v", err) + } + defer client.Dataset(allowedDatasetName).DeleteWithContents(ctx) + + if err := client.Dataset(disallowedDatasetName).Create(ctx, &bigqueryapi.DatasetMetadata{Location: "US"}); err != nil { + t.Fatalf("failed to create disallowed dataset: %v", err) + } + defer client.Dataset(disallowedDatasetName).DeleteWithContents(ctx) + + // Create source table in disallowed dataset + tableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, tableName) + if _, err := client.Query(fmt.Sprintf("CREATE TABLE %s (id INT64)", tableFullName)).Run(ctx); err != nil { + // Wait and retry if it's a "still being created" error? No, usually it's fine. + t.Fatalf("failed to create source table: %v", err) + } + + // Create view in allowed dataset referencing the disallowed table + viewFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName, viewName) + metadata := &bigqueryapi.TableMetadata{ + ViewQuery: fmt.Sprintf("SELECT * FROM %s", tableFullName), + } + if err := client.Dataset(allowedDatasetName).Table(viewName).Create(ctx, metadata); err != nil { + t.Fatalf("failed to create view: %v", err) + } + + // Authorize the view to access the disallowed dataset + dsMetadata, err := client.Dataset(disallowedDatasetName).Metadata(ctx) + if err != nil { + t.Fatalf("failed to get disallowed dataset metadata: %v", err) + } + + newAccess := append(dsMetadata.Access, &bigqueryapi.AccessEntry{ + EntityType: bigqueryapi.ViewEntity, + View: client.Dataset(allowedDatasetName).Table(viewName), + }) + + update := bigqueryapi.DatasetMetadataToUpdate{ + Access: newAccess, + } + if _, err := client.Dataset(disallowedDatasetName).Update(ctx, update, dsMetadata.ETag); err != 
nil { + t.Fatalf("failed to authorize view: %v", err) + } + + // Configure toolbox with ONLY the allowed dataset + sourceConfig := getBigQueryVars(t) + sourceConfig["allowedDatasets"] = []string{allowedDatasetName} + config := map[string]any{ + "sources": map[string]any{"my-instance": sourceConfig}, + "tools": map[string]any{ + "execute-sql-restricted": map[string]any{ + "kind": "bigquery-execute-sql", + "source": "my-instance", + "description": "Tool to execute SQL with restriction", + }, + }, + } + + cmd, cleanup, err := tests.StartCmd(ctx, config) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Invoke the tool querying the view. This should now SUCCEED. + t.Run("query authorized view", func(t *testing.T) { + sql := fmt.Sprintf("SELECT * FROM %s", viewFullName) + body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"%s"}`, sql))) + req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/execute-sql-restricted/invoke", body) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, http.StatusOK, string(bodyBytes)) + } + }) + + // Also verify that direct query to the disallowed table still FAILS. 
+ t.Run("query disallowed table directly", func(t *testing.T) { + sql := fmt.Sprintf("SELECT * FROM %s", tableFullName) + body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"%s"}`, sql))) + req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/execute-sql-restricted/invoke", body) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, http.StatusBadRequest, string(bodyBytes)) + } + bodyBytes, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(bodyBytes), fmt.Sprintf("query explicitly accesses dataset '%s.%s', which is not in the allowed list", BigqueryProject, disallowedDatasetName)) { + t.Errorf("unexpected error message: %s", string(bodyBytes)) + } + }) +} + func TestBigQueryWriteModeAllowed(t *testing.T) { sourceConfig := getBigQueryVars(t) sourceConfig["writeMode"] = "allowed" @@ -2760,7 +2895,7 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed name: "invoke on disallowed table", sql: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusBadRequest, - wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", + wantInError: fmt.Sprintf("query explicitly accesses dataset '%s', which is not in the allowed list", strings.Join( strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0:2], ".")), @@ -2787,13 +2922,13 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed name: "disallowed create procedure", sql: fmt.Sprintf("CREATE PROCEDURE %s.my_proc() BEGIN SELECT 1; END", allowedDatasetID), wantStatusCode: http.StatusBadRequest, - 
wantInError: "unanalyzable statements like 'CREATE PROCEDURE' are not allowed", + wantInError: "not allowed", }, { name: "disallowed execute immediate", sql: "EXECUTE IMMEDIATE 'SELECT 1'", wantStatusCode: http.StatusBadRequest, - wantInError: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place", + wantInError: "not allowed", }, } @@ -3096,7 +3231,7 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa name: "invoke with query on disallowed table", historyData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusBadRequest, - wantInError: fmt.Sprintf("query in history_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + wantInError: fmt.Sprintf("query explicitly accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), }, } @@ -3187,7 +3322,7 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d name: "invoke with query on disallowed table", inputData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusBadRequest, - wantInError: fmt.Sprintf("query in input_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + wantInError: fmt.Sprintf("query explicitly accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), }, } From 0fba781e71f205e8e57b13fc8cbab4f267eb1897 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Wed, 25 Feb 2026 07:52:42 +0000 Subject: [PATCH 02/15] support authorized views and align util/tests with latest main architecture --- .../bigqueryanalyzecontribution.go | 6 +- .../bigquerycommon/table_name_parser.go | 376 +++++++++++------- .../tools/bigquery/bigquerycommon/util.go | 4 +- .../bigqueryforecast/bigqueryforecast.go | 6 +- tests/bigquery/bigquery_integration_test.go | 335 ++++++---------- 5 files changed, 364 insertions(+), 363 deletions(-) diff --git 
a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index 54a7a32ffd11..6430675acdbb 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -217,9 +217,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - if _, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps, source); err != nil { + dryRunJob, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps, source) + if err != nil { return nil, util.ProcessGcpError(err) } + if dryRunJob.Statistics.Query.StatementType != "SELECT" { + return nil, util.NewAgentError(fmt.Sprintf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", dryRunJob.Statistics.Query.StatementType), nil) + } } inputDataSource = fmt.Sprintf("(%s)", inputData) } else { diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser.go b/internal/tools/bigquery/bigquerycommon/table_name_parser.go index 23c9438477c5..8b4a178f2490 100644 --- a/internal/tools/bigquery/bigquerycommon/table_name_parser.go +++ b/internal/tools/bigquery/bigquerycommon/table_name_parser.go @@ -20,63 +20,79 @@ import ( "unicode" ) -// parserState defines the state of the SQL parser's state machine. 
type parserState int const ( stateNormal parserState = iota - // String states stateInSingleQuoteString stateInDoubleQuoteString stateInTripleSingleQuoteString stateInTripleDoubleQuoteString + stateInSingleLineCommentDash + stateInSingleLineCommentHash + stateInMultiLineComment stateInRawSingleQuoteString stateInRawDoubleQuoteString stateInRawTripleSingleQuoteString stateInRawTripleDoubleQuoteString - // Comment states - stateInSingleLineCommentDash - stateInSingleLineCommentHash - stateInMultiLineComment -) - -// SQL statement verbs -const ( - verbCreate = "create" - verbAlter = "alter" - verbDrop = "drop" - verbSelect = "select" - verbInsert = "insert" - verbUpdate = "update" - verbDelete = "delete" - verbMerge = "merge" ) var tableFollowsKeywords = map[string]bool{ "from": true, "join": true, + "into": true, "update": true, - "into": true, // INSERT INTO, MERGE INTO - "table": true, // CREATE TABLE, ALTER TABLE - "using": true, // MERGE ... USING - "insert": true, // INSERT my_table - "merge": true, // MERGE my_table + "table": true, } var tableContextExitKeywords = map[string]bool{ - "where": true, - "group": true, // GROUP BY - "having": true, - "order": true, // ORDER BY - "limit": true, - "window": true, - "on": true, // JOIN ... ON - "set": true, // UPDATE ... SET - "when": true, // MERGE ... WHEN + "where": true, + "group": true, + "order": true, + "having": true, + "limit": true, + "window": true, + "union": true, + "intersect": true, + "except": true, +} + +// hasPrefix checks if the runes starting at offset match the given prefix. +func hasPrefix(r []rune, offset int, prefix string) bool { + if offset+len(prefix) > len(r) { + return false + } + for i := 0; i < len(prefix); i++ { + if r[offset+i] != rune(prefix[i]) { + return false + } + } + return true +} + +// hasPrefixFold checks if the runes starting at offset match the given prefix, ignoring case (ASCII only). 
+func hasPrefixFold(r []rune, offset int, prefix string) bool { + if offset+len(prefix) > len(r) { + return false + } + for i := 0; i < len(prefix); i++ { + rChar := r[offset+i] + pChar := rune(prefix[i]) + if rChar >= 'A' && rChar <= 'Z' { + rChar += 32 + } + if pChar >= 'A' && pChar <= 'Z' { + pChar += 32 + } + if rChar != pChar { + return false + } + } + return true } -// TableParser is the main entry point for parsing a SQL string to find all referenced table IDs. -// It handles multi-statement SQL, comments, and recursive parsing of EXECUTE IMMEDIATE statements. +// TableParser parses a SQL query and returns a list of table IDs that it references. +// It is intended as a conservative fallback for when a dry run cannot be performed or analyzed. func TableParser(sql, defaultProjectID string) ([]string, error) { tableIDSet := make(map[string]struct{}) visitedSQLs := make(map[string]struct{}) @@ -118,22 +134,21 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi runes := []rune(sql) for i := 0; i < len(runes); { - remaining := string(runes[i:]) char := runes[i] switch state { case stateNormal: - if strings.HasPrefix(remaining, "--") { + if hasPrefix(runes, i, "--") { state = stateInSingleLineCommentDash i += 2 continue } - if strings.HasPrefix(remaining, "#") { + if char == '#' { state = stateInSingleLineCommentHash i++ continue } - if strings.HasPrefix(remaining, "/*") { + if hasPrefix(runes, i, "/*") { state = stateInMultiLineComment i += 2 continue @@ -178,33 +193,34 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi i++ continue } - remLow := strings.ToLower(remaining) - if strings.HasPrefix(remLow, "r'''") { + + // Raw strings must be checked before regular strings. 
+ if hasPrefixFold(runes, i, "r'''") { state = stateInRawTripleSingleQuoteString i += 4 continue } - if strings.HasPrefix(remLow, `r"""`) { + if hasPrefixFold(runes, i, `r"""`) { state = stateInRawTripleDoubleQuoteString i += 4 continue } - if strings.HasPrefix(remLow, "r'") { + if hasPrefixFold(runes, i, "r'") { state = stateInRawSingleQuoteString i += 2 continue } - if strings.HasPrefix(remLow, `r"`) { + if hasPrefixFold(runes, i, `r"`) { state = stateInRawDoubleQuoteString i += 2 continue } - if strings.HasPrefix(remaining, "'''") { + if hasPrefix(runes, i, "'''") { state = stateInTripleSingleQuoteString i += 3 continue } - if strings.HasPrefix(remaining, `"""`) { + if hasPrefix(runes, i, `"""`) { state = stateInTripleDoubleQuoteString i += 3 continue @@ -221,7 +237,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } if unicode.IsLetter(char) || char == '`' || char == '_' { - parts, consumed, err := parseIdentifierSequence(remaining) + parts, consumed, err := parseIdentifierSequence(runes[i:]) if err != nil { return 0, err } @@ -229,33 +245,20 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi i++ continue } + keyword := strings.ToLower(parts[0]) fullID := strings.ToLower(strings.Join(parts, ".")) - // Handle security-restricted operations and verb identification. 
- if len(parts) == 1 { - switch keyword { - case "call": - return 0, fmt.Errorf("CALL is not allowed when dataset restrictions are in place, as the called procedure's contents cannot be safely analyzed") - case "immediate": - if lastToken == "execute" { - return 0, fmt.Errorf("EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place, as its contents cannot be safely analyzed") - } - case "procedure", "function": - if lastToken == "create" || lastToken == "create or replace" { - return 0, fmt.Errorf("unanalyzable statements like '%s %s' are not allowed", strings.ToUpper(lastToken), strings.ToUpper(keyword)) - } - case verbCreate, verbAlter, verbDrop, verbSelect, verbInsert, verbUpdate, verbDelete, verbMerge: - if statementVerb == "" { - statementVerb = keyword - } - } - - if statementVerb == verbCreate || statementVerb == verbAlter || statementVerb == verbDrop { - if keyword == "schema" || keyword == "dataset" { - return 0, fmt.Errorf("dataset-level operations like '%s %s' are not allowed when dataset restrictions are in place", strings.ToUpper(statementVerb), strings.ToUpper(keyword)) - } + if lastToken == "execute" && keyword == "immediate" { + // Found EXECUTE IMMEDIATE. The first expression must be the SQL string. + // Search for the next string literal. + sqlConsumed, err := findAndParseSQLString(runes[i+consumed:], defaultProjectID, tableIDSet, visitedSQLs, aliases) + if err != nil { + return 0, err } + i += consumed + sqlConsumed + lastToken = "execute immediate" + continue } // Resolve aliases and identify table references. @@ -343,7 +346,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi i++ case stateInSingleQuoteString: if char == '\\' { - i += 2 // Skip backslash and the escaped character. 
+ i += 2 continue } if char == '\'' { @@ -352,7 +355,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi i++ case stateInDoubleQuoteString: if char == '\\' { - i += 2 // Skip backslash and the escaped character. + i += 2 continue } if char == '"' { @@ -360,14 +363,14 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } i++ case stateInTripleSingleQuoteString: - if strings.HasPrefix(string(runes[i:]), "'''") { + if hasPrefix(runes, i, "'''") { state = stateNormal i += 3 } else { i++ } case stateInTripleDoubleQuoteString: - if strings.HasPrefix(string(runes[i:]), `"""`) { + if hasPrefix(runes, i, `"""`) { state = stateNormal i += 3 } else { @@ -379,7 +382,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } i++ case stateInMultiLineComment: - if strings.HasPrefix(string(runes[i:]), "*/") { + if hasPrefix(runes, i, "*/") { state = stateNormal i += 2 } else { @@ -396,14 +399,14 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } i++ case stateInRawTripleSingleQuoteString: - if strings.HasPrefix(string(runes[i:]), "'''") { + if hasPrefix(runes, i, "'''") { state = stateNormal i += 3 } else { i++ } case stateInRawTripleDoubleQuoteString: - if strings.HasPrefix(string(runes[i:]), `"""`) { + if hasPrefix(runes, i, `"""`) { state = stateNormal i += 3 } else { @@ -417,91 +420,99 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi return len(runes), nil } -// parseIdentifierSequence parses a sequence of dot-separated identifiers. -// It returns the parts of the identifier, the number of characters consumed, and an error. 
-func parseIdentifierSequence(s string) ([]string, int, error) { - var parts []string - var totalConsumed int - runes := []rune(s) - for { - for totalConsumed < len(runes) && unicode.IsSpace(runes[totalConsumed]) { - totalConsumed++ - } - if totalConsumed >= len(runes) { - break +// findAndParseSQLString scans for the first string literal and parses its content as SQL. +func findAndParseSQLString(runes []rune, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, aliases map[string]struct{}) (int, error) { + for i := 0; i < len(runes); { + if hasPrefix(runes, i, "'''") { + end := strings.Index(string(runes[i+3:]), "'''") + if end != -1 { + sqlContent := string(runes[i+3 : i+3+end]) + if _, err := parseSQL(sqlContent, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { + return 0, err + } + return i + 3 + end + 3, nil + } } - - var part string - var consumed int - - if runes[totalConsumed] == '`' { - end := strings.Index(string(runes[totalConsumed+1:]), "`") - if end == -1 { - return nil, 0, fmt.Errorf("unclosed backtick identifier") + if hasPrefix(runes, i, `"""`) { + end := strings.Index(string(runes[i+3:]), `"""`) + if end != -1 { + sqlContent := string(runes[i+3 : i+3+end]) + if _, err := parseSQL(sqlContent, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { + return 0, err + } + return i + 3 + end + 3, nil } - part = string(runes[totalConsumed+1 : totalConsumed+end+1]) - consumed = end + 2 - } else if unicode.IsLetter(runes[totalConsumed]) || runes[totalConsumed] == '_' { - end := totalConsumed - for end < len(runes) && (unicode.IsLetter(runes[end]) || unicode.IsNumber(runes[end]) || runes[end] == '_' || runes[end] == '-') { - end++ + } + if runes[i] == '\'' { + // Find end of single-quoted string, respecting backslash escapes. 
+ for j := i + 1; j < len(runes); j++ { + if runes[j] == '\\' { + j++ + continue + } + if runes[j] == '\'' { + sqlContent := string(runes[i+1 : j]) + if _, err := parseSQL(sqlContent, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { + return 0, err + } + return j + 1, nil + } } - part = string(runes[totalConsumed:end]) - consumed = end - totalConsumed - } else { - break } - - parts = append(parts, strings.Split(part, ".")...) - totalConsumed += consumed - - if totalConsumed >= len(runes) || runes[totalConsumed] != '.' { - break + if runes[i] == '"' { + for j := i + 1; j < len(runes); j++ { + if runes[j] == '\\' { + j++ + continue + } + if runes[j] == '"' { + sqlContent := string(runes[i+1 : j]) + if _, err := parseSQL(sqlContent, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { + return 0, err + } + return j + 1, nil + } + } } - totalConsumed++ + i++ } - return parts, totalConsumed, nil + return len(runes), nil } -// IsAnyTableExplicitlyReferenced checks if any target tables are explicitly mentioned as -// identifiers in the SQL, correctly skipping comments and strings. +// IsAnyTableExplicitlyReferenced performs a lexical audit of the SQL to see if any of the target tables +// are explicitly named as identifiers. It correctly ignores names inside comments or strings. 
func IsAnyTableExplicitlyReferenced(sql, defaultProjectID string, targetTableIDs []string) (bool, error) { - if len(targetTableIDs) == 0 { - return false, nil - } - targets := make(map[string]struct{}) for _, id := range targetTableIDs { targets[strings.ToLower(id)] = struct{}{} } - state := stateNormal runes := []rune(sql) + state := stateNormal for i := 0; i < len(runes); { - remaining := string(runes[i:]) char := runes[i] switch state { case stateNormal: - if strings.HasPrefix(remaining, "--") { + if hasPrefix(runes, i, "--") { state = stateInSingleLineCommentDash i += 2 continue } - if strings.HasPrefix(remaining, "#") { + if char == '#' { state = stateInSingleLineCommentHash i++ continue } - if strings.HasPrefix(remaining, "/*") { + if hasPrefix(runes, i, "/*") { state = stateInMultiLineComment i += 2 continue } if unicode.IsLetter(char) || char == '`' || char == '_' { - parts, consumed, err := parseIdentifierSequence(remaining) + parts, consumed, err := parseIdentifierSequence(runes[i:]) if err != nil { return false, err } @@ -530,33 +541,32 @@ func IsAnyTableExplicitlyReferenced(sql, defaultProjectID string, targetTableIDs } // Handle various BigQuery string literal formats. 
- remLow := strings.ToLower(remaining) - if strings.HasPrefix(remLow, "r'''") { + if hasPrefixFold(runes, i, "r'''") { state = stateInRawTripleSingleQuoteString i += 4 continue } - if strings.HasPrefix(remLow, `r"""`) { + if hasPrefixFold(runes, i, `r"""`) { state = stateInRawTripleDoubleQuoteString i += 4 continue } - if strings.HasPrefix(remLow, "r'") { + if hasPrefixFold(runes, i, "r'") { state = stateInRawSingleQuoteString i += 2 continue } - if strings.HasPrefix(remLow, `r"`) { + if hasPrefixFold(runes, i, `r"`) { state = stateInRawDoubleQuoteString i += 2 continue } - if strings.HasPrefix(remaining, "'''") { + if hasPrefix(runes, i, "'''") { state = stateInTripleSingleQuoteString i += 3 continue } - if strings.HasPrefix(remaining, `"""`) { + if hasPrefix(runes, i, `"""`) { state = stateInTripleDoubleQuoteString i += 3 continue @@ -589,13 +599,13 @@ func IsAnyTableExplicitlyReferenced(sql, defaultProjectID string, targetTableIDs state = stateNormal } case stateInTripleSingleQuoteString: - if strings.HasPrefix(remaining, "'''") { + if hasPrefix(runes, i, "'''") { state = stateNormal i += 3 continue } case stateInTripleDoubleQuoteString: - if strings.HasPrefix(remaining, `"""`) { + if hasPrefix(runes, i, `"""`) { state = stateNormal i += 3 continue @@ -605,7 +615,7 @@ func IsAnyTableExplicitlyReferenced(sql, defaultProjectID string, targetTableIDs state = stateNormal } case stateInMultiLineComment: - if strings.HasPrefix(remaining, "*/") { + if hasPrefix(runes, i, "*/") { state = stateNormal i += 2 continue @@ -619,13 +629,13 @@ func IsAnyTableExplicitlyReferenced(sql, defaultProjectID string, targetTableIDs state = stateNormal } case stateInRawTripleSingleQuoteString: - if strings.HasPrefix(remaining, "'''") { + if hasPrefix(runes, i, "'''") { state = stateNormal i += 3 continue } case stateInRawTripleDoubleQuoteString: - if strings.HasPrefix(remaining, `"""`) { + if hasPrefix(runes, i, `"""`) { state = stateNormal i += 3 continue @@ -637,6 +647,96 @@ func 
IsAnyTableExplicitlyReferenced(sql, defaultProjectID string, targetTableIDs return false, nil } +// parseIdentifierSequence parses a sequence of dot-separated identifiers. +// It returns the parts of the identifier, the number of characters consumed, and an error. +func parseIdentifierSequence(runes []rune) ([]string, int, error) { + var parts []string + var totalConsumed int + for { + // Skip whitespace and comments before identifier part + for { + originalConsumed := totalConsumed + for totalConsumed < len(runes) && unicode.IsSpace(runes[totalConsumed]) { + totalConsumed++ + } + if hasPrefix(runes, totalConsumed, "/*") { + endIdx := strings.Index(string(runes[totalConsumed:]), "*/") + if endIdx != -1 { + totalConsumed += endIdx + 2 + } + } else if hasPrefix(runes, totalConsumed, "--") || (totalConsumed < len(runes) && runes[totalConsumed] == '#') { + endIdx := strings.Index(string(runes[totalConsumed:]), "\n") + if endIdx != -1 { + totalConsumed += endIdx + 1 + } else { + totalConsumed = len(runes) + } + } + if totalConsumed == originalConsumed { + break + } + } + if totalConsumed >= len(runes) { + break + } + + var part string + var consumed int + + if runes[totalConsumed] == '`' { + end := strings.Index(string(runes[totalConsumed+1:]), "`") + if end == -1 { + return nil, 0, fmt.Errorf("unclosed backtick identifier") + } + part = string(runes[totalConsumed+1 : totalConsumed+end+1]) + consumed = end + 2 + } else if unicode.IsLetter(runes[totalConsumed]) || runes[totalConsumed] == '_' { + end := totalConsumed + for end < len(runes) && (unicode.IsLetter(runes[end]) || unicode.IsNumber(runes[end]) || runes[end] == '_' || runes[end] == '-') { + end++ + } + part = string(runes[totalConsumed:end]) + consumed = end - totalConsumed + } else { + break + } + + parts = append(parts, strings.Split(part, ".")...) 
+ totalConsumed += consumed + + // Skip whitespace and comments between parts (before potential dot) + for { + originalConsumed := totalConsumed + for totalConsumed < len(runes) && unicode.IsSpace(runes[totalConsumed]) { + totalConsumed++ + } + if hasPrefix(runes, totalConsumed, "/*") { + endIdx := strings.Index(string(runes[totalConsumed:]), "*/") + if endIdx != -1 { + totalConsumed += endIdx + 2 + } + } else if hasPrefix(runes, totalConsumed, "--") || (totalConsumed < len(runes) && runes[totalConsumed] == '#') { + endIdx := strings.Index(string(runes[totalConsumed:]), "\n") + if endIdx != -1 { + totalConsumed += endIdx + 1 + } else { + totalConsumed = len(runes) + } + } + if totalConsumed == originalConsumed { + break + } + } + + if totalConsumed >= len(runes) || runes[totalConsumed] != '.' { + break + } + totalConsumed++ + } + + return parts, totalConsumed, nil +} + func formatTableID(parts []string, defaultProjectID string) (string, error) { if len(parts) < 2 || len(parts) > 3 { // Not a table identifier (could be a CTE, column, etc.). 
diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index c72c703c947b..570aa695d444 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -129,7 +129,7 @@ func ValidateQueryAgainstAllowedDatasets( return nil, fmt.Errorf("failed to analyze query for explicit table references: %w", err) } if explicitlyReferenced { - return nil, fmt.Errorf("query explicitly accesses dataset '%s', which is not in the allowed list", strings.Join(strings.Split(violatingTables[0], ".")[:2], ".")) + return nil, fmt.Errorf("access to dataset '%s' is not allowed", strings.Join(strings.Split(violatingTables[0], ".")[:2], ".")) } } @@ -143,7 +143,7 @@ func ValidateQueryAgainstAllowedDatasets( parts := strings.Split(tableID, ".") if len(parts) == 3 { if !validator.IsDatasetAllowed(parts[0], parts[1]) { - return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", parts[0], parts[1]) + return nil, fmt.Errorf("access to dataset '%s.%s' is not allowed", parts[0], parts[1]) } } } diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index befa3649db0b..825765c21a55 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -194,9 +194,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - if _, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps, source); err != nil { + dryRunJob, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps, source) + if err != nil { return nil, util.ProcessGcpError(err) } + if 
dryRunJob.Statistics.Query.StatementType != "SELECT" { + return nil, util.NewAgentError(fmt.Sprintf("the 'history_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", dryRunJob.Statistics.Query.StatementType), nil) + } } historyDataSource = fmt.Sprintf("(%s)", historyData) } else { diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 170115eb0ebc..29e1f22a38e3 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -279,6 +279,43 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { teardownDisallowedAnalyzeContribution := setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams) defer teardownDisallowedAnalyzeContribution(t) + // Setup authorized views in BOTH allowed datasets pointing to disallowed tables + viewInAllowedPointingToDisallowedForecastName := "auth_view_forecast" + viewInAllowedPointingToDisallowedAnalyzeName := "auth_view_analyze" + + // Create Forecast views + for _, dsName := range []string{allowedDatasetName1, allowedDatasetName2} { + if err := client.Dataset(dsName).Table(viewInAllowedPointingToDisallowedForecastName).Create(ctx, &bigqueryapi.TableMetadata{ + ViewQuery: fmt.Sprintf("SELECT * FROM %s", disallowedForecastTableFullName), + }); err != nil { + t.Fatalf("failed to create forecast view in %s: %v", dsName, err) + } + defer client.Dataset(dsName).Table(viewInAllowedPointingToDisallowedForecastName).Delete(ctx) + + if err := client.Dataset(dsName).Table(viewInAllowedPointingToDisallowedAnalyzeName).Create(ctx, &bigqueryapi.TableMetadata{ + ViewQuery: fmt.Sprintf("SELECT * FROM %s", disallowedAnalyzeContributionTableFullName), + }); err != nil { + t.Fatalf("failed to create analyze view in %s: %v", dsName, err) 
+ } + defer client.Dataset(dsName).Table(viewInAllowedPointingToDisallowedAnalyzeName).Delete(ctx) + } + + // Authorize ALL views to access the disallowed dataset + dsMetadata, err := client.Dataset(disallowedDatasetName).Metadata(ctx) + if err != nil { + t.Fatalf("failed to get disallowed dataset metadata: %v", err) + } + newAccess := append(dsMetadata.Access, + &bigqueryapi.AccessEntry{EntityType: bigqueryapi.ViewEntity, View: client.Dataset(allowedDatasetName1).Table(viewInAllowedPointingToDisallowedForecastName)}, + &bigqueryapi.AccessEntry{EntityType: bigqueryapi.ViewEntity, View: client.Dataset(allowedDatasetName2).Table(viewInAllowedPointingToDisallowedForecastName)}, + &bigqueryapi.AccessEntry{EntityType: bigqueryapi.ViewEntity, View: client.Dataset(allowedDatasetName1).Table(viewInAllowedPointingToDisallowedAnalyzeName)}, + &bigqueryapi.AccessEntry{EntityType: bigqueryapi.ViewEntity, View: client.Dataset(allowedDatasetName2).Table(viewInAllowedPointingToDisallowedAnalyzeName)}, + ) + update := bigqueryapi.DatasetMetadataToUpdate{Access: newAccess} + if _, err := client.Dataset(disallowedDatasetName).Update(ctx, update, dsMetadata.ETag); err != nil { + t.Fatalf("failed to authorize views: %v", err) + } + // Configure source with dataset restriction. 
sourceConfig := getBigQueryVars(t) sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2} @@ -368,141 +405,6 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName2, disallowedAnalyzeContributionTableFullName) } -func TestBigQueryAuthorizedViewWithRestriction(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - client, err := initBigQueryConnection(BigqueryProject) - if err != nil { - t.Fatalf("unable to create BigQuery client: %s", err) - } - - baseName := strings.ReplaceAll(uuid.New().String(), "-", "") - allowedDatasetName := fmt.Sprintf("allowed_ds_%s", baseName) - disallowedDatasetName := fmt.Sprintf("disallowed_ds_%s", baseName) - tableName := "source_table" - viewName := "auth_view" - - // Create datasets - if err := client.Dataset(allowedDatasetName).Create(ctx, &bigqueryapi.DatasetMetadata{Location: "US"}); err != nil { - t.Fatalf("failed to create allowed dataset: %v", err) - } - defer client.Dataset(allowedDatasetName).DeleteWithContents(ctx) - - if err := client.Dataset(disallowedDatasetName).Create(ctx, &bigqueryapi.DatasetMetadata{Location: "US"}); err != nil { - t.Fatalf("failed to create disallowed dataset: %v", err) - } - defer client.Dataset(disallowedDatasetName).DeleteWithContents(ctx) - - // Create source table in disallowed dataset - tableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, tableName) - if _, err := client.Query(fmt.Sprintf("CREATE TABLE %s (id INT64)", tableFullName)).Run(ctx); err != nil { - // Wait and retry if it's a "still being created" error? No, usually it's fine. 
- t.Fatalf("failed to create source table: %v", err) - } - - // Create view in allowed dataset referencing the disallowed table - viewFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName, viewName) - metadata := &bigqueryapi.TableMetadata{ - ViewQuery: fmt.Sprintf("SELECT * FROM %s", tableFullName), - } - if err := client.Dataset(allowedDatasetName).Table(viewName).Create(ctx, metadata); err != nil { - t.Fatalf("failed to create view: %v", err) - } - - // Authorize the view to access the disallowed dataset - dsMetadata, err := client.Dataset(disallowedDatasetName).Metadata(ctx) - if err != nil { - t.Fatalf("failed to get disallowed dataset metadata: %v", err) - } - - newAccess := append(dsMetadata.Access, &bigqueryapi.AccessEntry{ - EntityType: bigqueryapi.ViewEntity, - View: client.Dataset(allowedDatasetName).Table(viewName), - }) - - update := bigqueryapi.DatasetMetadataToUpdate{ - Access: newAccess, - } - if _, err := client.Dataset(disallowedDatasetName).Update(ctx, update, dsMetadata.ETag); err != nil { - t.Fatalf("failed to authorize view: %v", err) - } - - // Configure toolbox with ONLY the allowed dataset - sourceConfig := getBigQueryVars(t) - sourceConfig["allowedDatasets"] = []string{allowedDatasetName} - config := map[string]any{ - "sources": map[string]any{"my-instance": sourceConfig}, - "tools": map[string]any{ - "execute-sql-restricted": map[string]any{ - "kind": "bigquery-execute-sql", - "source": "my-instance", - "description": "Tool to execute SQL with restriction", - }, - }, - } - - cmd, cleanup, err := tests.StartCmd(ctx, config) - if err != nil { - t.Fatalf("command initialization returned an error: %s", err) - } - defer cleanup() - - waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) - if err != nil { - t.Logf("toolbox command logs: \n%s", out) - t.Fatalf("toolbox didn't start successfully: %s", 
err) - } - - // Invoke the tool querying the view. This should now SUCCEED. - t.Run("query authorized view", func(t *testing.T) { - sql := fmt.Sprintf("SELECT * FROM %s", viewFullName) - body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"%s"}`, sql))) - req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/execute-sql-restricted/invoke", body) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, http.StatusOK, string(bodyBytes)) - } - }) - - // Also verify that direct query to the disallowed table still FAILS. - t.Run("query disallowed table directly", func(t *testing.T) { - sql := fmt.Sprintf("SELECT * FROM %s", tableFullName) - body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"%s"}`, sql))) - req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/execute-sql-restricted/invoke", body) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, http.StatusOK, string(bodyBytes)) - } - bodyBytes, _ := io.ReadAll(resp.Body) - if !strings.Contains(string(bodyBytes), fmt.Sprintf("query explicitly accesses dataset '%s.%s', which is not in the allowed list", BigqueryProject, disallowedDatasetName)) { - t.Errorf("unexpected error message: %s", string(bodyBytes)) - } - }) -} - func TestBigQueryWriteModeAllowed(t *testing.T) { sourceConfig := getBigQueryVars(t) sourceConfig["writeMode"] = "allowed" @@ -1935,7 +1837,7 @@ func runBigQueryListDatasetToolInvokeTest(t *testing.T, datasetWant string) { name: "invoke my-list-dataset-ids-tool with non-existent project", api: "http://127.0.0.1:5000/api/tool/my-auth-list-dataset-ids-tool/invoke", requestHeader: map[string]string{"my-google-auth_token": idToken}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\"}", BigqueryProject+"-nonexistent"))), + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\"}", BigqueryProject, uuid.NewString()))), isErr: true, }, { @@ -2064,7 +1966,7 @@ func runBigQueryGetDatasetInfoToolInvokeTest(t *testing.T, datasetName, datasetI name: "Invoke my-auth-get-dataset-info-tool with non-existent project", api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-info-tool/invoke", requestHeader: map[string]string{"my-google-auth_token": idToken}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\", \"dataset\":\"%s\"}", BigqueryProject+"-nonexistent", datasetName))), + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\", \"dataset\":\"%s\"}", BigqueryProject, uuid.NewString(), datasetName))), isErr: true, }, { @@ -2229,7 +2131,7 @@ func runBigQueryListTableIdsToolInvokeTest(t *testing.T, datasetName, tablename_ name: "Invoke my-auth-list-table-ids-tool with non-existent project", api: "http://127.0.0.1:5000/api/tool/my-auth-list-table-ids-tool/invoke", requestHeader: map[string]string{"my-google-auth_token": idToken}, - requestBody: 
bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\", \"dataset\":\"%s\"}", BigqueryProject+"-nonexistent", datasetName))), + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\", \"dataset\":\"%s\"}", BigqueryProject, uuid.NewString(), datasetName))), isErr: true, }, { @@ -2379,7 +2281,7 @@ func runBigQueryGetTableInfoToolInvokeTest(t *testing.T, datasetName, tableName, name: "Invoke my-auth-get-table-info-tool with non-existent project", api: "http://127.0.0.1:5000/api/tool/my-auth-get-table-info-tool/invoke", requestHeader: map[string]string{"my-google-auth_token": idToken}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\", \"dataset\":\"%s\", \"table\":\"%s\"}", BigqueryProject+"-nonexistent", datasetName, tableName))), + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\", \"dataset\":\"%s\", \"table\":\"%s\"}", BigqueryProject, uuid.NewString(), datasetName, tableName))), isErr: true, }, { @@ -2683,6 +2585,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed } func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string, allowedTableNames ...string) { + allowedTableNames = append(allowedTableNames, "auth_view_forecast", "auth_view_analyze") sort.Strings(allowedTableNames) var quotedNames []string for _, name := range allowedTableNames { @@ -2725,9 +2628,9 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { + if resp.StatusCode != tc.wantStatusCode { bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, http.StatusOK, string(bodyBytes)) + t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) } if tc.wantInResult != "" { @@ -2756,16 +2659,9 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed } if tc.wantInError != "" { - var respBody map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - t.Fatalf("error parsing response body: %v", err) - } - got, ok := respBody["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") - } - if !strings.Contains(got, tc.wantInError) { - t.Errorf("unexpected error message: got %q, want to contain %q", got, tc.wantInError) + bodyBytes, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(bodyBytes), tc.wantInError) { + t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) } } }) @@ -2806,22 +2702,15 @@ func runGetDatasetInfoWithRestriction(t *testing.T, allowedDatasetName, disallow } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { + if resp.StatusCode != tc.wantStatusCode { bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, http.StatusOK, string(bodyBytes)) + t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) } if tc.wantInError != "" { - var respBody map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - t.Fatalf("error parsing response body: %v", err) - } - got, ok := respBody["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") - } - if !strings.Contains(got, tc.wantInError) { - t.Errorf("unexpected error message: got %q, want to contain %q", got, tc.wantInError) + bodyBytes, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(bodyBytes), tc.wantInError) { + t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) } } }) @@ -2864,22 +2753,15 @@ func runGetTableInfoWithRestriction(t *testing.T, allowedDatasetName, disallowed } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { + if resp.StatusCode != tc.wantStatusCode { bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, http.StatusOK, string(bodyBytes)) + t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) } if tc.wantInError != "" { - var respBody map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - t.Fatalf("error parsing response body: %v", err) - } - got, ok := respBody["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") - } - if !strings.Contains(got, tc.wantInError) { - t.Errorf("unexpected error message: got %q, want to contain %q", got, tc.wantInError) + bodyBytes, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(bodyBytes), tc.wantInError) { + t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) } } }) @@ -2891,7 +2773,8 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed if len(allowedTableParts) != 3 { t.Fatalf("invalid allowed table name format: %s", allowedTableFullName) } - allowedDatasetID := allowedTableParts[1] + allowedProjectID, allowedDatasetID := allowedTableParts[0], allowedTableParts[1] + viewFullName := fmt.Sprintf("`%s.%s.auth_view_forecast`", allowedProjectID, allowedDatasetID) testCases := []struct { name string @@ -2904,11 +2787,16 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed sql: fmt.Sprintf("SELECT * FROM %s", allowedTableFullName), wantStatusCode: http.StatusOK, }, + { + name: "invoke on authorized view", + sql: fmt.Sprintf("SELECT * FROM %s", viewFullName), + wantStatusCode: http.StatusOK, + }, { name: "invoke on disallowed table", sql: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusOK, - wantInError: fmt.Sprintf("query explicitly accesses dataset '%s', which is not in the allowed list", + wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", strings.Join( strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0:2], ".")), @@ -2935,13 +2823,13 @@ func runExecuteSqlWithRestriction(t *testing.T, 
allowedTableFullName, disallowed name: "disallowed create procedure", sql: fmt.Sprintf("CREATE PROCEDURE %s.my_proc() BEGIN SELECT 1; END", allowedDatasetID), wantStatusCode: http.StatusOK, - wantInError: "not allowed", + wantInError: "unanalyzable statements like 'CREATE PROCEDURE' are not allowed", }, { name: "disallowed execute immediate", sql: "EXECUTE IMMEDIATE 'SELECT 1'", wantStatusCode: http.StatusOK, - wantInError: "not allowed", + wantInError: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place", }, } @@ -2965,16 +2853,9 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed } if tc.wantInError != "" { - var respBody map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - t.Fatalf("error parsing response body: %v", err) - } - got, ok := respBody["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") - } - if !strings.Contains(got, tc.wantInError) { - t.Errorf("unexpected error message: got %q, want to contain %q", got, tc.wantInError) + bodyBytes, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(bodyBytes), tc.wantInError) { + t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) } } }) @@ -3049,16 +2930,9 @@ func runConversationalAnalyticsWithRestriction(t *testing.T, allowedDatasetName, } if tc.wantInError != "" { - var respBody map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - t.Fatalf("error parsing response body: %v", err) - } - got, ok := respBody["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") - } - if !strings.Contains(got, tc.wantInError) { - t.Errorf("unexpected error message: got %q, want to contain %q", got, tc.wantInError) + bodyBytes, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(bodyBytes), tc.wantInError) { + t.Errorf("unexpected error message: got %q, want to contain %q", 
string(bodyBytes), tc.wantInError) } } }) @@ -3123,7 +2997,7 @@ func runBigQuerySearchCatalogToolInvokeTest(t *testing.T, datasetName string, ta name: "Invoke my-auth-search-catalog-tool with non-existent project", api: "http://127.0.0.1:5000/api/tool/my-auth-search-catalog-tool/invoke", requestHeader: map[string]string{"my-google-auth_token": idToken}, - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"projectIds\":[\"%s\"], \"datasetIds\":[\"%s\"]}", tableName, BigqueryProject+"-nonexistent", datasetName))), + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"projectIds\":[\"%s-%s\"], \"datasetIds\":[\"%s\"]}", tableName, BigqueryProject, uuid.NewString(), datasetName))), isErr: true, }, { @@ -3241,6 +3115,10 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "") disallowedDatasetFQN := strings.Join(strings.Split(disallowedTableUnquoted, ".")[0:2], ".") + allowedParts := strings.Split(allowedTableUnquoted, ".") + viewFullName := fmt.Sprintf("`%s.%s.auth_view_forecast`", allowedParts[0], allowedParts[1]) + viewUnquoted := strings.ReplaceAll(viewFullName, "`", "") + testCases := []struct { name string historyData string @@ -3254,6 +3132,12 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa wantStatusCode: http.StatusOK, wantInResult: `"forecast_timestamp"`, }, + { + name: "invoke with authorized view name", + historyData: viewUnquoted, + wantStatusCode: http.StatusOK, + wantInResult: `"forecast_timestamp"`, + }, { name: "invoke with disallowed table name", historyData: disallowedTableUnquoted, @@ -3266,11 +3150,17 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa wantStatusCode: http.StatusOK, wantInResult: `"forecast_timestamp"`, }, + { + name: "invoke with query on authorized view", + historyData: 
fmt.Sprintf("SELECT * FROM %s", viewFullName), + wantStatusCode: http.StatusOK, + wantInResult: `"forecast_timestamp"`, + }, { name: "invoke with query on disallowed table", historyData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusOK, - wantInError: fmt.Sprintf("query explicitly accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), }, } @@ -3313,21 +3203,14 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa t.Fatalf("unable to find result in response body") } if !strings.Contains(got, tc.wantInResult) { - t.Errorf("unexpected result: got %q, want %q", got, tc.wantInResult) + t.Errorf("unexpected result: got %q, want to contain %q", got, tc.wantInResult) } } if tc.wantInError != "" { - var respBody map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - t.Fatalf("error parsing response body: %v", err) - } - got, ok := respBody["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") - } - if !strings.Contains(got, tc.wantInError) { - t.Errorf("unexpected error message: got %q, want to contain %q", got, tc.wantInError) + bodyBytes, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(bodyBytes), tc.wantInError) { + t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) } } }) @@ -3339,6 +3222,10 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "") disallowedDatasetFQN := strings.Join(strings.Split(disallowedTableUnquoted, ".")[0:2], ".") + allowedParts := strings.Split(allowedTableUnquoted, ".") + viewFullName := fmt.Sprintf("`%s.%s.auth_view_analyze`", allowedParts[0], allowedParts[1]) + viewUnquoted := 
strings.ReplaceAll(viewFullName, "`", "") + testCases := []struct { name string inputData string @@ -3352,11 +3239,17 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d wantStatusCode: http.StatusOK, wantInResult: `"relative_difference"`, }, + { + name: "invoke with authorized view name", + inputData: viewUnquoted, + wantStatusCode: http.StatusOK, + wantInResult: `"relative_difference"`, + }, { name: "invoke with disallowed table name", inputData: disallowedTableUnquoted, wantStatusCode: http.StatusOK, - wantInResult: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted), + wantInError: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted), }, { name: "invoke with query on allowed table", @@ -3364,11 +3257,17 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d wantStatusCode: http.StatusOK, wantInResult: `"relative_difference"`, }, + { + name: "invoke with query on authorized view", + inputData: fmt.Sprintf("SELECT * FROM %s", viewFullName), + wantStatusCode: http.StatusOK, + wantInResult: `"relative_difference"`, + }, { name: "invoke with query on disallowed table", inputData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusOK, - wantInError: fmt.Sprintf("query explicitly accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), }, } @@ -3409,13 +3308,7 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d } if tc.wantInError != "" { - var got string - if respBody["result"] != nil { - got, _ = respBody["result"].(string) - } else if respBody["error"] != nil { - got, _ = respBody["error"].(string) - } - if !strings.Contains(got, tc.wantInError) && 
!strings.Contains(string(bodyBytes), tc.wantInError) { + if !strings.Contains(string(bodyBytes), tc.wantInError) { t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) } } From 74f479535653da045d1d45bd7a268cf26762148a Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Wed, 25 Feb 2026 09:13:36 +0000 Subject: [PATCH 03/15] resolve linting errors and test failures --- .../bigquerycommon/table_name_parser.go | 67 ++++++++++---- .../bigquerycommon/table_name_parser_test.go | 88 +++++++++++++++++++ .../tools/bigquery/bigquerycommon/util.go | 10 --- tests/bigquery/bigquery_integration_test.go | 80 +++++++++-------- 4 files changed, 183 insertions(+), 62 deletions(-) diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser.go b/internal/tools/bigquery/bigquerycommon/table_name_parser.go index 8b4a178f2490..e007edad9479 100644 --- a/internal/tools/bigquery/bigquerycommon/table_name_parser.go +++ b/internal/tools/bigquery/bigquerycommon/table_name_parser.go @@ -43,18 +43,24 @@ var tableFollowsKeywords = map[string]bool{ "into": true, "update": true, "table": true, + "using": true, + "insert": true, + "merge": true, } var tableContextExitKeywords = map[string]bool{ - "where": true, - "group": true, - "order": true, - "having": true, - "limit": true, - "window": true, - "union": true, + "where": true, + "group": true, + "order": true, + "having": true, + "limit": true, + "window": true, + "union": true, "intersect": true, - "except": true, + "except": true, + "on": true, + "set": true, + "when": true, } // hasPrefix checks if the runes starting at offset match the given prefix. 
@@ -249,6 +255,26 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi keyword := strings.ToLower(parts[0]) fullID := strings.ToLower(strings.Join(parts, ".")) + // Security check for restricted statements + if keyword == "immediate" && lastToken == "execute" { + return 0, fmt.Errorf("EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place") + } + if (lastToken == "create" || lastToken == "create or" || lastToken == "create or replace") && + (keyword == "procedure" || keyword == "function" || keyword == "table function") { + tokenToReport := strings.ToUpper(lastToken) + if tokenToReport == "" { + tokenToReport = "CREATE" + } + return 0, fmt.Errorf("unanalyzable statements like '%s %s' are not allowed", tokenToReport, strings.ToUpper(keyword)) + } + if keyword == "call" { + return 0, fmt.Errorf("CALL is not allowed when dataset restrictions are in place") + } + if (statementVerb == "create" || statementVerb == "alter" || statementVerb == "drop") && + (keyword == "schema" || keyword == "dataset") { + return 0, fmt.Errorf("dataset-level operations like '%s %s' are not allowed", strings.ToUpper(statementVerb), strings.ToUpper(keyword)) + } + if lastToken == "execute" && keyword == "immediate" { // Found EXECUTE IMMEDIATE. The first expression must be the SQL string. // Search for the next string literal. 
@@ -337,6 +363,12 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } else { lastToken = keyword } + // Also track statement verb for schema checks + if keyword == "select" || keyword == "insert" || keyword == "update" || keyword == "delete" || keyword == "merge" || keyword == "create" || keyword == "alter" || keyword == "drop" { + if statementVerb == "" || statementVerb == "with" { + statementVerb = keyword + } + } } else { lastToken = "" } @@ -517,20 +549,23 @@ func IsAnyTableExplicitlyReferenced(sql, defaultProjectID string, targetTableIDs return false, err } if consumed > 0 { - if len(parts) < 2 { - i += consumed - continue - } fullID := strings.ToLower(strings.Join(parts, ".")) for target := range targets { - // Match exact table name or as a prefix for column references. + // Exact match or as a prefix for column references. if fullID == target || strings.HasPrefix(fullID, target+".") { return true, nil } - // Also try matching with the default project ID prefix. + // Match without any backticks. + cleanFullID := strings.ReplaceAll(fullID, "`", "") + cleanTarget := strings.ReplaceAll(target, "`", "") + if cleanFullID == cleanTarget || strings.HasPrefix(cleanFullID, cleanTarget+".") { + return true, nil + } + // Try matching with the default project ID prefix. if defaultProjectID != "" { - withDefault := strings.ToLower(defaultProjectID + "." + fullID) - if withDefault == target || strings.HasPrefix(withDefault, target+".") { + cleanDefaultProjectID := strings.ReplaceAll(strings.ToLower(defaultProjectID), "`", "") + withDefault := cleanDefaultProjectID + "." 
+ cleanFullID + if withDefault == cleanTarget || strings.HasPrefix(withDefault, cleanTarget+".") { return true, nil } } diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go b/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go index 18855ec88472..50717eade3d1 100644 --- a/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go +++ b/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go @@ -542,6 +542,94 @@ func TestTableParser(t *testing.T) { want: []string{"proj.ds.tbl", "other.ds.tbl2"}, wantErr: false, }, + { + name: "create schema statement", + sql: "CREATE SCHEMA proj.data", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'CREATE SCHEMA' are not allowed", + }, + { + name: "create dataset statement", + sql: "CREATE DATASET proj.data", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'CREATE DATASET' are not allowed", + }, + { + name: "drop schema statement", + sql: "DROP SCHEMA proj.data", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'DROP SCHEMA' are not allowed", + }, + { + name: "drop dataset statement", + sql: "DROP DATASET proj.data", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'DROP DATASET' are not allowed", + }, + { + name: "alter schema statement", + sql: "ALTER SCHEMA proj.data SET OPTIONS(description='new one')", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'ALTER SCHEMA' are not allowed", + }, + { + name: "alter dataset statement", + sql: "ALTER DATASET proj.data SET OPTIONS(description='new one')", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'ALTER DATASET' are not allowed", + }, + { + name: "call fully 
qualified procedure", + sql: "CALL proj.data.proc()", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "CALL is not allowed when dataset restrictions are in place", + }, + { + name: "create procedure statement", + sql: "CREATE PROCEDURE proj.data.proc() BEGIN SELECT 1; END", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "unanalyzable statements like 'CREATE PROCEDURE' are not allowed", + }, + { + name: "create or replace procedure statement", + sql: "CREATE OR REPLACE PROCEDURE proj.data.proc() BEGIN SELECT 1; END", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "unanalyzable statements like 'CREATE OR REPLACE PROCEDURE' are not allowed", + }, + { + name: "create function statement", + sql: "CREATE FUNCTION proj.data.func() RETURNS INT64 AS (1)", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "unanalyzable statements like 'CREATE FUNCTION' are not allowed", + }, + { + name: "simple execute immediate", + sql: "EXECUTE IMMEDIATE 'SELECT 1'", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "EXECUTE IMMEDIATE is not allowed", + }, } for _, tc := range testCases { diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index 570aa695d444..1fcf55c29101 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -82,16 +82,6 @@ func ValidateQueryAgainstAllowedDatasets( if dryRunJob.Statistics == nil || dryRunJob.Statistics.Query == nil { return nil, fmt.Errorf("dry run failed to return query statistics") } - statementType := dryRunJob.Statistics.Query.StatementType - // Common restricted operations - switch statementType { - case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA": - return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", 
statementType) - case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE": - return nil, fmt.Errorf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType) - case "CALL": - return nil, fmt.Errorf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType) - } // Use a map to avoid duplicate table names from the dry run result. tableIDSet := make(map[string]struct{}) diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 29e1f22a38e3..6ed92ca64059 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -36,7 +36,6 @@ import ( "github.com/googleapis/genai-toolbox/tests" "golang.org/x/oauth2/google" "google.golang.org/api/googleapi" - "google.golang.org/api/iterator" "google.golang.org/api/option" ) @@ -87,6 +86,10 @@ func TestBigQueryToolEndpoints(t *testing.T) { // create table name with UUID datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) + + cleanupDatasets := ensureTeardownDatasets(ctx, client, datasetName) + defer cleanupDatasets(t) + tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) tableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, @@ -218,6 +221,10 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { allowedDatasetName1 := fmt.Sprintf("allowed_dataset_1_%s", baseName) allowedDatasetName2 := fmt.Sprintf("allowed_dataset_2_%s", baseName) disallowedDatasetName := fmt.Sprintf("disallowed_dataset_%s", baseName) + + cleanupDatasets := ensureTeardownDatasets(ctx, client, allowedDatasetName1, allowedDatasetName2, disallowedDatasetName) + defer cleanupDatasets(t) + allowedTableName1 := "allowed_table_1" allowedTableName2 := "allowed_table_2" 
disallowedTableName := "disallowed_table" @@ -285,19 +292,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { // Create Forecast views for _, dsName := range []string{allowedDatasetName1, allowedDatasetName2} { - if err := client.Dataset(dsName).Table(viewInAllowedPointingToDisallowedForecastName).Create(ctx, &bigqueryapi.TableMetadata{ - ViewQuery: fmt.Sprintf("SELECT * FROM %s", disallowedForecastTableFullName), - }); err != nil { - t.Fatalf("failed to create forecast view in %s: %v", dsName, err) - } - defer client.Dataset(dsName).Table(viewInAllowedPointingToDisallowedForecastName).Delete(ctx) + teardownForecastView := setupBigQueryView(t, ctx, client, dsName, viewInAllowedPointingToDisallowedForecastName, fmt.Sprintf("SELECT * FROM %s", disallowedForecastTableFullName)) + defer teardownForecastView(t) - if err := client.Dataset(dsName).Table(viewInAllowedPointingToDisallowedAnalyzeName).Create(ctx, &bigqueryapi.TableMetadata{ - ViewQuery: fmt.Sprintf("SELECT * FROM %s", disallowedAnalyzeContributionTableFullName), - }); err != nil { - t.Fatalf("failed to create analyze view in %s: %v", dsName, err) - } - defer client.Dataset(dsName).Table(viewInAllowedPointingToDisallowedAnalyzeName).Delete(ctx) + teardownAnalyzeView := setupBigQueryView(t, ctx, client, dsName, viewInAllowedPointingToDisallowedAnalyzeName, fmt.Sprintf("SELECT * FROM %s", disallowedAnalyzeContributionTableFullName)) + defer teardownAnalyzeView(t) } // Authorize ALL views to access the disallowed dataset @@ -660,6 +659,29 @@ func getBigQueryTmplToolStatement() (string, string) { return tmplSelectCombined, tmplSelectFilterCombined } +func ensureTeardownDatasets(ctx context.Context, client *bigqueryapi.Client, datasetNames ...string) func(*testing.T) { + return func(t *testing.T) { + for _, dsName := range datasetNames { + if err := client.Dataset(dsName).DeleteWithContents(ctx); err != nil { + t.Logf("failed to cleanup dataset %s: %v", dsName, err) + } + } + } +} + +func 
setupBigQueryView(t *testing.T, ctx context.Context, client *bigqueryapi.Client, datasetName, viewName, query string) func(*testing.T) { + if err := client.Dataset(datasetName).Table(viewName).Create(ctx, &bigqueryapi.TableMetadata{ + ViewQuery: query, + }); err != nil { + t.Fatalf("failed to create view %s in %s: %v", viewName, datasetName, err) + } + return func(t *testing.T) { + if err := client.Dataset(datasetName).Table(viewName).Delete(ctx); err != nil { + t.Errorf("failed to delete view %s in %s: %v", viewName, datasetName, err) + } + } +} + func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) { // Create dataset dataset := client.Dataset(datasetName) @@ -723,19 +745,6 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C if err := dropStatus.Err(); err != nil { t.Errorf("Error dropping table %s: %v", tableName, err) } - - // tear down dataset - datasetToTeardown := client.Dataset(datasetName) - tablesIterator := datasetToTeardown.Tables(ctx) - _, err = tablesIterator.Next() - - if err == iterator.Done { - if err := datasetToTeardown.Delete(ctx); err != nil { - t.Errorf("Failed to delete dataset %s: %v", datasetName, err) - } - } else if err != nil { - t.Errorf("Failed to list tables in dataset %s to check emptiness: %v.", datasetName, err) - } } } @@ -2796,28 +2805,27 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed name: "invoke on disallowed table", sql: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusOK, - wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", - strings.Join( - strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0:2], - ".")), + wantInError: fmt.Sprintf("access to dataset '%s.%s' is not allowed", + 
strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0], + strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[1]), }, { name: "disallowed create schema", sql: "CREATE SCHEMA another_dataset", wantStatusCode: http.StatusOK, - wantInError: "dataset-level operations like 'CREATE_SCHEMA' are not allowed", + wantInError: "dataset-level operations like 'CREATE SCHEMA' are not allowed", }, { name: "disallowed alter schema", sql: fmt.Sprintf("ALTER SCHEMA %s SET OPTIONS(description='new one')", allowedDatasetID), wantStatusCode: http.StatusOK, - wantInError: "dataset-level operations like 'ALTER_SCHEMA' are not allowed", + wantInError: "dataset-level operations like 'ALTER SCHEMA' are not allowed", }, { name: "disallowed create function", sql: fmt.Sprintf("CREATE FUNCTION %s.my_func() RETURNS INT64 AS (1)", allowedDatasetID), wantStatusCode: http.StatusOK, - wantInError: "creating stored routines ('CREATE_FUNCTION') is not allowed", + wantInError: "unanalyzable statements like 'CREATE FUNCTION' are not allowed", }, { name: "disallowed create procedure", @@ -2829,7 +2837,7 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed name: "disallowed execute immediate", sql: "EXECUTE IMMEDIATE 'SELECT 1'", wantStatusCode: http.StatusOK, - wantInError: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place", + wantInError: "EXECUTE IMMEDIATE is not allowed", }, } @@ -3160,7 +3168,7 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa name: "invoke with query on disallowed table", historyData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusOK, - wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + wantInError: fmt.Sprintf("access to dataset '%s' is not allowed", disallowedDatasetFQN), }, } @@ -3267,7 +3275,7 @@ func runAnalyzeContributionWithRestriction(t *testing.T, 
allowedTableFullName, d name: "invoke with query on disallowed table", inputData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusOK, - wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + wantInError: fmt.Sprintf("access to dataset '%s' is not allowed", disallowedDatasetFQN), }, } From 8de9da3192b54bb1ca9c536a7b6e4988f3e8a323 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Wed, 25 Feb 2026 18:01:45 +0000 Subject: [PATCH 04/15] update --- .../tools/bigquery/bigquery-execute-sql.md | 8 ++- .../bigquerycommon/table_name_parser.go | 62 ++++++++++++++----- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/docs/en/resources/tools/bigquery/bigquery-execute-sql.md b/docs/en/resources/tools/bigquery/bigquery-execute-sql.md index 15057c7e27ae..f748999f4cd5 100644 --- a/docs/en/resources/tools/bigquery/bigquery-execute-sql.md +++ b/docs/en/resources/tools/bigquery/bigquery-execute-sql.md @@ -39,9 +39,11 @@ layer of security by controlling which datasets can be accessed: - **Without `allowedDatasets` restriction:** The tool can execute any valid GoogleSQL query. -- **With `allowedDatasets` restriction:** The tool analyzes the query before execution to ensure that it only accesses the allowed datasets. - This check also supports authorized views by validating direct references against the allowed list. - To enforce this restriction, the following operations are also disallowed: +- **With `allowedDatasets` restriction:** Before execution, the tool performs a + dry run to analyze the query. + It will reject the query if it attempts to access any table outside the + allowed `datasets` list. To enforce this restriction, the following operations + are also disallowed: - **Dataset-level operations** (e.g., `CREATE SCHEMA`, `ALTER SCHEMA`). 
- **Unanalyzable operations** where the accessed tables cannot be determined diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser.go b/internal/tools/bigquery/bigquerycommon/table_name_parser.go index e007edad9479..6b5205f112a3 100644 --- a/internal/tools/bigquery/bigquerycommon/table_name_parser.go +++ b/internal/tools/bigquery/bigquerycommon/table_name_parser.go @@ -20,47 +20,79 @@ import ( "unicode" ) +// parserState defines the state of the SQL parser's state machine. type parserState int const ( stateNormal parserState = iota + // String states stateInSingleQuoteString stateInDoubleQuoteString stateInTripleSingleQuoteString stateInTripleDoubleQuoteString - stateInSingleLineCommentDash - stateInSingleLineCommentHash - stateInMultiLineComment stateInRawSingleQuoteString stateInRawDoubleQuoteString stateInRawTripleSingleQuoteString stateInRawTripleDoubleQuoteString + // Comment states + stateInSingleLineCommentDash + stateInSingleLineCommentHash + stateInMultiLineComment +) + +// SQL statement verbs +const ( + verbCreate = "create" + verbAlter = "alter" + verbDrop = "drop" + verbSelect = "select" + verbInsert = "insert" + verbUpdate = "update" + verbDelete = "delete" + verbMerge = "merge" ) var tableFollowsKeywords = map[string]bool{ "from": true, "join": true, - "into": true, + "into": true, // INSERT INTO, MERGE INTO "update": true, - "table": true, - "using": true, - "insert": true, - "merge": true, + "table": true, // CREATE TABLE, ALTER TABLE + "using": true, // MERGE ... USING + "insert": true, // INSERT my_table + "merge": true, // MERGE my_table } var tableContextExitKeywords = map[string]bool{ "where": true, - "group": true, - "order": true, + "group": true, // GROUP BY + "order": true, // ORDER BY "having": true, "limit": true, "window": true, "union": true, "intersect": true, "except": true, - "on": true, - "set": true, - "when": true, + "on": true, // JOIN ... ON + "set": true, // UPDATE ... SET + "when": true, // MERGE ... 
WHEN +} + +var sqlStatementVerbs = map[string]bool{ + verbCreate: true, + verbAlter: true, + verbDrop: true, + verbSelect: true, + verbInsert: true, + verbUpdate: true, + verbDelete: true, + verbMerge: true, +} + +var schemaOperationVerbs = map[string]bool{ + verbCreate: true, + verbAlter: true, + verbDrop: true, } // hasPrefix checks if the runes starting at offset match the given prefix. @@ -270,7 +302,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi if keyword == "call" { return 0, fmt.Errorf("CALL is not allowed when dataset restrictions are in place") } - if (statementVerb == "create" || statementVerb == "alter" || statementVerb == "drop") && + if schemaOperationVerbs[statementVerb] && (keyword == "schema" || keyword == "dataset") { return 0, fmt.Errorf("dataset-level operations like '%s %s' are not allowed", strings.ToUpper(statementVerb), strings.ToUpper(keyword)) } @@ -364,7 +396,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi lastToken = keyword } // Also track statement verb for schema checks - if keyword == "select" || keyword == "insert" || keyword == "update" || keyword == "delete" || keyword == "merge" || keyword == "create" || keyword == "alter" || keyword == "drop" { + if sqlStatementVerbs[keyword] { if statementVerb == "" || statementVerb == "with" { statementVerb = keyword } From 36c12ee4ac7a6dba7d71284193a88f4a2bb6e2ba Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Wed, 25 Feb 2026 18:10:14 +0000 Subject: [PATCH 05/15] update --- internal/tools/bigquery/bigquerycommon/table_name_parser.go | 2 +- src/google-cloud-bigquery-storage | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) delete mode 160000 src/google-cloud-bigquery-storage diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser.go b/internal/tools/bigquery/bigquerycommon/table_name_parser.go index 6b5205f112a3..87dabf78adf6 100644 --- a/internal/tools/bigquery/bigquerycommon/table_name_parser.go 
+++ b/internal/tools/bigquery/bigquerycommon/table_name_parser.go @@ -55,8 +55,8 @@ const ( var tableFollowsKeywords = map[string]bool{ "from": true, "join": true, - "into": true, // INSERT INTO, MERGE INTO "update": true, + "into": true, // INSERT INTO, MERGE INTO "table": true, // CREATE TABLE, ALTER TABLE "using": true, // MERGE ... USING "insert": true, // INSERT my_table diff --git a/src/google-cloud-bigquery-storage b/src/google-cloud-bigquery-storage deleted file mode 160000 index 316faaff534c..000000000000 --- a/src/google-cloud-bigquery-storage +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 316faaff534c3a989d92ff78efd1cfaa0e45a10b From 3192c3769568a0b643dd8916d7162b00d52df0ce Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 2 Jun 2026 00:07:20 +0000 Subject: [PATCH 06/15] improve dataset restriction reporting and increase test timeout --- .../tools/bigquery/bigquerycommon/util.go | 30 +++++++++++++++++-- tests/bigquery/bigquery_integration_test.go | 6 ++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index 891c9d06394f..c6cbc84168ed 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -147,7 +147,20 @@ func ValidateQueryAgainstAllowedDatasets( return nil, fmt.Errorf("failed to analyze query for explicit table references: %w", err) } if explicitlyReferenced { - return nil, fmt.Errorf("access to dataset '%s' is not allowed", strings.Join(strings.Split(violatingTables[0], ".")[:2], ".")) + violatingDatasets := []string{} + seenDatasets := make(map[string]struct{}) + for _, tableID := range violatingTables { + datasetFQN := strings.Join(strings.Split(tableID, ".")[:2], ".") + if _, seen := seenDatasets[datasetFQN]; !seen { + violatingDatasets = append(violatingDatasets, fmt.Sprintf("'%s'", datasetFQN)) + seenDatasets[datasetFQN] = struct{}{} + } + } + plural := "" + if 
len(violatingDatasets) > 1 { + plural = "s" + } + return nil, fmt.Errorf("access to dataset%s %s is not allowed", plural, strings.Join(violatingDatasets, ", ")) } } @@ -157,14 +170,27 @@ func ValidateQueryAgainstAllowedDatasets( return nil, fmt.Errorf("could not safely analyze query with dataset restrictions: %w", parseErr) } + var parsedViolatingDatasets []string + seenParsedDatasets := make(map[string]struct{}) for _, tableID := range parsedTables { parts := strings.Split(tableID, ".") if len(parts) == 3 { if !validator.IsDatasetAllowed(parts[0], parts[1]) { - return nil, fmt.Errorf("access to dataset '%s.%s' is not allowed", parts[0], parts[1]) + datasetFQN := fmt.Sprintf("%s.%s", parts[0], parts[1]) + if _, seen := seenParsedDatasets[datasetFQN]; !seen { + parsedViolatingDatasets = append(parsedViolatingDatasets, fmt.Sprintf("'%s'", datasetFQN)) + seenParsedDatasets[datasetFQN] = struct{}{} + } } } } + if len(parsedViolatingDatasets) > 0 { + plural := "" + if len(parsedViolatingDatasets) > 1 { + plural = "s" + } + return nil, fmt.Errorf("access to dataset%s %s is not allowed", plural, strings.Join(parsedViolatingDatasets, ", ")) + } return dryRunJob, nil } diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 6bfdfd7a2613..64749ac863fd 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -222,7 +222,7 @@ func TestBigQueryToolEndpoints(t *testing.T) { func TestBigQueryToolWithDatasetRestriction(t *testing.T) { uniqueID := strings.ReplaceAll(uuid.New().String(), "-", "") t.Logf("Starting restriction test with uniqueID: %s", uniqueID) - ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() client, err := initBigQueryConnection(BigqueryProject) @@ -426,7 +426,7 @@ func TestBigQueryWriteModeAllowed(t *testing.T) { sourceConfig := getBigQueryVars(t) 
sourceConfig["writeMode"] = "allowed" - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() datasetName := fmt.Sprintf("temp_toolbox_test_allowed_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) @@ -482,7 +482,7 @@ func TestBigQueryWriteModeBlocked(t *testing.T) { sourceConfig := getBigQueryVars(t) sourceConfig["writeMode"] = "blocked" - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() datasetName := fmt.Sprintf("temp_toolbox_test_blocked_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) From ff4495558f22ecf14543880def03f7da47f726e1 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 2 Jun 2026 00:41:56 +0000 Subject: [PATCH 07/15] improve bigquery parser and optimize test cleanup --- .../bigquerycommon/table_name_parser.go | 37 +++++++++++++++---- .../tools/bigquery/bigquerycommon/util.go | 4 +- tests/bigquery/bigquery_integration_test.go | 14 +++++-- 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser.go b/internal/tools/bigquery/bigquerycommon/table_name_parser.go index 87dabf78adf6..1c81021251c0 100644 --- a/internal/tools/bigquery/bigquerycommon/table_name_parser.go +++ b/internal/tools/bigquery/bigquerycommon/table_name_parser.go @@ -488,7 +488,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi func findAndParseSQLString(runes []rune, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, aliases map[string]struct{}) (int, error) { for i := 0; i < len(runes); { if hasPrefix(runes, i, "'''") { - end := strings.Index(string(runes[i+3:]), "'''") + end := indexRunes(runes[i+3:], "'''") if end != -1 { sqlContent := string(runes[i+3 : i+3+end]) if _, err := parseSQL(sqlContent, defaultProjectID, 
tableIDSet, visitedSQLs, aliases, false); err != nil { @@ -498,7 +498,7 @@ func findAndParseSQLString(runes []rune, defaultProjectID string, tableIDSet map } } if hasPrefix(runes, i, `"""`) { - end := strings.Index(string(runes[i+3:]), `"""`) + end := indexRunes(runes[i+3:], `"""`) if end != -1 { sqlContent := string(runes[i+3 : i+3+end]) if _, err := parseSQL(sqlContent, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { @@ -727,12 +727,12 @@ func parseIdentifierSequence(runes []rune) ([]string, int, error) { totalConsumed++ } if hasPrefix(runes, totalConsumed, "/*") { - endIdx := strings.Index(string(runes[totalConsumed:]), "*/") + endIdx := indexRunes(runes[totalConsumed:], "*/") if endIdx != -1 { totalConsumed += endIdx + 2 } } else if hasPrefix(runes, totalConsumed, "--") || (totalConsumed < len(runes) && runes[totalConsumed] == '#') { - endIdx := strings.Index(string(runes[totalConsumed:]), "\n") + endIdx := indexRunes(runes[totalConsumed:], "\n") if endIdx != -1 { totalConsumed += endIdx + 1 } else { @@ -751,7 +751,7 @@ func parseIdentifierSequence(runes []rune) ([]string, int, error) { var consumed int if runes[totalConsumed] == '`' { - end := strings.Index(string(runes[totalConsumed+1:]), "`") + end := indexRunes(runes[totalConsumed+1:], "`") if end == -1 { return nil, 0, fmt.Errorf("unclosed backtick identifier") } @@ -778,12 +778,12 @@ func parseIdentifierSequence(runes []rune) ([]string, int, error) { totalConsumed++ } if hasPrefix(runes, totalConsumed, "/*") { - endIdx := strings.Index(string(runes[totalConsumed:]), "*/") + endIdx := indexRunes(runes[totalConsumed:], "*/") if endIdx != -1 { totalConsumed += endIdx + 2 } } else if hasPrefix(runes, totalConsumed, "--") || (totalConsumed < len(runes) && runes[totalConsumed] == '#') { - endIdx := strings.Index(string(runes[totalConsumed:]), "\n") + endIdx := indexRunes(runes[totalConsumed:], "\n") if endIdx != -1 { totalConsumed += endIdx + 1 } else { @@ -805,6 +805,9 @@ func 
parseIdentifierSequence(runes []rune) ([]string, int, error) { } func formatTableID(parts []string, defaultProjectID string) (string, error) { + if len(parts) == 4 && strings.Contains(parts[1], ":") { + parts = []string{parts[0] + "." + parts[1], parts[2], parts[3]} + } if len(parts) < 2 || len(parts) > 3 { // Not a table identifier (could be a CTE, column, etc.). return "", nil @@ -820,3 +823,23 @@ func formatTableID(parts []string, defaultProjectID string) (string, error) { } return fmt.Sprintf("%s.%s", defaultProjectID, strings.Join(parts, ".")), nil } + +func indexRunes(r []rune, sub string) int { + subRunes := []rune(sub) + if len(subRunes) == 0 { + return 0 + } + for i := 0; i <= len(r)-len(subRunes); i++ { + match := true + for j := 0; j < len(subRunes); j++ { + if r[i+j] != subRunes[j] { + match = false + break + } + } + if match { + return i + } + } + return -1 +} diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index c6cbc84168ed..2c64f581d0b1 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -116,7 +116,9 @@ func ValidateQueryAgainstAllowedDatasets( queryStats := dryRunJob.Statistics.Query if queryStats != nil { for _, tableRef := range queryStats.ReferencedTables { - tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} + if tableRef != nil { + tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} + } } if tableRef := queryStats.DdlTargetTable; tableRef != nil { tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 64749ac863fd..736d3482fa86 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -685,8 +685,10 @@ 
func getBigQueryTmplToolStatement() (string, string) { func ensureTeardownDatasets(ctx context.Context, client *bigqueryapi.Client, datasetNames ...string) func(*testing.T) { return func(t *testing.T) { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() for _, dsName := range datasetNames { - if err := client.Dataset(dsName).DeleteWithContents(ctx); err != nil { + if err := client.Dataset(dsName).DeleteWithContents(cleanupCtx); err != nil { t.Logf("failed to cleanup dataset %s: %v", dsName, err) } } @@ -700,7 +702,9 @@ func setupBigQueryView(t *testing.T, ctx context.Context, client *bigqueryapi.Cl t.Fatalf("failed to create view %s in %s: %v", viewName, datasetName, err) } return func(t *testing.T) { - if err := client.Dataset(datasetName).Table(viewName).Delete(ctx); err != nil { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + if err := client.Dataset(datasetName).Table(viewName).Delete(cleanupCtx); err != nil { t.Errorf("failed to delete view %s in %s: %v", viewName, datasetName, err) } } @@ -754,14 +758,16 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C } return func(t *testing.T) { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() // tear down table dropSQL := fmt.Sprintf("drop table %s", tableName) - dropJob, err := client.Query(dropSQL).Run(ctx) + dropJob, err := client.Query(dropSQL).Run(cleanupCtx) if err != nil { t.Errorf("Failed to start drop table job for %s: %v", tableName, err) return } - dropStatus, err := dropJob.Wait(ctx) + dropStatus, err := dropJob.Wait(cleanupCtx) if err != nil { t.Errorf("Failed to wait for drop table job for %s: %v", tableName, err) return From c1033af6a157c53b60a9648c932c721b26a8378d Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 2 Jun 2026 03:38:50 +0000 Subject: [PATCH 08/15] fix cleanup race and parser bugs in bigquery integration --- 
tests/common.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/common.go b/tests/common.go index db2e0deb812c..a4fccea9105a 100644 --- a/tests/common.go +++ b/tests/common.go @@ -32,6 +32,7 @@ import ( "github.com/googleapis/mcp-toolbox/internal/testutils" "github.com/googleapis/mcp-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" + "google.golang.org/api/googleapi" "google.golang.org/api/iterator" ) @@ -1093,16 +1094,27 @@ func CleanupBigQueryDatasets(t *testing.T, ctx context.Context, client *bigquery break } if err != nil { + if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { + t.Logf("INTEGRATION CLEANUP: Dataset %s already deleted (during table iteration)", id) + break + } t.Errorf("INTEGRATION CLEANUP: Failed to iterate tables in %s: %v", id, err) break } if err := table.Delete(ctx); err != nil { + if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { + continue + } t.Errorf("INTEGRATION CLEANUP: Failed to delete table %s: %v", table.TableID, err) } } //delete empty dataset if err := ds.Delete(ctx); err != nil { - t.Errorf("INTEGRATION CLEANUP: Failed to delete dataset %s: %v", id, err) + if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { + t.Logf("INTEGRATION CLEANUP: Dataset %s already deleted", id) + } else { + t.Errorf("INTEGRATION CLEANUP: Failed to delete dataset %s: %v", id, err) + } } else { t.Logf("INTEGRATION CLEANUP SUCCESS: Wiped dataset %s", id) } From d6e799def85226a3d73658eb57842b5e760b21f9 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Tue, 9 Jun 2026 21:40:50 +0000 Subject: [PATCH 09/15] fix formatting in bigquerycommon utils --- internal/tools/bigquery/bigquerycommon/util.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index 43aa7b5d111c..03afffb6071d 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ 
b/internal/tools/bigquery/bigquerycommon/util.go @@ -21,9 +21,9 @@ import ( "sort" "strings" + bigqueryapi "cloud.google.com/go/bigquery" "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/parameters" - bigqueryapi "cloud.google.com/go/bigquery" bigqueryrestapi "google.golang.org/api/bigquery/v2" ) From 017409ce4356b7a9907a9d1b232133516fa05408 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Thu, 25 Jun 2026 21:06:54 +0000 Subject: [PATCH 10/15] relocate CleanupBigQueryDatasets helper to bigquery integration tests --- tests/bigquery/bigquery_integration_test.go | 46 ++++++++++++++++++++- tests/common.go | 42 ------------------- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 56851134720f..5edad34af1c6 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -36,6 +36,7 @@ import ( "github.com/googleapis/mcp-toolbox/tests" "golang.org/x/oauth2/google" "google.golang.org/api/googleapi" + "google.golang.org/api/iterator" "google.golang.org/api/option" ) @@ -128,7 +129,7 @@ func TestBigQueryToolEndpoints(t *testing.T) { // global cleanup for this test run t.Cleanup(func() { - tests.CleanupBigQueryDatasets(t, context.Background(), client, []string{datasetName}) + CleanupBigQueryDatasets(t, context.Background(), client, []string{datasetName}) }) // set up data for param tool @@ -260,7 +261,7 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { // global cleanup for this test run t.Cleanup(func() { - tests.CleanupBigQueryDatasets(t, context.Background(), client, []string{allowedDatasetName1, allowedDatasetName2, disallowedDatasetName}) + CleanupBigQueryDatasets(t, context.Background(), client, []string{allowedDatasetName1, allowedDatasetName2, disallowedDatasetName}) }) // Setup allowed table @@ -3308,3 +3309,44 @@ func 
getBigQueryVectorSearchStmts(vectorTableName string) (string, string) { searchStmt := fmt.Sprintf("SELECT id, content, ML.DISTANCE(embedding, @query, 'COSINE') AS distance FROM %s ORDER BY distance LIMIT 1", vectorTableName) return insertStmt, searchStmt } + +func CleanupBigQueryDatasets(t *testing.T, ctx context.Context, client *bigqueryapi.Client, datasetIDs []string) { + for _, id := range datasetIDs { + t.Logf("INTEGRATION CLEANUP: Purging dataset %s", id) + ds := client.Dataset(id) + + // Delete tables first since Dataset.Delete fails if not empty + tableIt := ds.Tables(ctx) + for { + table, err := tableIt.Next() + if err == iterator.Done { + break + } + if err != nil { + if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { + t.Logf("INTEGRATION CLEANUP: Dataset %s already deleted (during table iteration)", id) + break + } + t.Errorf("INTEGRATION CLEANUP: Failed to iterate tables in %s: %v", id, err) + break + } + if err := table.Delete(ctx); err != nil { + if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { + continue + } + t.Errorf("INTEGRATION CLEANUP: Failed to delete table %s: %v", table.TableID, err) + } + } + // delete empty dataset + if err := ds.Delete(ctx); err != nil { + if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { + t.Logf("INTEGRATION CLEANUP: Dataset %s already deleted", id) + } else { + t.Errorf("INTEGRATION CLEANUP: Failed to delete dataset %s: %v", id, err) + } + } else { + t.Logf("INTEGRATION CLEANUP SUCCESS: Wiped dataset %s", id) + } + } +} + diff --git a/tests/common.go b/tests/common.go index 2f26919fe234..4cac4c5fc67d 100644 --- a/tests/common.go +++ b/tests/common.go @@ -24,7 +24,6 @@ import ( "strings" "testing" - "cloud.google.com/go/bigquery" "cloud.google.com/go/bigtable" "github.com/google/go-cmp/cmp" "github.com/googleapis/mcp-toolbox/internal/server" @@ -32,8 +31,6 @@ import ( "github.com/googleapis/mcp-toolbox/internal/testutils" 
"github.com/googleapis/mcp-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" - "google.golang.org/api/googleapi" - "google.golang.org/api/iterator" ) // GetToolsConfig returns a mock tools config file @@ -1091,45 +1088,6 @@ func CleanupMSSQLTables(t *testing.T, ctx context.Context, pool *sql.DB) { } -func CleanupBigQueryDatasets(t *testing.T, ctx context.Context, client *bigquery.Client, datasetIDs []string) { - for _, id := range datasetIDs { - t.Logf("INTEGRATION CLEANUP: Purging dataset %s", id) - ds := client.Dataset(id) - - //Delete tables first since Dataset.Delete fails if not empty - tableIt := ds.Tables(ctx) - for { - table, err := tableIt.Next() - if err == iterator.Done { - break - } - if err != nil { - if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { - t.Logf("INTEGRATION CLEANUP: Dataset %s already deleted (during table iteration)", id) - break - } - t.Errorf("INTEGRATION CLEANUP: Failed to iterate tables in %s: %v", id, err) - break - } - if err := table.Delete(ctx); err != nil { - if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { - continue - } - t.Errorf("INTEGRATION CLEANUP: Failed to delete table %s: %v", table.TableID, err) - } - } - //delete empty dataset - if err := ds.Delete(ctx); err != nil { - if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { - t.Logf("INTEGRATION CLEANUP: Dataset %s already deleted", id) - } else { - t.Errorf("INTEGRATION CLEANUP: Failed to delete dataset %s: %v", id, err) - } - } else { - t.Logf("INTEGRATION CLEANUP SUCCESS: Wiped dataset %s", id) - } - } -} // finds and deletes all tables in a Bigtable instance that match the uniqueID. 
func CleanupBigtableTables(t *testing.T, ctx context.Context, adminClient *bigtable.AdminClient, uniqueID string) { From 1b854489cdce30359e8df8a9fc67eb936662d4ff Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Thu, 25 Jun 2026 21:08:58 +0000 Subject: [PATCH 11/15] remove redundant error wrappers from bigquerycommon utility --- internal/tools/bigquery/bigquerycommon/util.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index d1e9a47a2093..0b7fac00e800 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -22,7 +22,6 @@ import ( "strings" bigqueryapi "cloud.google.com/go/bigquery" - "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" ) @@ -316,17 +315,3 @@ func InitializeDatasetParameters( return projectParam, datasetParam } -// ProcessGcpError converts a Google API error into a user-friendly error. -func ProcessGcpError(err error) util.ToolboxError { - return util.ProcessGcpError(err) -} - -// NewAgentError returns a new AgentError. -func NewAgentError(message string, err error) util.ToolboxError { - return util.NewAgentError(message, err) -} - -// NewClientServerError returns a new ClientServerError. 
-func NewClientServerError(message string, statusCode int, err error) util.ToolboxError { - return util.NewClientServerError(message, statusCode, err) -} From 791d4a8c4e6da0f446f63b407f8c063cb918f6ef Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Thu, 25 Jun 2026 21:18:32 +0000 Subject: [PATCH 12/15] centralize and structure query validation errors in bigquery tools --- .../bigqueryanalyzecontribution.go | 15 +++++++------ .../tools/bigquery/bigquerycommon/util.go | 21 ++++++++++++------- .../bigqueryexecutesql/bigqueryexecutesql.go | 7 ++++--- .../bigqueryforecast/bigqueryforecast.go | 15 +++++++------ 4 files changed, 36 insertions(+), 22 deletions(-) diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index 74d472884bb0..a69bb8fd5c99 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -188,9 +188,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - dryRunJob, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps, source, source.GetMaximumBytesBilled(), false) - if err != nil { - return nil, util.ProcessGcpError(err) + var dryRunJob *bigqueryrestapi.Job + var validationErr util.ToolboxError + dryRunJob, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps, source, source.GetMaximumBytesBilled(), false) + if validationErr != nil { + return nil, validationErr } if dryRunJob.Statistics.Query.StatementType != "SELECT" { return nil, util.NewAgentError(fmt.Sprintf("the 'input_data' parameter only supports a table ID or a SELECT query. 
The provided query has statement type '%s'", dryRunJob.Statistics.Query.StatementType), nil) @@ -247,9 +249,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if len(source.BigQueryAllowedDatasets()) > 0 { - _, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, createModelSQL, nil, createModelQuery.ConnectionProperties, source, source.GetMaximumBytesBilled(), true) - if err != nil { - return nil, util.ProcessGcpError(err) + var validationErr util.ToolboxError + _, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, createModelSQL, nil, createModelQuery.ConnectionProperties, source, source.GetMaximumBytesBilled(), true) + if validationErr != nil { + return nil, validationErr } } diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index 0b7fac00e800..90b18729f431 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -16,13 +16,16 @@ package bigquerycommon import ( "context" + "errors" "fmt" "regexp" "sort" "strings" bigqueryapi "cloud.google.com/go/bigquery" + "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/parameters" + "google.golang.org/api/googleapi" bigqueryrestapi "google.golang.org/api/bigquery/v2" ) @@ -130,14 +133,18 @@ func ValidateQueryAgainstAllowedDatasets( validator DatasetValidator, maximumBytesBilled int64, createSession bool, -) (*bigqueryrestapi.Job, error) { +) (*bigqueryrestapi.Job, util.ToolboxError) { dryRunJob, err := DryRunQuery(ctx, restService, projectID, location, sql, params, connProps, maximumBytesBilled, createSession) if err != nil { - return nil, fmt.Errorf("query validation failed: %w", err) + var gErr *googleapi.Error + if errors.As(err, &gErr) { + 
return nil, util.ProcessGcpError(err) + } + return nil, util.NewAgentError("query validation failed", err) } if dryRunJob.Statistics == nil || dryRunJob.Statistics.Query == nil { - return nil, fmt.Errorf("dry run failed to return query statistics") + return nil, util.NewAgentError("dry run failed to return query statistics", nil) } // Use a map to avoid duplicate table names from the dry run result. @@ -176,7 +183,7 @@ func ValidateQueryAgainstAllowedDatasets( if len(violatingTables) > 0 { explicitlyReferenced, err := IsAnyTableExplicitlyReferenced(sql, projectID, violatingTables) if err != nil { - return nil, fmt.Errorf("failed to analyze query for explicit table references: %w", err) + return nil, util.NewAgentError("failed to analyze query for explicit table references", err) } if explicitlyReferenced { violatingDatasets := []string{} @@ -192,14 +199,14 @@ func ValidateQueryAgainstAllowedDatasets( if len(violatingDatasets) > 1 { plural = "s" } - return nil, fmt.Errorf("access to dataset%s %s is not allowed", plural, strings.Join(violatingDatasets, ", ")) + return nil, util.NewAgentError(fmt.Sprintf("access to dataset%s %s is not allowed", plural, strings.Join(violatingDatasets, ", ")), nil) } } // Fall back to TableParser for final intent verification or if dry run was inconclusive. 
parsedTables, parseErr := TableParser(sql, projectID) if parseErr != nil { - return nil, fmt.Errorf("could not safely analyze query with dataset restrictions: %w", parseErr) + return nil, util.NewAgentError("could not safely analyze query with dataset restrictions", parseErr) } var parsedViolatingDatasets []string @@ -224,7 +231,7 @@ func ValidateQueryAgainstAllowedDatasets( if len(parsedViolatingDatasets) > 1 { plural = "s" } - return nil, fmt.Errorf("access to dataset%s %s is not allowed", plural, strings.Join(parsedViolatingDatasets, ", ")) + return nil, util.NewAgentError(fmt.Sprintf("access to dataset%s %s is not allowed", plural, strings.Join(parsedViolatingDatasets, ", ")), nil) } return dryRunJob, nil diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index c4f7b78ccb91..6f930ffdf8e4 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -140,9 +140,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var dryRunJob *bigqueryrestapi.Job if len(source.BigQueryAllowedDatasets()) > 0 { - dryRunJob, err = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, source, source.GetMaximumBytesBilled(), false) - if err != nil { - return nil, util.ProcessGcpError(err) + var validationErr util.ToolboxError + dryRunJob, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, source, source.GetMaximumBytesBilled(), false) + if validationErr != nil { + return nil, validationErr } } else { dryRunJob, err = bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, source.GetMaximumBytesBilled(), false) diff --git 
a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index 6a527d94c941..3685250357d8 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -174,9 +174,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - dryRunJob, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps, source, source.GetMaximumBytesBilled(), false) - if err != nil { - return nil, util.ProcessGcpError(err) + var dryRunJob *bigqueryrestapi.Job + var validationErr util.ToolboxError + dryRunJob, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps, source, source.GetMaximumBytesBilled(), false) + if validationErr != nil { + return nil, validationErr } if dryRunJob.Statistics.Query.StatementType != "SELECT" { return nil, util.NewAgentError(fmt.Sprintf("the 'history_data' parameter only supports a table ID or a SELECT query. 
The provided query has statement type '%s'", dryRunJob.Statistics.Query.StatementType), nil) @@ -235,9 +237,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if len(source.BigQueryAllowedDatasets()) > 0 { - _, err := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, sql, nil, connProps, source, source.GetMaximumBytesBilled(), false) - if err != nil { - return nil, util.ProcessGcpError(err) + var validationErr util.ToolboxError + _, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, sql, nil, connProps, source, source.GetMaximumBytesBilled(), false) + if validationErr != nil { + return nil, validationErr } } From 0be4333159db9250dd852a02249da1302ada7b81 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Thu, 25 Jun 2026 21:21:51 +0000 Subject: [PATCH 13/15] centralize and structure query validation errors in bigquery tools --- .../bigqueryanalyzecontribution.go | 7 ++----- .../tools/bigquery/bigqueryforecast/bigqueryforecast.go | 7 ++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index a69bb8fd5c99..09b8dab2c193 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -188,9 +188,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - var dryRunJob *bigqueryrestapi.Job - var validationErr util.ToolboxError - dryRunJob, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps, source, 
source.GetMaximumBytesBilled(), false) + dryRunJob, validationErr := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps, source, source.GetMaximumBytesBilled(), false) if validationErr != nil { return nil, validationErr } @@ -249,8 +247,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if len(source.BigQueryAllowedDatasets()) > 0 { - var validationErr util.ToolboxError - _, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, createModelSQL, nil, createModelQuery.ConnectionProperties, source, source.GetMaximumBytesBilled(), true) + _, validationErr := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, createModelSQL, nil, createModelQuery.ConnectionProperties, source, source.GetMaximumBytesBilled(), true) if validationErr != nil { return nil, validationErr } diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index 3685250357d8..11da1e842b40 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -174,9 +174,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - var dryRunJob *bigqueryrestapi.Job - var validationErr util.ToolboxError - dryRunJob, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps, source, source.GetMaximumBytesBilled(), false) + dryRunJob, validationErr := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps, 
source, source.GetMaximumBytesBilled(), false) if validationErr != nil { return nil, validationErr } @@ -237,8 +235,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if len(source.BigQueryAllowedDatasets()) > 0 { - var validationErr util.ToolboxError - _, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, sql, nil, connProps, source, source.GetMaximumBytesBilled(), false) + _, validationErr := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, sql, nil, connProps, source, source.GetMaximumBytesBilled(), false) if validationErr != nil { return nil, validationErr } From cb5b67e1f9c8f5bd6d6ae410b08139a3cf8e9902 Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Thu, 25 Jun 2026 21:23:39 +0000 Subject: [PATCH 14/15] add unit tests for bigquery utility and validation functions --- .../bigquerycommon/validators_test.go | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/internal/tools/bigquery/bigquerycommon/validators_test.go b/internal/tools/bigquery/bigquerycommon/validators_test.go index 9f0efae07c70..ffabd553e394 100644 --- a/internal/tools/bigquery/bigquerycommon/validators_test.go +++ b/internal/tools/bigquery/bigquerycommon/validators_test.go @@ -139,3 +139,100 @@ func TestStripSingleQuotes(t *testing.T) { } } } + +func TestValidColumnParam(t *testing.T) { + tcs := []struct { + in string + valid bool + }{ + {"'sales'", true}, + {"sales", true}, + {"'sales_col'", true}, + {"_internal", true}, + {"'1col'", false}, + {"col'", false}, + {"'col", false}, + {"'col; DROP TABLE x'", false}, + {"", false}, + } + for _, tc := range tcs { + if got := bigquerycommon.ValidColumnParam(tc.in); got != tc.valid { + t.Errorf("ValidColumnParam(%q) = %v, want %v", tc.in, got, tc.valid) + } + } +} + +func TestValidContributionMetricParam(t *testing.T) { + tcs 
:= []struct { + in string + valid bool + }{ + {"'metric'", true}, + {"metric", true}, + {"'metric's'", false}, + {"metric's", false}, + {"'metric", false}, + {"''metric''", false}, + } + for _, tc := range tcs { + if got := bigquerycommon.ValidContributionMetricParam(tc.in); got != tc.valid { + t.Errorf("ValidContributionMetricParam(%q) = %v, want %v", tc.in, got, tc.valid) + } + } +} + +func TestIsSystemResource(t *testing.T) { + tcs := []struct { + datasetID string + resourceID string + want bool + }{ + {"AI", "FORECAST", true}, + {"ai", "forecast", true}, + {"AI", "GENERATE_TEXT", true}, + {"AI", "SUMMARIZE", true}, + {"ML", "PREDICT", true}, + {"ml", "predict", true}, + {"ML", "DISTANCE", true}, + {"ML", "GET_INSIGHTS", true}, + {"AI", "INVALID", false}, + {"ML", "INVALID", false}, + {"OTHER", "FORECAST", false}, + } + for _, tc := range tcs { + if got := bigquerycommon.IsSystemResource(tc.datasetID, tc.resourceID); got != tc.want { + t.Errorf("IsSystemResource(%q, %q) = %v, want %v", tc.datasetID, tc.resourceID, got, tc.want) + } + } +} + +func TestBQTypeStringFromToolType(t *testing.T) { + tcs := []struct { + in string + want string + wantErr bool + }{ + {"string", "STRING", false}, + {"integer", "INT64", false}, + {"float", "FLOAT64", false}, + {"boolean", "BOOL", false}, + {"map", "STRUCT", false}, + {"invalid", "", true}, + } + for _, tc := range tcs { + got, err := bigquerycommon.BQTypeStringFromToolType(tc.in) + if tc.wantErr { + if err == nil { + t.Errorf("BQTypeStringFromToolType(%q) expected error, got nil", tc.in) + } + } else { + if err != nil { + t.Errorf("BQTypeStringFromToolType(%q) unexpected error: %v", tc.in, err) + } + if got != tc.want { + t.Errorf("BQTypeStringFromToolType(%q) = %q, want %q", tc.in, got, tc.want) + } + } + } +} + From 78f82b1a6fcb0394be9359f8e5db197ebce4269d Mon Sep 17 00:00:00 2001 From: Huan Chen Date: Thu, 25 Jun 2026 21:34:13 +0000 Subject: [PATCH 15/15] run gofmt and goimports to fix formatting and imports order --- 
internal/tools/bigquery/bigquerycommon/util.go | 4 +--- internal/tools/bigquery/bigquerycommon/validators_test.go | 1 - .../tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go | 1 - tests/bigquery/bigquery_integration_test.go | 1 - tests/common.go | 1 - 5 files changed, 1 insertion(+), 7 deletions(-) diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index 90b18729f431..4c5720754d7a 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -25,8 +25,8 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/parameters" - "google.golang.org/api/googleapi" bigqueryrestapi "google.golang.org/api/bigquery/v2" + "google.golang.org/api/googleapi" ) // validBQTableID matches BigQuery table identifiers in 'dataset.table' or @@ -178,7 +178,6 @@ func ValidateQueryAgainstAllowedDatasets( } } - // If violations were found, check if they are explicitly in the SQL to support authorized views. 
if len(violatingTables) > 0 { explicitlyReferenced, err := IsAnyTableExplicitlyReferenced(sql, projectID, violatingTables) @@ -321,4 +320,3 @@ func InitializeDatasetParameters( return projectParam, datasetParam } - diff --git a/internal/tools/bigquery/bigquerycommon/validators_test.go b/internal/tools/bigquery/bigquerycommon/validators_test.go index ffabd553e394..8e8e114578f9 100644 --- a/internal/tools/bigquery/bigquerycommon/validators_test.go +++ b/internal/tools/bigquery/bigquerycommon/validators_test.go @@ -235,4 +235,3 @@ func TestBQTypeStringFromToolType(t *testing.T) { } } } - diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index 6f930ffdf8e4..543f692d34a0 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -171,7 +171,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - if dryRun { if dryRunJob != nil { jobJSON, err := json.MarshalIndent(dryRunJob, "", " ") diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 5edad34af1c6..1b8d4fe53e58 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -3349,4 +3349,3 @@ func CleanupBigQueryDatasets(t *testing.T, ctx context.Context, client *bigquery } } } - diff --git a/tests/common.go b/tests/common.go index 4cac4c5fc67d..06504065931a 100644 --- a/tests/common.go +++ b/tests/common.go @@ -1088,7 +1088,6 @@ func CleanupMSSQLTables(t *testing.T, ctx context.Context, pool *sql.DB) { } - // finds and deletes all tables in a Bigtable instance that match the uniqueID. func CleanupBigtableTables(t *testing.T, ctx context.Context, adminClient *bigtable.AdminClient, uniqueID string) { tables, err := adminClient.Tables(ctx)