diff --git a/src/ai/.x/help/chat.assistant.create b/src/ai/.x/help/chat.assistant.create index 1f4f58ce..8cf48254 100644 --- a/src/ai/.x/help/chat.assistant.create +++ b/src/ai/.x/help/chat.assistant.create @@ -18,6 +18,7 @@ USAGE: ai chat assistant create [...] --files FILE1 [...] --file-id ID --file-ids ID1 [...] + --file-search TRUE/FALSE TOOLS --code-interpreter TRUE/FALSE diff --git a/src/ai/.x/help/chat.assistant.update b/src/ai/.x/help/chat.assistant.update index 0fef7e4e..9fb1d984 100644 --- a/src/ai/.x/help/chat.assistant.update +++ b/src/ai/.x/help/chat.assistant.update @@ -18,6 +18,7 @@ USAGE: ai chat assistant update [...] --files FILE1 [...] --file-id ID --file-ids ID1 [...] + --file-search TRUE/FALSE TOOLS --code-interpreter TRUE/FALSE diff --git a/src/ai/commands/chat_command.cs b/src/ai/commands/chat_command.cs index 628f2689..cd8312aa 100644 --- a/src/ai/commands/chat_command.cs +++ b/src/ai/commands/chat_command.cs @@ -1100,7 +1100,7 @@ private async Task DoChatAssistantCreate() DemandKeyAndEndpoint(out var key, out var endpoint); - var codeInterpreter = CodeInterpreterToken.Data().GetOrDefault(_values, false); + var codeInterpreter = CodeInterpreterTrueFalseToken.Data().GetOrDefault(_values, false); var assistant = await OpenAIAssistantHelpers.CreateAssistantAsync(key, endpoint, name, deployment, instructions, codeInterpreter); IdHelpers.CheckWriteOutputNameOrId(assistant.Id, _values, "chat", IdKind.Id); IdHelpers.CheckWriteOutputNameOrId(assistant.Name, _values, "chat", IdKind.Name); @@ -1115,7 +1115,8 @@ private async Task DoChatAssistantCreate() files = ExpandFindFiles(files); - if (fileIds.Count() > 0 || files.Count() > 0) + var fileSearch = FileSearchTrueFalseToken.Data().GetOrDefault(_values, false); + if (fileSearch || fileIds.Count() > 0 || files.Count() > 0) { if (!_quiet) Console.WriteLine("\n Creating vector store ..."); @@ -1165,8 +1166,15 @@ private async Task DoChatAssistantUpdate() var deployment = ConfigDeploymentToken.Data().GetOrDefault(_values, assistant.Model); var instructions = InstructionsToken.Data().GetOrDefault(_values, assistant.Instructions); - var codeInterpreter = CodeInterpreterToken.Data().GetOrDefault(_values, assistant.Tools.FirstOrDefault(t => t is CodeInterpreterToolDefinition) != null); + + var existingCodeInterpreter = assistant.Tools.FirstOrDefault(t => t is CodeInterpreterToolDefinition) != null; + var codeInterpreterSpecifiedTrue = CodeInterpreterTrueFalseToken.Data().GetOrDefault(_values, false) == true; + var codeInterpreterSpecifiedFalse = CodeInterpreterTrueFalseToken.Data().GetOrDefault(_values, true) == false; + var vectorStoreId = assistant.ToolResources.FileSearch?.VectorStoreIds?.FirstOrDefault(); + var existingVectorStore = !string.IsNullOrEmpty(vectorStoreId); + var fileSearchSpecifiedFalse = FileSearchTrueFalseToken.Data().GetOrDefault(_values, true) == false; + var fileSearchSpecifiedTrue = FileSearchTrueFalseToken.Data().GetOrDefault(_values, false) == true; var fileIds = FileIdOptionXToken.GetOptions(_values).ToList(); fileIds.AddRange(FileIdsOptionXToken.GetOptions(_values)); @@ -1177,25 +1185,30 @@ private async Task DoChatAssistantUpdate() files.ExtendSplitItems(';'); files = ExpandFindFiles(files); - var existingVectorStore = !string.IsNullOrEmpty(vectorStoreId); var newFilesForVectorStore = fileIds.Count() > 0 || files.Count() > 0; - if (newFilesForVectorStore) + var createVectorStore = !existingVectorStore && (newFilesForVectorStore || fileSearchSpecifiedTrue); + var updateVectorStore = existingVectorStore && newFilesForVectorStore; + var removeVectorStore = existingVectorStore && fileSearchSpecifiedFalse; + + if (createVectorStore) { - if (!existingVectorStore) - { - if (!_quiet) Console.WriteLine("\n Creating vector store ..."); + if (!_quiet) Console.WriteLine("\n Creating vector store ..."); - var store = await CreateAssistantVectorStoreAsync(key, endpoint, name, fileIds, files); - vectorStoreId = store.Id; - } - else - { - if (!_quiet) Console.WriteLine("\n Updating vector store ..."); + var store = await CreateAssistantVectorStoreAsync(key, endpoint, name, fileIds, files); + vectorStoreId = store.Id; + } + else if (updateVectorStore) + { + if (!_quiet) Console.WriteLine("\n Updating vector store ..."); - var store = await OpenAIAssistantHelpers.GetVectorStoreAsync(key, endpoint, vectorStoreId); - store = await UploadFilesToAssistantVectorStore(key, endpoint, fileIds, files, store); - vectorStoreId = store.Id; - } + var store = await OpenAIAssistantHelpers.GetVectorStoreAsync(key, endpoint, vectorStoreId); + store = await UploadFilesToAssistantVectorStore(key, endpoint, fileIds, files, store); + vectorStoreId = store.Id; + } + else if (removeVectorStore) + { + Console.WriteLine(); + _values.AddThrowError("ERROR:", $"Removing vector store; not yet implemented."); } var modifyOptions = new AssistantModificationOptions() @@ -1206,7 +1219,13 @@ private async Task DoChatAssistantUpdate() ToolResources = ToolResourcesFromVectorStoreId(vectorStoreId), }; - if (codeInterpreter) + var removeCodeInterpreter = existingCodeInterpreter && codeInterpreterSpecifiedFalse; + if (removeCodeInterpreter) + { + Console.WriteLine(); + _values.AddThrowError("ERROR:", $"Removing code interpreter; not yet implemented."); + } + else if (existingCodeInterpreter) { modifyOptions.DefaultTools.Add(new CodeInterpreterToolDefinition()); } @@ -1265,12 +1284,11 @@ private async Task DoChatAssistantGet() if (!_quiet) Console.WriteLine(message); DemandKeyAndEndpoint(out var key, out var endpoint); - var json = await OpenAIAssistantHelpers.GetAssistantJsonAsync(key, endpoint, id); - if (!_quiet) Console.WriteLine($"{message} Done!\n"); + var assistant = await OpenAIAssistantHelpers.GetAssistantAsync(key, endpoint, id); + PrintAssistant(assistant); - var ok = !string.IsNullOrEmpty(json); - if (ok) Console.WriteLine(json); + if (!_quiet) Console.WriteLine($"{message} Done!\n"); return true; } @@ -1687,6 +1705,34 @@ private void PrintVectorStore(string key, string endpoint, VectorStore store) } } + private static void PrintAssistant(Assistant assistant) + { + var id = assistant.Id; + var name = assistant.Name ?? "(no name)"; + + Console.WriteLine(); + Console.WriteLine($" ID: {id}"); + Console.WriteLine($" Name: {name}"); + Console.WriteLine(); + Console.WriteLine($" Model: {assistant.Model}"); + Console.WriteLine($" Instructions: {assistant.Instructions}"); + Console.WriteLine(); + + var toolNames = string.Join(", ", assistant.Tools.Select(x => x.GetType().Name)); + Console.WriteLine($" Tools: {toolNames}"); + + var countVectorStoreIds = assistant.ToolResources?.FileSearch?.VectorStoreIds?.Count(); + if (countVectorStoreIds > 0) + { + var vectorStoreIds = string.Join(", ", assistant.ToolResources?.FileSearch?.VectorStoreIds); + Console.WriteLine(vectorStoreIds.Contains(',') + ? $" Vector stores: {vectorStoreIds}" + : $" Vector store: {vectorStoreIds}"); + } + + Console.WriteLine(); + } + private void DemandKeyAndEndpoint(out string key, out string endpoint) { key = _values["service.config.key"]; diff --git a/src/ai/commands/parsers/chat_command_parser.cs b/src/ai/commands/parsers/chat_command_parser.cs index 80b91657..cda6933c 100644 --- a/src/ai/commands/parsers/chat_command_parser.cs +++ b/src/ai/commands/parsers/chat_command_parser.cs @@ -161,13 +161,14 @@ public CommonChatNamedValueTokenParsers() : base( }; private static INamedValueTokenParser[] _chatAssistantCreateCommandParsers = { - CodeInterpreterToken.Parser(), + CodeInterpreterTrueFalseToken.Parser(), + FileSearchTrueFalseToken.Parser(), FileIdOptionXToken.Parser(), FileIdsOptionXToken.Parser(), FileOptionXToken.Parser(), FilesOptionXToken.Parser(), - + new CommonChatNamedValueTokenParsers(), new Any1ValueNamedValueTokenParser(null, "chat.assistant.id", "001"), @@ -181,8 +182,9 @@ public CommonChatNamedValueTokenParsers() : base( }; private static INamedValueTokenParser[] _chatAssistantUpdateCommandParsers = { - CodeInterpreterToken.Parser(), + CodeInterpreterTrueFalseToken.Parser(), + FileSearchTrueFalseToken.Parser(), FileIdOptionXToken.Parser(), FileIdsOptionXToken.Parser(), FileOptionXToken.Parser(), diff --git a/src/ai/helpers/openai/OpenAIAssistantHelpers.cs b/src/ai/helpers/openai/OpenAIAssistantHelpers.cs index 180be3f1..8a59af63 100644 --- a/src/ai/helpers/openai/OpenAIAssistantHelpers.cs +++ b/src/ai/helpers/openai/OpenAIAssistantHelpers.cs @@ -58,6 +58,18 @@ public static async Task DeleteAssistantAsync(AssistantClient client, string id) var response = await client.DeleteAssistantAsync(id); } + public static async Task GetAssistantAsync(string key, string endpoint, string id) + { + var client = CreateOpenAIAssistantClient(key, endpoint); + return await GetAssistantAsync(client, id); + } + + public static async Task GetAssistantAsync(AssistantClient client, string id) + { + var response = await client.GetAssistantAsync(id); + return response.Value; + } + public static async Task GetAssistantJsonAsync(string key, string endpoint, string id) { var client = CreateOpenAIAssistantClient(key, endpoint); diff --git a/src/common/details/commands/parsers/command_parser.cs b/src/common/details/commands/parsers/command_parser.cs index 8acb909b..000d5636 100644 --- a/src/common/details/commands/parsers/command_parser.cs +++ b/src/common/details/commands/parsers/command_parser.cs @@ -334,7 +334,7 @@ protected static void ParseInvalidArgumentError(INamedValueTokens tokens, INamed { if (!values.Contains("error")) { - var restOfTokens = tokens.PeekAllTokens(); + var restOfTokens = tokens.PeekAllTokens().Replace('.', '-'); var command = values.GetCommandForDisplay(); values.AddError( "ERROR:", $"Invalid {kind} at \"{restOfTokens}\".", "", diff --git a/src/common/details/named_values/tokens/code_interpreter_token.cs b/src/common/details/named_values/tokens/code_interpreter_true_false_token.cs similarity index 93% rename from src/common/details/named_values/tokens/code_interpreter_token.cs rename to src/common/details/named_values/tokens/code_interpreter_true_false_token.cs index 31ea0b0f..71978f2f 100644 --- a/src/common/details/named_values/tokens/code_interpreter_token.cs +++ b/src/common/details/named_values/tokens/code_interpreter_true_false_token.cs @@ -5,7 +5,7 @@ namespace Azure.AI.Details.Common.CLI { - public class CodeInterpreterToken + public class CodeInterpreterTrueFalseToken { public static NamedValueTokenData Data() => new NamedValueTokenData(_optionName, _fullName, _optionExample, _requiredDisplayName); public static INamedValueTokenParser Parser() => new TrueFalseNamedValueTokenParser(_optionName, _fullName, "11"); diff --git a/src/common/details/named_values/tokens/file_search_true_false_token.cs b/src/common/details/named_values/tokens/file_search_true_false_token.cs new file mode 100644 index 00000000..086f14b0 --- /dev/null +++ b/src/common/details/named_values/tokens/file_search_true_false_token.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +// + +namespace Azure.AI.Details.Common.CLI +{ + public class FileSearchTrueFalseToken + { + public static NamedValueTokenData Data() => new NamedValueTokenData(_optionName, _fullName, _optionExample, _requiredDisplayName); + public static INamedValueTokenParser Parser() => new TrueFalseNamedValueTokenParser(_optionName, _fullName, "11"); + + private const string _requiredDisplayName = "file search"; + private const string _optionName = "--file-search"; + private const string _optionExample = "true|false"; + private const string _fullName = "file.search"; + } +}