Skip to content

Commit 15f13b2

Browse files
authored
Python: allow default_headers configuration and users to pass in custom AzureOpenAI/OpenAI clients. (microsoft#3903)
### Motivation and Context Currently, users in Python have no ability to configure default_headers for the AzureOpenAI/OpenAI clients we use. Additionally, if a user wants to set up their own AzureOpenAI/OpenAI client, this isn't possible. This PR addresses microsoft#2895. <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> ### Description Allow users to set default_headers for Chat/Text/Embedding Completion classes for both OpenAI and AzureOpenAI. Additionally, allow users to pass in a customer OpenAI/AzureOpenAI client, if desired. Update tests accordingly. <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [ ] The code builds clean without any errors or warnings - [ ] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [ ] All unit tests pass, and I have added new tests where possible - [ ] I didn't break anyone 😄 --------- Co-authored-by: Evan Mattson <[email protected]>
1 parent 10e1998 commit 15f13b2

26 files changed

+861
-78
lines changed

.vscode/settings.json

+6-1
Original file line numberDiff line numberDiff line change
@@ -91,5 +91,10 @@
9191
},
9292
"java.debug.settings.onBuildFailureProceed": true,
9393
"java.compile.nullAnalysis.mode": "disabled",
94-
"dotnet.defaultSolution": "dotnet\\SK-dotnet.sln"
94+
"dotnet.defaultSolution": "dotnet\\SK-dotnet.sln",
95+
"python.testing.pytestArgs": [
96+
"python/tests"
97+
],
98+
"python.testing.unittestEnabled": false,
99+
"python.testing.pytestEnabled": true
95100
}

python/semantic_kernel/connectors/ai/open_ai/const.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from typing import final
44

55
DEFAULT_AZURE_API_VERSION: final = "2023-05-15"
6+
USER_AGENT: final = "User-Agent"

python/semantic_kernel/connectors/ai/open_ai/services/azure_chat_completion.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33

44
from logging import Logger
5-
from typing import Dict, Optional, Union, overload
5+
from typing import Dict, Mapping, Optional, Union, overload
66

7+
from openai import AsyncAzureOpenAI
78
from openai.lib.azure import AsyncAzureADTokenProvider
89

910
from semantic_kernel.connectors.ai.open_ai.const import DEFAULT_AZURE_API_VERSION
@@ -36,6 +37,7 @@ def __init__(
3637
api_key: Optional[str] = None,
3738
ad_token: Optional[str] = None,
3839
ad_token_provider: Optional[AsyncAzureADTokenProvider] = None,
40+
default_headers: Optional[Mapping[str, str]] = None,
3941
log: Optional[Logger] = None,
4042
) -> None:
4143
"""
@@ -59,6 +61,8 @@ def __init__(
5961
The default value is "2023-05-15".
6062
ad_auth: Whether to use Azure Active Directory authentication. (Optional)
6163
The default value is False.
64+
default_headers: The default headers mapping of string keys to
65+
string values for HTTP requests. (Optional)
6266
log: The logger instance to use. (Optional)
6367
logger: deprecated, use 'log' instead.
6468
"""
@@ -72,6 +76,7 @@ def __init__(
7276
api_key: Optional[str] = None,
7377
ad_token: Optional[str] = None,
7478
ad_token_provider: Optional[AsyncAzureADTokenProvider] = None,
79+
default_headers: Optional[Mapping[str, str]] = None,
7580
log: Optional[Logger] = None,
7681
) -> None:
7782
"""
@@ -93,10 +98,32 @@ def __init__(
9398
The default value is "2023-05-15".
9499
ad_auth: Whether to use Azure Active Directory authentication. (Optional)
95100
The default value is False.
101+
default_headers: The default headers mapping of string keys to
102+
string values for HTTP requests. (Optional)
96103
log: The logger instance to use. (Optional)
97104
logger: deprecated, use 'log' instead.
98105
"""
99106

