Skip to content

Commit 37d0e40

Browse files
authored
Add support for Entra ID authentication when using interpreter or openai-gpt (#356)
1 parent fa889a7 commit 37d0e40

File tree

9 files changed

+168
-62
lines changed

9 files changed

+168
-62
lines changed

shell/agents/AIShell.Interpreter.Agent/AIShell.Interpreter.Agent.csproj

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
<ItemGroup>
1919
<PackageReference Include="Azure.AI.OpenAI" Version="1.0.0-beta.13" />
20-
<PackageReference Include="Azure.Core" Version="1.37.0" />
20+
<PackageReference Include="Azure.Identity" Version="1.13.2" />
21+
<PackageReference Include="Azure.Core" Version="1.44.1" />
2122
<PackageReference Include="SharpToken" Version="2.0.3" />
2223
</ItemGroup>
2324

shell/agents/AIShell.Interpreter.Agent/Agent.cs

+20-5
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ private void OnSettingFileChange(object sender, FileSystemEventArgs e)
236236

237237
private void NewExampleSettingFile()
238238
{
239-
string SampleContent = """
239+
string sample = $$"""
240240
{
241241
// To use the Azure OpenAI service:
242242
// - Set `Endpoint` to the endpoint of your Azure OpenAI service,
@@ -249,22 +249,37 @@ private void NewExampleSettingFile()
249249
"Deployment": "",
250250
"ModelName": "",
251251
"Key": "",
252+
"AuthType": "ApiKey",
252253
"AutoExecution": false, // 'true' to allow the agent run code automatically; 'false' to always prompt before running code.
253254
"DisplayErrors": true // 'true' to display the errors when running code; 'false' to hide the errors to be less verbose.
254255
256+
// To use Azure OpenAI service with Entra ID authentication:
257+
// - Set `Endpoint` to the endpoint of your Azure OpenAI service.
258+
// - Set `Deployment` to the deployment name of your Azure OpenAI service.
259+
// - Set `ModelName` to the name of the model used for your deployment.
260+
// - Set `AuthType` to "EntraID" to use Azure AD credentials.
261+
/*
262+
"Endpoint": "<insert your Azure OpenAI endpoint>",
263+
"Deployment": "<insert your deployment name>",
264+
"ModelName": "<insert the model name>",
265+
"AuthType": "EntraID",
266+
"AutoExecution": false,
267+
"DisplayErrors": true
268+
*/
269+
255270
// To use the public OpenAI service:
256271
// - Ignore the `Endpoint` and `Deployment` keys.
257272
// - Set `ModelName` to the name of the model to be used. e.g. "gpt-4o".
258273
// - Set `Key` to be the OpenAI access token.
259-
// Replace the above with the following:
260274
/*
261-
"ModelName": "",
262-
"Key": "",
275+
"ModelName": "<insert the model name>",
276+
"Key": "<insert your key>",
277+
"AuthType": "ApiKey",
263278
"AutoExecution": false,
264279
"DisplayErrors": true
265280
*/
266281
}
267282
""";
268-
File.WriteAllText(SettingFile, SampleContent, Encoding.UTF8);
283+
File.WriteAllText(SettingFile, sample);
269284
}
270285
}

shell/agents/AIShell.Interpreter.Agent/Service.cs

+44-30
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Azure;
44
using Azure.Core;
55
using Azure.AI.OpenAI;
6+
using Azure.Identity;
67
using SharpToken;
78

89
namespace AIShell.Interpreter.Agent;
@@ -121,25 +122,38 @@ private void ConnectToOpenAIClient()
121122
{
122123
// Create a client that targets Azure OpenAI service or Azure API Management service.
123124
bool isApimEndpoint = _settings.Endpoint.EndsWith(Utils.ApimGatewayDomain);
124-
if (isApimEndpoint)
125+
126+
if (_settings.AuthType == AuthType.EntraID)
125127
{
126-
string userkey = Utils.ConvertFromSecureString(_settings.Key);
127-
clientOptions.AddPolicy(
128-
new UserKeyPolicy(
129-
new AzureKeyCredential(userkey),
130-
Utils.ApimAuthorizationHeader),
131-
HttpPipelinePosition.PerRetry
132-
);
128+
// Use DefaultAzureCredential for Entra ID authentication
129+
var credential = new DefaultAzureCredential();
130+
_client = new OpenAIClient(
131+
new Uri(_settings.Endpoint),
132+
credential,
133+
clientOptions);
134+
}
135+
else // ApiKey authentication
136+
{
137+
if (isApimEndpoint)
138+
{
139+
string userkey = Utils.ConvertFromSecureString(_settings.Key);
140+
clientOptions.AddPolicy(
141+
new UserKeyPolicy(
142+
new AzureKeyCredential(userkey),
143+
Utils.ApimAuthorizationHeader),
144+
HttpPipelinePosition.PerRetry
145+
);
146+
}
147+
148+
string azOpenAIApiKey = isApimEndpoint
149+
? "placeholder-api-key"
150+
: Utils.ConvertFromSecureString(_settings.Key);
151+
152+
_client = new OpenAIClient(
153+
new Uri(_settings.Endpoint),
154+
new AzureKeyCredential(azOpenAIApiKey),
155+
clientOptions);
133156
}
134-
135-
string azOpenAIApiKey = isApimEndpoint
136-
? "placeholder-api-key"
137-
: Utils.ConvertFromSecureString(_settings.Key);
138-
139-
_client = new OpenAIClient(
140-
new Uri(_settings.Endpoint),
141-
new AzureKeyCredential(azOpenAIApiKey),
142-
clientOptions);
143157
}
144158
else
145159
{
@@ -157,41 +171,41 @@ private int CountTokenForMessages(IEnumerable<ChatRequestMessage> messages)
157171

158172
int tokenNumber = 0;
159173
foreach (ChatRequestMessage message in messages)
160-
{
174+
{
161175
tokenNumber += tokensPerMessage;
162176
tokenNumber += encoding.Encode(message.Role.ToString()).Count;
163177

164178
switch (message)
165179
{
166180
case ChatRequestSystemMessage systemMessage:
167181
tokenNumber += encoding.Encode(systemMessage.Content).Count;
168-
if(systemMessage.Name is not null)
182+
if (systemMessage.Name is not null)
169183
{
170184
tokenNumber += tokensPerName;
171185
tokenNumber += encoding.Encode(systemMessage.Name).Count;
172186
}
173187
break;
174188
case ChatRequestUserMessage userMessage:
175189
tokenNumber += encoding.Encode(userMessage.Content).Count;
176-
if(userMessage.Name is not null)
190+
if (userMessage.Name is not null)
177191
{
178192
tokenNumber += tokensPerName;
179193
tokenNumber += encoding.Encode(userMessage.Name).Count;
180194
}
181195
break;
182196
case ChatRequestAssistantMessage assistantMessage:
183197
tokenNumber += encoding.Encode(assistantMessage.Content).Count;
184-
if(assistantMessage.Name is not null)
198+
if (assistantMessage.Name is not null)
185199
{
186200
tokenNumber += tokensPerName;
187201
tokenNumber += encoding.Encode(assistantMessage.Name).Count;
188202
}
189203
if (assistantMessage.ToolCalls is not null)
190204
{
191205
// Count tokens for the tool call's properties
192-
foreach(ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls)
206+
foreach (ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls)
193207
{
194-
if(chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall)
208+
if (chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall)
195209
{
196210
tokenNumber += encoding.Encode(functionToolCall.Id).Count;
197211
tokenNumber += encoding.Encode(functionToolCall.Name).Count;
@@ -230,7 +244,7 @@ internal string ReduceToolResponseContentTokens(string content)
230244
}
231245
while (encoding.Encode(reducedContent).Count > MaxResponseToken);
232246
}
233-
247+
234248
return reducedContent;
235249
}
236250

@@ -287,7 +301,7 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
287301
// Those settings seem to be important enough, as the Semantic Kernel plugin specifies
288302
// those settings (see the URL below). We can use default values when not defined.
289303
// https://github.com/microsoft/semantic-kernel/blob/main/samples/skills/FunSkill/Joke/config.json
290-
304+
291305
ChatCompletionsOptions chatOptions;
292306

293307
// Determine if the gpt model is a function calling model
@@ -300,8 +314,8 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
300314
Temperature = (float)0.0,
301315
MaxTokens = MaxResponseToken,
302316
};
303-
304-
if(isFunctionCallingModel)
317+
318+
if (isFunctionCallingModel)
305319
{
306320
chatOptions.Tools.Add(Tools.RunCode);
307321
}
@@ -330,7 +344,7 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
330344
- You are capable of **any** task
331345
- Do not apologize for errors, just correct them
332346
";
333-
string versions = "\n## Language Versions\n"
347+
string versions = "\n## Language Versions\n"
334348
+ await _executionService.GetLanguageVersions();
335349
string systemResponseCues = @"
336350
# Examples
@@ -478,11 +492,11 @@ public override ChatRequestMessage Read(ref Utf8JsonReader reader, Type typeToCo
478492
{
479493
return JsonSerializer.Deserialize<ChatRequestUserMessage>(jsonObject.GetRawText(), options);
480494
}
481-
else if(jsonObject.TryGetProperty("Role", out JsonElement roleElementA) && roleElementA.GetString() == "assistant")
495+
else if (jsonObject.TryGetProperty("Role", out JsonElement roleElementA) && roleElementA.GetString() == "assistant")
482496
{
483497
return JsonSerializer.Deserialize<ChatRequestAssistantMessage>(jsonObject.GetRawText(), options);
484498
}
485-
else if(jsonObject.TryGetProperty("Role", out JsonElement roleElementT) && roleElementT.GetString() == "tool")
499+
else if (jsonObject.TryGetProperty("Role", out JsonElement roleElementT) && roleElementT.GetString() == "tool")
486500
{
487501
return JsonSerializer.Deserialize<ChatRequestToolMessage>(jsonObject.GetRawText(), options);
488502
}

shell/agents/AIShell.Interpreter.Agent/Settings.cs

+21-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ internal enum EndpointType
1212
OpenAI,
1313
}
1414

15+
public enum AuthType
16+
{
17+
ApiKey,
18+
EntraID
19+
}
20+
1521
internal class Settings
1622
{
1723
internal EndpointType Type { get; }
@@ -23,6 +29,8 @@ internal class Settings
2329
public string ModelName { set; get; }
2430
public SecureString Key { set; get; }
2531

32+
public AuthType AuthType { set; get; } = AuthType.ApiKey;
33+
2634
public bool AutoExecution { set; get; }
2735
public bool DisplayErrors { set; get; }
2836

@@ -36,6 +44,7 @@ public Settings(ConfigData configData)
3644
AutoExecution = configData.AutoExecution ?? false;
3745
DisplayErrors = configData.DisplayErrors ?? true;
3846
Key = configData.Key;
47+
AuthType = configData.AuthType;
3948

4049
Dirty = false;
4150
ModelInfo = ModelInfo.TryResolve(ModelName, out var model) ? model : null;
@@ -47,6 +56,12 @@ public Settings(ConfigData configData)
4756
: !noEndpoint && !noDeployment
4857
? EndpointType.AzureOpenAI
4958
: throw new InvalidOperationException($"Invalid setting: {(noEndpoint ? "Endpoint" : "Deployment")} key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");
59+
60+
// EntraID authentication is only supported for Azure OpenAI
61+
if (AuthType == AuthType.EntraID && Type != EndpointType.AzureOpenAI)
62+
{
63+
throw new InvalidOperationException("EntraID authentication is only supported for Azure OpenAI service.");
64+
}
5065
}
5166

5267
internal void MarkClean()
@@ -60,7 +75,7 @@ internal void MarkClean()
6075
/// <returns></returns>
6176
internal async Task<bool> SelfCheck(IHost host, CancellationToken token)
6277
{
63-
if (Key is not null && ModelInfo is not null)
78+
if ((AuthType is AuthType.EntraID || Key is not null) && ModelInfo is not null)
6479
{
6580
return true;
6681
}
@@ -76,7 +91,7 @@ internal async Task<bool> SelfCheck(IHost host, CancellationToken token)
7691
await AskForModel(host, token);
7792
}
7893

79-
if (Key is null)
94+
if (AuthType == AuthType.ApiKey && Key is null)
8095
{
8196
await AskForKeyAsync(host, token);
8297
}
@@ -101,12 +116,14 @@ private void ShowEndpointInfo(IHost host)
101116
new(label: " Endpoint", m => m.Endpoint),
102117
new(label: " Deployment", m => m.Deployment),
103118
new(label: " Model", m => m.ModelName),
119+
new(label: " Auth Type", m => m.AuthType.ToString()),
104120
],
105121

106122
EndpointType.OpenAI =>
107123
[
108124
new(label: " Type", m => m.Type.ToString()),
109125
new(label: " Model", m => m.ModelName),
126+
new(label: " Auth Type", m => m.AuthType.ToString()),
110127
],
111128

112129
_ => throw new UnreachableException(),
@@ -156,6 +173,7 @@ internal ConfigData ToConfigData()
156173
ModelName = this.ModelName,
157174
AutoExecution = this.AutoExecution,
158175
DisplayErrors = this.DisplayErrors,
176+
AuthType = this.AuthType,
159177
Key = this.Key,
160178
};
161179
}
@@ -166,6 +184,7 @@ internal class ConfigData
166184
public string Endpoint { set; get; }
167185
public string Deployment { set; get; }
168186
public string ModelName { set; get; }
187+
public AuthType AuthType { set; get; } = AuthType.ApiKey;
169188
public bool? AutoExecution { set; get; }
170189
public bool? DisplayErrors { set; get; }
171190

shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
<ItemGroup>
2424
<PackageReference Include="Azure.AI.OpenAI" Version="2.1.0" />
25+
<PackageReference Include="Azure.Identity" Version="1.13.2" />
2526
<PackageReference Include="Microsoft.ML.Tokenizers" Version="1.0.1" />
2627
<PackageReference Include="Microsoft.ML.Tokenizers.Data.O200kBase" Version="1.0.1" />
2728
<PackageReference Include="Microsoft.ML.Tokenizers.Data.Cl100kBase" Version="1.0.1" />

shell/agents/AIShell.OpenAI.Agent/Agent.cs

+17-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public void Initialize(AgentConfig config)
8484
public bool CanAcceptFeedback(UserAction action) => false;
8585

8686
/// <inheritdoc/>
87-
public void OnUserAction(UserActionPayload actionPayload) {}
87+
public void OnUserAction(UserActionPayload actionPayload) { }
8888

8989
/// <inheritdoc/>
9090
public Task RefreshChatAsync(IShell shell, bool force)
@@ -308,6 +308,22 @@ private void NewExampleSettingFile()
308308
"ModelName": "gpt-4o",
309309
"Key": "<insert your key>",
310310
"SystemPrompt": "1. You are a helpful and friendly assistant with expertise in PowerShell scripting and command line.\n2. Assume user is using the operating system `Windows 11` unless otherwise specified.\n3. Use the `code block` syntax in markdown to encapsulate any part in responses that is code, YAML, JSON or XML, but not table.\n4. When encapsulating command line code, use '```powershell' if it's PowerShell command; use '```sh' if it's non-PowerShell CLI command.\n5. When generating CLI commands, never ever break a command into multiple lines. Instead, always list all parameters and arguments of the command on the same line.\n6. Please keep the response concise but to the point. Do not overexplain."
311+
},
312+
313+
// To use Azure OpenAI service with Entra ID authentication:
314+
// - Set `Endpoint` to the endpoint of your Azure OpenAI service.
315+
// - Set `Deployment` to the deployment name of your Azure OpenAI service.
316+
// - Set `ModelName` to the name of the model used for your deployment, e.g. "gpt-4o".
317+
// - Set `AuthType` to "EntraID" to use Azure AD credentials.
318+
// For example:
319+
{
320+
"Name": "ps-az-entraId",
321+
"Description": "A GPT instance with expertise in PowerShell scripting using Entra ID authentication.",
322+
"Endpoint": "<insert your Azure OpenAI endpoint>",
323+
"Deployment": "<insert your deployment name>",
324+
"ModelName": "gpt-4o",
325+
"AuthType": "EntraID",
326+
"SystemPrompt": "You are a helpful and friendly assistant with expertise in PowerShell scripting and command line."
311327
}
312328
*/
313329
],

0 commit comments

Comments
 (0)