diff --git a/docs/docs/installation/github.md b/docs/docs/installation/github.md index cb15073dd7..3ca27ca9f7 100644 --- a/docs/docs/installation/github.md +++ b/docs/docs/installation/github.md @@ -280,6 +280,35 @@ To use local models via Ollama: **Note:** For local models, you'll need to use a self-hosted runner with Ollama installed, as GitHub Actions hosted runners cannot access localhost services. +##### Using Amazon Bedrock + +To use Amazon Bedrock models with static IAM credentials: + +```yaml + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + config.model: "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0" + config.fallback_models: '["bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"]' + aws.AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws.AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws.AWS_REGION_NAME: "us-east-1" +``` + +**Recommended: IAM Role Credentials on AWS Compute** + +When the GitHub Actions runner is on AWS infrastructure (EC2, ECS, EKS), use the instance/task IAM role directly — no secrets required: + +```yaml + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + config.model: "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0" + config.fallback_models: '["bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"]' + AWS_USE_IMDS: "true" + # AWS_REGION_NAME: us-east-1 # optional if instance metadata provides the region +``` + +The IAM role must have `bedrock:InvokeModel` on the target model ARN. See [Bedrock model configuration](../usage-guide/changing_a_model.md#amazon-bedrock) for the full IAM policy example and supported models. 
+ #### Advanced Configuration Options ##### Custom Review Instructions @@ -732,4 +761,4 @@ After you set up AWS CodeCommit using the instructions above, here is an example PYTHONPATH="/PATH/TO/PROJECTS/pr-agent" python pr_agent/cli.py \ --pr_url https://us-east-1.console.aws.amazon.com/codesuite/codecommit/repositories/MY_REPO_NAME/pull-requests/321 \ review -``` \ No newline at end of file +``` diff --git a/docs/docs/usage-guide/changing_a_model.md b/docs/docs/usage-guide/changing_a_model.md index 140a16c612..5abbd62680 100644 --- a/docs/docs/usage-guide/changing_a_model.md +++ b/docs/docs/usage-guide/changing_a_model.md @@ -111,7 +111,7 @@ Please note that the `custom_model_max_tokens` setting should be configured in a Commercial models such as GPT-5, Claude Sonnet, and Gemini have demonstrated robust capabilities in generating structured output for code analysis tasks with large input. In contrast, most open-source models currently available (as of January 2025) face challenges with these complex tasks. Based on our testing, local open-source models are suitable for experimentation and learning purposes (mainly for the `ask` command), but they are not suitable for production-level code analysis tasks. - + Hence, for production workflows and real-world usage, we recommend using commercial models. ### Hugging Face @@ -251,6 +251,43 @@ model="bedrock/us.meta.llama4-scout-17b-instruct-v1:0" fallback_models=["bedrock/us.meta.llama4-maverick-17b-instruct-v1:0"] ``` +#### Using IAM Role Credentials (Recommended on AWS Compute) + +When running PR-Agent on AWS infrastructure (EC2, ECS/Fargate, EKS with IRSA, Lambda, or any self-hosted GitHub Actions runner on AWS), the instance or task typically already has an IAM role attached. You can use those ambient credentials directly instead of storing long-lived static keys. + +Set `AWS_USE_IMDS=true` in the environment. 
PR-Agent will resolve credentials via boto3's standard provider chain, which handles all AWS compute contexts transparently: + +| Compute context | Mechanism | +|---|---| +| EC2 instance with IAM role | IMDSv2 (169.254.169.254) | +| ECS / Fargate task role | Task metadata endpoint | +| EKS pod with IRSA | Web identity token + STS | +| Lambda function | Runtime-injected credentials | + +Minimal GitHub Actions workflow (no AWS secret keys required): + +```yaml +- uses: Codium-ai/pr-agent@main + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + AWS_USE_IMDS: "true" + # AWS_REGION_NAME: us-east-1 # optional if the instance metadata provides it + with: + command: review +``` + +The IAM role must have `bedrock:InvokeModel` permission on the target model ARN, for example: + +```json +{ + "Effect": "Allow", + "Action": "bedrock:InvokeModel", + "Resource": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20240620-v1:0" +} +``` + +If you also configure static keys in `[aws]`, they serve as an automatic fallback: if the ambient credentials fail a Bedrock call (e.g., the role lacks `bedrock:InvokeModel`), PR-Agent retries with the static keys and logs a warning. + #### Custom Inference Profiles To use a custom inference profile with Amazon Bedrock (for cost allocation tags and other configuration settings), add the `model_id` parameter to your configuration: @@ -339,7 +376,7 @@ key = "..." 
# your Codestral api key To use model from Openrouter, for example, set: ```toml -[config] # in configuration.toml +[config] # in configuration.toml model="openrouter/anthropic/claude-3.7-sonnet" fallback_models=["openrouter/deepseek/deepseek-chat"] custom_model_max_tokens=20000 diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py index a035d08fe8..2be64c265a 100644 --- a/pr_agent/algo/__init__.py +++ b/pr_agent/algo/__init__.py @@ -159,12 +159,22 @@ 'bedrock/anthropic.claude-sonnet-4-20250514-v1:0': 200000, 'bedrock/anthropic.claude-sonnet-4-5-20250929-v1:0': 200000, 'bedrock/anthropic.claude-sonnet-4-6': 200000, + 'bedrock/anthropic.claude-sonnet-4-6-v1:0': 200000, + 'bedrock/anthropic.claude-opus-4-5-20251101-v1:0': 200000, "bedrock/us.anthropic.claude-opus-4-20250514-v1:0": 200000, "bedrock/us.anthropic.claude-opus-4-1-20250805-v1:0": 200000, "bedrock/us.anthropic.claude-opus-4-6-20260120-v1:0": 200000, "bedrock/global.anthropic.claude-opus-4-5-20251101-v1:0": 200000, + "bedrock/eu.anthropic.claude-opus-4-5-20251101-v1:0": 200000, + "bedrock/au.anthropic.claude-opus-4-5-20251101-v1:0": 200000, + "bedrock/jp.anthropic.claude-opus-4-5-20251101-v1:0": 200000, + "bedrock/apac.anthropic.claude-opus-4-5-20251101-v1:0": 200000, "bedrock/us.anthropic.claude-opus-4-5-20251101-v1:0": 200000, "bedrock/global.anthropic.claude-opus-4-6-v1:0": 200000, + "bedrock/eu.anthropic.claude-opus-4-6-v1:0": 200000, + "bedrock/au.anthropic.claude-opus-4-6-v1:0": 200000, + "bedrock/jp.anthropic.claude-opus-4-6-v1:0": 200000, + "bedrock/apac.anthropic.claude-opus-4-6-v1:0": 200000, "bedrock/us.anthropic.claude-opus-4-6-v1:0": 200000, "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0": 100000, "bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0": 200000, @@ -179,16 +189,23 @@ "bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0": 200000, "bedrock/au.anthropic.claude-sonnet-4-5-20250929-v1:0": 200000, "bedrock/us.anthropic.claude-sonnet-4-6": 200000, + 
"bedrock/us.anthropic.claude-sonnet-4-6-v1:0": 200000, "bedrock/au.anthropic.claude-sonnet-4-6": 200000, + "bedrock/au.anthropic.claude-sonnet-4-6-v1:0": 200000, "bedrock/apac.anthropic.claude-3-5-sonnet-20241022-v2:0": 100000, "bedrock/apac.anthropic.claude-3-7-sonnet-20250219-v1:0": 200000, "bedrock/apac.anthropic.claude-sonnet-4-20250514-v1:0": 200000, "bedrock/eu.anthropic.claude-sonnet-4-5-20250929-v1:0": 200000, "bedrock/eu.anthropic.claude-sonnet-4-6": 200000, + "bedrock/eu.anthropic.claude-sonnet-4-6-v1:0": 200000, "bedrock/jp.anthropic.claude-sonnet-4-5-20250929-v1:0": 200000, "bedrock/jp.anthropic.claude-sonnet-4-6": 200000, + "bedrock/jp.anthropic.claude-sonnet-4-6-v1:0": 200000, + "bedrock/apac.anthropic.claude-sonnet-4-6": 200000, + "bedrock/apac.anthropic.claude-sonnet-4-6-v1:0": 200000, "bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0": 200000, "bedrock/global.anthropic.claude-sonnet-4-6": 200000, + "bedrock/global.anthropic.claude-sonnet-4-6-v1:0": 200000, 'claude-3-5-sonnet': 100000, 'bedrock/us.meta.llama4-scout-17b-instruct-v1:0': 128000, 'bedrock/us.meta.llama4-maverick-17b-instruct-v1:0': 128000, diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index de9993284d..6aca90f0ee 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -1,18 +1,27 @@ +import asyncio +import contextlib +import json import os + import litellm import openai import requests from litellm import acompletion -from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt - -from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS, STREAMING_REQUIRED_MODELS +from tenacity import (retry, retry_if_exception_type, + retry_if_not_exception_type, stop_after_attempt) + +from pr_agent.algo import 
(CLAUDE_EXTENDED_THINKING_MODELS, + NO_SUPPORT_TEMPERATURE_MODELS, + STREAMING_REQUIRED_MODELS, + SUPPORT_REASONING_EFFORT_MODELS, + USER_MESSAGE_ONLY_MODELS) from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler -from pr_agent.algo.ai_handlers.litellm_helpers import _handle_streaming_response, MockResponse, _get_azure_ad_token, \ - _process_litellm_extra_body +from pr_agent.algo.ai_handlers.litellm_helpers import ( + MockResponse, _get_azure_ad_token, _handle_streaming_response, + _process_litellm_extra_body) from pr_agent.algo.utils import ReasoningEffort, get_version from pr_agent.config_loader import get_settings from pr_agent.log import get_logger -import json MODEL_RETRIES = 2 DUMMY_LITELLM_API_KEY = "dummy_key" # placeholder set when no OpenAI key is configured @@ -33,6 +42,11 @@ def __init__(self): self.azure = False self.api_base = None self.repetition_penalty = None + self._aws_imds_mode = False + self._aws_static_creds = None + self._aws_imds_fell_back = False + self._aws_boto3_creds = None # original boto3 credentials object for IMDS refresh + self._aws_bedrock_lock = asyncio.Lock() if get_settings().get("LITELLM.DISABLE_AIOHTTP", False): litellm.disable_aiohttp_transport = True @@ -41,11 +55,82 @@ def __init__(self): litellm.openai_key = get_settings().openai.key elif 'OPENAI_API_KEY' not in os.environ: litellm.api_key = DUMMY_LITELLM_API_KEY - if get_settings().get("aws.AWS_ACCESS_KEY_ID"): + if os.environ.get("AWS_USE_IMDS", "").strip().lower() in ("1", "true", "yes"): + import boto3 + import botocore.exceptions + session = boto3.Session() + try: + creds = session.get_credentials() + if creds: + self._aws_boto3_creds = creds # store for refresh; avoids env-var re-read + self._write_frozen_aws_creds_to_env(creds.get_frozen_credentials()) + self._aws_imds_mode = True + get_logger().info("Using ambient AWS credentials from IMDS/task-role/IRSA") + else: + get_logger().warning( + "AWS_USE_IMDS is set but boto3 found no credentials; " + 
"falling through to static keys" + ) + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError): + # ClientError is intentionally not a BotoCoreError subclass in botocore's + # design; it is raised by STS-backed providers (AssumeRole, IRSA web-identity). + get_logger().exception( + "AWS_USE_IMDS: failed to resolve credentials via boto3; " + "falling through to static keys" + ) + if not os.environ.get("AWS_REGION_NAME"): + if get_settings().get("aws.AWS_REGION_NAME"): + os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME + else: + try: + region = session.region_name + if region: + os.environ["AWS_REGION_NAME"] = region + get_logger().info(f"AWS region resolved from environment: {region}") + else: + get_logger().warning( + "AWS_USE_IMDS: could not determine AWS region; " + "set AWS_REGION_NAME explicitly" + ) + except Exception as e: + get_logger().warning(f"AWS_USE_IMDS: failed to resolve region via boto3: {e}") + if get_settings().get("aws.AWS_ACCESS_KEY_ID"): + if get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME: + static_creds = { + "AWS_ACCESS_KEY_ID": get_settings().aws.AWS_ACCESS_KEY_ID, + "AWS_SECRET_ACCESS_KEY": get_settings().aws.AWS_SECRET_ACCESS_KEY, + "AWS_REGION_NAME": get_settings().aws.AWS_REGION_NAME, + } + static_token = get_settings().get("aws.AWS_SESSION_TOKEN", None) + if static_token: + static_creds["AWS_SESSION_TOKEN"] = static_token + if self._aws_imds_mode: + # IMDS succeeded; stash static keys for runtime fallback only + self._aws_static_creds = static_creds + else: + # IMDS failed; activate static credentials immediately and stash + # them so the runtime fallback path is also available if needed. 
+ self._aws_static_creds = static_creds + os.environ["AWS_ACCESS_KEY_ID"] = static_creds["AWS_ACCESS_KEY_ID"] + os.environ["AWS_SECRET_ACCESS_KEY"] = static_creds["AWS_SECRET_ACCESS_KEY"] + os.environ["AWS_REGION_NAME"] = static_creds["AWS_REGION_NAME"] + if static_token: + os.environ["AWS_SESSION_TOKEN"] = static_token + elif "AWS_SESSION_TOKEN" in os.environ: + del os.environ["AWS_SESSION_TOKEN"] + get_logger().info( + "AWS_USE_IMDS: IMDS resolution failed; using static credentials" + ) + elif get_settings().get("aws.AWS_ACCESS_KEY_ID"): assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete" os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME + static_token = get_settings().get("aws.AWS_SESSION_TOKEN", None) + if static_token: + os.environ["AWS_SESSION_TOKEN"] = static_token + elif "AWS_SESSION_TOKEN" in os.environ: + del os.environ["AWS_SESSION_TOKEN"] if get_settings().get("LITELLM.DROP_PARAMS", None): litellm.drop_params = get_settings().litellm.drop_params if get_settings().get("LITELLM.SUCCESS_CALLBACK", None): @@ -108,7 +193,7 @@ def __init__(self): # Support mistral models if get_settings().get("MISTRAL.KEY", None): os.environ["MISTRAL_API_KEY"] = get_settings().get("MISTRAL.KEY") - + # Support codestral models if get_settings().get("CODESTRAL.KEY", None): os.environ["CODESTRAL_API_KEY"] = get_settings().get("CODESTRAL.KEY") @@ -120,7 +205,7 @@ def __init__(self): access_token = _get_azure_ad_token() litellm.api_key = access_token openai.api_key = access_token - + # Set API base from settings self.api_base = get_settings().azure_ad.api_base litellm.api_base = self.api_base @@ -153,6 +238,48 @@ def __init__(self): # Models that require streaming self.streaming_required_models = STREAMING_REQUIRED_MODELS + @staticmethod + def 
_write_frozen_aws_creds_to_env(frozen) -> None: + """Write a botocore FrozenCredentials snapshot into os.environ for litellm/Bedrock.""" + os.environ["AWS_ACCESS_KEY_ID"] = frozen.access_key + os.environ["AWS_SECRET_ACCESS_KEY"] = frozen.secret_key + if frozen.token: + os.environ["AWS_SESSION_TOKEN"] = frozen.token + elif "AWS_SESSION_TOKEN" in os.environ: + del os.environ["AWS_SESSION_TOKEN"] + + def _refresh_aws_imds_credentials(self) -> bool: + """Refresh ambient AWS credentials from boto3 provider chain. Called before each Bedrock call + to avoid serving stale credentials from long-lived processes (EC2 roles rotate every ~6h). + + Uses the credentials object stored during __init__ rather than creating a new boto3.Session, + which would read the already-set AWS_* env vars and return stale values. + + Returns True on success, False on failure (caller should trigger static fallback).""" + import botocore.exceptions + try: + if self._aws_boto3_creds is None: + get_logger().warning("IMDS credential refresh: no boto3 credentials object stored") + return False + self._write_frozen_aws_creds_to_env(self._aws_boto3_creds.get_frozen_credentials()) + return True + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError): + # ClientError (STS/AssumeRole failures) is not a BotoCoreError subclass. 
+ get_logger().exception("IMDS credential refresh failed") + return False + + def _activate_static_aws_fallback(self): + """Swap process env to static credentials for Bedrock fallback after IMDS failure.""" + os.environ["AWS_ACCESS_KEY_ID"] = self._aws_static_creds["AWS_ACCESS_KEY_ID"] + os.environ["AWS_SECRET_ACCESS_KEY"] = self._aws_static_creds["AWS_SECRET_ACCESS_KEY"] + os.environ["AWS_REGION_NAME"] = self._aws_static_creds["AWS_REGION_NAME"] + if "AWS_SESSION_TOKEN" in self._aws_static_creds: + os.environ["AWS_SESSION_TOKEN"] = self._aws_static_creds["AWS_SESSION_TOKEN"] + elif "AWS_SESSION_TOKEN" in os.environ: + del os.environ["AWS_SESSION_TOKEN"] + self._aws_imds_fell_back = True + get_logger().warning("Bedrock call failed with ambient (IMDS) credentials; retrying with static credentials") + def prepare_logs(self, response, system, user, resp, finish_reason): response_log = response.dict().copy() response_log['system'] = system @@ -268,174 +395,190 @@ def deployment_id(self): stop=stop_after_attempt(MODEL_RETRIES), ) async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None): - try: - resp, finish_reason = None, None - deployment_id = self.deployment_id - if self.azure: - model = 'azure/' + model - if 'claude' in model and not system: - system = "No system prompt provided" - get_logger().warning( - "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.") - messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] - - if img_path: - try: - # check if the image link is alive - r = requests.head(img_path, allow_redirects=True) - if r.status_code == 404: - error_msg = f"The image link is not [alive](img_path).\nPlease repost the original image as a comment, and send the question again with 'quote reply' (see [instructions](https://pr-agent-docs.codium.ai/tools/ask/#ask-on-images-using-the-pr-code-as-context))." 
- get_logger().error(error_msg) - return f"{error_msg}", "error" - except Exception as e: - get_logger().error(f"Error fetching image: {img_path}", e) - return f"Error fetching image: {img_path}", "error" - messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]}, - {"type": "image_url", "image_url": {"url": img_path}}] - - thinking_kwargs_gpt5 = None - if model.startswith('gpt-5'): - # Use configured reasoning_effort or default to MEDIUM - config_effort = get_settings().config.reasoning_effort - try: - ReasoningEffort(config_effort) - effort = config_effort - except (ValueError, TypeError): - effort = ReasoningEffort.MEDIUM.value - if config_effort is not None: - get_logger().warning( - f"Invalid reasoning_effort '{config_effort}' in config. " - f"Using default '{effort}'. Valid values: {[e.value for e in ReasoningEffort]}" - ) - - thinking_kwargs_gpt5 = { - "reasoning_effort": effort, - "allowed_openai_params": ["reasoning_effort"], - } - get_logger().info(f"Using reasoning_effort='{effort}' for GPT-5 model") - model = 'openai/'+model.replace('_thinking', '') # remove _thinking suffix - - - # Currently, some models do not support a separate system and user prompts - if model in self.user_message_only_models or get_settings().config.custom_reasoning_model: - user = f"{system}\n\n\n{user}" - system = "" - get_logger().info(f"Using model {model}, combining system and user prompts") - messages = [{"role": "user", "content": user}] - kwargs = { - "model": model, - "deployment_id": deployment_id, - "messages": messages, - "timeout": get_settings().config.ai_timeout, - "api_base": self.api_base, - } - else: - kwargs = { - "model": model, - "deployment_id": deployment_id, - "messages": messages, - "timeout": get_settings().config.ai_timeout, - "api_base": self.api_base, - } - - # Add temperature only if model supports it - if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model: - # 
get_logger().info(f"Adding temperature with value {temperature} to model {model}.") - kwargs["temperature"] = temperature - - if thinking_kwargs_gpt5: - kwargs.update(thinking_kwargs_gpt5) - if 'temperature' in kwargs: - del kwargs['temperature'] - - # Add reasoning_effort if model supports it - if model in self.support_reasoning_models: - config_effort = get_settings().config.reasoning_effort - try: - ReasoningEffort(config_effort) - reasoning_effort = config_effort - except (ValueError, TypeError): - reasoning_effort = ReasoningEffort.MEDIUM.value - if config_effort is not None: - get_logger().warning( - f"Invalid reasoning_effort '{config_effort}' in config. " - f"Using default '{reasoning_effort}'. Valid values: {[e.value for e in ReasoningEffort]}" - ) - - get_logger().info(f"Adding reasoning_effort with value {reasoning_effort} to model {model}.") - kwargs["reasoning_effort"] = reasoning_effort - - # https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking - if (model in self.claude_extended_thinking_models) and get_settings().config.get("enable_claude_extended_thinking", False): - kwargs = self._configure_claude_extended_thinking(model, kwargs) - - if get_settings().litellm.get("enable_callbacks", False): - kwargs = self.add_litellm_callbacks(kwargs) - - seed = get_settings().config.get("seed", -1) - if temperature > 0 and seed >= 0: - raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0") - elif seed >= 0: - get_logger().info(f"Using fixed seed of {seed}") - kwargs["seed"] = seed - - if self.repetition_penalty: - kwargs["repetition_penalty"] = self.repetition_penalty - - #Added support for extra_headers while using litellm to call underlying model, via a api management gateway, would allow for passing custom headers for security and authorization - if get_settings().get("LITELLM.EXTRA_HEADERS", None): - try: - litellm_extra_headers = json.loads(get_settings().litellm.extra_headers) - if not 
isinstance(litellm_extra_headers, dict): - raise ValueError("LITELLM.EXTRA_HEADERS must be a JSON object") - except json.JSONDecodeError as e: - raise ValueError(f"LITELLM.EXTRA_HEADERS contains invalid JSON: {str(e)}") - kwargs["extra_headers"] = litellm_extra_headers - - # Support for custom OpenAI body fields (e.g., Flex Processing) - kwargs = _process_litellm_extra_body(kwargs) - - # Support for Bedrock custom inference profile via model_id - model_id = get_settings().get("litellm.model_id") - if model_id and 'bedrock/' in model: - kwargs["model_id"] = model_id - get_logger().info(f"Using Bedrock custom inference profile: {model_id}") - - get_logger().debug("Prompts", artifact={"system": system, "user": user}) + # Serialize env-var mutation + Bedrock call for IMDS mode to prevent concurrent + # requests from interleaving os.environ credentials during asyncio.gather usage. + _bedrock_imds = self._aws_imds_mode and 'bedrock/' in model + async with (self._aws_bedrock_lock if _bedrock_imds else contextlib.nullcontext()): + if _bedrock_imds and not self._aws_imds_fell_back: + if not self._refresh_aws_imds_credentials() and self._aws_static_creds: + self._activate_static_aws_fallback() + self._aws_imds_fell_back = True + try: + resp, finish_reason = None, None + deployment_id = self.deployment_id + if self.azure: + model = 'azure/' + model + if 'claude' in model and not system: + system = "No system prompt provided" + get_logger().warning( + "Empty system prompt for claude model. 
Adding a newline character to prevent OpenAI API error.") + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + + if img_path: + try: + # check if the image link is alive + r = requests.head(img_path, allow_redirects=True) + if r.status_code == 404: + error_msg = f"The image link is not [alive](img_path).\nPlease repost the original image as a comment, and send the question again with 'quote reply' (see [instructions](https://pr-agent-docs.codium.ai/tools/ask/#ask-on-images-using-the-pr-code-as-context))." + get_logger().error(error_msg) + return f"{error_msg}", "error" + except Exception as e: + get_logger().error(f"Error fetching image: {img_path}", e) + return f"Error fetching image: {img_path}", "error" + messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]}, + {"type": "image_url", "image_url": {"url": img_path}}] + + thinking_kwargs_gpt5 = None + if model.startswith('gpt-5'): + # Use configured reasoning_effort or default to MEDIUM + config_effort = get_settings().config.reasoning_effort + try: + ReasoningEffort(config_effort) + effort = config_effort + except (ValueError, TypeError): + effort = ReasoningEffort.MEDIUM.value + if config_effort is not None: + get_logger().warning( + f"Invalid reasoning_effort '{config_effort}' in config. " + f"Using default '{effort}'. 
Valid values: {[e.value for e in ReasoningEffort]}" + ) + + thinking_kwargs_gpt5 = { + "reasoning_effort": effort, + "allowed_openai_params": ["reasoning_effort"], + } + get_logger().info(f"Using reasoning_effort='{effort}' for GPT-5 model") + model = 'openai/'+model.replace('_thinking', '') # remove _thinking suffix + + + # Currently, some models do not support a separate system and user prompts + if model in self.user_message_only_models or get_settings().config.custom_reasoning_model: + user = f"{system}\n\n\n{user}" + system = "" + get_logger().info(f"Using model {model}, combining system and user prompts") + messages = [{"role": "user", "content": user}] + kwargs = { + "model": model, + "deployment_id": deployment_id, + "messages": messages, + "timeout": get_settings().config.ai_timeout, + "api_base": self.api_base, + } + else: + kwargs = { + "model": model, + "deployment_id": deployment_id, + "messages": messages, + "timeout": get_settings().config.ai_timeout, + "api_base": self.api_base, + } + # Add temperature only if model supports it + if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model: + # get_logger().info(f"Adding temperature with value {temperature} to model {model}.") + kwargs["temperature"] = temperature + + if thinking_kwargs_gpt5: + kwargs.update(thinking_kwargs_gpt5) + if 'temperature' in kwargs: + del kwargs['temperature'] + + # Add reasoning_effort if model supports it + if model in self.support_reasoning_models: + config_effort = get_settings().config.reasoning_effort + try: + ReasoningEffort(config_effort) + reasoning_effort = config_effort + except (ValueError, TypeError): + reasoning_effort = ReasoningEffort.MEDIUM.value + if config_effort is not None: + get_logger().warning( + f"Invalid reasoning_effort '{config_effort}' in config. " + f"Using default '{reasoning_effort}'. 
Valid values: {[e.value for e in ReasoningEffort]}" + ) + + get_logger().info(f"Adding reasoning_effort with value {reasoning_effort} to model {model}.") + kwargs["reasoning_effort"] = reasoning_effort + + # https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking + if (model in self.claude_extended_thinking_models) and get_settings().config.get("enable_claude_extended_thinking", False): + kwargs = self._configure_claude_extended_thinking(model, kwargs) + + if get_settings().litellm.get("enable_callbacks", False): + kwargs = self.add_litellm_callbacks(kwargs) + + seed = get_settings().config.get("seed", -1) + if temperature > 0 and seed >= 0: + raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0") + elif seed >= 0: + get_logger().info(f"Using fixed seed of {seed}") + kwargs["seed"] = seed + + if self.repetition_penalty: + kwargs["repetition_penalty"] = self.repetition_penalty + + #Added support for extra_headers while using litellm to call underlying model, via a api management gateway, would allow for passing custom headers for security and authorization + if get_settings().get("LITELLM.EXTRA_HEADERS", None): + try: + litellm_extra_headers = json.loads(get_settings().litellm.extra_headers) + if not isinstance(litellm_extra_headers, dict): + raise ValueError("LITELLM.EXTRA_HEADERS must be a JSON object") + except json.JSONDecodeError as e: + raise ValueError(f"LITELLM.EXTRA_HEADERS contains invalid JSON: {str(e)}") + kwargs["extra_headers"] = litellm_extra_headers + + # Support for custom OpenAI body fields (e.g., Flex Processing) + kwargs = _process_litellm_extra_body(kwargs) + + # Support for Bedrock custom inference profile via model_id + model_id = get_settings().get("litellm.model_id") + if model_id and 'bedrock/' in model: + kwargs["model_id"] = model_id + get_logger().info(f"Using Bedrock custom inference profile: {model_id}") + + get_logger().debug("Prompts", artifact={"system": system, "user": user}) + + if 
get_settings().config.verbosity_level >= 2: + get_logger().info(f"\nSystem prompt:\n{system}") + get_logger().info(f"\nUser prompt:\n{user}") + + # Inject api_key to the call. This key is populated during init by providers + # like Groq, XAI, Azure AD, and OpenRouter. Skip if None or placeholder. + if litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY: + kwargs["api_key"] = litellm.api_key + + # Get completion with automatic streaming detection + resp, finish_reason, response_obj = await self._get_completion(**kwargs) + + except openai.RateLimitError as e: + get_logger().error(f"Rate limit error during LLM inference: {e}") + raise + except openai.APIError as e: + if _bedrock_imds and not self._aws_imds_fell_back and self._aws_static_creds: + self._activate_static_aws_fallback() + # Retry immediately while still holding the lock so that the + # env-var swap is fully visible to this call. Letting @retry + # handle the retry would release the lock between attempts, + # allowing a concurrent coroutine to overwrite os.environ. + resp, finish_reason, response_obj = await self._get_completion(**kwargs) + else: + get_logger().warning(f"Error during LLM inference: {e}") + raise + except Exception as e: + get_logger().warning(f"Unknown error during LLM inference: {e}") + raise openai.APIError from e + + get_logger().debug(f"\nAI response:\n{resp}") + + # log the full response for debugging + response_log = self.prepare_logs(response_obj, system, user, resp, finish_reason) + get_logger().debug("Full_response", artifact=response_log) + + # for CLI debugging if get_settings().config.verbosity_level >= 2: - get_logger().info(f"\nSystem prompt:\n{system}") - get_logger().info(f"\nUser prompt:\n{user}") - - # Inject api_key to the call. This key is populated during init by providers - # like Groq, XAI, Azure AD, and OpenRouter. Skip if None or placeholder. 
- if litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY: - kwargs["api_key"] = litellm.api_key - - # Get completion with automatic streaming detection - resp, finish_reason, response_obj = await self._get_completion(**kwargs) - - except openai.RateLimitError as e: - get_logger().error(f"Rate limit error during LLM inference: {e}") - raise - except openai.APIError as e: - get_logger().warning(f"Error during LLM inference: {e}") - raise - except Exception as e: - get_logger().warning(f"Unknown error during LLM inference: {e}") - raise openai.APIError from e - - get_logger().debug(f"\nAI response:\n{resp}") - - # log the full response for debugging - response_log = self.prepare_logs(response_obj, system, user, resp, finish_reason) - get_logger().debug("Full_response", artifact=response_log) - - # for CLI debugging - if get_settings().config.verbosity_level >= 2: - get_logger().info(f"\nAI response:\n{resp}") + get_logger().info(f"\nAI response:\n{resp}") - return resp, finish_reason + return resp, finish_reason async def _get_completion(self, **kwargs): """ diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index b8a4875976..09806f6936 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -130,9 +130,14 @@ key = "" api_base = "" [aws] +# When running on AWS compute (EC2, ECS, EKS, Lambda) with an IAM role attached, +# set AWS_USE_IMDS=true in the environment instead of providing static keys here. +# These keys are only needed as an optional fallback when AWS_USE_IMDS=true, or +# when running outside AWS compute without ambient credentials. 
AWS_ACCESS_KEY_ID = "" AWS_SECRET_ACCESS_KEY = "" AWS_REGION_NAME = "" +AWS_SESSION_TOKEN = "" # optional: only needed for STS-derived or temporary credentials [aws_secrets_manager] secret_arn = "" # The ARN of the AWS Secrets Manager secret containing PR-Agent configuration diff --git a/tests/unittest/test_litellm_imds.py b/tests/unittest/test_litellm_imds.py new file mode 100644 index 0000000000..cd798ad0b8 --- /dev/null +++ b/tests/unittest/test_litellm_imds.py @@ -0,0 +1,647 @@ +""" +Tests for AWS_USE_IMDS ambient credential support in LiteLLMAIHandler. + +Covers: + - Credentials resolved via boto3 and written to os.environ + - AWS_SESSION_TOKEN set/cleared correctly + - Region auto-resolved from boto3 when not configured + - Static keys stashed for fallback (including session token) + - boto3 failure falls through gracefully (no crash) + - No boto3 call when AWS_USE_IMDS is absent + - _refresh_aws_imds_credentials called before each Bedrock chat_completion + - Fallback to static keys on Bedrock API failure + - _activate_static_aws_fallback correctly restores/clears session token +""" +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import openai +import pytest +from botocore.exceptions import ClientError, CredentialRetrievalError +from tenacity import RetryError + +import pr_agent.algo.ai_handlers.litellm_ai_handler as litellm_handler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _base_settings(extra_get=None): + """Minimal settings that satisfy __init__ and chat_completion.""" + def _get(self, key, default=None): + if extra_get: + v = extra_get(key) + if v is not None: + return v + return default + + return type("Settings", (), { + "config": type("Config", (), { + "reasoning_effort": None, + "ai_timeout": 30, + "custom_reasoning_model": 
False, + "max_model_tokens": 32000, + "verbosity_level": 0, + "seed": -1, + "get": lambda self, key, default=None: default, + })(), + "litellm": type("LiteLLM", (), { + "get": lambda self, key, default=None: default, + })(), + "get": _get, + })() + + +def _mock_acompletion_response(): + mock = MagicMock() + mock.__getitem__ = lambda self, key: { + "choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}] + }[key] + mock.dict.return_value = {"choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}]} + return mock + + +def _static_aws_settings(session_token=None): + """Settings object with static AWS credentials configured.""" + keys = { + "aws.AWS_ACCESS_KEY_ID": "STATICKEY", + "aws.AWS_SECRET_ACCESS_KEY": "STATICSECRET", + "aws.AWS_REGION_NAME": "us-east-1", + } + if session_token: + keys["aws.AWS_SESSION_TOKEN"] = session_token + settings = _base_settings(extra_get=lambda key: keys.get(key)) + aws_attrs = { + "AWS_ACCESS_KEY_ID": "STATICKEY", + "AWS_SECRET_ACCESS_KEY": "STATICSECRET", + "AWS_REGION_NAME": "us-east-1", + } + if session_token: + aws_attrs["AWS_SESSION_TOKEN"] = session_token + settings.aws = type("AWS", (), aws_attrs)() + return settings + + +def _frozen_creds( + access_key="FAKE-KEY", + secret_key="FAKE-SECRET", + token=None, +): + frozen = MagicMock() + frozen.access_key = access_key + frozen.secret_key = secret_key + frozen.token = token + return frozen + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def clean_aws_env(monkeypatch): + """Ensure AWS env vars don't bleed between tests.""" + for var in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", "AWS_REGION_NAME", "AWS_USE_IMDS"): + monkeypatch.delenv(var, raising=False) + + +@pytest.fixture(autouse=True) +def default_settings(monkeypatch): + monkeypatch.setattr(litellm_handler, "get_settings", 
lambda: _base_settings()) + + +# --------------------------------------------------------------------------- +# __init__ — credential resolution +# --------------------------------------------------------------------------- + +class TestImdsInit: + + def test_imds_creds_written_to_env(self, monkeypatch): + """When AWS_USE_IMDS=true, boto3 creds are placed in os.environ.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = None + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + assert os.environ["AWS_ACCESS_KEY_ID"] == frozen.access_key + assert os.environ["AWS_SECRET_ACCESS_KEY"] == frozen.secret_key + assert handler._aws_imds_mode is True + + def test_imds_session_token_set_when_present(self, monkeypatch): + monkeypatch.setenv("AWS_USE_IMDS", "1") + frozen = _frozen_creds(token="session-token-xyz") + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = None + + with patch("boto3.Session", return_value=mock_session): + LiteLLMAIHandler() + + assert os.environ["AWS_SESSION_TOKEN"] == "session-token-xyz" + + def test_imds_session_token_cleared_when_absent(self, monkeypatch): + """Stale AWS_SESSION_TOKEN from the environment is removed when IMDS creds have no token.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + monkeypatch.setenv("AWS_SESSION_TOKEN", "stale-token") + frozen = _frozen_creds(token=None) + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = None + + with patch("boto3.Session", return_value=mock_session): + LiteLLMAIHandler() + + assert "AWS_SESSION_TOKEN" not in os.environ + + def test_imds_region_auto_resolved(self, monkeypatch): + 
monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = "eu-west-1" + + with patch("boto3.Session", return_value=mock_session): + LiteLLMAIHandler() + + assert os.environ["AWS_REGION_NAME"] == "eu-west-1" + + def test_imds_region_not_overwritten_when_already_set(self, monkeypatch): + monkeypatch.setenv("AWS_USE_IMDS", "true") + monkeypatch.setenv("AWS_REGION_NAME", "us-west-2") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = "eu-west-1" + + with patch("boto3.Session", return_value=mock_session): + LiteLLMAIHandler() + + assert os.environ["AWS_REGION_NAME"] == "us-west-2" + + def test_imds_configured_region_exported_to_env(self, monkeypatch): + """aws.AWS_REGION_NAME in settings must be written to env even in IMDS mode.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = "eu-west-1" # would be used if settings region absent + + monkeypatch.setattr(litellm_handler, "get_settings", lambda: _static_aws_settings()) + + with patch("boto3.Session", return_value=mock_session): + LiteLLMAIHandler() + + # settings region (us-east-1) takes precedence over boto3-resolved region (eu-west-1) + assert os.environ["AWS_REGION_NAME"] == "us-east-1" + + def test_imds_boto3_creds_stored_for_refresh(self, monkeypatch): + """The boto3 credentials object must be stored so refresh avoids re-reading env vars.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_creds = MagicMock() + mock_creds.get_frozen_credentials.return_value = frozen + mock_session = MagicMock() + mock_session.get_credentials.return_value = 
mock_creds + mock_session.region_name = "us-east-1" + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + assert handler._aws_boto3_creds is mock_creds + + def test_imds_no_creds_from_boto3(self, monkeypatch): + """When boto3 returns no credentials, _aws_imds_mode remains False.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + mock_session = MagicMock() + mock_session.get_credentials.return_value = None + mock_session.region_name = None + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + assert handler._aws_imds_mode is False + + def test_imds_boto3_exception_does_not_crash(self, monkeypatch): + """A boto3 exception during credential resolution must not crash __init__.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + mock_session = MagicMock() + mock_session.get_credentials.side_effect = CredentialRetrievalError( + provider="imds", error_msg="connection timeout" + ) + mock_session.region_name = None + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() # must not raise + + assert handler._aws_imds_mode is False + + def test_imds_sts_client_error_does_not_crash(self, monkeypatch): + """A ClientError from STS/AssumeRole (IRSA path) must not crash __init__. 
+ ClientError is not a BotoCoreError subclass, so it needs an explicit catch.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + mock_session = MagicMock() + mock_session.get_credentials.side_effect = ClientError( + error_response={"Error": {"Code": "AccessDenied", "Message": "Not authorized to assume role"}}, + operation_name="AssumeRole", + ) + mock_session.region_name = None + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() # must not raise + + assert handler._aws_imds_mode is False + + def test_static_session_token_exported_without_imds(self, monkeypatch): + """Non-IMDS static path exports AWS_SESSION_TOKEN when present in settings.""" + settings = _static_aws_settings(session_token="STS-TOKEN-NOIMDS") + monkeypatch.setattr(litellm_handler, "get_settings", lambda: settings) + + LiteLLMAIHandler() + + assert os.environ["AWS_SESSION_TOKEN"] == "STS-TOKEN-NOIMDS" + + def test_static_session_token_cleared_without_imds(self, monkeypatch): + """Non-IMDS static path clears stale AWS_SESSION_TOKEN when not in settings.""" + monkeypatch.setenv("AWS_SESSION_TOKEN", "stale-token") + settings = _static_aws_settings() # no session_token + monkeypatch.setattr(litellm_handler, "get_settings", lambda: settings) + + LiteLLMAIHandler() + + assert "AWS_SESSION_TOKEN" not in os.environ + + def test_no_imds_when_env_var_absent(self, monkeypatch): + """boto3 must never be imported or called when AWS_USE_IMDS is not set.""" + with patch("boto3.Session") as mock_boto3: + LiteLLMAIHandler() + + mock_boto3.assert_not_called() + + def test_static_keys_stashed_for_fallback(self, monkeypatch): + """Static keys in config are stashed in _aws_static_creds when IMDS mode is active.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = None + + monkeypatch.setattr(litellm_handler, 
"get_settings", lambda: _static_aws_settings()) + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + assert handler._aws_static_creds is not None + assert handler._aws_static_creds["AWS_ACCESS_KEY_ID"] == "STATICKEY" + assert handler._aws_static_creds["AWS_SECRET_ACCESS_KEY"] == "STATICSECRET" + assert handler._aws_static_creds["AWS_REGION_NAME"] == "us-east-1" + + def test_static_session_token_stashed(self, monkeypatch): + """AWS_SESSION_TOKEN from static config is included in _aws_static_creds.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = None + + monkeypatch.setattr( + litellm_handler, "get_settings", + lambda: _static_aws_settings(session_token="STATIC-SESSION-TOKEN") + ) + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + assert handler._aws_static_creds.get("AWS_SESSION_TOKEN") == "STATIC-SESSION-TOKEN" + + def test_static_keys_applied_when_imds_returns_no_creds(self, monkeypatch): + """When AWS_USE_IMDS is set but boto3 returns None, static keys are applied to env.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + mock_session = MagicMock() + mock_session.get_credentials.return_value = None + mock_session.region_name = None + + monkeypatch.setattr(litellm_handler, "get_settings", lambda: _static_aws_settings()) + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + assert handler._aws_imds_mode is False + assert os.environ["AWS_ACCESS_KEY_ID"] == "STATICKEY" + assert os.environ["AWS_SECRET_ACCESS_KEY"] == "STATICSECRET" + assert os.environ["AWS_REGION_NAME"] == "us-east-1" + + def test_imds_failed_path_clears_stale_session_token(self, monkeypatch): + """When IMDS fails and static creds have no token, a stale AWS_SESSION_TOKEN is cleared.""" + 
monkeypatch.setenv("AWS_USE_IMDS", "true") + monkeypatch.setenv("AWS_SESSION_TOKEN", "stale-imds-token") + mock_session = MagicMock() + mock_session.get_credentials.return_value = None + mock_session.region_name = None + + monkeypatch.setattr(litellm_handler, "get_settings", lambda: _static_aws_settings()) + + with patch("boto3.Session", return_value=mock_session): + LiteLLMAIHandler() + + assert "AWS_SESSION_TOKEN" not in os.environ + + def test_static_keys_applied_when_boto3_raises(self, monkeypatch): + """When AWS_USE_IMDS is set but boto3 throws, static keys are applied to env.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + mock_session = MagicMock() + mock_session.get_credentials.side_effect = CredentialRetrievalError( + provider="imds", error_msg="connection timeout" + ) + mock_session.region_name = None + + monkeypatch.setattr(litellm_handler, "get_settings", lambda: _static_aws_settings()) + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + assert handler._aws_imds_mode is False + assert os.environ["AWS_ACCESS_KEY_ID"] == "STATICKEY" + assert os.environ["AWS_SECRET_ACCESS_KEY"] == "STATICSECRET" + + +# --------------------------------------------------------------------------- +# chat_completion — credential refresh and fallback +# --------------------------------------------------------------------------- + +class TestImdsCallBehavior: + + @pytest.mark.asyncio + async def test_refresh_called_before_bedrock_call(self, monkeypatch): + """_refresh_aws_imds_credentials is called before each Bedrock chat_completion.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = "us-east-1" + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + with patch.object(handler, "_refresh_aws_imds_credentials") as mock_refresh, \ 
+ patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", + new_callable=AsyncMock) as mock_call: + mock_call.return_value = _mock_acompletion_response() + await handler.chat_completion( + model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + system="sys", user="usr" + ) + + mock_refresh.assert_called_once() + + def test_refresh_uses_stored_creds_not_new_session(self, monkeypatch): + """_refresh_aws_imds_credentials must call get_frozen_credentials on the stored object, + not create a new boto3.Session (which would re-read env vars and return stale creds).""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen1 = _frozen_creds(access_key="FIRST-KEY", secret_key="FIRST-SECRET") + frozen2 = _frozen_creds(access_key="ROTATED-KEY", secret_key="ROTATED-SECRET") + mock_creds = MagicMock() + mock_creds.get_frozen_credentials.side_effect = [frozen1, frozen2] + mock_session = MagicMock() + mock_session.get_credentials.return_value = mock_creds + mock_session.region_name = "us-east-1" + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + # boto3.Session should only be called once (in __init__), not during refresh + with patch("boto3.Session") as mock_boto3_refresh: + handler._refresh_aws_imds_credentials() + + mock_boto3_refresh.assert_not_called() + assert os.environ["AWS_ACCESS_KEY_ID"] == "ROTATED-KEY" + assert os.environ["AWS_SECRET_ACCESS_KEY"] == "ROTATED-SECRET" + + def test_refresh_returns_false_and_warns_when_no_stored_creds(self, monkeypatch): + """_refresh_aws_imds_credentials returns False and logs a warning when _aws_boto3_creds is None.""" + handler = LiteLLMAIHandler() + assert handler._aws_boto3_creds is None + result = handler._refresh_aws_imds_credentials() + assert result is False + + def test_refresh_returns_false_on_exception(self, monkeypatch): + """_refresh_aws_imds_credentials returns False when get_frozen_credentials raises.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = 
_frozen_creds() + mock_creds = MagicMock() + mock_creds.get_frozen_credentials.side_effect = [ + frozen, + CredentialRetrievalError(provider="imds", error_msg="token expired"), + ] + mock_session = MagicMock() + mock_session.get_credentials.return_value = mock_creds + mock_session.region_name = "us-east-1" + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + result = handler._refresh_aws_imds_credentials() + assert result is False + + def test_refresh_returns_false_on_sts_client_error(self, monkeypatch): + """_refresh_aws_imds_credentials returns False on ClientError (STS/AssumeRole path). + ClientError is not a BotoCoreError subclass, so it needs an explicit catch.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_creds = MagicMock() + mock_creds.get_frozen_credentials.side_effect = [ + frozen, + ClientError( + error_response={"Error": {"Code": "AccessDenied", "Message": "Token expired"}}, + operation_name="AssumeRoleWithWebIdentity", + ), + ] + mock_session = MagicMock() + mock_session.get_credentials.return_value = mock_creds + mock_session.region_name = "us-east-1" + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + result = handler._refresh_aws_imds_credentials() + assert result is False + + @pytest.mark.asyncio + async def test_static_fallback_activated_on_refresh_failure(self, monkeypatch): + """When refresh fails and static creds are available, fallback activates before the call.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_creds = MagicMock() + mock_creds.get_frozen_credentials.side_effect = [ + frozen, + CredentialRetrievalError(provider="imds", error_msg="IMDS unreachable"), + ] + mock_session = MagicMock() + mock_session.get_credentials.return_value = mock_creds + mock_session.region_name = "us-east-1" + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + 
handler._aws_static_creds = { + "AWS_ACCESS_KEY_ID": "STATICKEY", + "AWS_SECRET_ACCESS_KEY": "STATICSECRET", + "AWS_REGION_NAME": "us-east-1", + } + + with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", + new_callable=AsyncMock) as mock_call: + mock_call.return_value = _mock_acompletion_response() + await handler.chat_completion( + model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + system="sys", user="usr" + ) + + assert handler._aws_imds_fell_back is True + assert os.environ["AWS_ACCESS_KEY_ID"] == "STATICKEY" + + @pytest.mark.asyncio + async def test_refresh_not_called_for_non_bedrock_model(self, monkeypatch): + """_refresh_aws_imds_credentials is NOT called when model is not a Bedrock model.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = "us-east-1" + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + with patch.object(handler, "_refresh_aws_imds_credentials") as mock_refresh, \ + patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", + new_callable=AsyncMock) as mock_call: + mock_call.return_value = _mock_acompletion_response() + await handler.chat_completion(model="gpt-4o", system="sys", user="usr") + + mock_refresh.assert_not_called() + + @pytest.mark.asyncio + async def test_fallback_to_static_on_bedrock_failure(self, monkeypatch): + """On Bedrock APIError, _activate_static_aws_fallback is called and the + in-lock retry uses static credentials (second acompletion call).""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = "us-east-1" + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + # Manually 
inject static creds as fallback + handler._aws_static_creds = { + "AWS_ACCESS_KEY_ID": "STATICKEY", + "AWS_SECRET_ACCESS_KEY": "STATICSECRET", + "AWS_REGION_NAME": "us-east-1", + } + + call_count = 0 + + async def flaky_completion(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise openai.APIError("Bedrock auth failed", request=MagicMock(), body=None) + return _mock_acompletion_response() + + with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", + side_effect=flaky_completion): + await handler.chat_completion( + model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + system="sys", user="usr" + ) + + assert call_count == 2 + assert handler._aws_imds_fell_back is True + assert os.environ["AWS_ACCESS_KEY_ID"] == "STATICKEY" + assert os.environ["AWS_SECRET_ACCESS_KEY"] == "STATICSECRET" + + @pytest.mark.asyncio + async def test_fallback_not_triggered_without_static_creds(self, monkeypatch): + """If no static fallback credentials exist, APIError propagates normally.""" + monkeypatch.setenv("AWS_USE_IMDS", "true") + frozen = _frozen_creds() + mock_session = MagicMock() + mock_session.get_credentials.return_value.get_frozen_credentials.return_value = frozen + mock_session.region_name = "us-east-1" + + with patch("boto3.Session", return_value=mock_session): + handler = LiteLLMAIHandler() + + # No static creds stashed + handler._aws_static_creds = None + + with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", + new_callable=AsyncMock, + side_effect=openai.APIError("auth failed", request=MagicMock(), body=None)): + # tenacity exhausts MODEL_RETRIES and re-raises as RetryError + with pytest.raises((openai.APIError, RetryError)): + await handler.chat_completion( + model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + system="sys", user="usr" + ) + + +# --------------------------------------------------------------------------- +# _activate_static_aws_fallback — session token handling +# 
# ---------------------------------------------------------------------------

class TestActivateStaticFallback:
    """Exercise ``_activate_static_aws_fallback`` on a handler built in IMDS mode.

    Each test overwrites ``handler._aws_static_creds`` by hand, invokes the
    fallback, and then inspects ``os.environ`` / ``_aws_imds_fell_back``.
    """

    @staticmethod
    def _token_free_static_creds():
        """Return a fresh static-credential mapping that carries no session token."""
        return {
            "AWS_ACCESS_KEY_ID": "SK",
            "AWS_SECRET_ACCESS_KEY": "SS",
            "AWS_REGION_NAME": "us-east-1",
        }

    def _make_handler_in_imds_mode(self, monkeypatch):
        """Build a handler with AWS_USE_IMDS=true and ``boto3.Session`` mocked out.

        The mocked session reports region ``us-east-1`` and hands back frozen
        credentials whose token is ``imds-session-token``.
        """
        monkeypatch.setenv("AWS_USE_IMDS", "true")
        fake_session = MagicMock()
        imds_creds = _frozen_creds(token="imds-session-token")
        fake_session.get_credentials.return_value.get_frozen_credentials.return_value = imds_creds
        fake_session.region_name = "us-east-1"
        with patch("boto3.Session", return_value=fake_session):
            imds_handler = LiteLLMAIHandler()
        return imds_handler

    def test_restores_static_session_token(self, monkeypatch):
        """If static creds include a session token, it is restored in env."""
        handler = self._make_handler_in_imds_mode(monkeypatch)
        with_token = self._token_free_static_creds()
        with_token["AWS_SESSION_TOKEN"] = "static-sts-token"
        handler._aws_static_creds = with_token

        handler._activate_static_aws_fallback()

        assert os.environ["AWS_SESSION_TOKEN"] == "static-sts-token"

    def test_clears_session_token_when_static_creds_have_none(self, monkeypatch):
        """IMDS session token is removed from env when static creds have no token."""
        handler = self._make_handler_in_imds_mode(monkeypatch)
        # Precondition: construction in IMDS mode exported the mocked token.
        assert os.environ.get("AWS_SESSION_TOKEN") == "imds-session-token"
        # Deliberately no AWS_SESSION_TOKEN entry in the static credentials.
        handler._aws_static_creds = self._token_free_static_creds()

        handler._activate_static_aws_fallback()

        assert "AWS_SESSION_TOKEN" not in os.environ

    def test_sets_fell_back_flag(self, monkeypatch):
        """The fell-back flag flips from False to True once the fallback is activated."""
        handler = self._make_handler_in_imds_mode(monkeypatch)
        handler._aws_static_creds = self._token_free_static_creds()
        assert handler._aws_imds_fell_back is False

        handler._activate_static_aws_fallback()

        assert handler._aws_imds_fell_back is True