107+
@overload
108+
def __init__(
109+
self,
110+
deployment_name: str,
111+
async_client: AsyncAzureOpenAI,
112+
log: Optional[Logger] = None,
113+
) -> None:
114+
"""
115+
Initialize an AzureChatCompletion service.
116+
117+
Arguments:
118+
deployment_name: The name of the Azure deployment. This value
119+
will correspond to the custom name you chose for your deployment
120+
when you deployed a model. This value can be found under
121+
Resource Management > Deployments in the Azure portal or, alternatively,
122+
under Management > Deployments in Azure OpenAI Studio.
123+
async_client {AsyncAzureOpenAI} -- An existing client to use.
124+
log: The logger instance to use. (Optional)
125+
"""
126+
100127
def __init__(
101128
self,
102129
deployment_name: str,
@@ -106,8 +133,10 @@ def __init__(
106133
api_key: Optional[str] = None,
107134
ad_token: Optional[str] = None,
108135
ad_token_provider: Optional[AsyncAzureADTokenProvider] = None,
136+
default_headers: Optional[Mapping[str, str]] = None,
109137
log: Optional[Logger] = None,
110138
logger: Optional[Logger] = None,
139+
async_client: Optional[AsyncAzureOpenAI] = None,
111140
) -> None:
112141
"""
113142
Initialize an AzureChatCompletion service.
@@ -134,8 +163,11 @@ def __init__(
134163
The default value is "2023-05-15".
135164
ad_auth: Whether to use Azure Active Directory authentication. (Optional)
136165
The default value is False.
166+
default_headers: The default headers mapping of string keys to
167+
string values for HTTP requests. (Optional)
137168
log: The logger instance to use. (Optional)
138169
logger: deprecated, use 'log' instead.
170+
async_client {Optional[AsyncAzureOpenAI]} -- An existing client to use. (Optional)
139171
"""
140172
if logger:
141173
logger.warning("The 'logger' argument is deprecated, use 'log' instead.")
@@ -151,8 +183,10 @@ def __init__(
151183
api_key=api_key,
152184
ad_token=ad_token,
153185
ad_token_provider=ad_token_provider,
186+
default_headers=default_headers,
154187
log=log or logger,
155188
ai_model_type=OpenAIModelTypes.CHAT,
189+
async_client=async_client,
156190
)
157191

158192
@classmethod
@@ -163,7 +197,7 @@ def from_dict(cls, settings: Dict[str, str]) -> "AzureChatCompletion":
163197
Arguments:
164198
settings: A dictionary of settings for the service.
165199
should contains keys: deployment_name, endpoint, api_key
166-
and optionally: api_version, ad_auth, log
200+
and optionally: api_version, ad_auth, default_headers, log
167201
"""
168202
return AzureChatCompletion(
169203
deployment_name=settings.get("deployment_name"),
@@ -173,5 +207,6 @@ def from_dict(cls, settings: Dict[str, str]) -> "AzureChatCompletion":
173207
api_key=settings.get("api_key"),
174208
ad_token=settings.get("ad_token"),
175209
ad_token_provider=settings.get("ad_token_provider"),
210+
default_headers=settings.get("default_headers"),
176211
log=settings.get("log"),
177212
)

python/semantic_kernel/connectors/ai/open_ai/services/azure_config_base.py

+64-35
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22

33
import json
44
from logging import Logger
5-
from typing import Any, Awaitable, Callable, Dict, Optional, Union
5+
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Union
66

77
from openai import AsyncAzureOpenAI
88
from pydantic import validate_call
99

1010
from semantic_kernel.connectors.ai.ai_exception import AIException
11-
from semantic_kernel.connectors.ai.open_ai.const import DEFAULT_AZURE_API_VERSION
11+
from semantic_kernel.connectors.ai.open_ai.const import (
12+
DEFAULT_AZURE_API_VERSION,
13+
USER_AGENT,
14+
)
1215
from semantic_kernel.connectors.ai.open_ai.services.open_ai_handler import (
1316
OpenAIHandler,
1417
OpenAIModelTypes,
@@ -31,56 +34,82 @@ def __init__(
3134
api_key: Optional[str] = None,
3235
ad_token: Optional[str] = None,
3336
ad_token_provider: Optional[Callable[[], Union[str, Awaitable[str]]]] = None,
37+
default_headers: Union[Mapping[str, str], None] = None,
3438
log: Optional[Logger] = None,
39+
async_client: Optional[AsyncAzureOpenAI] = None,
3540
) -> None:
36-
# TODO: add SK user-agent here
37-
if not api_key and not ad_token and not ad_token_provider:
38-
raise AIException(
39-
AIException.ErrorCodes.InvalidConfiguration,
40-
"Please provide either api_key, ad_token or ad_token_provider",
41-
)
42-
if base_url:
43-
client = AsyncAzureOpenAI(
44-
base_url=str(base_url),
45-
api_version=api_version,
46-
api_key=api_key,
47-
azure_ad_token=ad_token,
48-
azure_ad_token_provider=ad_token_provider,
49-
default_headers={"User-Agent": json.dumps(APP_INFO)}
50-
if APP_INFO
51-
else None,
52-
)
53-
else:
54-
if not endpoint:
41+
"""Internal class for configuring a connection to an Azure OpenAI service.
42+
43+
Arguments:
44+
deployment_name {str} -- Name of the deployment.
45+
ai_model_type {OpenAIModelTypes} -- The type of OpenAI model to deploy.
46+
endpoint {Optional[HttpsUrl]} -- The specific endpoint URL for the deployment. (Optional)
47+
base_url {Optional[HttpsUrl]} -- The base URL for Azure services. (Optional)
48+
api_version {str} -- Azure API version. Defaults to the defined DEFAULT_AZURE_API_VERSION.
49+
api_key {Optional[str]} -- API key for Azure services. (Optional)
50+
ad_token {Optional[str]} -- Azure AD token for authentication. (Optional)
51+
ad_token_provider {Optional[Callable[[], Union[str, Awaitable[str]]]]} -- A callable
52+
or coroutine function providing Azure AD tokens. (Optional)
53+
default_headers {Union[Mapping[str, str], None]} -- Default headers for HTTP requests. (Optional)
54+
log {Optional[Logger]} -- Logger instance for logging purposes. (Optional)
55+
async_client {Optional[AsyncAzureOpenAI]} -- An existing client to use. (Optional)
56+
57+
The `validate_call` decorator is used with a configuration that allows arbitrary types.
58+
This is necessary for types like `HttpsUrl` and `OpenAIModelTypes`.
59+
"""
60+
# Merge APP_INFO into the headers if it exists
61+
merged_headers = default_headers.copy() if default_headers else {}
62+
if APP_INFO:
63+
merged_headers[USER_AGENT] = json.dumps(APP_INFO)
64+
65+
if not async_client:
66+
if not api_key and not ad_token and not ad_token_provider:
5567
raise AIException(
5668
AIException.ErrorCodes.InvalidConfiguration,
57-
"Please provide either base_url or endpoint",
69+
"Please provide either api_key, ad_token or ad_token_provider",
5870
)
59-
client = AsyncAzureOpenAI(
60-
azure_endpoint=endpoint,
61-
azure_deployment=deployment_name,
62-
api_version=api_version,
63-
api_key=api_key,
64-
azure_ad_token=ad_token,
65-
azure_ad_token_provider=ad_token_provider,
66-
default_headers={"User-Agent": json.dumps(APP_INFO)}
67-
if APP_INFO
68-
else None,
69-
)
71+
if base_url:
72+
async_client = AsyncAzureOpenAI(
73+
base_url=str(base_url),
74+
api_version=api_version,
75+
api_key=api_key,
76+
azure_ad_token=ad_token,
77+
azure_ad_token_provider=ad_token_provider,
78+
default_headers=merged_headers,
79+
)
80+
else:
81+
if not endpoint:
82+
raise AIException(
83+
AIException.ErrorCodes.InvalidConfiguration,
84+
"Please provide either base_url or endpoint",
85+
)
86+
async_client = AsyncAzureOpenAI(
87+
azure_endpoint=endpoint,
88+
azure_deployment=deployment_name,
89+
api_version=api_version,
90+
api_key=api_key,
91+
azure_ad_token=ad_token,
92+
azure_ad_token_provider=ad_token_provider,
93+
default_headers=merged_headers,
94+
)
95+
7096
super().__init__(
7197
ai_model_id=deployment_name,
7298
log=log,
73-
client=client,
99+
client=async_client,
74100
ai_model_type=ai_model_type,
75101
)
76102

77103
def to_dict(self) -> Dict[str, str]:
78104
client_settings = {
79-
"base_url": self.client.base_url,
105+
"base_url": str(self.client.base_url),
80106
"api_version": self.client._custom_query["api-version"],
81107
"api_key": self.client.api_key,
82108
"ad_token": self.client._azure_ad_token,
83109
"ad_token_provider": self.client._azure_ad_token_provider,
110+
"default_headers": {
111+
k: v for k, v in self.client.default_headers.items() if k != USER_AGENT
112+
},
84113
}
85114
base = self.model_dump(
86115
exclude={

python/semantic_kernel/connectors/ai/open_ai/services/azure_text_completion.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33

44
from logging import Logger
5-
from typing import Dict, Optional, overload
5+
from typing import Dict, Mapping, Optional, overload
66

7+
from openai import AsyncAzureOpenAI
78
from openai.lib.azure import AsyncAzureADTokenProvider
89

910
from semantic_kernel.connectors.ai.open_ai.const import DEFAULT_AZURE_API_VERSION
@@ -29,6 +30,7 @@ def __init__(
2930
api_key: Optional[str] = None,
3031
ad_token: Optional[str] = None,
3132
ad_token_provider: Optional[AsyncAzureADTokenProvider] = None,
33+
default_headers: Optional[Mapping[str, str]] = None,
3234
log: Optional[Logger] = None,
3335
) -> None:
3436
"""
@@ -50,6 +52,8 @@ def __init__(
5052
The default value is "2023-05-15".
5153
ad_auth: Whether to use Azure Active Directory authentication. (Optional)
5254
The default value is False.
55+
default_headers: The default headers mapping of string keys to
56+
string values for HTTP requests. (Optional)
5357
log: The logger instance to use. (Optional)
5458
logger: deprecated, use 'log' instead.
5559
"""
@@ -63,6 +67,7 @@ def __init__(
6367
api_key: Optional[str] = None,
6468
ad_token: Optional[str] = None,
6569
ad_token_provider: Optional[AsyncAzureADTokenProvider] = None,
70+
default_headers: Optional[Mapping[str, str]] = None,
6671
log: Optional[Logger] = None,
6772
) -> None:
6873
"""
@@ -84,10 +89,32 @@ def __init__(
8489
The default value is "2023-05-15".
8590
ad_auth: Whether to use Azure Active Directory authentication. (Optional)
8691
The default value is False.
92+
default_headers: The default headers mapping of string keys to
93+
string values for HTTP requests. (Optional)
8794
log: The logger instance to use. (Optional)
8895
logger: deprecated, use 'log' instead.
8996
"""
9097

98+
@overload
99+
def __init__(
100+
self,
101+
deployment_name: str,
102+
async_client: AsyncAzureOpenAI,
103+
log: Optional[Logger] = None,
104+
) -> None:
105+
"""
106+
Initialize an AzureChatCompletion service.
107+
108+
Arguments:
109+
deployment_name: The name of the Azure deployment. This value
110+
will correspond to the custom name you chose for your deployment
111+
when you deployed a model. This value can be found under
112+
Resource Management > Deployments in the Azure portal or, alternatively,
113+
under Management > Deployments in Azure OpenAI Studio.
114+
async_client {AsyncAzureOpenAI} -- An existing client to use.
115+
log: The logger instance to use. (Optional)
116+
"""
117+
91118
def __init__(
92119
self,
93120
deployment_name: Optional[str] = None,
@@ -97,8 +124,10 @@ def __init__(
97124
api_key: Optional[str] = None,
98125
ad_token: Optional[str] = None,
99126
ad_token_provider: Optional[AsyncAzureADTokenProvider] = None,
127+
default_headers: Optional[Mapping[str, str]] = None,
100128
log: Optional[Logger] = None,
101129
logger: Optional[Logger] = None,
130+
async_client: Optional[AsyncAzureOpenAI] = None,
102131
) -> None:
103132
"""
104133
Initialize an AzureTextCompletion service.
@@ -119,8 +148,11 @@ def __init__(
119148
The default value is "2023-03-15-preview".
120149
ad_auth: Whether to use Azure Active Directory authentication. (Optional)
121150
The default value is False.
151+
default_headers: The default headers mapping of string keys to
152+
string values for HTTP requests. (Optional)
122153
log: The logger instance to use. (Optional)
123154
logger: deprecated, use 'log' instead.
155+
async_client {Optional[AsyncAzureOpenAI]} -- An existing client to use.
124156
"""
125157
if logger:
126158
logger.warning("The 'logger' argument is deprecated, use 'log' instead.")
@@ -132,8 +164,10 @@ def __init__(
132164
api_key=api_key,
133165
ad_token=ad_token,
134166
ad_token_provider=ad_token_provider,
167+
default_headers=default_headers,
135168
log=log or logger,
136169
ai_model_type=OpenAIModelTypes.TEXT,
170+
async_client=async_client,
137171
)
138172

139173
@classmethod
@@ -155,5 +189,6 @@ def from_dict(cls, settings: Dict[str, str]) -> "AzureTextCompletion":
155189
api_key=settings["api_key"],
156190
ad_token=settings.get("ad_token"),
157191
ad_token_provider=settings.get("ad_token_provider"),
192+
default_headers=settings.get("default_headers"),
158193
log=settings.get("log"),
159194
)

0 commit comments

Comments
 (0)