Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
19 changes: 16 additions & 3 deletions cmd/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type Config struct {
Tools server.ToolConfigs `yaml:"tools"`
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
Prompts server.PromptConfigs `yaml:"prompts"`
PiiPolicies server.PiiPolicyConfigs `yaml:"piiPolicies"`
}

type ConfigParser struct {
Expand Down Expand Up @@ -150,7 +151,7 @@ func (p *ConfigParser) ParseConfig(ctx context.Context, raw []byte) (Config, err
}

// Parse contents
config.Sources, config.AuthServices, config.EmbeddingModels, config.Tools, config.Toolsets, config.Prompts, err = server.UnmarshalResourceConfig(ctx, raw)
config.Sources, config.AuthServices, config.EmbeddingModels, config.Tools, config.Toolsets, config.Prompts, config.PiiPolicies, err = server.UnmarshalResourceConfig(ctx, raw)
if err != nil {
return config, err
}
Expand Down Expand Up @@ -180,7 +181,7 @@ func ConvertConfig(raw []byte) ([]byte, error) {
decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap())
encoder := yaml.NewEncoder(&buf, yaml.UseLiteralStyleIfMultiline(true))

nestedFormatKey := []string{"sources", "authServices", "embeddingModels", "tools", "toolsets", "prompts"}
nestedFormatKey := []string{"sources", "authServices", "embeddingModels", "tools", "toolsets", "prompts", "piiPolicies"}
docIndex := 0
for {
if err := decoder.Decode(&input); err != nil {
Expand Down Expand Up @@ -217,6 +218,8 @@ func ConvertConfig(raw []byte) ([]byte, error) {
key = "toolset"
case "prompts":
key = "prompt"
case "piiPolicies":
key = "piiPolicy"
}
transformed, err := transformDocs(key, slice)
if err != nil {
Expand Down Expand Up @@ -316,6 +319,7 @@ func mergeConfigs(files ...Config) (Config, error) {
Tools: make(server.ToolConfigs),
Toolsets: make(server.ToolsetConfigs),
Prompts: make(server.PromptConfigs),
PiiPolicies: make(server.PiiPolicyConfigs),
}

var conflicts []string
Expand Down Expand Up @@ -376,11 +380,20 @@ func mergeConfigs(files ...Config) (Config, error) {
merged.Prompts[name] = prompt
}
}

// Check for conflicts and merge piiPolicies
for name, piiPolicy := range file.PiiPolicies {
if _, exists := merged.PiiPolicies[name]; exists {
conflicts = append(conflicts, fmt.Sprintf("piiPolicy '%s' (file #%d)", name, fileIndex+1))
} else {
merged.PiiPolicies[name] = piiPolicy
}
}
}

// If conflicts were detected, return an error
if len(conflicts) > 0 {
return Config{}, fmt.Errorf("resource conflicts detected:\n - %s\n\nPlease ensure each source, authService, tool, toolset and prompt has a unique name across all files", strings.Join(conflicts, "\n - "))
return Config{}, fmt.Errorf("resource conflicts detected:\n - %s\n\nPlease ensure each source, authService, tool, toolset, prompt and piiPolicy has a unique name across all files", strings.Join(conflicts, "\n - "))
}

// Ensure only one authService has mcpEnabled = true
Expand Down
4 changes: 2 additions & 2 deletions cmd/internal/invoke/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ func runInvoke(cmd *cobra.Command, args []string, opts *internal.ToolboxOptions)
}

// Initialize Resources
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, opts.Cfg)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, _, err := server.InitializeConfigs(ctx, opts.Cfg)
if err != nil {
errMsg := fmt.Errorf("failed to initialize resources: %w", err)
opts.Logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}

resourceMgr := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
resourceMgr := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil)

