diff --git a/docs/en/integrations/bigquery/tools/bigquery-execute-sql.md b/docs/en/integrations/bigquery/tools/bigquery-execute-sql.md index e552a01a43cc..a5fa0157e2e8 100644 --- a/docs/en/integrations/bigquery/tools/bigquery-execute-sql.md +++ b/docs/en/integrations/bigquery/tools/bigquery-execute-sql.md @@ -39,6 +39,7 @@ layer of security by controlling which datasets can be accessed: It will reject the query if it attempts to access any table outside the allowed `datasets` list. To enforce this restriction, the following operations are also disallowed: + - **Dataset-level operations** (e.g., `CREATE SCHEMA`, `ALTER SCHEMA`). - **Unanalyzable operations** where the accessed tables cannot be determined statically (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`, `CALL`). diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index d2989586adb3..09b8dab2c193 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -116,7 +116,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, util.NewAgentError(fmt.Sprintf("unable to cast input_data parameter %s", paramsMap["input_data"]), nil) } - bqClient, _, err := source.RetrieveClientAndService(accessToken) + bqClient, restService, err := source.RetrieveClientAndService(accessToken) if err != nil { return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } @@ -176,6 +176,26 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var inputDataSource string trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData)) if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") { + if len(source.BigQueryAllowedDatasets()) > 0 { + var connProps []*bigqueryapi.ConnectionProperty + session, err := source.BigQuerySession()(ctx) + if err != nil { + return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err) + } + if session != nil { + connProps = []*bigqueryapi.ConnectionProperty{ + {Key: "session_id", Value: session.ID}, + } + } + + dryRunJob, validationErr := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps, source, source.GetMaximumBytesBilled(), false) + if validationErr != nil { + return nil, validationErr + } + if dryRunJob.Statistics.Query.StatementType != "SELECT" { + return nil, util.NewAgentError(fmt.Sprintf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", dryRunJob.Statistics.Query.StatementType), nil) + } + } inputDataSource = fmt.Sprintf("(%s)", inputData) } else { if !bqutil.ValidTableID(inputData) { @@ -225,27 +245,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // If not in protected mode, create a session for this invocation. createModelQuery.CreateSession = true } + if len(source.BigQueryAllowedDatasets()) > 0 { - createModelQuery.DryRun = true - dryRunJob, err := createModelQuery.Run(ctx) - if err != nil { - return nil, util.ProcessGcpError(err) - } - status := dryRunJob.LastStatus() - if status.Statistics != nil { - if qStats, ok := status.Statistics.Details.(*bigqueryapi.QueryStatistics); ok { - for _, tableRef := range qStats.ReferencedTables { - if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) { - return nil, util.NewAgentError(fmt.Sprintf("query accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectID, tableRef.DatasetID), nil) - } - } - } else { - return nil, util.NewAgentError("could not get query statistics details during dry run validation", nil) - } - } else { - return nil, util.NewAgentError("could not dry run model creation query to validate allowed datasets", nil) + _, validationErr := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, createModelSQL, nil, createModelQuery.ConnectionProperties, source, source.GetMaximumBytesBilled(), true) + if validationErr != nil { + return nil, validationErr } - createModelQuery.DryRun = false } createModelJob, err := createModelQuery.Run(ctx) diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution_test.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution_test.go index a73728021ecf..5f2d06b14480 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution_test.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution_test.go @@ -32,6 +32,7 @@ import ( "github.com/googleapis/mcp-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution" "github.com/googleapis/mcp-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/mcp-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/option" ) @@ -319,9 +320,15 @@ func TestInvokeAllowedDatasetsValidation(t *testing.T) { t.Fatalf("failed to create mocked BigQuery client: %v", err) } + restService, err := bigqueryrestapi.NewService(ctx, option.WithEndpoint(mockServer.URL), option.WithoutAuthentication()) + if err != nil { + t.Fatalf("failed to create mocked BigQuery REST service: %v", err) + } + // 3. Define mock source that returns this client and allowed datasets configuration testSrc := &bigquerycommon.MockSource{ Client: bqClient, + Service: restService, AllowedDatasets: []string{"allowed_dataset"}, } @@ -348,7 +355,7 @@ func TestInvokeAllowedDatasetsValidation(t *testing.T) { // 4. Set up parameters data := map[string]any{ - "input_data": "allowed_dataset.my_table", + "input_data": "SELECT * FROM unauthorized_dataset.some_table", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": []any{"dim1"}, @@ -370,7 +377,7 @@ func TestInvokeAllowedDatasetsValidation(t *testing.T) { t.Fatal("expected Invoke to return an error due to out-of-allowlist dataset reference, but got nil") } - expectedErr := "query accesses dataset 'test-project.unauthorized_dataset', which is not in the allowed list" + expectedErr := "access to dataset 'test-project.unauthorized_dataset' is not allowed" if !strings.Contains(err.Error(), expectedErr) { t.Errorf("expected error to contain %q, got: %v", expectedErr, err) } diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser.go b/internal/tools/bigquery/bigquerycommon/table_name_parser.go index f7fbfba8f775..b106947c58d9 100644 --- a/internal/tools/bigquery/bigquerycommon/table_name_parser.go +++ b/internal/tools/bigquery/bigquerycommon/table_name_parser.go @@ -78,42 +78,110 @@ var tableFollowsKeywords = map[string]bool{ "update": true, "into": true, // INSERT INTO, MERGE INTO "table": true, // CREATE TABLE, ALTER TABLE + "model": true, // ML.GET_INSIGHTS(MODEL ...) + "view": true, // DROP VIEW ... "using": true, // MERGE ... USING "insert": true, // INSERT my_table "merge": true, // MERGE my_table } var tableContextExitKeywords = map[string]bool{ - "where": true, - "group": true, // GROUP BY - "having": true, - "order": true, // ORDER BY - "limit": true, - "window": true, - "on": true, // JOIN ... ON - "set": true, // UPDATE ... SET - "when": true, // MERGE ... WHEN + "where": true, + "group": true, // GROUP BY + "order": true, // ORDER BY + "having": true, + "limit": true, + "window": true, + "union": true, + "intersect": true, + "except": true, + "on": true, // JOIN ... ON + "set": true, // UPDATE ... SET + "when": true, // MERGE ... WHEN } -// TableParser is the main entry point for parsing a SQL string to find all referenced table IDs. -// It handles multi-statement SQL, comments, and recursive parsing of EXECUTE IMMEDIATE statements. +var sqlStatementVerbs = map[string]bool{ + verbCreate: true, + verbAlter: true, + verbDrop: true, + verbSelect: true, + verbInsert: true, + verbUpdate: true, + verbDelete: true, + verbMerge: true, +} + +var schemaOperationVerbs = map[string]bool{ + verbCreate: true, + verbAlter: true, + verbDrop: true, +} + +// hasPrefix checks if the runes starting at offset match the given prefix. +func hasPrefix(r []rune, offset int, prefix string) bool { + if offset+len(prefix) > len(r) { + return false + } + for i := 0; i < len(prefix); i++ { + if r[offset+i] != rune(prefix[i]) { + return false + } + } + return true +} + +// hasPrefixFold checks if the runes starting at offset match the given prefix, ignoring case (ASCII only). +func hasPrefixFold(r []rune, offset int, prefix string) bool { + if offset+len(prefix) > len(r) { + return false + } + for i := 0; i < len(prefix); i++ { + rChar := r[offset+i] + pChar := rune(prefix[i]) + if rChar >= 'A' && rChar <= 'Z' { + rChar += 32 + } + if pChar >= 'A' && pChar <= 'Z' { + pChar += 32 + } + if rChar != pChar { + return false + } + } + return true +} + +// TableParser parses a SQL query and returns a list of table IDs that it references. +// It is intended as a conservative fallback for when a dry run cannot be performed or analyzed. func TableParser(sql, defaultProjectID string) ([]string, error) { tableIDSet := make(map[string]struct{}) visitedSQLs := make(map[string]struct{}) - if _, err := parseSQL(sql, defaultProjectID, tableIDSet, visitedSQLs, false); err != nil { + aliases := make(map[string]struct{}) + if _, err := parseSQL(sql, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { return nil, err } tableIDs := make([]string, 0, len(tableIDSet)) for id := range tableIDSet { - tableIDs = append(tableIDs, id) + isAlias := false + parts := strings.Split(id, ".") + for j := 0; j < len(parts); j++ { + suffix := strings.ToLower(strings.Join(parts[j:], ".")) + if _, ok := aliases[suffix]; ok { + isAlias = true + break + } + } + if !isAlias { + tableIDs = append(tableIDs, id) + } } return tableIDs, nil } // parseSQL is the core recursive function that processes SQL strings. // It uses a state machine to find table names and recursively parse EXECUTE IMMEDIATE. -func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, inSubquery bool) (int, error) { +func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, aliases map[string]struct{}, inSubquery bool) (int, error) { // Prevent infinite recursion. if _, ok := visitedSQLs[sql]; ok { return len(sql), nil @@ -121,45 +189,53 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi visitedSQLs[sql] = struct{}{} state := stateNormal - expectingTable := false + expectingTable, expectingAlias, expectingCTE := false, false, false var lastTableKeyword, lastToken, statementVerb string runes := []rune(sql) for i := 0; i < len(runes); { char := runes[i] - remaining := sql[i:] switch state { case stateNormal: - if strings.HasPrefix(remaining, "--") { + if hasPrefix(runes, i, "--") { state = stateInSingleLineCommentDash i += 2 continue } - if strings.HasPrefix(remaining, "#") { + if char == '#' { state = stateInSingleLineCommentHash i++ continue } - if strings.HasPrefix(remaining, "/*") { + if hasPrefix(runes, i, "/*") { state = stateInMultiLineComment i += 2 continue } + if char == ',' { + if lastTableKeyword == "from" { + expectingTable = true + expectingAlias = false + } else if statementVerb == "with" { + expectingCTE = true + expectingAlias = false + } + i++ + continue + } if char == '(' { - if expectingTable { - // The subquery starts after '('. - consumed, err := parseSQL(remaining[1:], defaultProjectID, tableIDSet, visitedSQLs, true) + if expectingTable || expectingCTE || lastToken == "as" { + consumed, err := parseSQL(string(runes[i+1:]), defaultProjectID, tableIDSet, visitedSQLs, aliases, true) if err != nil { return 0, err } - // Advance i by the length of the subquery + the opening parenthesis. - // The recursive call returns what it consumed, including the closing parenthesis. i += consumed + 1 - // For most keywords, we expect only one table. `from` can have multiple "tables" (subqueries). if lastTableKeyword != "from" { expectingTable = false } + expectingAlias = true + expectingCTE = false continue } } @@ -168,42 +244,43 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi return i + 1, nil } } - if char == ';' { statementVerb = "" lastToken = "" + expectingTable = false + expectingAlias = false + expectingCTE = false i++ continue - } // Raw strings must be checked before regular strings. - if strings.HasPrefix(remaining, "r'''") || strings.HasPrefix(remaining, "R'''") { + if hasPrefixFold(runes, i, "r'''") { state = stateInRawTripleSingleQuoteString i += 4 continue } - if strings.HasPrefix(remaining, `r"""`) || strings.HasPrefix(remaining, `R"""`) { + if hasPrefixFold(runes, i, `r"""`) { state = stateInRawTripleDoubleQuoteString i += 4 continue } - if strings.HasPrefix(remaining, "r'") || strings.HasPrefix(remaining, "R'") { + if hasPrefixFold(runes, i, "r'") { state = stateInRawSingleQuoteString i += 2 continue } - if strings.HasPrefix(remaining, `r"`) || strings.HasPrefix(remaining, `R"`) { + if hasPrefixFold(runes, i, `r"`) { state = stateInRawDoubleQuoteString i += 2 continue } - if strings.HasPrefix(remaining, "'''") { + if hasPrefix(runes, i, "'''") { state = stateInTripleSingleQuoteString i += 3 continue } - if strings.HasPrefix(remaining, `"""`) { + if hasPrefix(runes, i, `"""`) { state = stateInTripleDoubleQuoteString i += 3 continue @@ -219,8 +296,8 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi continue } - if unicode.IsLetter(char) || char == '`' { - parts, consumed, err := parseIdentifierSequence(remaining) + if unicode.IsLetter(char) || char == '`' || char == '_' { + parts, consumed, err := parseIdentifierSequence(runes[i:]) if err != nil { return 0, err } @@ -229,6 +306,10 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi continue } + keyword := strings.ToLower(parts[0]) + fullID := strings.ToLower(strings.Join(parts, ".")) + + // Check for EXTERNAL_QUERY for _, part := range parts { if strings.EqualFold(part, "EXTERNAL_QUERY") { return 0, fmt.Errorf("EXTERNAL_QUERY is not allowed when dataset restrictions are in place") @@ -255,39 +336,115 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi return 0, fmt.Errorf("invalid INFORMATION_SCHEMA query path %q", strings.Join(parts, ".")) } parts = parts[:infoSchemaIdx+1] + fullID = strings.ToLower(strings.Join(parts, ".")) } - if len(parts) == 1 { - keyword := strings.ToLower(parts[0]) - switch keyword { - case "call": - return 0, fmt.Errorf("CALL is not allowed when dataset restrictions are in place, as the called procedure's contents cannot be safely analyzed") - case "immediate": - if lastToken == "execute" { - return 0, fmt.Errorf("EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place, as its contents cannot be safely analyzed") - } - case "procedure", "function": - if lastToken == "create" || lastToken == "create or replace" { - return 0, fmt.Errorf("unanalyzable statements like '%s %s' are not allowed", strings.ToUpper(lastToken), strings.ToUpper(keyword)) - } - case verbCreate, verbAlter, verbDrop, verbSelect, verbInsert, verbUpdate, verbDelete, verbMerge: - if statementVerb == "" { - statementVerb = keyword - } + // Security check for restricted statements + if keyword == "immediate" && lastToken == "execute" { + return 0, fmt.Errorf("EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place") + } + if (lastToken == "create" || lastToken == "create or" || lastToken == "create or replace") && + (keyword == "procedure" || keyword == "function" || keyword == "table function") { + tokenToReport := strings.ToUpper(lastToken) + if tokenToReport == "" { + tokenToReport = "CREATE" } + return 0, fmt.Errorf("unanalyzable statements like '%s %s' are not allowed", tokenToReport, strings.ToUpper(keyword)) + } + if keyword == "call" { + return 0, fmt.Errorf("CALL is not allowed when dataset restrictions are in place") + } + if schemaOperationVerbs[statementVerb] && + (keyword == "schema" || keyword == "dataset") { + return 0, fmt.Errorf("dataset-level operations like '%s %s' are not allowed", strings.ToUpper(statementVerb), strings.ToUpper(keyword)) + } - if statementVerb == verbCreate || statementVerb == verbAlter || statementVerb == verbDrop { - if keyword == "schema" || keyword == "dataset" { - return 0, fmt.Errorf("dataset-level operations like '%s %s' are not allowed when dataset restrictions are in place", strings.ToUpper(statementVerb), strings.ToUpper(keyword)) - } + if lastToken == "execute" && keyword == "immediate" { + // Found EXECUTE IMMEDIATE. The first expression must be the SQL string. + // Search for the next string literal. + sqlConsumed, err := findAndParseSQLString(runes[i+consumed:], defaultProjectID, tableIDSet, visitedSQLs, aliases) + if err != nil { + return 0, err } + i += consumed + sqlConsumed + lastToken = "execute immediate" + continue + } - if _, ok := tableFollowsKeywords[keyword]; ok { + // Resolve aliases and identify table references. + isKnownAlias := false + if _, ok := aliases[fullID]; ok { + isKnownAlias = true + } + if !isKnownAlias && len(parts) > 1 { + if _, ok := aliases[strings.ToLower(parts[0])]; ok { + isKnownAlias = true + } + } + + if expectingCTE { + aliases[fullID] = struct{}{} + aliases[strings.ToLower(parts[0])] = struct{}{} + expectingCTE = false + } else if expectingAlias { + if len(parts) == 1 && (tableContextExitKeywords[keyword] || tableFollowsKeywords[keyword] || keyword == "select" || keyword == "with") { + expectingAlias = false + } else { + aliases[fullID] = struct{}{} + aliases[strings.ToLower(parts[0])] = struct{}{} + expectingAlias = false + isKnownAlias = true + } + } + + // Re-check aliases after potential registration. + if !isKnownAlias { + if _, ok := aliases[fullID]; ok { + isKnownAlias = true + } + } + + if expectingTable && !isKnownAlias { + if len(parts) >= 2 { + tableID, err := formatTableID(parts, defaultProjectID) + if err != nil { + return 0, err + } + if tableID != "" { + // If it's a system function (AI.FORECAST, etc.), don't treat it as a table. + isSystem := false + p := strings.Split(tableID, ".") + if len(p) == 3 && IsSystemResource(p[1], p[2]) { + isSystem = true + } + if !isSystem { + tableIDSet[tableID] = struct{}{} + } + } + } + // For most keywords, we expect only one table. + if lastTableKeyword != "from" { + expectingTable = false + } + expectingAlias = true + } + if len(parts) == 1 { + if keyword == "with" { + expectingCTE = true + statementVerb = "with" + } else if keyword == "as" { + if statementVerb != "with" { + expectingAlias = true + } + expectingTable = false + } else if _, ok := tableFollowsKeywords[keyword]; ok { expectingTable = true lastTableKeyword = keyword + expectingAlias = false } else if _, ok := tableContextExitKeywords[keyword]; ok { expectingTable = false lastTableKeyword = "" + expectingAlias = false } if lastToken == "create" && keyword == "or" { lastToken = "create or" @@ -296,32 +453,22 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } else { lastToken = keyword } - } else if len(parts) >= 2 { - // This is a multi-part identifier. If we were expecting a table, this is it. - if expectingTable { - tableID, err := formatTableID(parts, defaultProjectID) - if err != nil { - return 0, err - } - if tableID != "" { - tableIDSet[tableID] = struct{}{} - } - // For most keywords, we expect only one table. - if lastTableKeyword != "from" { - expectingTable = false + // Also track statement verb for schema checks + if sqlStatementVerbs[keyword] { + if statementVerb == "" || statementVerb == "with" { + statementVerb = keyword } } + } else { lastToken = "" } - i += consumed continue } i++ - case stateInSingleQuoteString: if char == '\\' { - i += 2 // Skip backslash and the escaped character. + i += 2 continue } if char == '\'' { @@ -330,7 +477,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi i++ case stateInDoubleQuoteString: if char == '\\' { - i += 2 // Skip backslash and the escaped character. + i += 2 continue } if char == '"' { @@ -338,14 +485,14 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } i++ case stateInTripleSingleQuoteString: - if strings.HasPrefix(remaining, "'''") { + if hasPrefix(runes, i, "'''") { state = stateNormal i += 3 } else { i++ } case stateInTripleDoubleQuoteString: - if strings.HasPrefix(remaining, `"""`) { + if hasPrefix(runes, i, `"""`) { state = stateNormal i += 3 } else { @@ -357,7 +504,7 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } i++ case stateInMultiLineComment: - if strings.HasPrefix(remaining, "*/") { + if hasPrefix(runes, i, "*/") { state = stateNormal i += 2 } else { @@ -374,14 +521,14 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } i++ case stateInRawTripleSingleQuoteString: - if strings.HasPrefix(remaining, "'''") { + if hasPrefix(runes, i, "'''") { state = stateNormal i += 3 } else { i++ } case stateInRawTripleDoubleQuoteString: - if strings.HasPrefix(remaining, `"""`) { + if hasPrefix(runes, i, `"""`) { state = stateNormal i += 3 } else { @@ -389,91 +536,368 @@ func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visi } } } - if inSubquery { return 0, fmt.Errorf("unclosed subquery parenthesis") } - return len(sql), nil + return len(runes), nil +} + +// findAndParseSQLString scans for the first string literal and parses its content as SQL. +func findAndParseSQLString(runes []rune, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, aliases map[string]struct{}) (int, error) { + for i := 0; i < len(runes); { + if hasPrefix(runes, i, "'''") { + end := indexRunes(runes[i+3:], "'''") + if end != -1 { + sqlContent := string(runes[i+3 : i+3+end]) + if _, err := parseSQL(sqlContent, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { + return 0, err + } + return i + 3 + end + 3, nil + } + } + if hasPrefix(runes, i, `"""`) { + end := indexRunes(runes[i+3:], `"""`) + if end != -1 { + sqlContent := string(runes[i+3 : i+3+end]) + if _, err := parseSQL(sqlContent, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { + return 0, err + } + return i + 3 + end + 3, nil + } + } + if runes[i] == '\'' { + // Find end of single-quoted string, respecting backslash escapes. + for j := i + 1; j < len(runes); j++ { + if runes[j] == '\\' { + j++ + continue + } + if runes[j] == '\'' { + sqlContent := string(runes[i+1 : j]) + if _, err := parseSQL(sqlContent, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { + return 0, err + } + return j + 1, nil + } + } + } + if runes[i] == '"' { + for j := i + 1; j < len(runes); j++ { + if runes[j] == '\\' { + j++ + continue + } + if runes[j] == '"' { + sqlContent := string(runes[i+1 : j]) + if _, err := parseSQL(sqlContent, defaultProjectID, tableIDSet, visitedSQLs, aliases, false); err != nil { + return 0, err + } + return j + 1, nil + } + } + } + i++ + } + return len(runes), nil +} + +// IsAnyTableExplicitlyReferenced performs a lexical audit of the SQL to see if any of the target tables +// are explicitly named as identifiers. It correctly ignores names inside comments or strings. +func IsAnyTableExplicitlyReferenced(sql, defaultProjectID string, targetTableIDs []string) (bool, error) { + targets := make(map[string]struct{}) + for _, id := range targetTableIDs { + targets[strings.ToLower(id)] = struct{}{} + } + + runes := []rune(sql) + state := stateNormal + + for i := 0; i < len(runes); { + char := runes[i] + + switch state { + case stateNormal: + if hasPrefix(runes, i, "--") { + state = stateInSingleLineCommentDash + i += 2 + continue + } + if char == '#' { + state = stateInSingleLineCommentHash + i++ + continue + } + if hasPrefix(runes, i, "/*") { + state = stateInMultiLineComment + i += 2 + continue + } + + if unicode.IsLetter(char) || char == '`' || char == '_' { + parts, consumed, err := parseIdentifierSequence(runes[i:]) + if err != nil { + return false, err + } + if consumed > 0 { + fullID := strings.ToLower(strings.Join(parts, ".")) + for target := range targets { + // Exact match or as a prefix for column references. + if fullID == target || strings.HasPrefix(fullID, target+".") { + return true, nil + } + // Match without any backticks. + cleanFullID := strings.ReplaceAll(fullID, "`", "") + cleanTarget := strings.ReplaceAll(target, "`", "") + if cleanFullID == cleanTarget || strings.HasPrefix(cleanFullID, cleanTarget+".") { + return true, nil + } + // Try matching with the default project ID prefix. + if defaultProjectID != "" { + cleanDefaultProjectID := strings.ReplaceAll(strings.ToLower(defaultProjectID), "`", "") + withDefault := cleanDefaultProjectID + "." + cleanFullID + if withDefault == cleanTarget || strings.HasPrefix(withDefault, cleanTarget+".") { + return true, nil + } + } + } + i += consumed + continue + } + } + + // Handle various BigQuery string literal formats. + if hasPrefixFold(runes, i, "r'''") { + state = stateInRawTripleSingleQuoteString + i += 4 + continue + } + if hasPrefixFold(runes, i, `r"""`) { + state = stateInRawTripleDoubleQuoteString + i += 4 + continue + } + if hasPrefixFold(runes, i, "r'") { + state = stateInRawSingleQuoteString + i += 2 + continue + } + if hasPrefixFold(runes, i, `r"`) { + state = stateInRawDoubleQuoteString + i += 2 + continue + } + if hasPrefix(runes, i, "'''") { + state = stateInTripleSingleQuoteString + i += 3 + continue + } + if hasPrefix(runes, i, `"""`) { + state = stateInTripleDoubleQuoteString + i += 3 + continue + } + if char == '\'' { + state = stateInSingleQuoteString + i++ + continue + } + if char == '"' { + state = stateInDoubleQuoteString + i++ + continue + } + + case stateInSingleQuoteString: + if char == '\\' { + i += 2 + continue + } + if char == '\'' { + state = stateNormal + } + case stateInDoubleQuoteString: + if char == '\\' { + i += 2 + continue + } + if char == '"' { + state = stateNormal + } + case stateInTripleSingleQuoteString: + if hasPrefix(runes, i, "'''") { + state = stateNormal + i += 3 + continue + } + case stateInTripleDoubleQuoteString: + if hasPrefix(runes, i, `"""`) { + state = stateNormal + i += 3 + continue + } + case stateInSingleLineCommentDash, stateInSingleLineCommentHash: + if char == '\n' { + state = stateNormal + } + case stateInMultiLineComment: + if hasPrefix(runes, i, "*/") { + state = stateNormal + i += 2 + continue + } + case stateInRawSingleQuoteString: + if char == '\'' { + state = stateNormal + } + case stateInRawDoubleQuoteString: + if char == '"' { + state = stateNormal + } + case stateInRawTripleSingleQuoteString: + if hasPrefix(runes, i, "'''") { + state = stateNormal + i += 3 + continue + } + case stateInRawTripleDoubleQuoteString: + if hasPrefix(runes, i, `"""`) { + state = stateNormal + i += 3 + continue + } + } + i++ + } + + return false, nil } // parseIdentifierSequence parses a sequence of dot-separated identifiers. // It returns the parts of the identifier, the number of characters consumed, and an error. -func parseIdentifierSequence(s string) ([]string, int, error) { +func parseIdentifierSequence(runes []rune) ([]string, int, error) { var parts []string var totalConsumed int - for { - remaining := s[totalConsumed:] - trimmed := strings.TrimLeftFunc(remaining, unicode.IsSpace) - totalConsumed += len(remaining) - len(trimmed) - current := s[totalConsumed:] - - if len(current) == 0 { + // Skip whitespace and comments before identifier part + for { + originalConsumed := totalConsumed + for totalConsumed < len(runes) && unicode.IsSpace(runes[totalConsumed]) { + totalConsumed++ + } + if hasPrefix(runes, totalConsumed, "/*") { + endIdx := indexRunes(runes[totalConsumed:], "*/") + if endIdx != -1 { + totalConsumed += endIdx + 2 + } + } else if hasPrefix(runes, totalConsumed, "--") || (totalConsumed < len(runes) && runes[totalConsumed] == '#') { + endIdx := indexRunes(runes[totalConsumed:], "\n") + if endIdx != -1 { + totalConsumed += endIdx + 1 + } else { + totalConsumed = len(runes) + } + } + if totalConsumed == originalConsumed { + break + } + } + if totalConsumed >= len(runes) { break } var part string var consumed int - if current[0] == '`' { - end := strings.Index(current[1:], "`") + if runes[totalConsumed] == '`' { + end := indexRunes(runes[totalConsumed+1:], "`") if end == -1 { return nil, 0, fmt.Errorf("unclosed backtick identifier") } - part = current[1 : end+1] + part = string(runes[totalConsumed+1 : totalConsumed+end+1]) consumed = end + 2 - } else if len(current) > 0 && unicode.IsLetter(rune(current[0])) { - end := strings.IndexFunc(current, func(r rune) bool { - return !unicode.IsLetter(r) && !unicode.IsNumber(r) && r != '_' && r != '-' - }) - if end == -1 { - part = current - consumed = len(current) - } else { - part = current[:end] - consumed = end + } else if unicode.IsLetter(runes[totalConsumed]) || runes[totalConsumed] == '_' { + end := totalConsumed + for end < len(runes) && (unicode.IsLetter(runes[end]) || unicode.IsNumber(runes[end]) || runes[end] == '_' || runes[end] == '-') { + end++ } + part = string(runes[totalConsumed:end]) + consumed = end - totalConsumed } else { break } - if current[0] == '`' && strings.Contains(part, ".") { - // This handles cases like `project.dataset.table` but not `project.dataset`.table. - // If the character after the quoted identifier is not a dot, we treat it as a full name. - if len(current) <= consumed || current[consumed] != '.' { - parts = append(parts, strings.Split(part, ".")...) - totalConsumed += consumed + parts = append(parts, strings.Split(part, ".")...) + totalConsumed += consumed + + // Skip whitespace and comments between parts (before potential dot) + for { + originalConsumed := totalConsumed + for totalConsumed < len(runes) && unicode.IsSpace(runes[totalConsumed]) { + totalConsumed++ + } + if hasPrefix(runes, totalConsumed, "/*") { + endIdx := indexRunes(runes[totalConsumed:], "*/") + if endIdx != -1 { + totalConsumed += endIdx + 2 + } + } else if hasPrefix(runes, totalConsumed, "--") || (totalConsumed < len(runes) && runes[totalConsumed] == '#') { + endIdx := indexRunes(runes[totalConsumed:], "\n") + if endIdx != -1 { + totalConsumed += endIdx + 1 + } else { + totalConsumed = len(runes) + } + } + if totalConsumed == originalConsumed { break } } - parts = append(parts, strings.Split(part, ".")...) - totalConsumed += consumed - - if len(s) <= totalConsumed || s[totalConsumed] != '.' { + if totalConsumed >= len(runes) || runes[totalConsumed] != '.' { break } totalConsumed++ } + return parts, totalConsumed, nil } func formatTableID(parts []string, defaultProjectID string) (string, error) { + if len(parts) == 4 && strings.Contains(parts[1], ":") { + parts = []string{parts[0] + "." + parts[1], parts[2], parts[3]} + } if len(parts) < 2 || len(parts) > 3 { // Not a table identifier (could be a CTE, column, etc.). - // Return the consumed length so the main loop can skip this identifier. return "", nil } - var tableID string if len(parts) == 3 { // project.dataset.table - tableID = strings.Join(parts, ".") - } else { // dataset.table - if defaultProjectID == "" { - return "", fmt.Errorf("query contains table '%s' without project ID, and no default project ID is provided", strings.Join(parts, ".")) - } - tableID = fmt.Sprintf("%s.%s", defaultProjectID, strings.Join(parts, ".")) + return strings.Join(parts, "."), nil + } + + // dataset.table + if defaultProjectID == "" { + return "", fmt.Errorf("query contains table '%s' without project ID, and no default project ID is provided", strings.Join(parts, ".")) } + return fmt.Sprintf("%s.%s", defaultProjectID, strings.Join(parts, ".")), nil +} - return tableID, nil +func indexRunes(r []rune, sub string) int { + subRunes := []rune(sub) + if len(subRunes) == 0 { + return 0 + } + for i := 0; i <= len(r)-len(subRunes); i++ { + match := true + for j := 0; j < len(subRunes); j++ { + if r[i+j] != subRunes[j] { + match = false + break + } + } + if match { + return i + } + } + return -1 } diff --git a/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go b/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go index 75bedbe22de2..b24ae9496d38 100644 --- a/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go +++ b/internal/tools/bigquery/bigquerycommon/table_name_parser_test.go @@ -144,6 +144,27 @@ func TestTableParser(t *testing.T) { want: []string{"proj1.data1.tbl1", "proj2.data2.tbl2"}, wantErr: false, }, + { + name: "model as column name in where clause", + sql: "SELECT * FROM `proj.data.tbl` WHERE model = 'v1' AND status = 'active'", + defaultProjectID: "default-proj", + want: []string{"proj.data.tbl"}, + wantErr: false, + }, + { + name: "AI.FORECAST function call", + sql: "SELECT * FROM AI.FORECAST(TABLE `project.dataset.table`, data_col => 'val')", + defaultProjectID: "my-project", + want: []string{"project.dataset.table"}, + wantErr: false, + }, + { + name: "ML.GET_INSIGHTS function call", + sql: "SELECT * FROM ML.GET_INSIGHTS(MODEL `project.dataset.model`)", + defaultProjectID: "my-project", + want: []string{"project.dataset.model"}, + wantErr: false, + }, { name: "multi-statement with semicolon", sql: "SELECT * FROM `proj1.data1.tbl1`; SELECT * FROM `proj2.data2.tbl2`", @@ -471,6 +492,165 @@ func TestTableParser(t *testing.T) { wantErr: true, wantErrMsg: "unanalyzable statements like 'CREATE FUNCTION' are not allowed", }, + { + name: "alias filtering simple", + sql: "SELECT t1.col FROM proj.data.table AS t1", + defaultProjectID: "default-proj", + want: []string{"proj.data.table"}, + wantErr: false, + }, + { + name: "alias filtering complex", + sql: "SELECT t1.col1, t2.col2 FROM proj.data.tbl1 t1 JOIN proj.data.tbl2 AS t2 ON t1.id = t2.id", + defaultProjectID: "default-proj", + want: []string{"proj.data.tbl1", "proj.data.tbl2"}, + wantErr: false, + }, + { + name: "alias filtering in where clause", + sql: "SELECT * FROM proj.data.tbl1 AS t1 WHERE t1.id > 10", + defaultProjectID: "default-proj", + want: []string{"proj.data.tbl1"}, + wantErr: false, + }, + { + name: "unnest column reference", + sql: "SELECT x FROM `proj.ds.tbl` AS t, UNNEST(t.arr) AS x", + defaultProjectID: "default-proj", + want: []string{"proj.ds.tbl"}, + wantErr: false, + }, + { + name: "CTE with dots", + sql: "WITH `my.cte` AS (SELECT 1) SELECT * FROM `my.cte`", + defaultProjectID: "default-proj", + want: []string{}, + wantErr: false, + }, + { + name: "nested CTEs with dot-containing aliases", + sql: ` + WITH raw_metrics AS ( + SELECT id, score FROM production-data.analytics.events + ), + derived.results AS ( + SELECT id, score * 2 as double_score FROM raw_metrics + ) + SELECT * FROM derived.results WHERE double_score > 100 + `, + defaultProjectID: "default-proj", + want: []string{"production-data.analytics.events"}, + wantErr: false, + }, + { + name: "implicit join with comma", + sql: "SELECT * FROM proj.data.tbl1, proj.data.tbl2", + defaultProjectID: "default-proj", + want: []string{"proj.data.tbl1", "proj.data.tbl2"}, + wantErr: false, + }, + { + name: "implicit alias", + sql: "SELECT t.col FROM proj.data.tbl t", + defaultProjectID: "default-proj", + want: []string{"proj.data.tbl"}, + wantErr: false, + }, + { + name: "unnest column reference complex", + sql: "SELECT x FROM `proj.ds.tbl` AS t, UNNEST(t.arr) AS x JOIN `other.ds.tbl2` as o ON t.id = o.id", + defaultProjectID: "default-proj", + want: []string{"proj.ds.tbl", "other.ds.tbl2"}, + wantErr: false, + }, + { + name: "create schema statement", + sql: "CREATE SCHEMA proj.data", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'CREATE SCHEMA' are not allowed", + }, + { + name: "create dataset statement", + sql: "CREATE DATASET proj.data", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'CREATE DATASET' are not allowed", + }, + { + name: "drop schema statement", + sql: "DROP SCHEMA proj.data", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'DROP SCHEMA' are not allowed", + }, + { + name: "drop dataset statement", + sql: "DROP DATASET proj.data", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'DROP DATASET' are not allowed", + }, + { + name: "alter schema statement", + sql: "ALTER SCHEMA proj.data SET OPTIONS(description='new one')", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'ALTER SCHEMA' are not allowed", + }, + { + name: "alter dataset statement", + sql: "ALTER DATASET proj.data SET OPTIONS(description='new one')", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "dataset-level operations like 'ALTER DATASET' are not allowed", + }, + { + name: "call fully qualified procedure", + sql: "CALL proj.data.proc()", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "CALL is not allowed when dataset restrictions are in place", + }, + { + name: "create procedure statement", + sql: "CREATE PROCEDURE proj.data.proc() BEGIN SELECT 1; END", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "unanalyzable statements like 'CREATE PROCEDURE' are not allowed", + }, + { + name: "create or replace procedure statement", + sql: "CREATE OR REPLACE PROCEDURE proj.data.proc() BEGIN SELECT 1; END", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "unanalyzable statements like 'CREATE OR REPLACE PROCEDURE' are not allowed", + }, + { + name: "create function statement", + sql: "CREATE FUNCTION proj.data.func() RETURNS INT64 AS (1)", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "unanalyzable statements like 'CREATE FUNCTION' are not allowed", + }, + { + name: "simple execute immediate", + sql: "EXECUTE IMMEDIATE 'SELECT 1'", + defaultProjectID: "default-proj", + want: nil, + wantErr: true, + wantErrMsg: "EXECUTE IMMEDIATE is not allowed", + }, { name: "EXTERNAL_QUERY query", sql: "SELECT * FROM EXTERNAL_QUERY('my-conn', 'SELECT 1')", @@ -548,3 +728,131 @@ func TestTableParser(t *testing.T) { }) } } + +func TestIsAnyTableExplicitlyReferenced(t *testing.T) { + testCases := []struct { + name string + sql string + defaultProjectID string + targetTableIDs []string + want bool + }{ + { + name: "simple match", + sql: "SELECT * FROM `proj.ds.tbl`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "match without project id in sql", + sql: "SELECT * FROM `ds.tbl`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"def-proj.ds.tbl"}, + want: true, + }, + { + name: "no match", + sql: "SELECT * FROM `ds.view`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"def-proj.ds.tbl"}, + want: false, + }, + { + name: "ignore in strings", + sql: "SELECT 'proj.ds.tbl' FROM `ds.view` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: false, + }, + { + name: "ignore in comments", + sql: "SELECT * FROM `ds.view` -- referencing proj.ds.tbl here", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: false, + }, + { + name: "match in join", + sql: "SELECT * FROM `ds.view` JOIN `proj.ds.tbl` ON 1=1", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "match as column reference", + sql: "SELECT proj.ds.tbl.col FROM `ds.view`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "match with different casing", + sql: "SELECT * FROM `PROJ.ds.TBL`", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "raw string ignore", + sql: "SELECT r'proj.ds.tbl' FROM `ds.view` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: false, + }, + { + name: "mixed quoting style", + sql: "SELECT * FROM `proj`.ds.`tbl` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "all parts quoted separately", + sql: "SELECT * FROM `proj`.`ds`.`tbl` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "middle part quoted", + sql: "SELECT * FROM proj.`ds`.tbl ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "fully qualified column reference", + sql: "SELECT proj.ds.tbl.col FROM `something` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"proj.ds.tbl"}, + want: true, + }, + { + name: "domain scoped project ID match", + sql: "SELECT * FROM `google.com:project.dataset.table` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"google.com:project.dataset.table"}, + want: true, + }, + { + name: "domain scoped project ID match with column", + sql: "SELECT `google.com:project.dataset.table`.col FROM `something_else` ", + defaultProjectID: "def-proj", + targetTableIDs: []string{"google.com:project.dataset.table"}, + want: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := bigquerycommon.IsAnyTableExplicitlyReferenced(tc.sql, tc.defaultProjectID, tc.targetTableIDs) + if err != nil { + t.Fatalf("IsAnyTableExplicitlyReferenced() error = %v", err) + } + if got != tc.want { + t.Errorf("IsAnyTableExplicitlyReferenced() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index 18280f00dc22..4c5720754d7a 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -16,14 +16,17 @@ package bigquerycommon import ( "context" + "errors" "fmt" "regexp" "sort" "strings" bigqueryapi "cloud.google.com/go/bigquery" + "github.com/googleapis/mcp-toolbox/internal/util" "github.com/googleapis/mcp-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" + "google.golang.org/api/googleapi" ) // validBQTableID matches BigQuery table identifiers in 'dataset.table' or @@ -44,6 +47,24 @@ func ValidTableID(s string) bool { return validBQTableID.MatchString(s) } +// ValidColumnParam returns true if s (stripped of leading/trailing single quotes) is a safe column name. +func ValidColumnParam(s string) bool { + return ValidColumnName(StripSingleQuotes(s)) +} + +// ValidContributionMetricParam returns true if s (stripped of leading/trailing single quotes) is a safe contribution metric (does not contain single quotes). +func ValidContributionMetricParam(s string) bool { + return !strings.ContainsRune(StripSingleQuotes(s), '\'') +} + +// StripSingleQuotes removes leading and trailing single quotes from a string if both are present. +func StripSingleQuotes(s string) string { + if len(s) >= 2 && s[0] == '\'' && s[len(s)-1] == '\'' { + return s[1 : len(s)-1] + } + return s +} + // ValidColumnName returns true if s is a safe BigQuery column name. // Values that fail this check must not be interpolated as SQL identifiers // or into single-quoted SQL string arguments that represent column references. @@ -52,9 +73,18 @@ func ValidColumnName(s string) bool { } // DryRunQuery performs a dry run of the SQL query to validate it and get metadata. -func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string, params []*bigqueryrestapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty, maximumBytesBilled int64) (*bigqueryrestapi.Job, error) { +func DryRunQuery( + ctx context.Context, + restService *bigqueryrestapi.Service, + projectID string, + location string, + sql string, + params []*bigqueryrestapi.QueryParameter, + connProps []*bigqueryapi.ConnectionProperty, + maximumBytesBilled int64, + createSession bool, +) (*bigqueryrestapi.Job, error) { useLegacySql := false - restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps)) for i, prop := range connProps { restConnProps[i] = &bigqueryrestapi.ConnectionProperty{Key: prop.Key, Value: prop.Value} @@ -73,6 +103,7 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj ConnectionProperties: restConnProps, QueryParameters: params, MaximumBytesBilled: maximumBytesBilled, + CreateSession: createSession, }, }, } @@ -84,6 +115,148 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj return insertResponse, nil } +// DatasetValidator defines the interface for checking if a dataset is allowed. +type DatasetValidator interface { + IsDatasetAllowed(projectID, datasetID string) bool +} + +// ValidateQueryAgainstAllowedDatasets validates a SQL query against a list of allowed datasets. +// It uses both dry run and a local parser to support authorized views. +func ValidateQueryAgainstAllowedDatasets( + ctx context.Context, + restService *bigqueryrestapi.Service, + projectID string, + location string, + sql string, + params []*bigqueryrestapi.QueryParameter, + connProps []*bigqueryapi.ConnectionProperty, + validator DatasetValidator, + maximumBytesBilled int64, + createSession bool, +) (*bigqueryrestapi.Job, util.ToolboxError) { + dryRunJob, err := DryRunQuery(ctx, restService, projectID, location, sql, params, connProps, maximumBytesBilled, createSession) + if err != nil { + var gErr *googleapi.Error + if errors.As(err, &gErr) { + return nil, util.ProcessGcpError(err) + } + return nil, util.NewAgentError("query validation failed", err) + } + + if dryRunJob.Statistics == nil || dryRunJob.Statistics.Query == nil { + return nil, util.NewAgentError("dry run failed to return query statistics", nil) + } + + // Use a map to avoid duplicate table names from the dry run result. + tableIDSet := make(map[string]struct{}) + queryStats := dryRunJob.Statistics.Query + if queryStats != nil { + for _, tableRef := range queryStats.ReferencedTables { + if tableRef != nil { + tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} + } + } + if tableRef := queryStats.DdlTargetTable; tableRef != nil { + tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} + } + if tableRef := queryStats.DdlDestinationTable; tableRef != nil { + tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} + } + } + + var violatingTables []string + for tableID := range tableIDSet { + parts := strings.Split(tableID, ".") + if len(parts) == 3 { + // Skip validation for specific system functions that BigQuery reports as referenced tables. + if IsSystemResource(parts[1], parts[2]) { + continue + } + if !validator.IsDatasetAllowed(parts[0], parts[1]) { + violatingTables = append(violatingTables, tableID) + } + } + } + + // If violations were found, check if they are explicitly in the SQL to support authorized views. + if len(violatingTables) > 0 { + explicitlyReferenced, err := IsAnyTableExplicitlyReferenced(sql, projectID, violatingTables) + if err != nil { + return nil, util.NewAgentError("failed to analyze query for explicit table references", err) + } + if explicitlyReferenced { + violatingDatasets := []string{} + seenDatasets := make(map[string]struct{}) + for _, tableID := range violatingTables { + datasetFQN := strings.Join(strings.Split(tableID, ".")[:2], ".") + if _, seen := seenDatasets[datasetFQN]; !seen { + violatingDatasets = append(violatingDatasets, fmt.Sprintf("'%s'", datasetFQN)) + seenDatasets[datasetFQN] = struct{}{} + } + } + plural := "" + if len(violatingDatasets) > 1 { + plural = "s" + } + return nil, util.NewAgentError(fmt.Sprintf("access to dataset%s %s is not allowed", plural, strings.Join(violatingDatasets, ", ")), nil) + } + } + + // Fall back to TableParser for final intent verification or if dry run was inconclusive. + parsedTables, parseErr := TableParser(sql, projectID) + if parseErr != nil { + return nil, util.NewAgentError("could not safely analyze query with dataset restrictions", parseErr) + } + + var parsedViolatingDatasets []string + seenParsedDatasets := make(map[string]struct{}) + for _, tableID := range parsedTables { + parts := strings.Split(tableID, ".") + if len(parts) == 3 { + if IsSystemResource(parts[1], parts[2]) { + continue + } + if !validator.IsDatasetAllowed(parts[0], parts[1]) { + datasetFQN := fmt.Sprintf("%s.%s", parts[0], parts[1]) + if _, seen := seenParsedDatasets[datasetFQN]; !seen { + parsedViolatingDatasets = append(parsedViolatingDatasets, fmt.Sprintf("'%s'", datasetFQN)) + seenParsedDatasets[datasetFQN] = struct{}{} + } + } + } + } + if len(parsedViolatingDatasets) > 0 { + plural := "" + if len(parsedViolatingDatasets) > 1 { + plural = "s" + } + return nil, util.NewAgentError(fmt.Sprintf("access to dataset%s %s is not allowed", plural, strings.Join(parsedViolatingDatasets, ", ")), nil) + } + + return dryRunJob, nil +} + +// IsSystemResource checks if a given dataset and table/function ID refer to a BigQuery system resource +// (like a built-in AI or ML function) that should be exempted from dataset restriction checks. +func IsSystemResource(datasetID, resourceID string) bool { + datasetID = strings.ToUpper(datasetID) + resourceID = strings.ToUpper(resourceID) + + if datasetID == "AI" { + switch resourceID { + case "FORECAST", "GENERATE_TEXT", "EXTRACT_ENTITY", "SUMMARIZE": + return true + } + } + if datasetID == "ML" { + switch resourceID { + case "GET_INSIGHTS", "EXPLAIN_PREDICT", "GENERATE_TEXT", "DISTANCE", "PREDICT": + return true + } + } + return false +} + // BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string. func BQTypeStringFromToolType(toolType string) (string, error) { switch toolType { @@ -147,21 +320,3 @@ func InitializeDatasetParameters( return projectParam, datasetParam } - -// StripSingleQuotes removes leading and trailing single quotes from a string if both are present. -func StripSingleQuotes(s string) string { - if len(s) >= 2 && s[0] == '\'' && s[len(s)-1] == '\'' { - return s[1 : len(s)-1] - } - return s -} - -// ValidColumnParam returns true if s (stripped of leading/trailing single quotes) is a safe column name. -func ValidColumnParam(s string) bool { - return ValidColumnName(StripSingleQuotes(s)) -} - -// ValidContributionMetricParam returns true if s (stripped of leading/trailing single quotes) is a safe contribution metric (does not contain single quotes). -func ValidContributionMetricParam(s string) bool { - return !strings.ContainsRune(StripSingleQuotes(s), '\'') -} diff --git a/internal/tools/bigquery/bigquerycommon/validators_test.go b/internal/tools/bigquery/bigquerycommon/validators_test.go index 9f0efae07c70..8e8e114578f9 100644 --- a/internal/tools/bigquery/bigquerycommon/validators_test.go +++ b/internal/tools/bigquery/bigquerycommon/validators_test.go @@ -139,3 +139,99 @@ func TestStripSingleQuotes(t *testing.T) { } } } + +func TestValidColumnParam(t *testing.T) { + tcs := []struct { + in string + valid bool + }{ + {"'sales'", true}, + {"sales", true}, + {"'sales_col'", true}, + {"_internal", true}, + {"'1col'", false}, + {"col'", false}, + {"'col", false}, + {"'col; DROP TABLE x'", false}, + {"", false}, + } + for _, tc := range tcs { + if got := bigquerycommon.ValidColumnParam(tc.in); got != tc.valid { + t.Errorf("ValidColumnParam(%q) = %v, want %v", tc.in, got, tc.valid) + } + } +} + +func TestValidContributionMetricParam(t *testing.T) { + tcs := []struct { + in string + valid bool + }{ + {"'metric'", true}, + {"metric", true}, + {"'metric's'", false}, + {"metric's", false}, + {"'metric", false}, + {"''metric''", false}, + } + for _, tc := range tcs { + if got := bigquerycommon.ValidContributionMetricParam(tc.in); got != tc.valid { + t.Errorf("ValidContributionMetricParam(%q) = %v, want %v", tc.in, got, tc.valid) + } + } +} + +func TestIsSystemResource(t *testing.T) { + tcs := []struct { + datasetID string + resourceID string + want bool + }{ + {"AI", "FORECAST", true}, + {"ai", "forecast", true}, + {"AI", "GENERATE_TEXT", true}, + {"AI", "SUMMARIZE", true}, + {"ML", "PREDICT", true}, + {"ml", "predict", true}, + {"ML", "DISTANCE", true}, + {"ML", "GET_INSIGHTS", true}, + {"AI", "INVALID", false}, + {"ML", "INVALID", false}, + {"OTHER", "FORECAST", false}, + } + for _, tc := range tcs { + if got := bigquerycommon.IsSystemResource(tc.datasetID, tc.resourceID); got != tc.want { + t.Errorf("IsSystemResource(%q, %q) = %v, want %v", tc.datasetID, tc.resourceID, got, tc.want) + } + } +} + +func TestBQTypeStringFromToolType(t *testing.T) { + tcs := []struct { + in string + want string + wantErr bool + }{ + {"string", "STRING", false}, + {"integer", "INT64", false}, + {"float", "FLOAT64", false}, + {"boolean", "BOOL", false}, + {"map", "STRUCT", false}, + {"invalid", "", true}, + } + for _, tc := range tcs { + got, err := bigquerycommon.BQTypeStringFromToolType(tc.in) + if tc.wantErr { + if err == nil { + t.Errorf("BQTypeStringFromToolType(%q) expected error, got nil", tc.in) + } + } else { + if err != nil { + t.Errorf("BQTypeStringFromToolType(%q) unexpected error: %v", tc.in, err) + } + if got != tc.want { + t.Errorf("BQTypeStringFromToolType(%q) = %q, want %q", tc.in, got, tc.want) + } + } + } +} diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index 214360f31e3c..543f692d34a0 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -138,11 +138,23 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, source.GetMaximumBytesBilled()) - if err != nil { - return nil, util.ProcessGcpError(err) + var dryRunJob *bigqueryrestapi.Job + if len(source.BigQueryAllowedDatasets()) > 0 { + var validationErr util.ToolboxError + dryRunJob, validationErr = bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, source, source.GetMaximumBytesBilled(), false) + if validationErr != nil { + return nil, validationErr + } + } else { + dryRunJob, err = bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, source.GetMaximumBytesBilled(), false) + if err != nil { + return nil, util.ProcessGcpError(err) + } } + if dryRunJob.Statistics == nil || dryRunJob.Statistics.Query == nil { + return nil, util.NewClientServerError("dry run failed to return query statistics", http.StatusInternalServerError, nil) + } statementType := dryRunJob.Statistics.Query.StatementType switch source.BigQueryWriteMode() { @@ -159,53 +171,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - if len(source.BigQueryAllowedDatasets()) > 0 { - switch statementType { - case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA": - return nil, util.NewAgentError(fmt.Sprintf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType), nil) - case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE": - return nil, util.NewAgentError(fmt.Sprintf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType), nil) - case "CALL": - return nil, util.NewAgentError(fmt.Sprintf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType), nil) - } - - // Use a map to avoid duplicate table names. - tableIDSet := make(map[string]struct{}) - - // Get all tables from the dry run result. This is the most reliable method. - queryStats := dryRunJob.Statistics.Query - if queryStats != nil { - for _, tableRef := range queryStats.ReferencedTables { - tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} - } - if tableRef := queryStats.DdlTargetTable; tableRef != nil { - tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} - } - if tableRef := queryStats.DdlDestinationTable; tableRef != nil { - tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{} - } - } - - // Always run the parser to ensure we catch views/tables that the dry run might bypass - parsedTables, parseErr := bqutil.TableParser(sql, bqClient.Project()) - if parseErr != nil { - return nil, util.NewAgentError("could not parse tables from query to validate against allowed datasets", parseErr) - } - for _, tableID := range parsedTables { - tableIDSet[tableID] = struct{}{} - } - - for tableID := range tableIDSet { - parts := strings.Split(tableID, ".") - if len(parts) == 3 { - projectID, datasetID := parts[0], parts[1] - if !source.IsDatasetAllowed(projectID, datasetID) { - return nil, util.NewAgentError(fmt.Sprintf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID), nil) - } - } - } - } - if dryRun { if dryRunJob != nil { jobJSON, err := json.MarshalIndent(dryRunJob, "", " ") @@ -226,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) resp, err := source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps, map[string]string{"mcp-toolbox-tool": resourceType}) if err != nil { - return nil, util.NewClientServerError("error running sql", http.StatusInternalServerError, err) + return nil, util.ProcessGcpError(err) } return resp, nil } diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql_test.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql_test.go index 78d45943c462..20ae0dbc0b43 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql_test.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql_test.go @@ -210,7 +210,7 @@ func TestInvokeDatasetRestrictions(t *testing.T) { desc: "querying forbidden dataset table", sql: "SELECT * FROM forbidden_dataset.my_table", wantErr: true, - wantSub: "query accesses dataset 'test-project.forbidden_dataset', which is not in the allowed list", + wantSub: "access to dataset 'test-project.forbidden_dataset' is not allowed", }, { desc: "querying allowed dataset INFORMATION_SCHEMA tables", @@ -221,7 +221,7 @@ func TestInvokeDatasetRestrictions(t *testing.T) { desc: "querying forbidden dataset INFORMATION_SCHEMA tables", sql: "SELECT * FROM forbidden_dataset.INFORMATION_SCHEMA.TABLES", wantErr: true, - wantSub: "query accesses dataset 'test-project.forbidden_dataset', which is not in the allowed list", + wantSub: "access to dataset 'test-project.forbidden_dataset' is not allowed", }, { desc: "querying regional INFORMATION_SCHEMA schemata", @@ -239,7 +239,7 @@ func TestInvokeDatasetRestrictions(t *testing.T) { desc: "querying mixed allowed table and forbidden INFORMATION_SCHEMA view", sql: "SELECT * FROM allowed_dataset.my_table JOIN forbidden_dataset.INFORMATION_SCHEMA.TABLES ON true", wantErr: true, - wantSub: "query accesses dataset 'test-project.forbidden_dataset', which is not in the allowed list", + wantSub: "access to dataset 'test-project.forbidden_dataset' is not allowed", }, { desc: "querying EXTERNAL_QUERY", diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index 3c4c819c0157..11da1e842b40 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -154,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - bqClient, _, err := source.RetrieveClientAndService(accessToken) + bqClient, restService, err := source.RetrieveClientAndService(accessToken) if err != nil { return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } @@ -162,6 +162,26 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var historyDataSource string trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData)) if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") { + if len(source.BigQueryAllowedDatasets()) > 0 { + var connProps []*bigqueryapi.ConnectionProperty + session, err := source.BigQuerySession()(ctx) + if err != nil { + return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err) + } + if session != nil { + connProps = []*bigqueryapi.ConnectionProperty{ + {Key: "session_id", Value: session.ID}, + } + } + + dryRunJob, validationErr := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps, source, source.GetMaximumBytesBilled(), false) + if validationErr != nil { + return nil, validationErr + } + if dryRunJob.Statistics.Query.StatementType != "SELECT" { + return nil, util.NewAgentError(fmt.Sprintf("the 'history_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", dryRunJob.Statistics.Query.StatementType), nil) + } + } historyDataSource = fmt.Sprintf("(%s)", historyData) } else { if !bqutil.ValidTableID(historyData) { @@ -215,29 +235,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if len(source.BigQueryAllowedDatasets()) > 0 { - dryRunQuery := bqClient.Query(sql) - dryRunQuery.Location = bqClient.Location - if connProps != nil { - dryRunQuery.ConnectionProperties = connProps - } - dryRunQuery.DryRun = true - dryRunJob, err := dryRunQuery.Run(ctx) - if err != nil { - return nil, util.ProcessGcpError(err) - } - status := dryRunJob.LastStatus() - if status.Statistics != nil { - if qStats, ok := status.Statistics.Details.(*bigqueryapi.QueryStatistics); ok { - for _, tableRef := range qStats.ReferencedTables { - if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) { - return nil, util.NewAgentError(fmt.Sprintf("query accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectID, tableRef.DatasetID), nil) - } - } - } else { - return nil, util.NewAgentError("could not get query statistics details during dry run validation", nil) - } - } else { - return nil, util.NewAgentError("could not dry run final query to validate allowed datasets", nil) + _, validationErr := bqutil.ValidateQueryAgainstAllowedDatasets(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, sql, nil, connProps, source, source.GetMaximumBytesBilled(), false) + if validationErr != nil { + return nil, validationErr } } diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast_test.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast_test.go index 242af97f9d14..ec9a94f23c1a 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast_test.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast_test.go @@ -31,6 +31,7 @@ import ( "github.com/googleapis/mcp-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/mcp-toolbox/internal/tools/bigquery/bigqueryforecast" "github.com/googleapis/mcp-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/option" ) @@ -299,9 +300,15 @@ func TestInvokeAllowedDatasetsValidation(t *testing.T) { t.Fatalf("failed to create mocked BigQuery client: %v", err) } + restService, err := bigqueryrestapi.NewService(ctx, option.WithEndpoint(mockServer.URL), option.WithoutAuthentication()) + if err != nil { + t.Fatalf("failed to create mocked BigQuery REST service: %v", err) + } + // 3. Define mock source that returns this client and allowed datasets configuration testSrc := &bigquerycommon.MockSource{ Client: bqClient, + Service: restService, AllowedDatasets: []string{"allowed_dataset"}, // only "allowed_dataset" is allowed! } @@ -329,7 +336,7 @@ func TestInvokeAllowedDatasetsValidation(t *testing.T) { // 4. Set up parameters mimicking the bypass/injection attempt // We try to run the tool, but the dry-run of the final query will detect the reference to "unauthorized_dataset" data := map[string]any{ - "history_data": "allowed_dataset.my_table", + "history_data": "SELECT * FROM unauthorized_dataset.some_table", "timestamp_col": "ts", "data_col": "val", "horizon": 5, @@ -351,7 +358,7 @@ func TestInvokeAllowedDatasetsValidation(t *testing.T) { t.Fatal("expected Invoke to return an error due to out-of-allowlist dataset reference, but got nil") } - expectedErr := "query accesses dataset 'test-project.unauthorized_dataset', which is not in the allowed list" + expectedErr := "access to dataset 'test-project.unauthorized_dataset' is not allowed" if !strings.Contains(err.Error(), expectedErr) { t.Errorf("expected error to contain %q, got: %v", expectedErr, err) } diff --git a/internal/tools/bigquery/bigquerysql/bigquerysql.go b/internal/tools/bigquery/bigquerysql/bigquerysql.go index b11e7807e3ee..1f12896f791f 100644 --- a/internal/tools/bigquery/bigquerysql/bigquerysql.go +++ b/internal/tools/bigquery/bigquerysql/bigquerysql.go @@ -140,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps, source.GetMaximumBytesBilled()) + dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps, source.GetMaximumBytesBilled(), false) if err != nil { return nil, util.ProcessGcpError(err) } diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 2cf0eaa99913..1b8d4fe53e58 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -89,8 +89,12 @@ func TestBigQueryToolEndpoints(t *testing.T) { } // create table name with UUID - datasetName := fmt.Sprintf("temp_toolbox_test_%s", uniqueID) - tableName := fmt.Sprintf("param_table_%s", uniqueID) + datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) + + cleanupDatasets := ensureTeardownDatasets(ctx, client, datasetName) + defer cleanupDatasets(t) + + tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) tableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, datasetName, @@ -125,7 +129,7 @@ func TestBigQueryToolEndpoints(t *testing.T) { // global cleanup for this test run t.Cleanup(func() { - tests.CleanupBigQueryDatasets(t, context.Background(), client, []string{datasetName}) + CleanupBigQueryDatasets(t, context.Background(), client, []string{datasetName}) }) // set up data for param tool @@ -227,7 +231,7 @@ func TestBigQueryToolEndpoints(t *testing.T) { func TestBigQueryToolWithDatasetRestriction(t *testing.T) { uniqueID := strings.ReplaceAll(uuid.New().String(), "-", "") t.Logf("Starting restriction test with uniqueID: %s", uniqueID) - ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() client, err := initBigQueryConnection(BigqueryProject) @@ -235,9 +239,15 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { t.Fatalf("unable to create BigQuery client: %s", err) } - allowedDatasetName1 := fmt.Sprintf("allowed_dataset_1_%s", uniqueID) - allowedDatasetName2 := fmt.Sprintf("allowed_dataset_2_%s", uniqueID) - disallowedDatasetName := fmt.Sprintf("disallowed_dataset_%s", uniqueID) + // Create two datasets, one allowed, one not. + baseName := strings.ReplaceAll(uuid.New().String(), "-", "") + allowedDatasetName1 := fmt.Sprintf("allowed_dataset_1_%s", baseName) + allowedDatasetName2 := fmt.Sprintf("allowed_dataset_2_%s", baseName) + disallowedDatasetName := fmt.Sprintf("disallowed_dataset_%s", baseName) + + cleanupDatasets := ensureTeardownDatasets(ctx, client, allowedDatasetName1, allowedDatasetName2, disallowedDatasetName) + defer cleanupDatasets(t) + allowedTableName1 := "allowed_table_1" allowedTableName2 := "allowed_table_2" disallowedTableName := "disallowed_table" @@ -251,7 +261,7 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { // global cleanup for this test run t.Cleanup(func() { - tests.CleanupBigQueryDatasets(t, context.Background(), client, []string{allowedDatasetName1, allowedDatasetName2, disallowedDatasetName}) + CleanupBigQueryDatasets(t, context.Background(), client, []string{allowedDatasetName1, allowedDatasetName2, disallowedDatasetName}) }) // Setup allowed table @@ -296,6 +306,35 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName) setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams) + // Setup authorized views in BOTH allowed datasets pointing to disallowed tables + viewInAllowedPointingToDisallowedForecastName := "auth_view_forecast" + viewInAllowedPointingToDisallowedAnalyzeName := "auth_view_analyze" + + // Create Forecast views + for _, dsName := range []string{allowedDatasetName1, allowedDatasetName2} { + teardownForecastView := setupBigQueryView(t, ctx, client, dsName, viewInAllowedPointingToDisallowedForecastName, fmt.Sprintf("SELECT * FROM %s", disallowedForecastTableFullName)) + defer teardownForecastView(t) + + teardownAnalyzeView := setupBigQueryView(t, ctx, client, dsName, viewInAllowedPointingToDisallowedAnalyzeName, fmt.Sprintf("SELECT * FROM %s", disallowedAnalyzeContributionTableFullName)) + defer teardownAnalyzeView(t) + } + + // Authorize ALL views to access the disallowed dataset + dsMetadata, err := client.Dataset(disallowedDatasetName).Metadata(ctx) + if err != nil { + t.Fatalf("failed to get disallowed dataset metadata: %v", err) + } + newAccess := append(dsMetadata.Access, + &bigqueryapi.AccessEntry{EntityType: bigqueryapi.ViewEntity, View: client.Dataset(allowedDatasetName1).Table(viewInAllowedPointingToDisallowedForecastName)}, + &bigqueryapi.AccessEntry{EntityType: bigqueryapi.ViewEntity, View: client.Dataset(allowedDatasetName2).Table(viewInAllowedPointingToDisallowedForecastName)}, + &bigqueryapi.AccessEntry{EntityType: bigqueryapi.ViewEntity, View: client.Dataset(allowedDatasetName1).Table(viewInAllowedPointingToDisallowedAnalyzeName)}, + &bigqueryapi.AccessEntry{EntityType: bigqueryapi.ViewEntity, View: client.Dataset(allowedDatasetName2).Table(viewInAllowedPointingToDisallowedAnalyzeName)}, + ) + update := bigqueryapi.DatasetMetadataToUpdate{Access: newAccess} + if _, err := client.Dataset(disallowedDatasetName).Update(ctx, update, dsMetadata.ETag); err != nil { + t.Fatalf("failed to authorize views: %v", err) + } + // Configure source with dataset restriction. sourceConfig := getBigQueryVars(t) sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2} @@ -396,7 +435,7 @@ func TestBigQueryWriteModeAllowed(t *testing.T) { sourceConfig := getBigQueryVars(t) sourceConfig["writeMode"] = "allowed" - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() datasetName := fmt.Sprintf("temp_toolbox_test_allowed_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) @@ -452,7 +491,7 @@ func TestBigQueryWriteModeBlocked(t *testing.T) { sourceConfig := getBigQueryVars(t) sourceConfig["writeMode"] = "blocked" - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() datasetName := fmt.Sprintf("temp_toolbox_test_blocked_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) @@ -653,6 +692,33 @@ func getBigQueryTmplToolStatement() (string, string) { return tmplSelectCombined, tmplSelectFilterCombined } +func ensureTeardownDatasets(ctx context.Context, client *bigqueryapi.Client, datasetNames ...string) func(*testing.T) { + return func(t *testing.T) { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + for _, dsName := range datasetNames { + if err := client.Dataset(dsName).DeleteWithContents(cleanupCtx); err != nil { + t.Logf("failed to cleanup dataset %s: %v", dsName, err) + } + } + } +} + +func setupBigQueryView(t *testing.T, ctx context.Context, client *bigqueryapi.Client, datasetName, viewName, query string) func(*testing.T) { + if err := client.Dataset(datasetName).Table(viewName).Create(ctx, &bigqueryapi.TableMetadata{ + ViewQuery: query, + }); err != nil { + t.Fatalf("failed to create view %s in %s: %v", viewName, datasetName, err) + } + return func(t *testing.T) { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + if err := client.Dataset(datasetName).Table(viewName).Delete(cleanupCtx); err != nil { + t.Errorf("failed to delete view %s in %s: %v", viewName, datasetName, err) + } + } +} + func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) { // Create dataset dataset := client.Dataset(datasetName) @@ -701,14 +767,16 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C } return func(t *testing.T) { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() // tear down table dropSQL := fmt.Sprintf("drop table %s", tableName) - dropJob, err := client.Query(dropSQL).Run(ctx) + dropJob, err := client.Query(dropSQL).Run(cleanupCtx) if err != nil { t.Errorf("Failed to start drop table job for %s: %v", tableName, err) return } - dropStatus, err := dropJob.Wait(ctx) + dropStatus, err := dropJob.Wait(cleanupCtx) if err != nil { t.Errorf("Failed to wait for drop table job for %s: %v", tableName, err) return @@ -716,19 +784,6 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C if err := dropStatus.Err(); err != nil { t.Errorf("Error dropping table %s: %v", tableName, err) } - - // tear down dataset - datasetToTeardown := client.Dataset(datasetName) - tablesIterator := datasetToTeardown.Tables(ctx) - _, err = tablesIterator.Next() - - if err == iterator.Done { - if err := datasetToTeardown.Delete(ctx); err != nil { - t.Errorf("Failed to delete dataset %s: %v", datasetName, err) - } - } else if err != nil { - t.Errorf("Failed to list tables in dataset %s to check emptiness: %v.", datasetName, err) - } } } @@ -2587,6 +2642,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed } func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string, allowedTableNames ...string) { + allowedTableNames = append(allowedTableNames, "auth_view_forecast", "auth_view_analyze") sort.Strings(allowedTableNames) var quotedNames []string for _, name := range allowedTableNames { @@ -2774,7 +2830,8 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed if len(allowedTableParts) != 3 { t.Fatalf("invalid allowed table name format: %s", allowedTableFullName) } - allowedDatasetID := allowedTableParts[1] + allowedProjectID, allowedDatasetID := allowedTableParts[0], allowedTableParts[1] + viewFullName := fmt.Sprintf("`%s.%s.auth_view_forecast`", allowedProjectID, allowedDatasetID) testCases := []struct { name string @@ -2787,32 +2844,36 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed sql: fmt.Sprintf("SELECT * FROM %s", allowedTableFullName), wantStatusCode: http.StatusOK, }, + { + name: "invoke on authorized view", + sql: fmt.Sprintf("SELECT * FROM %s", viewFullName), + wantStatusCode: http.StatusOK, + }, { name: "invoke on disallowed table", sql: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusOK, - wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", - strings.Join( - strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0:2], - ".")), + wantInError: fmt.Sprintf("access to dataset '%s.%s' is not allowed", + strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0], + strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[1]), }, { name: "disallowed create schema", sql: "CREATE SCHEMA another_dataset", wantStatusCode: http.StatusOK, - wantInError: "dataset-level operations like 'CREATE_SCHEMA' are not allowed", + wantInError: "dataset-level operations like 'CREATE SCHEMA' are not allowed", }, { name: "disallowed alter schema", sql: fmt.Sprintf("ALTER SCHEMA %s SET OPTIONS(description='new one')", allowedDatasetID), wantStatusCode: http.StatusOK, - wantInError: "dataset-level operations like 'ALTER_SCHEMA' are not allowed", + wantInError: "dataset-level operations like 'ALTER SCHEMA' are not allowed", }, { name: "disallowed create function", sql: fmt.Sprintf("CREATE FUNCTION %s.my_func() RETURNS INT64 AS (1)", allowedDatasetID), wantStatusCode: http.StatusOK, - wantInError: "creating stored routines ('CREATE_FUNCTION') is not allowed", + wantInError: "unanalyzable statements like 'CREATE FUNCTION' are not allowed", }, { name: "disallowed create procedure", @@ -2824,7 +2885,7 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed name: "disallowed execute immediate", sql: "EXECUTE IMMEDIATE 'SELECT 1'", wantStatusCode: http.StatusOK, - wantInError: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place", + wantInError: "EXECUTE IMMEDIATE is not allowed", }, } @@ -2939,6 +3000,10 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "") disallowedDatasetFQN := strings.Join(strings.Split(disallowedTableUnquoted, ".")[0:2], ".") + allowedParts := strings.Split(allowedTableUnquoted, ".") + viewFullName := fmt.Sprintf("`%s.%s.auth_view_forecast`", allowedParts[0], allowedParts[1]) + viewUnquoted := strings.ReplaceAll(viewFullName, "`", "") + testCases := []struct { name string historyData string @@ -2954,6 +3019,12 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa wantStatusCode: http.StatusOK, wantInResult: `"forecast_timestamp"`, }, + { + name: "invoke with authorized view name", + historyData: viewUnquoted, + wantStatusCode: http.StatusOK, + wantInResult: `"forecast_timestamp"`, + }, { name: "invoke with disallowed table name", historyData: disallowedTableUnquoted, @@ -2966,11 +3037,17 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa wantStatusCode: http.StatusOK, wantInResult: `"forecast_timestamp"`, }, + { + name: "invoke with query on authorized view", + historyData: fmt.Sprintf("SELECT * FROM %s", viewFullName), + wantStatusCode: http.StatusOK, + wantInResult: `"forecast_timestamp"`, + }, { name: "invoke with query on disallowed table", historyData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusOK, - wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + wantInError: fmt.Sprintf("access to dataset '%s' is not allowed", disallowedDatasetFQN), }, { name: "invoke with SQL injection in timestamp_col", @@ -3070,6 +3147,10 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "") disallowedDatasetFQN := strings.Join(strings.Split(disallowedTableUnquoted, ".")[0:2], ".") + allowedParts := strings.Split(allowedTableUnquoted, ".") + viewFullName := fmt.Sprintf("`%s.%s.auth_view_analyze`", allowedParts[0], allowedParts[1]) + viewUnquoted := strings.ReplaceAll(viewFullName, "`", "") + testCases := []struct { name string inputData string @@ -3086,6 +3167,12 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d wantStatusCode: http.StatusOK, wantInResult: `"relative_difference"`, }, + { + name: "invoke with authorized view name", + inputData: viewUnquoted, + wantStatusCode: http.StatusOK, + wantInResult: `"relative_difference"`, + }, { name: "invoke with disallowed table name", inputData: disallowedTableUnquoted, @@ -3098,11 +3185,17 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d wantStatusCode: http.StatusOK, wantInResult: `"relative_difference"`, }, + { + name: "invoke with query on authorized view", + inputData: fmt.Sprintf("SELECT * FROM %s", viewFullName), + wantStatusCode: http.StatusOK, + wantInResult: `"relative_difference"`, + }, { name: "invoke with query on disallowed table", inputData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), wantStatusCode: http.StatusOK, - wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + wantInError: fmt.Sprintf("access to dataset '%s' is not allowed", disallowedDatasetFQN), }, { name: "invoke with SQL injection in is_test_col", @@ -3216,3 +3309,43 @@ func getBigQueryVectorSearchStmts(vectorTableName string) (string, string) { searchStmt := fmt.Sprintf("SELECT id, content, ML.DISTANCE(embedding, @query, 'COSINE') AS distance FROM %s ORDER BY distance LIMIT 1", vectorTableName) return insertStmt, searchStmt } + +func CleanupBigQueryDatasets(t *testing.T, ctx context.Context, client *bigqueryapi.Client, datasetIDs []string) { + for _, id := range datasetIDs { + t.Logf("INTEGRATION CLEANUP: Purging dataset %s", id) + ds := client.Dataset(id) + + // Delete tables first since Dataset.Delete fails if not empty + tableIt := ds.Tables(ctx) + for { + table, err := tableIt.Next() + if err == iterator.Done { + break + } + if err != nil { + if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { + t.Logf("INTEGRATION CLEANUP: Dataset %s already deleted (during table iteration)", id) + break + } + t.Errorf("INTEGRATION CLEANUP: Failed to iterate tables in %s: %v", id, err) + break + } + if err := table.Delete(ctx); err != nil { + if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { + continue + } + t.Errorf("INTEGRATION CLEANUP: Failed to delete table %s: %v", table.TableID, err) + } + } + // delete empty dataset + if err := ds.Delete(ctx); err != nil { + if apiErr, ok := err.(*googleapi.Error); ok && apiErr.Code == 404 { + t.Logf("INTEGRATION CLEANUP: Dataset %s already deleted", id) + } else { + t.Errorf("INTEGRATION CLEANUP: Failed to delete dataset %s: %v", id, err) + } + } else { + t.Logf("INTEGRATION CLEANUP SUCCESS: Wiped dataset %s", id) + } + } +} diff --git a/tests/common.go b/tests/common.go index 4e540a59be27..06504065931a 100644 --- a/tests/common.go +++ b/tests/common.go @@ -24,7 +24,6 @@ import ( "strings" "testing" - "cloud.google.com/go/bigquery" "cloud.google.com/go/bigtable" "github.com/google/go-cmp/cmp" "github.com/googleapis/mcp-toolbox/internal/server" @@ -32,7 +31,6 @@ import ( "github.com/googleapis/mcp-toolbox/internal/testutils" "github.com/googleapis/mcp-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" - "google.golang.org/api/iterator" ) // GetToolsConfig returns a mock tools config file @@ -1090,35 +1088,6 @@ func CleanupMSSQLTables(t *testing.T, ctx context.Context, pool *sql.DB) { } -func CleanupBigQueryDatasets(t *testing.T, ctx context.Context, client *bigquery.Client, datasetIDs []string) { - for _, id := range datasetIDs { - t.Logf("INTEGRATION CLEANUP: Purging dataset %s", id) - ds := client.Dataset(id) - - //Delete tables first since Dataset.Delete fails if not empty - tableIt := ds.Tables(ctx) - for { - table, err := tableIt.Next() - if err == iterator.Done { - break - } - if err != nil { - t.Errorf("INTEGRATION CLEANUP: Failed to iterate tables in %s: %v", id, err) - break - } - if err := table.Delete(ctx); err != nil { - t.Errorf("INTEGRATION CLEANUP: Failed to delete table %s: %v", table.TableID, err) - } - } - //delete empty dataset - if err := ds.Delete(ctx); err != nil { - t.Errorf("INTEGRATION CLEANUP: Failed to delete dataset %s: %v", id, err) - } else { - t.Logf("INTEGRATION CLEANUP SUCCESS: Wiped dataset %s", id) - } - } -} - // finds and deletes all tables in a Bigtable instance that match the uniqueID. func CleanupBigtableTables(t *testing.T, ctx context.Context, adminClient *bigtable.AdminClient, uniqueID string) { tables, err := adminClient.Tables(ctx)