diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..0ed16cf --- /dev/null +++ b/PLAN.md @@ -0,0 +1,48 @@ +# SecretMasking Capability + +## Problem + +Without credential masking, secrets (API keys, tokens, connection strings, private keys) can leak through: +- Tool outputs (e.g., `git push` echoing credentials) +- Model responses (LLM reproducing secrets from context) +- Conversation history, logs, and serialized state + +Closes #78. + +## Design + +`SecretMasking` is an `AbstractCapability` that uses two hooks: + +1. **`after_tool_execute`** -- scrubs tool return values before they enter message history +2. **`after_model_request`** -- scrubs `TextPart` content in model responses + +### Built-in pattern categories + +| Category | Patterns | +|---|---| +| `api_keys` | OpenAI (`sk-*`), Anthropic (`sk-ant-*`), AWS (`AKIA*`), GitHub (`gh[psorat]_*`), Slack (`xox[bpas]-*`), Google (`AIza*`), generic `api_key=` | +| `tokens` | Bearer tokens, JWTs | +| `connection_strings` | Passwords in URLs (`://user:pass@host`), database connection strings (postgres, mongo, mysql, redis, amqp) | +| `private_keys` | RSA, EC, OpenSSH private key headers | + +All patterns are compiled at module level as constants. + +### Configuration + +- `categories`: select which built-in categories to enable (default: all) +- `custom_patterns`: additional `{name: regex}` pairs +- `replacement`: the replacement string (default: `"[REDACTED]"`) + +### Non-string tool results + +For string results, masking is applied directly. For non-string results, we convert to string to check for matches; if any secret is found, the masked string is returned instead of the original object (safe default -- the model sees the sanitized representation). + +## Scope + +This PR implements regex-based secret _masking_ (redaction). The broader SecretRegistry / encrypted storage / env-var blocking described in #78 are left for follow-up work, as they are infrastructure concerns rather than capability hooks. + +## Files + +- `src/pydantic_harness/secret_masking.py` -- capability implementation +- `src/pydantic_harness/__init__.py` -- public export +- `tests/test_secret_masking.py` -- 45 tests, 100% coverage diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 9d728b6..e29b8d6 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,4 +7,8 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -__all__: list[str] = [] +from pydantic_harness.secret_masking import SecretMasking + +__all__: list[str] = [ + 'SecretMasking', +] diff --git a/src/pydantic_harness/secret_masking.py b/src/pydantic_harness/secret_masking.py new file mode 100644 index 0000000..fe46e2c --- /dev/null +++ b/src/pydantic_harness/secret_masking.py @@ -0,0 +1,217 @@ +"""Secret masking capability that redacts secrets from tool outputs and model responses.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Any, cast + +from pydantic_ai.capabilities import AbstractCapability, ValidatedToolArgs +from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart +from pydantic_ai.models import ModelRequestContext +from pydantic_ai.tools import RunContext, ToolDefinition + +# --- Built-in pattern categories --- + +_API_KEY_PATTERNS: dict[str, re.Pattern[str]] = { + 'openai_key': re.compile(r'sk-[A-Za-z0-9_-]{20,}'), + 'anthropic_key': re.compile(r'sk-ant-[A-Za-z0-9_-]{20,}'), + 'aws_access_key': re.compile(r'AKIA[0-9A-Z]{16}'), + 'github_token': re.compile(r'gh[psorat]_[A-Za-z0-9_]{36,}'), + 'slack_token': re.compile(r'xox[bpas]-[A-Za-z0-9-]+'), + 'google_api_key': re.compile(r'AIza[A-Za-z0-9_-]{35}'), + 'azure_subscription_key': re.compile(r'(?i)Ocp-Apim-Subscription-Key\s*[:=]\s*[A-Fa-f0-9]{32}'), + 'stripe_secret_key': re.compile(r'sk_live_[A-Za-z0-9]{24,}'), + 'stripe_publishable_key': re.compile(r'pk_live_[A-Za-z0-9]{24,}'), + 'sendgrid_key': re.compile(r'SG\.[A-Za-z0-9_-]{22,}\.[A-Za-z0-9_-]{22,}'), + 'twilio_key': re.compile(r'SK[0-9a-fA-F]{32}'), + 'gcp_service_account_key': re.compile(r'"private_key"\s*:\s*"-----BEGIN (?:RSA )?PRIVATE KEY-----'), + 'generic_api_key': re.compile( + r"""(?i)(?:api[_-]?key|api[_-]?secret|access[_-]?key)\s*[:=]\s*['"]?[A-Za-z0-9_\-/+=]{16,}['"]?""" + ), +} + +_TOKEN_PATTERNS: dict[str, re.Pattern[str]] = { + 'bearer_token': re.compile(r'Bearer\s+[A-Za-z0-9_\-./+=]{20,}'), + 'jwt': re.compile(r'eyJ[A-Za-z0-9_-]+\.eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_\-+=]+'), +} + +_CONNECTION_STRING_PATTERNS: dict[str, re.Pattern[str]] = { + 'password_in_url': re.compile(r'://[^:/?#\s]+:[^@/?#\s]+@[^/?#\s]+'), + 'database_connection': re.compile(r'(?i)(?:mongodb(?:\+srv)?|postgres(?:ql)?|mysql|redis|amqp)://[^\s]+'), +} + +_PRIVATE_KEY_PATTERNS: dict[str, re.Pattern[str]] = { + 'private_key': re.compile(r'-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----'), +} + +_ENV_FILE_PATTERNS: dict[str, re.Pattern[str]] = { + 'env_key_value': re.compile(r'(?m)^[A-Z][A-Z0-9_]+=.+$'), +} + +_BUILTIN_CATEGORIES: dict[str, dict[str, re.Pattern[str]]] = { + 'api_keys': _API_KEY_PATTERNS, + 'tokens': _TOKEN_PATTERNS, + 'connection_strings': _CONNECTION_STRING_PATTERNS, + 'private_keys': _PRIVATE_KEY_PATTERNS, + 'env_file': _ENV_FILE_PATTERNS, +} + +_ALL_BUILTIN_PATTERNS: dict[str, re.Pattern[str]] = {} +for _patterns in _BUILTIN_CATEGORIES.values(): + _ALL_BUILTIN_PATTERNS.update(_patterns) + + +def _mask_text(text: str, patterns: dict[str, re.Pattern[str]], replacement: str) -> str: + """Apply all patterns to replace matched secrets in `text`.""" + for pattern in patterns.values(): + text = pattern.sub(replacement, text) + return text + + +def _partial_mask_text(text: str, patterns: dict[str, re.Pattern[str]], visible_chars: int = 4) -> str: + """Apply all patterns, keeping the first `visible_chars` characters and masking the rest.""" + for pattern in patterns.values(): + text = pattern.sub( + lambda m: m.group()[:visible_chars] + '****' if len(m.group()) > visible_chars else '****', text + ) + return text + + +def _mask_dict_values( + d: dict[str, Any], + patterns: dict[str, re.Pattern[str]], + replacement: str, + *, + partial: bool = False, + visible_chars: int = 4, +) -> dict[str, Any]: + """Recursively scrub secret patterns from string values in a dict.""" + result: dict[str, Any] = {} + for key, value in d.items(): + if isinstance(value, str): + if partial: + result[key] = _partial_mask_text(value, patterns, visible_chars) + else: + result[key] = _mask_text(value, patterns, replacement) + elif isinstance(value, dict): + result[key] = _mask_dict_values( + cast(dict[str, Any], value), patterns, replacement, partial=partial, visible_chars=visible_chars + ) + else: + result[key] = value + return result + + +@dataclass +class SecretMasking(AbstractCapability[Any]): + """Redacts secrets, API keys, and sensitive data from tool args, outputs, and model responses. + + Uses `before_tool_execute` to scrub secrets from tool arguments, + `after_tool_execute` to scrub tool return values, and `after_model_request` + to scrub model response text before they enter the conversation history. + + By default all built-in pattern categories are enabled: `api_keys`, `tokens`, + `connection_strings`, `private_keys`, and `env_file`. + + Example: + ```python + from pydantic_ai import Agent + from pydantic_harness import SecretMasking + + agent = Agent('openai:gpt-5', capabilities=[SecretMasking()]) + ``` + """ + + categories: list[str] | None = None + """Built-in pattern categories to enable. + + Choose from `'api_keys'`, `'tokens'`, `'connection_strings'`, `'private_keys'`, `'env_file'`. + When `None` (default), all categories are enabled. + """ + + custom_patterns: dict[str, str] | None = None + """Additional regex patterns as `{name: pattern}` pairs. + + These are compiled once at init time and applied alongside the built-in patterns. + """ + + replacement: str = '[REDACTED]' + """The string that replaces matched secrets.""" + + partial_mask: bool = False + """When True, keep the first 4 characters of matched secrets visible and mask the rest + (e.g. ``sk-pr****`` instead of ``[REDACTED]``). The ``replacement`` field is ignored + when partial masking is enabled.""" + + _compiled: dict[str, re.Pattern[str]] = field(default_factory=lambda: {}, init=False, repr=False) + + def __post_init__(self) -> None: + """Compile built-in and custom patterns.""" + if self.categories is not None: + for category in self.categories: + if category not in _BUILTIN_CATEGORIES: + raise ValueError( + f'Unknown secret pattern category {category!r}, expected one of {sorted(_BUILTIN_CATEGORIES)}' + ) + self._compiled.update(_BUILTIN_CATEGORIES[category]) + else: + self._compiled.update(_ALL_BUILTIN_PATTERNS) + + if self.custom_patterns: + for name, pattern in self.custom_patterns.items(): + self._compiled[name] = re.compile(pattern) + + def _apply_mask(self, text: str) -> str: + """Mask secrets in ``text`` using full or partial masking depending on config.""" + if self.partial_mask: + return _partial_mask_text(text, self._compiled) + return _mask_text(text, self._compiled, self.replacement) + + async def before_tool_execute( + self, + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: ValidatedToolArgs, + ) -> ValidatedToolArgs: + """Scrub secrets from tool call argument values before the tool executes.""" + return _mask_dict_values( + args, + self._compiled, + self.replacement, + partial=self.partial_mask, + ) + + async def after_tool_execute( + self, + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: ValidatedToolArgs, + result: Any, + ) -> Any: + """Scrub secrets from tool return values.""" + if isinstance(result, str): + return self._apply_mask(result) + # For non-string results, convert to string to check, but only replace if secrets found. + text = str(result) + masked = self._apply_mask(text) + if masked != text: + return masked + return result + + async def after_model_request( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + response: ModelResponse, + ) -> ModelResponse: + """Scrub secrets from model response text parts.""" + for part in response.parts: + if isinstance(part, TextPart): + part.content = self._apply_mask(part.content) + return response diff --git a/tests/test_secret_masking.py b/tests/test_secret_masking.py new file mode 100644 index 0000000..7fd9810 --- /dev/null +++ b/tests/test_secret_masking.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import re +from typing import Any +from unittest.mock import MagicMock + +import pytest +from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart + +from pydantic_harness.secret_masking import ( + _ALL_BUILTIN_PATTERNS, + _BUILTIN_CATEGORIES, + SecretMasking, + _mask_dict_values, + _mask_text, + _partial_mask_text, +) + +# --- Unit tests for _mask_text --- + + +class TestMaskText: + def test_no_match_returns_original(self): + assert _mask_text('hello world', _ALL_BUILTIN_PATTERNS, '[REDACTED]') == 'hello world' + + def test_openai_key(self): + text = 'key is sk-abc123def456ghi789jkl012mno' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'sk-abc123' not in result + assert '[REDACTED]' in result + + def test_anthropic_key(self): + text = 'sk-ant-api03-abcdefghijklmnopqrstuvwxyz' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'sk-ant-' not in result + assert result == '[REDACTED]' + + def test_aws_access_key(self): + text = 'AWS key: AKIAIOSFODNN7EXAMPLE' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'AKIA' not in result + assert '[REDACTED]' in result + + def test_github_token(self): + text = 'token: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'ghp_' not in result + + def test_slack_token(self): + text = 'xoxb-123456789012-1234567890123-AbCdEfGhIjKlMnOpQrStUvWx' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'xoxb-' not in result + + def test_google_api_key(self): + text = 'AIzaSyD-abcdefghijklmnopqrstuvwxyz01234' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'AIza' not in result + + def test_generic_api_key(self): + text = 'api_key = "abcdef1234567890abcdef"' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'abcdef1234567890' not in result + + def test_bearer_token(self): + text = 'Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'Bearer eyJ' not in result + + def test_jwt(self): + text = 'eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.abc123def456' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'eyJ' not in result + + def test_password_in_url(self): + text = 'postgresql://admin:s3cret_pass@db.example.com:5432/mydb' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 's3cret_pass' not in result + + def test_database_connection_string(self): + text = 'mongodb+srv://user:pass@cluster.mongodb.net/db' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'user:pass' not in result + + def test_private_key(self): + text = '-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAK...' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert '-----BEGIN RSA PRIVATE KEY-----' not in result + + def test_ec_private_key(self): + text = '-----BEGIN EC PRIVATE KEY-----\ndata...' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert '-----BEGIN EC PRIVATE KEY-----' not in result + + def test_openssh_private_key(self): + text = '-----BEGIN OPENSSH PRIVATE KEY-----\ndata...' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert '-----BEGIN OPENSSH PRIVATE KEY-----' not in result + + def test_multiple_secrets_in_one_string(self): + text = 'key=sk-abc123def456ghi789jkl012mno, token=ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'sk-abc123' not in result + assert 'ghp_' not in result + assert result.count('[REDACTED]') >= 2 + + def test_custom_replacement(self): + text = 'sk-abc123def456ghi789jkl012mno' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '***') + assert result == '***' + + # --- New provider key patterns --- + + def test_azure_subscription_key(self): + text = 'Ocp-Apim-Subscription-Key: abcdef1234567890abcdef1234567890' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'abcdef1234567890' not in result + assert '[REDACTED]' in result + + def test_stripe_secret_key(self): + text = 'sk_live_abcdefghijklmnopqrstuvwx' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'sk_live_' not in result + assert result == '[REDACTED]' + + def test_stripe_publishable_key(self): + text = 'pk_live_abcdefghijklmnopqrstuvwx' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'pk_live_' not in result + assert result == '[REDACTED]' + + def test_sendgrid_key(self): + text = 'SG.abcdefghijklmnopqrstuv.abcdefghijklmnopqrstuvwx' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'SG.' not in result + assert result == '[REDACTED]' + + def test_twilio_key(self): + text = 'SK0123456789abcdef0123456789abcdef' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'SK01234' not in result + assert '[REDACTED]' in result + + def test_gcp_service_account_key(self): + text = '"private_key": "-----BEGIN RSA PRIVATE KEY-----\\nMIIE..."' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert '-----BEGIN RSA PRIVATE KEY-----' not in result + assert '[REDACTED]' in result + + # --- .env content detection --- + + def test_env_key_value_single_line(self): + text = 'DATABASE_URL=postgres://user:pass@localhost/db' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'DATABASE_URL=' not in result + + def test_env_key_value_multiline(self): + text = 'API_KEY=some_secret_value\nDB_PASSWORD=hunter2\nDEBUG=true' + result = _mask_text(text, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert 'some_secret_value' not in result + assert 'hunter2' not in result + + def test_env_key_value_does_not_match_lowercase(self): + """Lowercase variable names are not typical .env format and should not match.""" + patterns = _BUILTIN_CATEGORIES['env_file'] + text = 'lowercase_var=value' + result = _mask_text(text, patterns, '[REDACTED]') + assert result == text + + +# --- Unit tests for _partial_mask_text --- + + +class TestPartialMaskText: + def test_partial_mask_openai_key(self): + text = 'sk-abc123def456ghi789jkl012mno' + result = _partial_mask_text(text, _ALL_BUILTIN_PATTERNS) + assert result.startswith('sk-a') + assert result.endswith('****') + assert 'abc123def456' not in result + + def test_partial_mask_preserves_surrounding_text(self): + text = 'key is sk-abc123def456ghi789jkl012mno here' + result = _partial_mask_text(text, _ALL_BUILTIN_PATTERNS) + assert result.startswith('key is sk-a') + assert 'here' in result + + def test_partial_mask_short_match_becomes_stars(self): + """When matched text is 4 chars or fewer, the whole thing becomes ****.""" + patterns = {'short': re.compile(r'AB')} + result = _partial_mask_text('xABx', patterns) + assert result == 'x****x' + + +# --- Unit tests for _mask_dict_values --- + + +class TestMaskDictValues: + def test_masks_string_values(self): + d = {'key': 'sk-abc123def456ghi789jkl012mno', 'name': 'safe'} + result = _mask_dict_values(d, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert result['key'] == '[REDACTED]' + assert result['name'] == 'safe' + + def test_masks_nested_dicts(self): + d = {'outer': {'inner_key': 'sk-abc123def456ghi789jkl012mno'}} + result = _mask_dict_values(d, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert result['outer']['inner_key'] == '[REDACTED]' + + def test_non_string_values_unchanged(self): + d: dict[str, Any] = {'count': 42, 'flag': True, 'items': [1, 2]} + result = _mask_dict_values(d, _ALL_BUILTIN_PATTERNS, '[REDACTED]') + assert result == d + + def test_partial_mask_in_dict(self): + d = {'token': 'sk-abc123def456ghi789jkl012mno'} + result = _mask_dict_values(d, _ALL_BUILTIN_PATTERNS, '[REDACTED]', partial=True) + assert result['token'].startswith('sk-a') + assert result['token'].endswith('****') + + def test_empty_dict(self): + assert _mask_dict_values({}, _ALL_BUILTIN_PATTERNS, '[REDACTED]') == {} + + +# --- Tests for SecretMasking dataclass construction --- + + +class TestSecretMaskingInit: + def test_defaults(self): + sm = SecretMasking() + assert sm.categories is None + assert sm.custom_patterns is None + assert sm.replacement == '[REDACTED]' + assert sm.partial_mask is False + assert sm._compiled == _ALL_BUILTIN_PATTERNS + + def test_specific_categories(self): + sm = SecretMasking(categories=['api_keys', 'tokens']) + expected = {**_BUILTIN_CATEGORIES['api_keys'], **_BUILTIN_CATEGORIES['tokens']} + assert sm._compiled == expected + + def test_single_category(self): + sm = SecretMasking(categories=['private_keys']) + assert sm._compiled == _BUILTIN_CATEGORIES['private_keys'] + + def test_unknown_category_raises(self): + with pytest.raises(ValueError, match="Unknown secret pattern category 'bogus'"): + SecretMasking(categories=['bogus']) + + def test_custom_patterns(self): + sm = SecretMasking(custom_patterns={'my_secret': r'SECRET-\d{6}'}) + assert 'my_secret' in sm._compiled + assert sm._compiled['my_secret'].pattern == r'SECRET-\d{6}' + + def test_custom_patterns_with_categories(self): + sm = SecretMasking(categories=['api_keys'], custom_patterns={'my_secret': r'SECRET-\d{6}'}) + assert 'openai_key' in sm._compiled + assert 'my_secret' in sm._compiled + assert 'bearer_token' not in sm._compiled + + def test_custom_replacement(self): + sm = SecretMasking(replacement='') + assert sm.replacement == '' + + def test_partial_mask_flag(self): + sm = SecretMasking(partial_mask=True) + assert sm.partial_mask is True + + def test_env_file_category(self): + sm = SecretMasking(categories=['env_file']) + assert 'env_key_value' in sm._compiled + + +# --- Tests for before_tool_execute --- + + +class TestBeforeToolExecute: + @pytest.fixture() + def capability(self) -> SecretMasking: + return SecretMasking() + + @pytest.fixture() + def ctx(self) -> Any: + return MagicMock() + + @pytest.fixture() + def call(self) -> Any: + return MagicMock() + + @pytest.fixture() + def tool_def(self) -> Any: + return MagicMock() + + @pytest.mark.anyio() + async def test_scrubs_secret_in_args(self, capability: SecretMasking, ctx: Any, call: Any, tool_def: Any): + args = {'api_key': 'sk-abc123def456ghi789jkl012mno', 'query': 'hello'} + result = await capability.before_tool_execute(ctx, call=call, tool_def=tool_def, args=args) + assert result['api_key'] == '[REDACTED]' + assert result['query'] == 'hello' + + @pytest.mark.anyio() + async def test_scrubs_nested_dict_args(self, capability: SecretMasking, ctx: Any, call: Any, tool_def: Any): + args: dict[str, Any] = {'config': {'token': 'sk-abc123def456ghi789jkl012mno'}, 'name': 'test'} + result = await capability.before_tool_execute(ctx, call=call, tool_def=tool_def, args=args) + assert result['config']['token'] == '[REDACTED]' + assert result['name'] == 'test' + + @pytest.mark.anyio() + async def test_no_secrets_unchanged(self, capability: SecretMasking, ctx: Any, call: Any, tool_def: Any): + args = {'query': 'hello world', 'count': 5} + result = await capability.before_tool_execute(ctx, call=call, tool_def=tool_def, args=args) + assert result == args + + @pytest.mark.anyio() + async def test_partial_mask_in_args(self, ctx: Any, call: Any, tool_def: Any): + capability = SecretMasking(partial_mask=True) + args = {'key': 'sk-abc123def456ghi789jkl012mno'} + result = await capability.before_tool_execute(ctx, call=call, tool_def=tool_def, args=args) + assert result['key'].startswith('sk-a') + assert result['key'].endswith('****') + + @pytest.mark.anyio() + async def test_empty_args(self, capability: SecretMasking, ctx: Any, call: Any, tool_def: Any): + result = await capability.before_tool_execute(ctx, call=call, tool_def=tool_def, args={}) + assert result == {} + + +# --- Tests for after_tool_execute --- + + +class TestAfterToolExecute: + @pytest.fixture() + def capability(self) -> SecretMasking: + return SecretMasking() + + @pytest.fixture() + def ctx(self) -> Any: + return MagicMock() + + @pytest.fixture() + def call(self) -> Any: + return MagicMock() + + @pytest.fixture() + def tool_def(self) -> Any: + return MagicMock() + + @pytest.mark.anyio() + async def test_string_result_with_secret(self, capability: SecretMasking, ctx: Any, call: Any, tool_def: Any): + result = await capability.after_tool_execute( + ctx, call=call, tool_def=tool_def, args={}, result='key: sk-abc123def456ghi789jkl012mno' + ) + assert isinstance(result, str) + assert 'sk-abc123' not in result + assert '[REDACTED]' in result + + @pytest.mark.anyio() + async def test_string_result_without_secret(self, capability: SecretMasking, ctx: Any, call: Any, tool_def: Any): + result = await capability.after_tool_execute(ctx, call=call, tool_def=tool_def, args={}, result='hello world') + assert result == 'hello world' + + @pytest.mark.anyio() + async def test_non_string_result_with_secret(self, capability: SecretMasking, ctx: Any, call: Any, tool_def: Any): + result = await capability.after_tool_execute( + ctx, call=call, tool_def=tool_def, args={}, result=['key', 'sk-abc123def456ghi789jkl012mno'] + ) + assert isinstance(result, str) + assert 'sk-abc123' not in result + + @pytest.mark.anyio() + async def test_non_string_result_without_secret( + self, capability: SecretMasking, ctx: Any, call: Any, tool_def: Any + ): + result = await capability.after_tool_execute( + ctx, call=call, tool_def=tool_def, args={}, result={'status': 'ok'} + ) + assert result == {'status': 'ok'} + + @pytest.mark.anyio() + async def test_custom_replacement(self, ctx: Any, call: Any, tool_def: Any): + capability = SecretMasking(replacement='') + result = await capability.after_tool_execute( + ctx, call=call, tool_def=tool_def, args={}, result='sk-abc123def456ghi789jkl012mno' + ) + assert result == '' + + @pytest.mark.anyio() + async def test_custom_pattern(self, ctx: Any, call: Any, tool_def: Any): + capability = SecretMasking(categories=[], custom_patterns={'internal': r'INT-[A-Z]{8}'}) + result = await capability.after_tool_execute( + ctx, call=call, tool_def=tool_def, args={}, result='secret: INT-ABCDEFGH' + ) + assert 'INT-ABCDEFGH' not in result + assert '[REDACTED]' in result + + @pytest.mark.anyio() + async def test_partial_mask_in_tool_result(self, ctx: Any, call: Any, tool_def: Any): + capability = SecretMasking(partial_mask=True) + result = await capability.after_tool_execute( + ctx, call=call, tool_def=tool_def, args={}, result='sk-abc123def456ghi789jkl012mno' + ) + assert isinstance(result, str) + assert result.startswith('sk-a') + assert result.endswith('****') + + +# --- Tests for after_model_request --- + + +class TestAfterModelRequest: + @pytest.fixture() + def capability(self) -> SecretMasking: + return SecretMasking() + + @pytest.fixture() + def ctx(self) -> Any: + return MagicMock() + + @pytest.fixture() + def request_context(self) -> Any: + return MagicMock() + + def _make_response(self, *texts: str) -> ModelResponse: + return ModelResponse(parts=[TextPart(content=t) for t in texts]) + + @pytest.mark.anyio() + async def test_scrubs_text_parts(self, capability: SecretMasking, ctx: Any, request_context: Any): + response = self._make_response('Your key is sk-abc123def456ghi789jkl012mno') + result = await capability.after_model_request(ctx, request_context=request_context, response=response) + assert isinstance(result.parts[0], TextPart) + assert 'sk-abc123' not in result.parts[0].content + assert '[REDACTED]' in result.parts[0].content + + @pytest.mark.anyio() + async def test_clean_text_unchanged(self, capability: SecretMasking, ctx: Any, request_context: Any): + response = self._make_response('No secrets here') + result = await capability.after_model_request(ctx, request_context=request_context, response=response) + assert isinstance(result.parts[0], TextPart) + assert result.parts[0].content == 'No secrets here' + + @pytest.mark.anyio() + async def test_multiple_parts(self, capability: SecretMasking, ctx: Any, request_context: Any): + response = self._make_response( + 'key: AKIAIOSFODNN7EXAMPLE', + 'clean text', + 'token: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn', + ) + result = await capability.after_model_request(ctx, request_context=request_context, response=response) + parts = result.parts + assert isinstance(parts[0], TextPart) + assert 'AKIA' not in parts[0].content + assert isinstance(parts[1], TextPart) + assert parts[1].content == 'clean text' + assert isinstance(parts[2], TextPart) + assert 'ghp_' not in parts[2].content + + @pytest.mark.anyio() + async def test_non_text_parts_are_untouched(self, capability: SecretMasking, ctx: Any, request_context: Any): + tool_call = ToolCallPart(tool_name='get_secret', args='{}') + response = ModelResponse(parts=[tool_call]) + result = await capability.after_model_request(ctx, request_context=request_context, response=response) + assert result.parts[0] is tool_call + + @pytest.mark.anyio() + async def test_partial_mask_in_model_response(self, ctx: Any, request_context: Any): + capability = SecretMasking(partial_mask=True) + response = ModelResponse(parts=[TextPart(content='key: sk-abc123def456ghi789jkl012mno')]) + result = await capability.after_model_request(ctx, request_context=request_context, response=response) + part = result.parts[0] + assert isinstance(part, TextPart) + assert part.content.startswith('key: sk-a') + assert part.content.endswith('****') + + +# --- Test pattern categories --- + + +class TestPatternCategories: + def test_all_categories_exist(self): + assert set(_BUILTIN_CATEGORIES) == { + 'api_keys', + 'tokens', + 'connection_strings', + 'private_keys', + 'env_file', + } + + def test_api_keys_category(self): + patterns = _BUILTIN_CATEGORIES['api_keys'] + assert 'openai_key' in patterns + assert 'anthropic_key' in patterns + assert 'aws_access_key' in patterns + assert 'github_token' in patterns + assert 'slack_token' in patterns + assert 'google_api_key' in patterns + assert 'generic_api_key' in patterns + assert 'azure_subscription_key' in patterns + assert 'stripe_secret_key' in patterns + assert 'stripe_publishable_key' in patterns + assert 'sendgrid_key' in patterns + assert 'twilio_key' in patterns + assert 'gcp_service_account_key' in patterns + + def test_tokens_category(self): + patterns = _BUILTIN_CATEGORIES['tokens'] + assert 'bearer_token' in patterns + assert 'jwt' in patterns + + def test_connection_strings_category(self): + patterns = _BUILTIN_CATEGORIES['connection_strings'] + assert 'password_in_url' in patterns + assert 'database_connection' in patterns + + def test_private_keys_category(self): + patterns = _BUILTIN_CATEGORIES['private_keys'] + assert 'private_key' in patterns + + def test_env_file_category(self): + patterns = _BUILTIN_CATEGORIES['env_file'] + assert 'env_key_value' in patterns + + def test_all_builtin_is_union_of_categories(self): + expected: dict[str, re.Pattern[str]] = {} + for cat_patterns in _BUILTIN_CATEGORIES.values(): + expected.update(cat_patterns) + assert _ALL_BUILTIN_PATTERNS == expected + + +# --- Edge cases --- + + +class TestEdgeCases: + def test_empty_categories_list_with_custom(self): + sm = SecretMasking(categories=[], custom_patterns={'test': r'TEST-\d+'}) + # Only custom patterns, no builtins. + assert 'test' in sm._compiled + assert 'openai_key' not in sm._compiled + + def test_empty_categories_no_custom(self): + sm = SecretMasking(categories=[]) + assert sm._compiled == {} + + @pytest.mark.anyio() + async def test_empty_string_tool_result(self): + sm = SecretMasking() + ctx = MagicMock() + result = await sm.after_tool_execute(ctx, call=MagicMock(), tool_def=MagicMock(), args={}, result='') + assert result == '' + + @pytest.mark.anyio() + async def test_none_tool_result(self): + sm = SecretMasking() + ctx = MagicMock() + result = await sm.after_tool_execute(ctx, call=MagicMock(), tool_def=MagicMock(), args={}, result=None) + assert result is None