// Execute Tool
toolName := args[0]
Expand Down
1 change: 1 addition & 0 deletions cmd/internal/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ func (opts *ToolboxOptions) LoadConfig(ctx context.Context, parser *ConfigParser
opts.Cfg.ToolConfigs = finalConfig.Tools
opts.Cfg.ToolsetConfigs = finalConfig.Toolsets
opts.Cfg.PromptConfigs = finalConfig.Prompts
opts.Cfg.PiiPolicyConfigs = finalConfig.PiiPolicies

return isCustomConfigured, nil
}
2 changes: 1 addition & 1 deletion cmd/internal/skills/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func (c *skillsCmd) collectTools(ctx context.Context, opts *internal.ToolboxOpti
return nil, fmt.Errorf("failed to initialize resources: %w", err)
}

resourceMgr := resources.NewResourceManager(nil, nil, nil, toolsMap, toolsetsMap, nil, nil)
resourceMgr := resources.NewResourceManager(nil, nil, nil, toolsMap, toolsetsMap, nil, nil, nil)

skillsToTools := make(map[string]map[string]tools.Tool)

Expand Down
13 changes: 7 additions & 6 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/googleapis/mcp-toolbox/cmd/internal/skills"
"github.com/googleapis/mcp-toolbox/internal/auth"
"github.com/googleapis/mcp-toolbox/internal/embeddingmodels"
"github.com/googleapis/mcp-toolbox/internal/piipolicy"
"github.com/googleapis/mcp-toolbox/internal/prompts"
"github.com/googleapis/mcp-toolbox/internal/server"
"github.com/googleapis/mcp-toolbox/internal/sources"
Expand Down Expand Up @@ -138,22 +139,22 @@ func handleDynamicReload(ctx context.Context, toolsFile internal.Config, s *serv
panic(err)
}

sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, piiPoliciesMap, err := validateReloadEdits(ctx, toolsFile)
if err != nil {
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
logger.WarnContext(ctx, errMsg.Error())
return err
}

s.ResourceMgr.SetResources(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, piiPoliciesMap)

return nil
}

// validateReloadEdits checks that the reloaded config configs can initialized without failing
func validateReloadEdits(
ctx context.Context, toolsFile internal.Config,
) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, map[string]piipolicy.Config, error,
) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
Expand Down Expand Up @@ -181,14 +182,14 @@ func validateReloadEdits(
IgnoreUnknownTools: util.IgnoreUnknownToolsFromContext(ctx),
}

sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, piiPoliciesMap, err := server.InitializeConfigs(ctx, reloadedConfig)
if err != nil {
errMsg := fmt.Errorf("unable to initialize reloaded configs: %w", err)
logger.WarnContext(ctx, errMsg.Error())
return nil, nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, nil, err
}

return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, piiPoliciesMap, nil
}

