Skip to content

Commit

Permalink
Resolve last parts of Issue 294 (#319)
Browse files Browse the repository at this point in the history
  • Loading branch information
robch authored Jul 5, 2024
1 parent 2bf1c64 commit 341a900
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 28 deletions.
1 change: 1 addition & 0 deletions src/ai/.x/help/chat.assistant.create
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ USAGE: ai chat assistant create [...]
--files FILE1 [...]
--file-id ID
--file-ids ID1 [...]
--file-search TRUE/FALSE

TOOLS
--code-interpreter TRUE/FALSE
Expand Down
1 change: 1 addition & 0 deletions src/ai/.x/help/chat.assistant.update
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ USAGE: ai chat assistant update [...]
--files FILE1 [...]
--file-id ID
--file-ids ID1 [...]
--file-search TRUE/FALSE

TOOLS
--code-interpreter TRUE/FALSE
Expand Down
92 changes: 69 additions & 23 deletions src/ai/commands/chat_command.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ private async Task<bool> DoChatAssistantCreate()

DemandKeyAndEndpoint(out var key, out var endpoint);

var codeInterpreter = CodeInterpreterToken.Data().GetOrDefault(_values, false);
var codeInterpreter = CodeInterpreterTrueFalseToken.Data().GetOrDefault(_values, false);
var assistant = await OpenAIAssistantHelpers.CreateAssistantAsync(key, endpoint, name, deployment, instructions, codeInterpreter);
IdHelpers.CheckWriteOutputNameOrId(assistant.Id, _values, "chat", IdKind.Id);
IdHelpers.CheckWriteOutputNameOrId(assistant.Name, _values, "chat", IdKind.Name);
Expand All @@ -1115,7 +1115,8 @@ private async Task<bool> DoChatAssistantCreate()

files = ExpandFindFiles(files);

if (fileIds.Count() > 0 || files.Count() > 0)
var fileSearch = FileSearchTrueFalseToken.Data().GetOrDefault(_values, false);
if (fileSearch || fileIds.Count() > 0 || files.Count() > 0)
{
if (!_quiet) Console.WriteLine("\n Creating vector store ...");

Expand Down Expand Up @@ -1165,8 +1166,15 @@ private async Task<bool> DoChatAssistantUpdate()

var deployment = ConfigDeploymentToken.Data().GetOrDefault(_values, assistant.Model);
var instructions = InstructionsToken.Data().GetOrDefault(_values, assistant.Instructions);
var codeInterpreter = CodeInterpreterToken.Data().GetOrDefault(_values, assistant.Tools.FirstOrDefault(t => t is CodeInterpreterToolDefinition) != null);

var existingCodeInterpreter = assistant.Tools.FirstOrDefault(t => t is CodeInterpreterToolDefinition) != null;
var codeInterpreterSpecifiedTrue = CodeInterpreterTrueFalseToken.Data().GetOrDefault(_values, false) == true;
var codeInterpreterSpecifiedFalse = CodeInterpreterTrueFalseToken.Data().GetOrDefault(_values, true) == false;

var vectorStoreId = assistant.ToolResources.FileSearch?.VectorStoreIds?.FirstOrDefault();
var existingVectorStore = !string.IsNullOrEmpty(vectorStoreId);
var fileSearchSpecifiedFalse = FileSearchTrueFalseToken.Data().GetOrDefault(_values, true) == false;
var fileSearchSpecifiedTrue = FileSearchTrueFalseToken.Data().GetOrDefault(_values, false) == true;

var fileIds = FileIdOptionXToken.GetOptions(_values).ToList();
fileIds.AddRange(FileIdsOptionXToken.GetOptions(_values));
Expand All @@ -1177,25 +1185,30 @@ private async Task<bool> DoChatAssistantUpdate()
files.ExtendSplitItems(';');
files = ExpandFindFiles(files);

var existingVectorStore = !string.IsNullOrEmpty(vectorStoreId);
var newFilesForVectorStore = fileIds.Count() > 0 || files.Count() > 0;
if (newFilesForVectorStore)
var createVectorStore = !existingVectorStore && (newFilesForVectorStore || fileSearchSpecifiedTrue);
var updateVectorStore = existingVectorStore && newFilesForVectorStore;
var removeVectorStore = existingVectorStore && fileSearchSpecifiedFalse;

if (createVectorStore)
{
if (!existingVectorStore)
{
if (!_quiet) Console.WriteLine("\n Creating vector store ...");
if (!_quiet) Console.WriteLine("\n Creating vector store ...");

var store = await CreateAssistantVectorStoreAsync(key, endpoint, name, fileIds, files);
vectorStoreId = store.Id;
}
else
{
if (!_quiet) Console.WriteLine("\n Updating vector store ...");
var store = await CreateAssistantVectorStoreAsync(key, endpoint, name, fileIds, files);
vectorStoreId = store.Id;
}
else if (updateVectorStore)
{
if (!_quiet) Console.WriteLine("\n Updating vector store ...");

var store = await OpenAIAssistantHelpers.GetVectorStoreAsync(key, endpoint, vectorStoreId);
store = await UploadFilesToAssistantVectorStore(key, endpoint, fileIds, files, store);
vectorStoreId = store.Id;
}
var store = await OpenAIAssistantHelpers.GetVectorStoreAsync(key, endpoint, vectorStoreId);
store = await UploadFilesToAssistantVectorStore(key, endpoint, fileIds, files, store);
vectorStoreId = store.Id;
}
else if (removeVectorStore)
{
Console.WriteLine();
_values.AddThrowError("ERROR:", $"Removing vector store; not yet implemented.");
}

var modifyOptions = new AssistantModificationOptions()
Expand All @@ -1206,7 +1219,13 @@ private async Task<bool> DoChatAssistantUpdate()
ToolResources = ToolResourcesFromVectorStoreId(vectorStoreId),
};

if (codeInterpreter)
var removeCodeInterpreter = existingCodeInterpreter && codeInterpreterSpecifiedFalse;
if (removeCodeInterpreter)
{
Console.WriteLine();
_values.AddThrowError("ERROR:", $"Removing code interpreter; not yet implemented.");
}
else if (existingCodeInterpreter)
{
modifyOptions.DefaultTools.Add(new CodeInterpreterToolDefinition());
}
Expand Down Expand Up @@ -1265,12 +1284,11 @@ private async Task<bool> DoChatAssistantGet()
if (!_quiet) Console.WriteLine(message);

DemandKeyAndEndpoint(out var key, out var endpoint);
var json = await OpenAIAssistantHelpers.GetAssistantJsonAsync(key, endpoint, id);

if (!_quiet) Console.WriteLine($"{message} Done!\n");
var assistant = await OpenAIAssistantHelpers.GetAssistantAsync(key, endpoint, id);
PrintAssistant(assistant);

var ok = !string.IsNullOrEmpty(json);
if (ok) Console.WriteLine(json);
if (!_quiet) Console.WriteLine($"{message} Done!\n");

return true;
}
Expand Down Expand Up @@ -1687,6 +1705,34 @@ private void PrintVectorStore(string key, string endpoint, VectorStore store)
}
}

