diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 4176f85a6a..541fbb432a 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -218,7 +218,7 @@ def available_providers() -> list[ProviderSpec]:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="vertexai",
-            pip_packages=["litellm", "google-cloud-aiplatform"],
+            pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
             module="llama_stack.providers.remote.inference.vertexai",
             config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
             provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
diff --git a/llama_stack/providers/remote/inference/vertexai/vertexai.py b/llama_stack/providers/remote/inference/vertexai/vertexai.py
index 8807fd0e6a..27f953ab97 100644
--- a/llama_stack/providers/remote/inference/vertexai/vertexai.py
+++ b/llama_stack/providers/remote/inference/vertexai/vertexai.py
@@ -6,16 +6,20 @@
 
 from typing import Any
 
+import google.auth.transport.requests
+from google.auth import default
+
 from llama_stack.apis.inference import ChatCompletionRequest
 from llama_stack.providers.utils.inference.litellm_openai_mixin import (
     LiteLLMOpenAIMixin,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import VertexAIConfig
 from .models import MODEL_ENTRIES
 
 
-class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
+class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     def __init__(self, config: VertexAIConfig) -> None:
         LiteLLMOpenAIMixin.__init__(
             self,
@@ -27,9 +31,30 @@ def __init__(self, config: VertexAIConfig) -> None:
         self.config = config
 
     def get_api_key(self) -> str:
-        # Vertex AI doesn't use API keys, it uses Application Default Credentials
-        # Return empty string to let litellm handle authentication via ADC
-        return ""
+        """
+        Get an access token for Vertex AI using Application Default Credentials.
+
+        Vertex AI uses ADC instead of API keys. This method obtains an access token
+        from the default credentials and returns it for use with the OpenAI-compatible client.
+        """
+        try:
+            # Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
+            credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
+            credentials.refresh(google.auth.transport.requests.Request())
+            return credentials.token
+        except Exception:
+            # If we can't get credentials, return empty string to let LiteLLM handle it
+            # This allows the LiteLLM mixin to work with ADC directly
+            return ""
+
+    def get_base_url(self) -> str:
+        """
+        Get the Vertex AI OpenAI-compatible API base URL.
+
+        Returns the Vertex AI OpenAI-compatible endpoint URL.
+        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
+        """
+        return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
         # Get base parameters from parent
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index df1184f1c4..f9c837ebdb 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -76,6 +76,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
         "remote::gemini",
         # https://docs.anthropic.com/en/api/openai-sdk#simple-fields
         "remote::anthropic",
+        "remote::vertexai",
+        # Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
+        # the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
 
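For context, a minimal usage sketch (not part of the diff) of how the two new methods resolve auth and endpoint at runtime. It assumes VertexAIConfig exposes project and location fields, as the get_base_url() f-string implies; the config values and the VertexAIConfig import path are illustrative.

    # Illustrative sketch only; project/location values are hypothetical.
    from llama_stack.providers.remote.inference.vertexai.config import VertexAIConfig
    from llama_stack.providers.remote.inference.vertexai.vertexai import VertexAIInferenceAdapter

    config = VertexAIConfig(project="my-gcp-project", location="us-central1")
    adapter = VertexAIInferenceAdapter(config)

    # get_base_url() builds the OpenAI-compatible Vertex AI endpoint, e.g.
    #   https://us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-central1/endpoints/openapi
    print(adapter.get_base_url())

    # get_api_key() refreshes Application Default Credentials (reading
    # GOOGLE_APPLICATION_CREDENTIALS if set) and returns a short-lived access
    # token; on failure it returns "" so the LiteLLM path can drive ADC itself.
    token = adapter.get_api_key()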