// Helper to check if a file has a newer ModTime than stored in the map
Expand Down
4 changes: 2 additions & 2 deletions internal/embeddingmodels/gemini/gemini_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestParseFromYamlGemini(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
// Parse contents
_, _, got, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
_, _, got, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
Expand Down Expand Up @@ -149,7 +149,7 @@ func TestFailParseFromYamlGemini(t *testing.T) {
t.Setenv("GOOGLE_CLOUD_PROJECT", "")
t.Setenv("GOOGLE_CLOUD_LOCATION", "")

_, embeddingConfigs, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
_, embeddingConfigs, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
if err.Error() != tc.err {
t.Fatalf("unexpected unmarshal error:\ngot: %q\nwant: %q", err.Error(), tc.err)
Expand Down
192 changes: 192 additions & 0 deletions internal/piipolicy/engine.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package piipolicy

import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
)
Comment on lines +17 to +23

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Performance & Correctness: Thread-Safe Regex Cache

Regex compilation via regexp.Compile is an expensive operation. Currently, regexes are compiled on the fly during every string evaluation. Introducing a thread-safe cache using sync.Map avoids redundant compilation and significantly improves performance, especially when processing large datasets.

import (
	"context"
	"encoding/json"
	"fmt"
	"regexp"
	"strings"
	"sync"
	"unicode/utf8"
)

var regexCache sync.Map // map[string]*regexp.Regexp

func getRegexp(pattern string) (*regexp.Regexp, error) {
	if val, ok := regexCache.Load(pattern); ok {
		return val.(*regexp.Regexp), nil
	}
	re, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}
	regexCache.Store(pattern, re)
	return re, nil
}


// ApplyPolicy evaluates the PII policy against the provided data given the user's claims.
func ApplyPolicy(ctx context.Context, config Config, claims map[string]any, data any) (any, error) {
if len(config.Rules) == 0 {
return data, nil // Nothing to do
}

// Determine the user's tier based on the claim
tier := ""
if config.TierClaim != "" && claims != nil {
if val, ok := claims[config.TierClaim]; ok {
if strVal, ok := val.(string); ok {
tier = strVal
}
}
}

switch v := data.(type) {
case string:
return applyToString(v, config.Rules, tier)
case map[string]any:
return applyToMap(ctx, config, claims, tier, v)
case []map[string]any:
return applyToMapSlice(ctx, config, claims, tier, v)
case []any:
var out []any
for _, el := range v {
res, err := ApplyPolicy(ctx, config, claims, el)
if err != nil {
return nil, err
}
out = append(out, res)
}
return out, nil
default:
Comment on lines +57 to +58

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Correctness & Efficiency: Prevent Type Mutation of Primitives

Currently, any primitive type other than string (such as int, bool, float64, or nil) falls into the default case of the type switch, triggering a full json.Marshal and json.Unmarshal cycle. For integer types, this decodes them as float64, causing type mutation that can break downstream code expecting concrete integer types. Handling primitive types directly in the switch avoids this mutation and eliminates serialization overhead.

Suggested change
return out, nil
default:
return out, nil
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool, nil:
return data, nil
default:

// Attempt to convert struct to map/slice using JSON
bytes, err := json.Marshal(data)
if err == nil {
var parsed any
if err := json.Unmarshal(bytes, &parsed); err == nil {
// Prevent infinite loop if type doesn't change
if fmt.Sprintf("%T", parsed) != fmt.Sprintf("%T", data) {
return ApplyPolicy(ctx, config, claims, parsed)
}
}
}
// Unsupported type, return as is
return data, nil
}
}

func getActionForRule(rule Rule, tier string) Action {
if action, ok := rule.Actions[tier]; ok && tier != "" {
return action
}
if action, ok := rule.Actions["default"]; ok {
return action
}
// Fallback secure default
return MaskFull
}

func applyToString(data string, rules []Rule, tier string) (string, error) {
result := data
for _, rule := range rules {
if rule.Pattern == "" {
continue // Column-based rule, skip for unstructured text
}
action := getActionForRule(rule, tier)
if action == Unmask {
continue
}

re, err := regexp.Compile(rule.Pattern)
if err != nil {
return "", fmt.Errorf("invalid pattern in rule %q: %w", rule.Name, err)
}
Comment on lines +97 to +100

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Performance: Use Cached Regex

Use the newly introduced getRegexp helper to retrieve the compiled regex from the cache instead of compiling it on every string evaluation.

Suggested change
re, err := regexp.Compile(rule.Pattern)
if err != nil {
return "", fmt.Errorf("invalid pattern in rule %q: %w", rule.Name, err)
}
re, err := getRegexp(rule.Pattern)
if err != nil {
return "", fmt.Errorf("invalid pattern in rule %q: %w", rule.Name, err)
}


result = re.ReplaceAllStringFunc(result, func(match string) string {
return applyActionToString(action, match)
})
}
return result, nil
}

func applyToMapSlice(ctx context.Context, config Config, claims map[string]any, tier string, data []map[string]any) ([]map[string]any, error) {
var out []map[string]any
for _, m := range data {
res, err := applyToMap(ctx, config, claims, tier, m)
if err != nil {
return nil, err
}
out = append(out, res)
}
return out, nil
}

func applyToMap(ctx context.Context, config Config, claims map[string]any, tier string, data map[string]any) (map[string]any, error) {
out := make(map[string]any, len(data))
for k, v := range data {
res, err := ApplyPolicy(ctx, config, claims, v)
if err != nil {
return nil, err
}
out[k] = res
}

for _, rule := range config.Rules {
if rule.Column == "" {
// Try applying pattern if value is a string, even in structured data
if rule.Pattern != "" {
for k, v := range out {
if strVal, ok := v.(string); ok {
action := getActionForRule(rule, tier)
if action == Unmask {
continue
}
re, err := regexp.Compile(rule.Pattern)
if err != nil {
return nil, fmt.Errorf("invalid pattern in rule %q: %w", rule.Name, err)
}
out[k] = re.ReplaceAllStringFunc(strVal, func(match string) string {
return applyActionToString(action, match)
})
}
}
}
continue
Comment on lines +134 to +151

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Performance: Hoist Action Check and Regex Compilation

In applyToMap, the action and regex pattern are constant for the entire rule. Hoisting the action check and regex compilation outside of the map key iteration loop avoids redundant lookups and regex retrievals for every key in the map.

Suggested change
if rule.Pattern != "" {
for k, v := range out {
if strVal, ok := v.(string); ok {
action := getActionForRule(rule, tier)
if action == Unmask {
continue
}
re, err := regexp.Compile(rule.Pattern)
if err != nil {
return nil, fmt.Errorf("invalid pattern in rule %q: %w", rule.Name, err)
}
out[k] = re.ReplaceAllStringFunc(strVal, func(match string) string {
return applyActionToString(action, match)
})
}
}
}
continue
if rule.Pattern != "" {
action := getActionForRule(rule, tier)
if action != Unmask {
re, err := getRegexp(rule.Pattern)
if err != nil {
return nil, fmt.Errorf("invalid pattern in rule %q: %w", rule.Name, err)
}
for k, v := range out {
if strVal, ok := v.(string); ok {
out[k] = re.ReplaceAllStringFunc(strVal, func(match string) string {
return applyActionToString(action, match)
})
}
}
}
}
continue

}

// Exact column match
if val, exists := out[rule.Column]; exists {
action := getActionForRule(rule, tier)
if action == Unmask {
continue
}
if action == DenyField {
out[rule.Column] = "[DENIED]"
continue
}

// Apply masking
if strVal, ok := val.(string); ok {
out[rule.Column] = applyActionToString(action, strVal)
} else {
// Non-string value masked entirely
out[rule.Column] = "***"
}
}
}
return out, nil
}

func applyActionToString(action Action, val string) string {
switch action {
case MaskFull:
return strings.Repeat("*", len(val))
case MaskPartial:
if len(val) <= 2 {
return strings.Repeat("*", len(val))
}
visibleLen := len(val) / 2
return val[:visibleLen] + strings.Repeat("*", len(val)-visibleLen)
case DenyField:
return "[DENIED]"
default:
return val // Unmask or unknown
}
}
Comment thread
Deeven-Seru marked this conversation as resolved.
Comment on lines +177 to +192

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Correctness: Handle Multi-Byte UTF-8 Characters Correctly

In Go, len(val) returns the number of bytes, not characters (runes). For multi-byte UTF-8 characters (e.g., accented characters, non-ASCII names, or emojis), strings.Repeat("*", len(val)) will produce more asterisks than the actual character count, leaking byte-length information.

Additionally, slicing a string using byte indices (val[:visibleLen]) can slice in the middle of a multi-byte UTF-8 character, producing invalid UTF-8 sequences. Converting the string to []rune and using utf8.RuneCountInString ensures correct character-based masking and slicing.

Suggested change
func applyActionToString(action Action, val string) string {
switch action {
case MaskFull:
return strings.Repeat("*", len(val))
case MaskPartial:
if len(val) <= 2 {
return strings.Repeat("*", len(val))
}
visibleLen := len(val) / 2
return val[:visibleLen] + strings.Repeat("*", len(val)-visibleLen)
case DenyField:
return "[DENIED]"
default:
return val // Unmask or unknown
}
}
func applyActionToString(action Action, val string) string {
switch action {
case MaskFull:
return strings.Repeat("*", utf8.RuneCountInString(val))
case MaskPartial:
runes := []rune(val)
if len(runes) <= 2 {
return strings.Repeat("*", len(runes))
}
visibleLen := len(runes) / 2
return string(runes[:visibleLen]) + strings.Repeat("*", len(runes)-visibleLen)
case DenyField:
return "[DENIED]"
default:
return val // Unmask or unknown
}
}

Loading