private static void PrintAssistant(Assistant assistant)
{
var id = assistant.Id;
var name = assistant.Name ?? "(no name)";

Console.WriteLine();
Console.WriteLine($" ID: {id}");
Console.WriteLine($" Name: {name}");
Console.WriteLine();
Console.WriteLine($" Model: {assistant.Model}");
Console.WriteLine($" Instructions: {assistant.Instructions}");
Console.WriteLine();

var toolNames = string.Join(", ", assistant.Tools.Select(x => x.GetType().Name));
Console.WriteLine($" Tools: {toolNames}");

var countVectorStoreIds = assistant.ToolResources?.FileSearch?.VectorStoreIds?.Count();
if (countVectorStoreIds > 0)
{
var vectorStoreIds = string.Join(", ", assistant.ToolResources?.FileSearch?.VectorStoreIds);
Console.WriteLine(vectorStoreIds.Contains(',')
? $" Vector stores: {vectorStoreIds}"
: $" Vector store: {vectorStoreIds}");
}

Console.WriteLine();
}

private void DemandKeyAndEndpoint(out string key, out string endpoint)
{
key = _values["service.config.key"];
Expand Down
8 changes: 5 additions & 3 deletions src/ai/commands/parsers/chat_command_parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,14 @@ public CommonChatNamedValueTokenParsers() : base(
};

private static INamedValueTokenParser[] _chatAssistantCreateCommandParsers = {
CodeInterpreterToken.Parser(),
CodeInterpreterTrueFalseToken.Parser(),

FileSearchTrueFalseToken.Parser(),
FileIdOptionXToken.Parser(),
FileIdsOptionXToken.Parser(),
FileOptionXToken.Parser(),
FilesOptionXToken.Parser(),

new CommonChatNamedValueTokenParsers(),

new Any1ValueNamedValueTokenParser(null, "chat.assistant.id", "001"),
Expand All @@ -181,8 +182,9 @@ public CommonChatNamedValueTokenParsers() : base(
};

private static INamedValueTokenParser[] _chatAssistantUpdateCommandParsers = {
CodeInterpreterToken.Parser(),
CodeInterpreterTrueFalseToken.Parser(),

FileSearchTrueFalseToken.Parser(),
FileIdOptionXToken.Parser(),
FileIdsOptionXToken.Parser(),
FileOptionXToken.Parser(),
Expand Down
12 changes: 12 additions & 0 deletions src/ai/helpers/openai/OpenAIAssistantHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ public static async Task DeleteAssistantAsync(AssistantClient client, string id)
var response = await client.DeleteAssistantAsync(id);
}

public static async Task<Assistant> GetAssistantAsync(string key, string endpoint, string id)
{
var client = CreateOpenAIAssistantClient(key, endpoint);
return await GetAssistantAsync(client, id);
}

public static async Task<Assistant> GetAssistantAsync(AssistantClient client, string id)
{
var response = await client.GetAssistantAsync(id);
return response.Value;
}

public static async Task<string?> GetAssistantJsonAsync(string key, string endpoint, string id)
{
var client = CreateOpenAIAssistantClient(key, endpoint);
Expand Down
2 changes: 1 addition & 1 deletion src/common/details/commands/parsers/command_parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ protected static void ParseInvalidArgumentError(INamedValueTokens tokens, INamed
{
if (!values.Contains("error"))
{
var restOfTokens = tokens.PeekAllTokens();
var restOfTokens = tokens.PeekAllTokens().Replace('.', '-');
var command = values.GetCommandForDisplay();
values.AddError(
"ERROR:", $"Invalid {kind} at \"{restOfTokens}\".", "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace Azure.AI.Details.Common.CLI
{
public class CodeInterpreterToken
public class CodeInterpreterTrueFalseToken
{
public static NamedValueTokenData Data() => new NamedValueTokenData(_optionName, _fullName, _optionExample, _requiredDisplayName);
public static INamedValueTokenParser Parser() => new TrueFalseNamedValueTokenParser(_optionName, _fullName, "11");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

namespace Azure.AI.Details.Common.CLI
{
public class FileSearchTrueFalseToken
{
public static NamedValueTokenData Data() => new NamedValueTokenData(_optionName, _fullName, _optionExample, _requiredDisplayName);
public static INamedValueTokenParser Parser() => new TrueFalseNamedValueTokenParser(_optionName, _fullName, "11");

private const string _requiredDisplayName = "file search";
private const string _optionName = "--file-search";
private const string _optionExample = "true|false";
private const string _fullName = "file.search";
}
}

0 comments on commit 341a900

Please sign in to comment.