feat: azure ai inference support #364

Open · wants to merge 6 commits into main
9 changes: 9 additions & 0 deletions llama_stack/providers/registry/inference.py
@@ -149,4 +149,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
),
),
remote_provider_spec(
Contributor:

just naming feedback: given that we are trying to get azure to provide a full llama stack distribution, suggest calling the provider: azure_ai

            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="azure-ai-inference",
                pip_packages=["azure-ai-inference", "azure-identity", "aiohttp"],
                module="llama_stack.providers.remote.inference.azure_ai_inference",
                config_class="llama_stack.providers.remote.inference.azure_ai_inference.AzureAIInferenceConfig",
            ),
        ),
    ]
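Following up on the naming comment above, a registry entry with the suggested provider name could look roughly as follows. This is only an illustrative sketch: the azure_ai adapter_type, module, and config_class paths are the reviewer's proposed rename, not part of this diff.

        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="azure_ai",
                pip_packages=["azure-ai-inference", "azure-identity", "aiohttp"],
                module="llama_stack.providers.remote.inference.azure_ai",
                config_class="llama_stack.providers.remote.inference.azure_ai.AzureAIInferenceConfig",
            ),
        ),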
17 changes: 17 additions & 0 deletions llama_stack/providers/remote/inference/azure_ai_inference/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .azure_ai_inference import AzureAIInferenceAdapter
from .config import AzureAIInferenceConfig


async def get_adapter_impl(config: AzureAIInferenceConfig, _deps):
    assert isinstance(config, AzureAIInferenceConfig), f"Unexpected config type: {type(config)}"

    impl = AzureAIInferenceAdapter(config)

    await impl.initialize()

    return impl
257 changes: 257 additions & 0 deletions llama_stack/providers/remote/inference/azure_ai_inference/azure_ai_inference.py
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from typing import AsyncGenerator

from llama_models.llama3.api.chat_format import ChatFormat

from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from azure.ai.inference.aio import ChatCompletionsClient as ChatCompletionsClientAsync
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
from azure.identity import DefaultAzureCredential

from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate

from llama_stack.providers.utils.inference.openai_compat import (
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

from .config import AzureAIInferenceConfig

# Mapping of model names from the Llama model names to the Azure AI model catalog names
SUPPORTED_INSTRUCT_MODELS = {
    "Llama3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
    "Llama3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
    "Llama3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
    "Llama3.2-1B-Instruct": "Llama-3.2-1B-Instruct",
    "Llama3.2-3B-Instruct": "Llama-3.2-3B-Instruct",
    "Llama3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
    "Llama3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
}

logger = logging.getLogger(__name__)

class AzureAIInferenceAdapter(Inference, ModelsProtocolPrivate):
    def __init__(self, config: AzureAIInferenceConfig) -> None:
        tokenizer = Tokenizer.get_instance()

        self.config = config
        self.formatter = ChatFormat(tokenizer)
        self._model_name = None

    @property
    def client(self) -> ChatCompletionsClientAsync:
        if self.config.credential is None:
            credential = DefaultAzureCredential()
        else:
            credential = AzureKeyCredential(self.config.credential)

        if self.config.api_version:
            return ChatCompletionsClientAsync(
                endpoint=self.config.endpoint,
                credential=credential,
                user_agent="llama-stack",
                api_version=self.config.api_version,
            )
        else:
            return ChatCompletionsClientAsync(
                endpoint=self.config.endpoint,
                credential=credential,
                user_agent="llama-stack",
            )

    async def initialize(self) -> None:
        async with self.client as async_client:
            try:
                model_info = await async_client.get_model_info()
                if model_info:
                    self._model_name = model_info.get("model_name", None)
                    logger.info(
                        f"Endpoint {self.config.endpoint} supports model {self._model_name}"
                    )
                    if self._model_name not in SUPPORTED_INSTRUCT_MODELS.values():
                        logger.warning(
                            f"Endpoint serves model {self._model_name}, which may not be supported"
                        )
            except HttpResponseError:
                logger.info(
                    f"Endpoint {self.config.endpoint} supports multiple models"
                )
                self._model_name = None


    async def shutdown(self) -> None:
        pass


    async def list_models(self) -> List[ModelDef]:
        if self._model_name is None:
            return [
                ModelDef(identifier=model_name, llama_model=azure_model_id)
                for model_name, azure_model_id in SUPPORTED_INSTRUCT_MODELS.items()
            ]
        else:
            # Find the Llama identifier whose Azure catalog name matches the served model
            supported_model = next(
                (model for model in SUPPORTED_INSTRUCT_MODELS if SUPPORTED_INSTRUCT_MODELS[model] == self._model_name),
                None,
            )
            return [
                ModelDef(
                    identifier=supported_model or self._model_name,
                    llama_model=self._model_name,
                )
            ]


    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()


    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            # The config has no model_name field; fall back to the model resolved from the endpoint
            model=model or self._model_name,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )
        params = self._get_params(request)
        if stream:
            return self._stream_chat_completion(params)
        else:
            return await self._nonstream_chat_completion(params)


    async def _nonstream_chat_completion(
        self, params: dict
    ) -> ChatCompletionResponse:
        async with self.client as client:
            r = await client.complete(**params)
            return process_chat_completion_response(r, self.formatter)


    async def _stream_chat_completion(
        self, params: dict
    ) -> AsyncGenerator:
        async with self.client as client:
            stream = await client.complete(**params, stream=True)
            async for chunk in process_chat_completion_stream_response(
                stream, self.formatter
            ):
                yield chunk


    @staticmethod
    def _get_sampling_options(
        params: SamplingParams,
        logprobs: Optional[LogProbConfig] = None,
    ) -> dict:
        options = {}
        model_extras = {}
        if params:
            # repetition_penalty is not supported by the Azure AI inference API
            for attr in {"temperature", "top_p", "max_tokens"}:
                if getattr(params, attr):
                    options[attr] = getattr(params, attr)

            if params.top_k is not None and params.top_k != 0:
                model_extras["top_k"] = params.top_k

        if logprobs is not None:
            # SamplingParams has no logprobs field; forward the request's LogProbConfig
            model_extras["logprobs"] = logprobs

        if model_extras:
            options["model_extras"] = model_extras

        return options

    @staticmethod
    def _to_azure_ai_messages(messages: List[Message]) -> List[dict]:
        """
        Convert the messages to the format expected by the Azure AI API.
        """
        azure_ai_messages = []
        for message in messages:
            role = message.role
            content = message.content

            if role == "user":
                azure_ai_messages.append({"role": role, "content": content})
            elif role == "assistant":
                azure_ai_messages.append({"role": role, "content": content, "tool_calls": message.tool_calls})
            elif role == "system":
                azure_ai_messages.append({"role": role, "content": content})
            elif role == "ipython":
                azure_ai_messages.append(
                    {
                        "role": "tool",
                        "content": content,
                        "tool_call_id": message.call_id,
                    }
                )

        return azure_ai_messages


    def _get_params(self, request: ChatCompletionRequest) -> dict:
        """
        Gets the parameters for the Azure AI model inference API from the chat completions request.
        Parameters are returned as a dictionary.
        """
        options = self._get_sampling_options(request.sampling_params, request.logprobs)
        messages = self._to_azure_ai_messages(chat_completion_request_to_messages(request))
        if self._model_name:
            # If the model name is already resolved, then the endpoint
            # is serving a single model and we don't need to specify it
            return {
                "messages": messages,
                **options,
            }
        else:
            return {
                "messages": messages,
                "model": SUPPORTED_INSTRUCT_MODELS.get(request.model, request.model),
                **options,
            }

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
26 changes: 26 additions & 0 deletions llama_stack/providers/remote/inference/azure_ai_inference/config.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import * # noqa: F403

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class AzureAIInferenceConfig(BaseModel):
    endpoint: str = Field(
        default=None,
        description="The endpoint URL where the model(s) is/are deployed.",
    )
    credential: Optional[str] = Field(
        default=None,
        description="The secret to access the model. If None, then `DefaultAzureCredential` is attempted.",
    )
    api_version: Optional[str] = Field(
        default=None,
        description="The API version to use for the endpoint. If None, the default version of the "
        "`azure-ai-inference` package is used. Typically set from the AZURE_AI_API_VERSION environment variable.",
    )
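
For illustration, a minimal sketch of constructing this config and adapter directly. The endpoint URL and key below are placeholders, and the snippet assumes only what this diff defines (AzureAIInferenceConfig and get_adapter_impl):

import asyncio

from llama_stack.providers.remote.inference.azure_ai_inference import (
    AzureAIInferenceConfig,
    get_adapter_impl,
)


async def main() -> None:
    # Placeholder values; with credential=None, DefaultAzureCredential is attempted instead.
    config = AzureAIInferenceConfig(
        endpoint="https://my-deployment.eastus2.models.ai.azure.com",
        credential="<api-key>",
    )
    # get_adapter_impl() awaits initialize(), which probes the endpoint via get_model_info().
    adapter = await get_adapter_impl(config, {})
    await adapter.shutdown()


asyncio.run(main())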
3 changes: 3 additions & 0 deletions llama_stack/providers/utils/inference/openai_compat.py
@@ -45,6 +45,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
def text_from_choice(choice) -> str:
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content

    if hasattr(choice, "message"):
        return choice.message.content
Contributor:

is this a bad merge?


    if hasattr(choice, "message"):
        return choice.message.content
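
If the second message branch above is indeed a duplicate, the hunk could collapse to a single check. A sketch of the de-duplicated helper (the remainder of the existing function is unchanged and not shown here):

def text_from_choice(choice) -> str:
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content

    if hasattr(choice, "message"):
        return choice.message.content

    # ... rest of the existing function unchanged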