From 27a0545f5fc08390816be7e723d6d6c63d729f40 Mon Sep 17 00:00:00 2001
From: Facundo Santiago
Date: Mon, 4 Nov 2024 06:41:15 +0000
Subject: [PATCH 1/3] feat: azure ai inference support

---
 .../inference/azure_ai_inference/__init__.py       |  17 ++
 .../azure_ai_inference/azure_ai_inference.py       | 259 ++++++++++++++++++
 .../inference/azure_ai_inference/config.py         |  26 ++
 llama_stack/providers/registry/inference.py        |   9 +
 .../utils/inference/openai_compat.py               |   6 +
 5 files changed, 317 insertions(+)
 create mode 100644 llama_stack/providers/adapters/inference/azure_ai_inference/__init__.py
 create mode 100644 llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py
 create mode 100644 llama_stack/providers/adapters/inference/azure_ai_inference/config.py

diff --git a/llama_stack/providers/adapters/inference/azure_ai_inference/__init__.py b/llama_stack/providers/adapters/inference/azure_ai_inference/__init__.py
new file mode 100644
index 0000000000..b332d926c2
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/azure_ai_inference/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from .azure_ai_inference import AzureAIInferenceAdapter
+from .config import AzureAIInferenceConfig
+
+
+async def get_adapter_impl(config: AzureAIInferenceConfig, _deps):
+    assert isinstance(config, AzureAIInferenceConfig), f"Unexpected config type: {type(config)}"
+
+    impl = AzureAIInferenceAdapter(config)
+
+    await impl.initialize()
+
+    return impl
diff --git a/llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py b/llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py
new file mode 100644
index 0000000000..f36f9e3b55
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from typing import AsyncGenerator
+
+from llama_models.llama3.api.chat_format import ChatFormat
+
+from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
+
+from azure.ai.inference.aio import ChatCompletionsClient as ChatCompletionsClientAsync
+from azure.core.credentials import AzureKeyCredential
+from azure.core.exceptions import HttpResponseError
+from azure.identity import DefaultAzureCredential
+
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
+
+from llama_stack.providers.utils.inference.openai_compat import (
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
+)
+
+from .config import AzureAIInferenceConfig
+
+# Mapping from Llama model names to the Azure AI model catalog names
+SUPPORTED_INSTRUCT_MODELS = {
+    "Llama3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
+    "Llama3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
+    "Llama3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
+    "Llama3.2-1B-Instruct": "Llama-3.2-1B-Instruct",
+    "Llama3.2-3B-Instruct": "Llama-3.2-3B-Instruct",
+    "Llama3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
+    "Llama3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
+}
+
+logger = logging.getLogger(__name__)
+
+
+class AzureAIInferenceAdapter(Inference, ModelsProtocolPrivate):
+    def __init__(self, config: AzureAIInferenceConfig) -> None:
+        tokenizer = Tokenizer.get_instance()
+
+        self.config = config
+        self.formatter = ChatFormat(tokenizer)
+        self._model_name = None
+
+    @property
+    def client(self) -> ChatCompletionsClientAsync:
+        if self.config.credential is None:
+            credential = DefaultAzureCredential() 
+        else:
+            credential = AzureKeyCredential(self.config.credential)
+
+        if self.config.api_version:
+            return ChatCompletionsClientAsync(
+                endpoint=self.config.endpoint,
+                credential=credential,
+                user_agent="llama-stack",
+                api_version=self.config.api_version,
+            )
+        else:
+            return ChatCompletionsClientAsync(
+                endpoint=self.config.endpoint, 
+                credential=credential,
+                user_agent="llama-stack",
+            )
+
+    async def initialize(self) -> None:
+        async with self.client as async_client:
+            try:
+                model_info = await async_client.get_model_info()
+                if model_info:
+                    self._model_name = model_info.get("model_name", None)
+                    logger.info(
+                        f"Endpoint {self.config.endpoint} supports model {self._model_name}"
+                    )
+                    if self._model_name not in SUPPORTED_INSTRUCT_MODELS.values():
+                        logger.warning(
+                            f"Endpoint serves model {self._model_name} which may not be supported"
+                        )
+            except HttpResponseError:
+                logger.info(
+                    f"Endpoint {self.config.endpoint} supports multiple models"
+                )
+                self._model_name = None
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_models(self) -> List[ModelDef]:
+        print("Model name: ", self._model_name)
+        if self._model_name is None:
+            return [
+                ModelDef(identifier=model_name, llama_model=azure_model_id)
+                for model_name, azure_model_id in SUPPORTED_INSTRUCT_MODELS.items()
+            ]
+        else:
+            # find if there is a value in the SUPPORTED_INSTRUCT_MODELS that matches the model name
+            supported_model = next(
+                (model for model in SUPPORTED_INSTRUCT_MODELS if SUPPORTED_INSTRUCT_MODELS[model] == self._model_name),
+                None
+            )
+            return [
+                ModelDef(
+                    identifier=supported_model or self._model_name,
+                    llama_model=self._model_name
+                )
+            ]
+
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        raise NotImplementedError()
+
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model or self.config.model_name,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        params = self._get_params(request)
+        if stream:
+            return self._stream_chat_completion(params)
+        else:
+            return await self._nonstream_chat_completion(params)
+
+    async def _nonstream_chat_completion(
+        self, params: dict
+    ) -> ChatCompletionResponse:
+        async with self.client as client:
+            r = await client.complete(**params)
+            return process_chat_completion_response(r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, params: dict
+    ) -> AsyncGenerator:
+        async with self.client as client:
+            stream = await client.complete(**params, stream=True)
+            async for chunk in process_chat_completion_stream_response(
+                stream, self.formatter
+            ):
+                yield chunk
+
+    @staticmethod
+    def _get_sampling_options(
+        params: SamplingParams,
+        logprobs: Optional[LogProbConfig] = None
+    ) -> dict:
+        options = {}
+        model_extras = {}
+        if params:
+            # repetition_penalty is not supported by Azure AI inference API
+            for attr in {"temperature", "top_p", "max_tokens"}:
+                if getattr(params, attr):
+                    options[attr] = getattr(params, attr)
+
+            if params.top_k is not None and params.top_k != 0:
+                model_extras["top_k"] = params.top_k
+
+            if logprobs is not None:
+                model_extras["logprobs"] = params.logprobs
+
+        if model_extras:
+            options["model_extras"] = model_extras
+
+        return options
+
+    @staticmethod
+    def _to_azure_ai_messages(messages: List[Message]) -> List[dict]:
+        """
+        Convert the messages to the format expected by the Azure AI API.
+        """
+        azure_ai_messages = []
+        for message in messages:
+            role = message.role
+            content = message.content
+
+            if role == "user":
+                azure_ai_messages.append({"role": role, "content": content})
+            elif role == "assistant":
+                azure_ai_messages.append({"role": role, "content": content, "tool_calls": message.tool_calls})
+            elif role == "system":
+                azure_ai_messages.append({"role": role, "content": content})
+            elif role == "ipython":
+                azure_ai_messages.append(
+                    {
+                        "role": "tool",
+                        "content": content,
+                        "tool_call_id": message.call_id
+                    }
+                )
+
+        return azure_ai_messages
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        """
+        Gets the parameters for the Azure AI model inference API from the Chat completions request.
+        Parameters are returned as a dictionary.
+        """
+        options = self._get_sampling_options(request.sampling_params, request.logprobs)
+        messages = self._to_azure_ai_messages(chat_completion_request_to_messages(request))
+        if (self._model_name):
+            # If the model name is already resolved, then the endpoint
+            # is serving a single model and we don't need to specify it
+            return {
+                "messages": messages,
+                **options
+            }
+        else:
+            return {
+                "messages": messages,
+                "model": SUPPORTED_INSTRUCT_MODELS.get(request.model, request.model),
+                **options
+            }
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/adapters/inference/azure_ai_inference/config.py b/llama_stack/providers/adapters/inference/azure_ai_inference/config.py
new file mode 100644
index 0000000000..e24d8afc7a
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/azure_ai_inference/config.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import *  # noqa: F403
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class AzureAIInferenceConfig(BaseModel):
+    endpoint: str = Field(
+        default=None,
+        description="The endpoint URL where the model(s) is/are deployed.",
+    )
+    credential: Optional[str] = Field(
+        default=None,
+        description="The secret to access the model. If None, then `DefaultAzureCredential` is attempted.",
+    )
+    api_version: Optional[str] = Field(
+        default=None,
+        description="The API version to use in the endpoint. Indicating None will use the default version in the "
+        "`azure-ai-inference` package. Defaults to the AZURE_AI_API_VERSION environment variable.",
+    )
\ No newline at end of file
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 88265f1b46..e034089ff3 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -140,6 +140,15 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="azure-ai-inference",
+                pip_packages=["azure-ai-inference", "azure-identity", "aiohttp"],
+                module="llama_stack.providers.adapters.inference.azure_ai_inference",
+                config_class="llama_stack.providers.adapters.inference.azure_ai_inference.AzureAIInferenceConfig",
+            ),
+        ),
         InlineProviderSpec(
             api=Api.inference,
             provider_type="vllm",
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 086227c731..9ab3d29efa 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -45,6 +45,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
 def text_from_choice(choice) -> str:
     if hasattr(choice, "delta") and choice.delta:
         return choice.delta.content
+
+    if hasattr(choice, "message"):
+        return choice.message.content
 
     return choice.text
 
@@ -158,6 +161,9 @@ async def process_chat_completion_stream_response(
                 break
 
         text = text_from_choice(choice)
+        if not text:
+            continue
+
         # check if its a tool call ( aka starts with <|python_tag|> )
         if not ipython and text.startswith("<|python_tag|>"):
             ipython = True

From e247849d1bdcc66c3d7c03756af228587cd324f5 Mon Sep 17 00:00:00 2001
From: Facundo Santiago
Date: Mon, 4 Nov 2024 07:54:31 +0000
Subject: [PATCH 2/3] nit

---
 .../inference/azure_ai_inference/azure_ai_inference.py   | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py b/llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py
index f36f9e3b55..cbc2d61d90 100644
--- a/llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py
+++ b/llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py
@@ -9,9 +9,8 @@
 from llama_models.llama3.api.chat_format import ChatFormat
 
-from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
 
 from azure.ai.inference.aio import ChatCompletionsClient as ChatCompletionsClientAsync
 from azure.core.credentials import AzureKeyCredential
 from azure.core.exceptions import HttpResponseError
@@ -55,7 +54,7 @@ def __init__(self, config: AzureAIInferenceConfig) -> None:
     @property
     def client(self) -> ChatCompletionsClientAsync:
         if self.config.credential is None:
-            credential = DefaultAzureCredential() 
+            credential = DefaultAzureCredential()
         else:
             credential = AzureKeyCredential(self.config.credential)
 
@@ -68,7 +67,7 @@ def client(self) -> ChatCompletionsClientAsync:
             )
         else:
             return ChatCompletionsClientAsync(
-                endpoint=self.config.endpoint, 
+                endpoint=self.config.endpoint,
                 credential=credential,
                 user_agent="llama-stack",
             )
@@ -98,7 +97,6 @@ async def shutdown(self) -> None:
         pass
 
     async def list_models(self) -> List[ModelDef]:
-        print("Model name: ", self._model_name)
         if self._model_name is None:
             return [
                 ModelDef(identifier=model_name, llama_model=azure_model_id)

From 2b21e976244cd5a545563b93eb74293dbf0a2669 Mon Sep 17 00:00:00 2001
From: Facundo Santiago
Date: Mon, 11 Nov 2024 21:14:52 +0000
Subject: [PATCH 3/3] feat: refactor code base

---
 llama_stack/providers/registry/inference.py               | 4 ++--
 .../inference/azure_ai_inference/__init__.py               | 0
 .../inference/azure_ai_inference/azure_ai_inference.py     | 0
 .../inference/azure_ai_inference/config.py                 | 0
 4 files changed, 2 insertions(+), 2 deletions(-)
 rename llama_stack/providers/{adapters => remote}/inference/azure_ai_inference/__init__.py (100%)
 rename llama_stack/providers/{adapters => remote}/inference/azure_ai_inference/azure_ai_inference.py (100%)
 rename llama_stack/providers/{adapters => remote}/inference/azure_ai_inference/config.py (100%)

diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index f56df223fe..f22270b6c4 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -154,8 +154,8 @@ def available_providers() -> List[ProviderSpec]:
             adapter=AdapterSpec(
                 adapter_type="azure-ai-inference",
                 pip_packages=["azure-ai-inference", "azure-identity", "aiohttp"],
-                module="llama_stack.providers.adapters.inference.azure_ai_inference",
-                config_class="llama_stack.providers.adapters.inference.azure_ai_inference.AzureAIInferenceConfig",
+                module="llama_stack.providers.remote.inference.azure_ai_inference",
+                config_class="llama_stack.providers.remote.inference.azure_ai_inference.AzureAIInferenceConfig",
             ),
         ),
     ]
diff --git a/llama_stack/providers/adapters/inference/azure_ai_inference/__init__.py b/llama_stack/providers/remote/inference/azure_ai_inference/__init__.py
similarity index 100%
rename from llama_stack/providers/adapters/inference/azure_ai_inference/__init__.py
rename to llama_stack/providers/remote/inference/azure_ai_inference/__init__.py
diff --git a/llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py b/llama_stack/providers/remote/inference/azure_ai_inference/azure_ai_inference.py
similarity index 100%
rename from llama_stack/providers/adapters/inference/azure_ai_inference/azure_ai_inference.py
rename to llama_stack/providers/remote/inference/azure_ai_inference/azure_ai_inference.py
diff --git a/llama_stack/providers/adapters/inference/azure_ai_inference/config.py b/llama_stack/providers/remote/inference/azure_ai_inference/config.py
similarity index 100%
rename from llama_stack/providers/adapters/inference/azure_ai_inference/config.py
rename to llama_stack/providers/remote/inference/azure_ai_inference/config.py
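
A minimal smoke-test sketch for the new adapter, for illustration only and not part of the patches above. It assumes the post-refactor module path from PATCH 3/3, an existing Azure AI inference deployment, and the `azure-ai-inference` and `azure-identity` packages installed; the endpoint URL and API key below are hypothetical placeholders, not values taken from this PR.

import asyncio

from llama_stack.providers.remote.inference.azure_ai_inference import (
    AzureAIInferenceConfig,
    get_adapter_impl,
)


async def main():
    # Placeholder endpoint and key; replace with your own deployment details.
    config = AzureAIInferenceConfig(
        endpoint="https://<your-deployment>.inference.ai.azure.com",
        credential="<your-api-key>",  # None falls back to DefaultAzureCredential
    )
    # get_adapter_impl() builds the adapter and calls initialize(), which probes the endpoint.
    adapter = await get_adapter_impl(config, None)
    # list_models() reports either the single model the endpoint serves or the supported catalog.
    for model in await adapter.list_models():
        print(model.identifier, "->", model.llama_model)
    await adapter.shutdown()


asyncio.run(main())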