Skip to content

Commit 1685b19

Browse files
mattfleseb
authored and committed
chore: update the vertexai inference impl to use openai-python for openai-compat functions (llamastack#3377)
# What does this PR do? update VertexAI inference provider to use openai-python for openai-compat functions ## Test Plan ``` $ VERTEX_AI_PROJECT=... uv run llama stack build --image-type venv --providers inference=remote::vertexai --run ... $ LLAMA_STACK_CONFIG=http://localhost:8321 uv run --group test pytest -v -ra --text-model vertexai/vertex_ai/gemini-2.5-flash tests/integration/inference/test_openai_completion.py ... ``` i don't have an account to test this. `get_api_key` may also need to be updated per https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai --------- Signed-off-by: Sébastien Han <[email protected]> Co-authored-by: Sébastien Han <[email protected]>
1 parent 4b84012 commit 1685b19

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

llama_stack/providers/registry/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def available_providers() -> list[ProviderSpec]:
218218
api=Api.inference,
219219
adapter=AdapterSpec(
220220
adapter_type="vertexai",
221-
pip_packages=["litellm", "google-cloud-aiplatform"],
221+
pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
222222
module="llama_stack.providers.remote.inference.vertexai",
223223
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
224224
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",

llama_stack/providers/remote/inference/vertexai/vertexai.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,20 @@
66

77
from typing import Any
88

9+
import google.auth.transport.requests
10+
from google.auth import default
11+
912
from llama_stack.apis.inference import ChatCompletionRequest
1013
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
1114
LiteLLMOpenAIMixin,
1215
)
16+
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
1317

1418
from .config import VertexAIConfig
1519
from .models import MODEL_ENTRIES
1620

1721

18-
class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
22+
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
1923
def __init__(self, config: VertexAIConfig) -> None:
2024
LiteLLMOpenAIMixin.__init__(
2125
self,
@@ -27,9 +31,30 @@ def __init__(self, config: VertexAIConfig) -> None:
2731
self.config = config
2832

2933
def get_api_key(self) -> str:
30-
# Vertex AI doesn't use API keys, it uses Application Default Credentials
31-
# Return empty string to let litellm handle authentication via ADC
32-
return ""
34+
"""
35+
Get an access token for Vertex AI using Application Default Credentials.
36+
37+
Vertex AI uses ADC instead of API keys. This method obtains an access token
38+
from the default credentials and returns it for use with the OpenAI-compatible client.
39+
"""
40+
try:
41+
# Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
42+
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
43+
credentials.refresh(google.auth.transport.requests.Request())
44+
return credentials.token
45+
except Exception:
46+
# If we can't get credentials, return empty string to let LiteLLM handle it
47+
# This allows the LiteLLM mixin to work with ADC directly
48+
return ""
49+
50+
def get_base_url(self) -> str:
51+
"""
52+
Get the Vertex AI OpenAI-compatible API base URL.
53+
54+
Returns the Vertex AI OpenAI-compatible endpoint URL.
55+
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
56+
"""
57+
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
3358

3459
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
3560
# Get base parameters from parent

tests/integration/inference/test_openai_completion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
7676
"remote::gemini",
7777
# https://docs.anthropic.com/en/api/openai-sdk#simple-fields
7878
"remote::anthropic",
79+
"remote::vertexai",
80+
# Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
81+
# the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
7982
):
8083
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
8184

0 commit comments

Comments (0)