diff --git a/examples/src/main/java/io/dapr/examples/conversation/AssistantMessageDemo.java b/examples/src/main/java/io/dapr/examples/conversation/AssistantMessageDemo.java new file mode 100644 index 000000000..7ff2d43be --- /dev/null +++ b/examples/src/main/java/io/dapr/examples/conversation/AssistantMessageDemo.java @@ -0,0 +1,136 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.examples.conversation; + +import io.dapr.client.DaprClientBuilder; +import io.dapr.client.DaprPreviewClient; +import io.dapr.client.domain.AssistantMessage; +import io.dapr.client.domain.ConversationInputAlpha2; +import io.dapr.client.domain.ConversationMessage; +import io.dapr.client.domain.ConversationMessageContent; +import io.dapr.client.domain.ConversationRequestAlpha2; +import io.dapr.client.domain.ConversationResponseAlpha2; +import io.dapr.client.domain.ConversationResultAlpha2; +import io.dapr.client.domain.ConversationResultChoices; +import io.dapr.client.domain.ConversationToolCalls; +import io.dapr.client.domain.ConversationToolCallsOfFunction; +import io.dapr.client.domain.SystemMessage; +import io.dapr.client.domain.ToolMessage; +import io.dapr.client.domain.UserMessage; +import reactor.core.publisher.Mono; + +import java.util.ArrayList; +import java.util.List; + +public class AssistantMessageDemo { + /** + * The main method to demonstrate conversation AI with assistant messages and conversation history. + * + * @param args Input arguments (unused). + */ + public static void main(String[] args) { + try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { + System.out.println("Demonstrating Conversation AI with Assistant Messages and Conversation History"); + + // Create a conversation history with multiple message types + List conversationHistory = new ArrayList<>(); + + // 1. System message to set context + SystemMessage systemMessage = new SystemMessage(List.of( + new ConversationMessageContent("You are a helpful assistant that can help with weather queries.") + )); + systemMessage.setName("WeatherBot"); + conversationHistory.add(systemMessage); + + // 2. Initial user greeting + UserMessage greeting = new UserMessage(List.of( + new ConversationMessageContent("Hello! I need help with weather information.") + )); + greeting.setName("User123"); + conversationHistory.add(greeting); + + // 3. Assistant response with tool call + AssistantMessage assistantResponse = new AssistantMessage( + List.of(new ConversationMessageContent("I'll help you with weather information. Let me check the weather for you.")), + List.of(new ConversationToolCalls( + new ConversationToolCallsOfFunction("get_weather", "{\"location\": \"San Francisco\", \"unit\": \"fahrenheit\"}") + )) + ); + assistantResponse.setName("WeatherBot"); + conversationHistory.add(assistantResponse); + + // 4. Tool response (simulating weather API response) + ToolMessage toolResponse = new ToolMessage(List.of( + new ConversationMessageContent("{\"temperature\": \"72F\", \"condition\": \"sunny\", \"humidity\": \"65%\"}") + )); + toolResponse.setName("weather_api"); + conversationHistory.add(toolResponse); + + // 5. Current user question + UserMessage currentQuestion = new UserMessage(List.of( + new ConversationMessageContent("Based on that weather data, should I wear a jacket today?") + )); + currentQuestion.setName("User123"); + conversationHistory.add(currentQuestion); + + // Create conversation input with the full history + ConversationInputAlpha2 conversationInput = new ConversationInputAlpha2(conversationHistory); + conversationInput.setScrubPii(false); + + // Create the conversation request + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", List.of(conversationInput)) + .setContextId("assistant-demo-context") + .setTemperature(0.8d); + + // Send the request + System.out.println("Sending conversation with assistant messages and history..."); + System.out.println("Conversation includes:"); + System.out.println("- System message (context setting)"); + System.out.println("- User greeting"); + System.out.println("- Assistant response with tool call"); + System.out.println("- Tool response with weather data"); + System.out.println("- User follow-up question"); + + Mono responseMono = client.converseAlpha2(request); + ConversationResponseAlpha2 response = responseMono.block(); + + // Process and display the response + if (response != null && response.getOutputs() != null && !response.getOutputs().isEmpty()) { + ConversationResultAlpha2 result = response.getOutputs().get(0); + if (result.getChoices() != null && !result.getChoices().isEmpty()) { + ConversationResultChoices choice = result.getChoices().get(0); + + if (choice.getMessage() != null && choice.getMessage().getContent() != null) { + System.out.printf("Assistant Response: %s%n", choice.getMessage().getContent()); + } + + // Check for additional tool calls in the response + if (choice.getMessage() != null && choice.getMessage().getToolCalls() != null) { + System.out.println("Assistant requested additional tool calls:"); + choice.getMessage().getToolCalls().forEach(toolCall -> { + System.out.printf("Tool: %s, Arguments: %s%n", + toolCall.getFunction().getName(), + toolCall.getFunction().getArguments()); + }); + } + } + } + + System.out.println("Assistant message demonstration completed."); + + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java deleted file mode 100644 index 09c957026..000000000 --- a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2021 The Dapr Authors - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and -limitations under the License. -*/ - -package io.dapr.examples.conversation; - -import io.dapr.client.DaprClientBuilder; -import io.dapr.client.DaprPreviewClient; -import io.dapr.client.domain.ConversationInput; -import io.dapr.client.domain.ConversationRequest; -import io.dapr.client.domain.ConversationResponse; -import reactor.core.publisher.Mono; - -import java.util.List; - -public class DemoConversationAI { - /** - * The main method to start the client. - * - * @param args Input arguments (unused). - */ - public static void main(String[] args) { - try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { - System.out.println("Sending the following input to LLM: Hello How are you? This is the my number 672-123-4567"); - - ConversationInput daprConversationInput = new ConversationInput("Hello How are you? " - + "This is the my number 672-123-4567"); - - // Component name is the name provided in the metadata block of the conversation.yaml file. - Mono responseMono = client.converse(new ConversationRequest("echo", - List.of(daprConversationInput)) - .setContextId("contextId") - .setScrubPii(true).setTemperature(1.1d)); - ConversationResponse response = responseMono.block(); - System.out.printf("Conversation output: %s", response.getConversationOutputs().get(0).getResult()); - } catch (Exception e) { - throw new RuntimeException(e); - } - } -} diff --git a/examples/src/main/java/io/dapr/examples/conversation/README.md b/examples/src/main/java/io/dapr/examples/conversation/README.md index 29468cfb3..cb12b6460 100644 --- a/examples/src/main/java/io/dapr/examples/conversation/README.md +++ b/examples/src/main/java/io/dapr/examples/conversation/README.md @@ -45,30 +45,51 @@ Run `dapr init` to initialize Dapr in Self-Hosted Mode if it's not already initi ### Running the example This example uses the Java SDK Dapr client in order to **Converse** with an LLM. -`DemoConversationAI.java` is the example class demonstrating these features. +`UserMessageDemo.java` is the example class demonstrating these features. Kindly check [DaprPreviewClient.java](https://github.com/dapr/java-sdk/blob/master/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java) for a detailed description of the supported APIs. ```java -public class DemoConversationAI { +public class UserMessageDemo { /** * The main method to start the client. * * @param args Input arguments (unused). */ public static void main(String[] args) { - try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { + Map, String> overrides = Map.of( + Properties.HTTP_PORT, "3500", + Properties.GRPC_PORT, "50001" + ); + + try (DaprPreviewClient client = new DaprClientBuilder().withPropertyOverrides(overrides).buildPreviewClient()) { System.out.println("Sending the following input to LLM: Hello How are you? This is the my number 672-123-4567"); - ConversationInput daprConversationInput = new ConversationInput("Hello How are you? " - + "This is the my number 672-123-4567"); + // Create user message with content + UserMessage userMessage = new UserMessage(List.of(new ConversationMessageContent("Hello How are you? " + + "This is the my number 672-123-4567"))); + + // Create conversation input with the user message + ConversationInputAlpha2 daprConversationInput = new ConversationInputAlpha2(List.of(userMessage)); // Component name is the name provided in the metadata block of the conversation.yaml file. - Mono responseMono = client.converse(new ConversationRequest("echo", + Mono responseMono = client.converseAlpha2(new ConversationRequestAlpha2("echo", List.of(daprConversationInput)) .setContextId("contextId") - .setScrubPii(true).setTemperature(1.1d)); - ConversationResponse response = responseMono.block(); - System.out.printf("Conversation output: %s", response.getConversationOutpus().get(0).getResult()); + .setScrubPii(true) + .setTemperature(1.1d)); + + ConversationResponseAlpha2 response = responseMono.block(); + + // Extract and print the conversation result + if (response != null && response.getOutputs() != null && !response.getOutputs().isEmpty()) { + ConversationResultAlpha2 result = response.getOutputs().get(0); + if (result.getChoices() != null && !result.getChoices().isEmpty()) { + ConversationResultChoices choice = result.getChoices().get(0); + if (choice.getMessage() != null && choice.getMessage().getContent() != null) { + System.out.printf("Conversation output: %s", choice.getMessage().getContent()); + } + } + } } catch (Exception e) { throw new RuntimeException(e); } @@ -88,7 +109,7 @@ sleep: 10 --> ```bash -dapr run --resources-path ./components/conversation --app-id myapp --app-port 8080 --dapr-http-port 3500 --dapr-grpc-port 51439 --log-level debug -- java -jar target/dapr-java-sdk-examples-exec.jar io.dapr.examples.conversation.DemoConversationAI +dapr run --resources-path ./components/conversation --app-id myapp --app-port 8080 --dapr-http-port 3500 --dapr-grpc-port 51439 --log-level debug -- java -jar target/dapr-java-sdk-examples-exec.jar io.dapr.examples.conversation.UserMessageDemo ``` diff --git a/examples/src/main/java/io/dapr/examples/conversation/ToolsCallDemo.java b/examples/src/main/java/io/dapr/examples/conversation/ToolsCallDemo.java new file mode 100644 index 000000000..30335802c --- /dev/null +++ b/examples/src/main/java/io/dapr/examples/conversation/ToolsCallDemo.java @@ -0,0 +1,109 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.examples.conversation; + +import io.dapr.client.DaprClientBuilder; +import io.dapr.client.DaprPreviewClient; +import io.dapr.client.domain.ConversationInputAlpha2; +import io.dapr.client.domain.ConversationMessageContent; +import io.dapr.client.domain.ConversationRequestAlpha2; +import io.dapr.client.domain.ConversationResponseAlpha2; +import io.dapr.client.domain.ConversationResultAlpha2; +import io.dapr.client.domain.ConversationResultChoices; +import io.dapr.client.domain.ConversationTools; +import io.dapr.client.domain.ConversationToolsFunction; +import io.dapr.client.domain.SystemMessage; +import io.dapr.client.domain.UserMessage; +import reactor.core.publisher.Mono; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ToolsCallDemo { + /** + * The main method to demonstrate conversation AI with tools/function calling. + * + * @param args Input arguments (unused). + */ + public static void main(String[] args) { + try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { + System.out.println("Demonstrating Conversation AI with Tools/Function Calling"); + + // Create system message to set context + SystemMessage systemMessage = new SystemMessage(List.of( + new ConversationMessageContent("You are a helpful weather assistant. Use the provided tools to get weather information.") + )); + + // Create user message asking for weather + UserMessage userMessage = new UserMessage(List.of( + new ConversationMessageContent("What's the weather like in San Francisco?") + )); + + // Create conversation input with messages + ConversationInputAlpha2 conversationInput = new ConversationInputAlpha2(List.of(systemMessage, userMessage)); + + // Define function parameters for the weather tool + Map functionParams = new HashMap<>(); + functionParams.put("location", "string"); + functionParams.put("unit", "string"); + + // Create the weather function definition + ConversationToolsFunction weatherFunction = new ConversationToolsFunction("get_current_weather", functionParams); + weatherFunction.setDescription("Get the current weather for a specified location"); + + // Create the tool wrapper + ConversationTools weatherTool = new ConversationTools(weatherFunction); + + // Create the conversation request with tools + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", List.of(conversationInput)) + .setContextId("weather-demo-context") + .setTemperature(0.7d) + .setTools(List.of(weatherTool)); + + // Send the request + System.out.println("Sending request to AI with weather tool available..."); + Mono responseMono = client.converseAlpha2(request); + ConversationResponseAlpha2 response = responseMono.block(); + + // Process and display the response + if (response != null && response.getOutputs() != null && !response.getOutputs().isEmpty()) { + ConversationResultAlpha2 result = response.getOutputs().get(0); + if (result.getChoices() != null && !result.getChoices().isEmpty()) { + ConversationResultChoices choice = result.getChoices().get(0); + + // Check if the AI wants to call a tool + if (choice.getMessage() != null && choice.getMessage().getToolCalls() != null) { + System.out.println("AI requested to call tools:"); + choice.getMessage().getToolCalls().forEach(toolCall -> { + System.out.printf("Tool: %s, Arguments: %s%n", + toolCall.getFunction().getName(), + toolCall.getFunction().getArguments()); + }); + } + + // Display the message content if available + if (choice.getMessage() != null && choice.getMessage().getContent() != null) { + System.out.printf("AI Response: %s%n", choice.getMessage().getContent()); + } + } + } + + System.out.println("Tools call demonstration completed."); + + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/examples/src/main/java/io/dapr/examples/conversation/UserMessageDemo.java b/examples/src/main/java/io/dapr/examples/conversation/UserMessageDemo.java new file mode 100644 index 000000000..e5afbc475 --- /dev/null +++ b/examples/src/main/java/io/dapr/examples/conversation/UserMessageDemo.java @@ -0,0 +1,72 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.examples.conversation; + +import io.dapr.client.DaprClientBuilder; +import io.dapr.client.DaprPreviewClient; +import io.dapr.client.domain.ConversationInputAlpha2; +import io.dapr.client.domain.ConversationMessageContent; +import io.dapr.client.domain.ConversationRequestAlpha2; +import io.dapr.client.domain.ConversationResponseAlpha2; +import io.dapr.client.domain.ConversationResultAlpha2; +import io.dapr.client.domain.ConversationResultChoices; +import io.dapr.client.domain.UserMessage; +import io.dapr.config.Properties; +import io.dapr.config.Property; +import reactor.core.publisher.Mono; + +import java.util.List; +import java.util.Map; + +public class UserMessageDemo { + /** + * The main method to start the client. + * + * @param args Input arguments (unused). + */ + public static void main(String[] args) { + try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { + System.out.println("Sending the following input to LLM: Hello How are you? This is the my number 672-123-4567"); + + // Create user message with content + UserMessage userMessage = new UserMessage(List.of(new ConversationMessageContent("Hello How are you? " + + "This is the my number 672-123-4567"))); + + // Create conversation input with the user message + ConversationInputAlpha2 daprConversationInput = new ConversationInputAlpha2(List.of(userMessage)); + + // Component name is the name provided in the metadata block of the conversation.yaml file. + Mono responseMono = client.converseAlpha2(new ConversationRequestAlpha2("echo", + List.of(daprConversationInput)) + .setContextId("contextId") + .setScrubPii(true) + .setTemperature(1.1d)); + + ConversationResponseAlpha2 response = responseMono.block(); + + // Extract and print the conversation result + if (response != null && response.getOutputs() != null && !response.getOutputs().isEmpty()) { + ConversationResultAlpha2 result = response.getOutputs().get(0); + if (result.getChoices() != null && !result.getChoices().isEmpty()) { + ConversationResultChoices choice = result.getChoices().get(0); + if (choice.getMessage() != null && choice.getMessage().getContent() != null) { + System.out.printf("Conversation output: %s", choice.getMessage().getContent()); + } + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/sdk-tests/src/test/java/io/dapr/it/testcontainers/conversations/DaprConversationAlpha2IT.java b/sdk-tests/src/test/java/io/dapr/it/testcontainers/conversations/DaprConversationAlpha2IT.java new file mode 100644 index 000000000..00522eae6 --- /dev/null +++ b/sdk-tests/src/test/java/io/dapr/it/testcontainers/conversations/DaprConversationAlpha2IT.java @@ -0,0 +1,386 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.it.testcontainers.conversations; + +import io.dapr.client.DaprPreviewClient; +import io.dapr.client.domain.AssistantMessage; +import io.dapr.client.domain.ConversationInputAlpha2; +import io.dapr.client.domain.ConversationMessage; +import io.dapr.client.domain.ConversationMessageContent; +import io.dapr.client.domain.ConversationRequestAlpha2; +import io.dapr.client.domain.ConversationResponseAlpha2; +import io.dapr.client.domain.ConversationResultAlpha2; +import io.dapr.client.domain.ConversationResultChoices; +import io.dapr.client.domain.ConversationToolCalls; +import io.dapr.client.domain.ConversationToolCallsOfFunction; +import io.dapr.client.domain.ConversationTools; +import io.dapr.client.domain.ConversationToolsFunction; +import io.dapr.client.domain.DeveloperMessage; +import io.dapr.client.domain.SystemMessage; +import io.dapr.client.domain.ToolMessage; +import io.dapr.client.domain.UserMessage; +import io.dapr.it.testcontainers.DaprPreviewClientConfiguration; +import io.dapr.testcontainers.Component; +import io.dapr.testcontainers.DaprContainer; +import io.dapr.testcontainers.DaprLogLevel; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.test.context.DynamicPropertyRegistry; +import org.springframework.test.context.DynamicPropertySource; +import org.testcontainers.containers.Network; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static io.dapr.it.testcontainers.ContainerConstants.DAPR_RUNTIME_IMAGE_TAG; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@SpringBootTest( + webEnvironment = WebEnvironment.RANDOM_PORT, + classes = { + DaprPreviewClientConfiguration.class, + TestConversationApplication.class + } +) +@Testcontainers +@Tag("testcontainers") +public class DaprConversationAlpha2IT { + + private static final Network DAPR_NETWORK = Network.newNetwork(); + private static final Random RANDOM = new Random(); + private static final int PORT = RANDOM.nextInt(1000) + 8000; + + @Container + private static final DaprContainer DAPR_CONTAINER = new DaprContainer(DAPR_RUNTIME_IMAGE_TAG) + .withAppName("conversation-alpha2-dapr-app") + .withComponent(new Component("echo", "conversation.echo", "v1", new HashMap<>())) + .withNetwork(DAPR_NETWORK) + .withDaprLogLevel(DaprLogLevel.DEBUG) + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withAppChannelAddress("host.testcontainers.internal") + .withAppPort(PORT); + + /** + * Expose the Dapr port to the host. + * + * @param registry the dynamic property registry + */ + @DynamicPropertySource + static void daprProperties(DynamicPropertyRegistry registry) { + registry.add("dapr.http.endpoint", DAPR_CONTAINER::getHttpEndpoint); + registry.add("dapr.grpc.endpoint", DAPR_CONTAINER::getGrpcEndpoint); + registry.add("server.port", () -> PORT); + } + + @Autowired + private DaprPreviewClient daprPreviewClient; + + @BeforeEach + public void setUp() { + org.testcontainers.Testcontainers.exposeHostPorts(PORT); + } + + @Test + public void testConverseAlpha2WithUserMessage() { + // Create a user message + UserMessage userMessage = new UserMessage(List.of(new ConversationMessageContent("Hello, how are you?"))); + userMessage.setName("TestUser"); + + // Create input with the message + ConversationInputAlpha2 input = new ConversationInputAlpha2(List.of(userMessage)); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", List.of(input)); + + ConversationResponseAlpha2 response = daprPreviewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertNotNull(response.getOutputs()); + assertEquals(1, response.getOutputs().size()); + + ConversationResultAlpha2 result = response.getOutputs().get(0); + assertNotNull(result.getChoices()); + assertTrue(result.getChoices().size() > 0); + + ConversationResultChoices choice = result.getChoices().get(0); + assertNotNull(choice.getMessage()); + assertNotNull(choice.getMessage().getContent()); + } + + @Test + public void testConverseAlpha2WithAllMessageTypes() { + List messages = new ArrayList<>(); + + // System message + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("You are a helpful assistant."))); + systemMsg.setName("system"); + messages.add(systemMsg); + + // User message + UserMessage userMsg = new UserMessage(List.of(new ConversationMessageContent("Hello!"))); + userMsg.setName("user"); + messages.add(userMsg); + + // Assistant message + AssistantMessage assistantMsg = new AssistantMessage(List.of(new ConversationMessageContent("Hi there!")), + List.of(new ConversationToolCalls( + new ConversationToolCallsOfFunction("get_weather", "{\"location\": \"New York\"}")))); + assistantMsg.setName("assistant"); + messages.add(assistantMsg); + + // Tool message + ToolMessage toolMsg = new ToolMessage(List.of(new ConversationMessageContent("Weather data: 72F"))); + toolMsg.setName("tool"); + messages.add(toolMsg); + + // Developer message + DeveloperMessage devMsg = new DeveloperMessage(List.of(new ConversationMessageContent("Debug info"))); + devMsg.setName("developer"); + messages.add(devMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", List.of(input)); + + ConversationResponseAlpha2 response = daprPreviewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertNotNull(response.getOutputs()); + assertTrue(response.getOutputs().size() > 0); + } + + @Test + public void testConverseAlpha2WithScrubPII() { + // Create a user message with PII + UserMessage userMessage = new UserMessage(List.of(new ConversationMessageContent("My email is test@example.com and phone is +1234567890"))); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(List.of(userMessage)); + input.setScrubPii(true); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", List.of(input)); + request.setScrubPii(true); + + ConversationResponseAlpha2 response = daprPreviewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertNotNull(response.getOutputs()); + assertTrue(response.getOutputs().size() > 0); + + // Verify response structure (actual PII scrubbing depends on echo component implementation) + ConversationResultChoices choice = response.getOutputs().get(0).getChoices().get(0); + assertNotNull(choice.getMessage()); + assertNotNull(choice.getMessage().getContent()); + } + + @Test + public void testConverseAlpha2WithTools() { + // Create a tool function + Map parameters = new HashMap<>(); + parameters.put("location", "string"); + parameters.put("unit", "celsius"); + ConversationToolsFunction function = new ConversationToolsFunction("get_weather", parameters); + function.setDescription("Get current weather information"); + + ConversationTools tool = new ConversationTools(function); + + // Create user message + UserMessage userMessage = new UserMessage(List.of(new ConversationMessageContent("What's the weather like?"))); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(List.of(userMessage)); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", List.of(input)); + request.setTools(List.of(tool)); + request.setToolChoice("auto"); + + ConversationResponseAlpha2 response = daprPreviewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertNotNull(response.getOutputs()); + assertTrue(response.getOutputs().size() > 0); + } + + @Test + public void testConverseAlpha2WithMetadataAndParameters() { + UserMessage userMessage = new UserMessage(List.of(new ConversationMessageContent("Hello world"))); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(List.of(userMessage)); + + // Set metadata and parameters + Map metadata = new HashMap<>(); + metadata.put("request-id", "test-123"); + metadata.put("source", "integration-test"); + + Map parameters = new HashMap<>(); + parameters.put("max_tokens", "1000"); + parameters.put("temperature", "0.7"); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", List.of(input)); + request.setContextId("test-context-123"); + request.setTemperature(0.8); + request.setMetadata(metadata); + request.setParameters(parameters); + + ConversationResponseAlpha2 response = daprPreviewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertNotNull(response.getOutputs()); + assertTrue(response.getOutputs().size() > 0); + + // Verify context ID is handled properly + // Note: actual context ID behavior depends on echo component implementation + assertNotNull(response.getContextId()); + } + + @Test + public void testConverseAlpha2WithAssistantToolCalls() { + // Create a tool call + ConversationToolCallsOfFunction toolFunction = + new ConversationToolCallsOfFunction("get_weather", "{\"location\": \"New York\"}"); + ConversationToolCalls toolCall = new ConversationToolCalls(toolFunction); + toolCall.setId("call_123"); + + // Create assistant message with tool calls + AssistantMessage assistantMsg = new AssistantMessage(List.of(new ConversationMessageContent("Hi there!")), + List.of(new ConversationToolCalls( + new ConversationToolCallsOfFunction("get_weather", "{\"location\": \"New York\"}")))); // Note: Current implementation doesn't support setting tool calls in constructor + // This tests the structure and ensures no errors occur + + ConversationInputAlpha2 input = new ConversationInputAlpha2(List.of(assistantMsg)); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", List.of(input)); + + ConversationResponseAlpha2 response = daprPreviewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertNotNull(response.getOutputs()); + assertTrue(response.getOutputs().size() > 0); + } + + @Test + public void testConverseAlpha2WithComplexScenario() { + List messages = new ArrayList<>(); + + // System message setting context + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("You are a helpful weather assistant."))); + systemMsg.setName("WeatherBot"); + messages.add(systemMsg); + + // User asking for weather + UserMessage userMsg = new UserMessage(List.of(new ConversationMessageContent("What's the weather in San Francisco?"))); + userMsg.setName("User123"); + messages.add(userMsg); + + // Assistant response + AssistantMessage assistantMsg = new AssistantMessage(List.of(new ConversationMessageContent("Hi there!")), + List.of(new ConversationToolCalls( + new ConversationToolCallsOfFunction("get_weather", "{\"location\": \"New York\"}")))); + assistantMsg.setName("WeatherBot"); + messages.add(assistantMsg); + + // Tool response + ToolMessage toolMsg = new ToolMessage(List.of(new ConversationMessageContent("{\"temperature\": \"68F\", \"condition\": \"sunny\"}"))); + toolMsg.setName("weather_api"); + messages.add(toolMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + input.setScrubPii(false); + + // Create tools + Map functionParams = new HashMap<>(); + functionParams.put("location", "string"); + functionParams.put("unit", "fahrenheit"); + ConversationToolsFunction weatherFunction = new ConversationToolsFunction("get_current_weather", + functionParams); + weatherFunction.setDescription("Get current weather for a location"); + + + ConversationTools weatherTool = new ConversationTools(weatherFunction); + + // Set up complete request + Map metadata = new HashMap<>(); + metadata.put("conversation-type", "weather-query"); + metadata.put("user-session", "session-456"); + + Map parameters = new HashMap<>(); + parameters.put("max_tokens", "2000"); + parameters.put("response_format", "json"); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", List.of(input)); + request.setContextId("weather-conversation-789"); + request.setTemperature(0.7); + request.setScrubPii(false); + request.setTools(List.of(weatherTool)); + request.setToolChoice("auto"); + request.setMetadata(metadata); + request.setParameters(parameters); + + ConversationResponseAlpha2 response = daprPreviewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertNotNull(response.getOutputs()); + assertTrue(response.getOutputs().size() > 0); + + ConversationResultAlpha2 result = response.getOutputs().get(0); + assertNotNull(result.getChoices()); + assertTrue(result.getChoices().size() > 0); + + ConversationResultChoices choice = result.getChoices().get(0); + assertNotNull(choice.getFinishReason()); + assertTrue(choice.getIndex() >= 0); + + if (choice.getMessage() != null) { + assertNotNull(choice.getMessage().getContent()); + } + } + + @Test + public void testConverseAlpha2MultipleInputs() { + // Create multiple conversation inputs + List inputs = new ArrayList<>(); + + // First input - greeting + UserMessage greeting = new UserMessage(List.of(new ConversationMessageContent("Hello!"))); + ConversationInputAlpha2 input1 = new ConversationInputAlpha2(List.of(greeting)); + inputs.add(input1); + + // Second input - question + UserMessage question = new UserMessage(List.of(new ConversationMessageContent("How are you?"))); + ConversationInputAlpha2 input2 = new ConversationInputAlpha2(List.of(question)); + input2.setScrubPii(true); + inputs.add(input2); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("echo", inputs); + + ConversationResponseAlpha2 response = daprPreviewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertNotNull(response.getOutputs()); + assertTrue(response.getOutputs().size() > 0); + + // Should handle multiple inputs appropriately + for (ConversationResultAlpha2 result : response.getOutputs()) { + assertNotNull(result.getChoices()); + assertTrue(result.getChoices().size() > 0); + } + } +} diff --git a/sdk-tests/src/test/java/io/dapr/it/testcontainers/conversations/DaprConversationIT.java b/sdk-tests/src/test/java/io/dapr/it/testcontainers/conversations/DaprConversationIT.java index 2301e6fc3..64eac5cce 100644 --- a/sdk-tests/src/test/java/io/dapr/it/testcontainers/conversations/DaprConversationIT.java +++ b/sdk-tests/src/test/java/io/dapr/it/testcontainers/conversations/DaprConversationIT.java @@ -132,3 +132,4 @@ public void testConversationSDKShouldScrubPIIOnlyForTheInputWhereScrubPIIIsSet() response.getConversationOutputs().get(1).getResult()); } } + diff --git a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java index 8bae8f51d..a8454c9c8 100644 --- a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java +++ b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java @@ -17,9 +17,12 @@ import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Empty; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; import io.dapr.client.domain.ActorMetadata; import io.dapr.client.domain.AppConnectionPropertiesHealthMetadata; import io.dapr.client.domain.AppConnectionPropertiesMetadata; +import io.dapr.client.domain.AssistantMessage; import io.dapr.client.domain.BulkPublishEntry; import io.dapr.client.domain.BulkPublishRequest; import io.dapr.client.domain.BulkPublishResponse; @@ -29,9 +32,21 @@ import io.dapr.client.domain.ConfigurationItem; import io.dapr.client.domain.ConstantFailurePolicy; import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationInputAlpha2; +import io.dapr.client.domain.ConversationMessage; +import io.dapr.client.domain.ConversationMessageContent; import io.dapr.client.domain.ConversationOutput; import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationRequestAlpha2; import io.dapr.client.domain.ConversationResponse; +import io.dapr.client.domain.ConversationResponseAlpha2; +import io.dapr.client.domain.ConversationResultAlpha2; +import io.dapr.client.domain.ConversationResultChoices; +import io.dapr.client.domain.ConversationResultMessage; +import io.dapr.client.domain.ConversationToolCalls; +import io.dapr.client.domain.ConversationToolCallsOfFunction; +import io.dapr.client.domain.ConversationTools; +import io.dapr.client.domain.ConversationToolsFunction; import io.dapr.client.domain.DaprMetadata; import io.dapr.client.domain.DeleteJobRequest; import io.dapr.client.domain.DeleteStateRequest; @@ -64,6 +79,7 @@ import io.dapr.client.domain.SubscribeConfigurationRequest; import io.dapr.client.domain.SubscribeConfigurationResponse; import io.dapr.client.domain.SubscriptionMetadata; +import io.dapr.client.domain.ToolMessage; import io.dapr.client.domain.TransactionalStateOperation; import io.dapr.client.domain.UnlockRequest; import io.dapr.client.domain.UnlockResponseStatus; @@ -1620,6 +1636,7 @@ public Mono getMetadata() { /** * {@inheritDoc} */ + @Deprecated(forRemoval = true) @Override public Mono converse(ConversationRequest conversationRequest) { @@ -1690,6 +1707,283 @@ private void validateConversationRequest(ConversationRequest conversationRequest } } + /** + * {@inheritDoc} + */ + @Override + public Mono converseAlpha2(ConversationRequestAlpha2 conversationRequestAlpha2) { + try { + if ((conversationRequestAlpha2.getName() == null) || (conversationRequestAlpha2.getName().trim().isEmpty())) { + throw new IllegalArgumentException("LLM name cannot be null or empty."); + } + + if (conversationRequestAlpha2.getInputs() == null || conversationRequestAlpha2.getInputs().isEmpty()) { + throw new IllegalArgumentException("Conversation Inputs cannot be null or empty."); + } + + DaprProtos.ConversationRequestAlpha2.Builder builder = DaprProtos.ConversationRequestAlpha2 + .newBuilder() + .setTemperature(conversationRequestAlpha2.getTemperature()) + .setScrubPii(conversationRequestAlpha2.isScrubPii()) + .setName(conversationRequestAlpha2.getName()); + + if (conversationRequestAlpha2.getContextId() != null) { + builder.setContextId(conversationRequestAlpha2.getContextId()); + } + + if (conversationRequestAlpha2.getToolChoice() != null) { + builder.setToolChoice(conversationRequestAlpha2.getToolChoice()); + } + + DaprProtos.ConversationRequestAlpha2 protoRequest = buildConversationRequestProto(conversationRequestAlpha2, + builder); + + Mono conversationResponseMono = Mono.deferContextual( + context -> this.createMono( + it -> intercept(context, asyncStub).converseAlpha2(protoRequest, it) + ) + ); + + DaprProtos.ConversationResponseAlpha2 conversationResponse = conversationResponseMono.block(); + + assert conversationResponse != null; + List results = buildConversationResults(conversationResponse.getOutputsList()); + return Mono.just(new ConversationResponseAlpha2(conversationResponse.getContextId(), results)); + } catch (Exception ex) { + return DaprException.wrapMono(ex); + } + } + + private DaprProtos.ConversationRequestAlpha2 buildConversationRequestProto(ConversationRequestAlpha2 request, + DaprProtos.ConversationRequestAlpha2.Builder builder) { + if (request.getTools() != null) { + buildConversationTools(request.getTools(), builder); + } + + if (request.getMetadata() != null) { + builder.putAllMetadata(request.getMetadata()); + } + + + if (request.getParameters() != null) { + Map parameters = request.getParameters() + .entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + e -> { + try { + return Any.newBuilder().setValue(ByteString.copyFrom(objectSerializer.serialize(e.getValue()))) + .build(); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }) + ); + builder.putAllParameters(parameters); + } + + for (ConversationInputAlpha2 input : request.getInputs()) { + DaprProtos.ConversationInputAlpha2.Builder inputBuilder = DaprProtos.ConversationInputAlpha2 + .newBuilder() + .setScrubPii(input.isScrubPii()); + + if (input.getMessages() != null) { + for (ConversationMessage message : input.getMessages()) { + DaprProtos.ConversationMessage protoMessage = buildConversationMessage(message); + inputBuilder.addMessages(protoMessage); + } + } + + builder.addInputs(inputBuilder.build()); + } + + return builder.build(); + } + + private void buildConversationTools(List tools, + DaprProtos.ConversationRequestAlpha2.Builder builder) { + for (ConversationTools tool : tools) { + ConversationToolsFunction function = tool.getFunction(); + + DaprProtos.ConversationToolsFunction.Builder protoFunction = DaprProtos.ConversationToolsFunction.newBuilder() + .setName(function.getName()); + + if (function.getDescription() != null) { + protoFunction.setDescription(function.getDescription()); + } + + if (function.getParameters() != null) { + Map functionParams = function.getParameters() + .entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + e -> { + try { + return ProtobufValueHelper.toProtobufValue(e.getValue()); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + )); + + protoFunction.setParameters(Struct.newBuilder().putAllFields(functionParams).build()); + } + + builder.addTools(DaprProtos.ConversationTools.newBuilder() + .setFunction(protoFunction) + .build()); + } + } + + private DaprProtos.ConversationMessage buildConversationMessage(ConversationMessage message) { + DaprProtos.ConversationMessage.Builder messageBuilder = DaprProtos.ConversationMessage.newBuilder(); + + switch (message.getRole()) { + case TOOL: + DaprProtos.ConversationMessageOfTool.Builder toolMessage = + DaprProtos.ConversationMessageOfTool.newBuilder(); + if (message.getName() != null) { + toolMessage.setName(message.getName()); + } + if (message.getContent() != null) { + toolMessage.addAllContent(getConversationMessageContent(message)); + } + if (((ToolMessage)message).getToolId() != null) { + toolMessage.setToolId(((ToolMessage)message).getToolId()); + } + messageBuilder.setOfTool(toolMessage); + break; + case USER: + DaprProtos.ConversationMessageOfUser.Builder userMessage = + DaprProtos.ConversationMessageOfUser.newBuilder(); + if (message.getName() != null) { + userMessage.setName(message.getName()); + } + if (message.getContent() != null) { + userMessage.addAllContent(getConversationMessageContent(message)); + } + messageBuilder.setOfUser(userMessage); + break; + case ASSISTANT: + DaprProtos.ConversationMessageOfAssistant.Builder assistantMessage = + DaprProtos.ConversationMessageOfAssistant.newBuilder(); + + if (message.getName() != null) { + assistantMessage.setName(message.getName()); + } + if (message.getContent() != null) { + assistantMessage.addAllContent(getConversationMessageContent(message)); + } + if (((AssistantMessage)message).getToolCalls() != null) { + assistantMessage.addAllToolCalls(getConversationToolCalls((AssistantMessage)message)); + } + messageBuilder.setOfAssistant(assistantMessage); + break; + case DEVELOPER: + DaprProtos.ConversationMessageOfDeveloper.Builder developerMessage = + DaprProtos.ConversationMessageOfDeveloper.newBuilder(); + if (message.getName() != null) { + developerMessage.setName(message.getName()); + } + if (message.getContent() != null) { + developerMessage.addAllContent(getConversationMessageContent(message)); + } + messageBuilder.setOfDeveloper(developerMessage); + break; + case SYSTEM: + DaprProtos.ConversationMessageOfSystem.Builder systemMessage = + DaprProtos.ConversationMessageOfSystem.newBuilder(); + if (message.getName() != null) { + systemMessage.setName(message.getName()); + } + if (message.getContent() != null) { + systemMessage.addAllContent(getConversationMessageContent(message)); + } + messageBuilder.setOfSystem(systemMessage); + break; + default: + throw new IllegalArgumentException("No role of type " + message.getRole() + " found"); + } + + return messageBuilder.build(); + } + + private List buildConversationResults( + List protoResults) { + List results = new ArrayList<>(); + + for (DaprProtos.ConversationResultAlpha2 protoResult : protoResults) { + List choices = new ArrayList<>(); + + for (DaprProtos.ConversationResultChoices protoChoice : protoResult.getChoicesList()) { + ConversationResultMessage message = buildConversationResultMessage(protoChoice); + choices.add(new ConversationResultChoices(protoChoice.getFinishReason(), protoChoice.getIndex(), message)); + } + + results.add(new ConversationResultAlpha2(choices)); + } + + return results; + } + + private ConversationResultMessage buildConversationResultMessage(DaprProtos.ConversationResultChoices protoChoice) { + if (!protoChoice.hasMessage()) { + return null; + } + + List toolCalls = new ArrayList<>(); + + for (DaprProtos.ConversationToolCalls protoToolCall : protoChoice.getMessage().getToolCallsList()) { + ConversationToolCallsOfFunction function = null; + if (protoToolCall.hasFunction()) { + function = new ConversationToolCallsOfFunction( + protoToolCall.getFunction().getName(), + protoToolCall.getFunction().getArguments() + ); + } + + ConversationToolCalls conversationToolCalls = new ConversationToolCalls(function); + conversationToolCalls.setId(protoToolCall.getId()); + + toolCalls.add(conversationToolCalls); + } + + return new ConversationResultMessage(protoChoice.getMessage().getContent(), toolCalls + ); + } + + private List getConversationMessageContent( + ConversationMessage conversationMessage) { + + List conversationMessageContents = new ArrayList<>(); + for (ConversationMessageContent conversationMessageContent: conversationMessage.getContent()) { + conversationMessageContents.add(DaprProtos.ConversationMessageContent.newBuilder() + .setText(conversationMessageContent.getText()) + .build()); + } + + return conversationMessageContents; + } + + private List getConversationToolCalls( + AssistantMessage assistantMessage) { + List conversationToolCalls = new ArrayList<>(); + for (ConversationToolCalls conversationToolCall: assistantMessage.getToolCalls()) { + DaprProtos.ConversationToolCalls.Builder toolCallsBuilder = DaprProtos.ConversationToolCalls.newBuilder() + .setFunction(DaprProtos.ConversationToolCallsOfFunction.newBuilder() + .setName(conversationToolCall.getFunction().getName()) + .setArguments(conversationToolCall.getFunction().getArguments()) + .build()); + if (conversationToolCall.getId() != null) { + toolCallsBuilder.setId(conversationToolCall.getId()); + } + + conversationToolCalls.add(toolCallsBuilder.build()); + } + + return conversationToolCalls; + } + private DaprMetadata buildDaprMetadata(DaprProtos.GetMetadataResponse response) throws IOException { String id = response.getId(); String runtimeVersion = response.getRuntimeVersion(); diff --git a/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java b/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java index 89c6eded8..92c6a61c3 100644 --- a/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java +++ b/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java @@ -18,7 +18,9 @@ import io.dapr.client.domain.BulkPublishResponse; import io.dapr.client.domain.BulkPublishResponseFailedEntry; import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationRequestAlpha2; import io.dapr.client.domain.ConversationResponse; +import io.dapr.client.domain.ConversationResponseAlpha2; import io.dapr.client.domain.DeleteJobRequest; import io.dapr.client.domain.GetJobRequest; import io.dapr.client.domain.GetJobResponse; @@ -313,5 +315,14 @@ Subscription subscribeToEvents( * @param conversationRequest request to be passed to the LLM. * @return {@link ConversationResponse}. */ + @Deprecated public Mono converse(ConversationRequest conversationRequest); + + /* + * Converse with an LLM using Alpha2 API. + * + * @param conversationRequestAlpha2 request to be passed to the LLM with Alpha2 features. + * @return {@link ConversationResponseAlpha2}. + */ + public Mono converseAlpha2(ConversationRequestAlpha2 conversationRequestAlpha2); } diff --git a/sdk/src/main/java/io/dapr/client/ProtobufValueHelper.java b/sdk/src/main/java/io/dapr/client/ProtobufValueHelper.java new file mode 100644 index 000000000..409967f8e --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/ProtobufValueHelper.java @@ -0,0 +1,76 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client; + +import com.google.protobuf.ListValue; +import com.google.protobuf.NullValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.dapr.serializer.DaprObjectSerializer; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +/** + * Helper class to convert Java objects to Google Protobuf Value types. + */ +public class ProtobufValueHelper { + + /** + * Converts a Java object to a Google Protobuf Value. + * + * @param obj the Java object to convert + * @return the corresponding Protobuf Value + * @throws IOException if serialization fails + */ + public static Value toProtobufValue(Object obj) throws IOException { + if (obj == null) { + return Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(); + } + + if (obj instanceof Boolean) { + return Value.newBuilder().setBoolValue((Boolean) obj).build(); + } + + if (obj instanceof String) { + return Value.newBuilder().setStringValue((String) obj).build(); + } + + if (obj instanceof Number) { + return Value.newBuilder().setNumberValue(((Number) obj).doubleValue()).build(); + } + + if (obj instanceof List) { + ListValue.Builder listBuilder = ListValue.newBuilder(); + for (Object item : (List) obj) { + listBuilder.addValues(toProtobufValue(item)); + } + return Value.newBuilder().setListValue(listBuilder.build()).build(); + } + + if (obj instanceof Map) { + Struct.Builder structBuilder = Struct.newBuilder(); + for (Map.Entry entry : ((Map) obj).entrySet()) { + String key = entry.getKey().toString(); + Value value = toProtobufValue(entry.getValue()); + structBuilder.putFields(key, value); + } + return Value.newBuilder().setStructValue(structBuilder.build()).build(); + } + + // Fallback: convert to string + return Value.newBuilder().setStringValue(obj.toString()).build(); + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java b/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java new file mode 100644 index 000000000..1007066aa --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java @@ -0,0 +1,67 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Assistant message containing responses from the AI model. + * Can include regular content and/or tool calls that the model wants to make. + */ +public class AssistantMessage implements ConversationMessage { + + private String name; + private final List content; + private final List toolCalls; + + /** + * Creates an assistant message with content and optional tool calls. + * @param content the content of the assistant message. + * @param toolCalls the tool calls requested by the assistant. + */ + public AssistantMessage(List content, List toolCalls) { + this.content = List.copyOf(content); + this.toolCalls = List.copyOf(toolCalls); + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.ASSISTANT; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the name of the assistant participant. + * + * @param name the name to set + * @return this instance for method chaining + */ + public AssistantMessage setName(String name) { + this.name = name; + return this; + } + + @Override + public List getContent() { + return content; + } + + public List getToolCalls() { + return toolCalls; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationInputAlpha2.java b/sdk/src/main/java/io/dapr/client/domain/ConversationInputAlpha2.java new file mode 100644 index 000000000..52ee2f40a --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationInputAlpha2.java @@ -0,0 +1,63 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Represents an Alpha2 input for conversation with enhanced message support. + */ +public class ConversationInputAlpha2 { + + private final List messages; + private boolean scrubPii; + + /** + * Constructor. + * + * @param messages the list of conversation messages + */ + public ConversationInputAlpha2(List messages) { + this.messages = List.copyOf(messages); + } + + /** + * Gets the list of conversation messages. + * + * @return the list of messages + */ + public List getMessages() { + return messages; + } + + /** + * Checks if Personally Identifiable Information (PII) should be scrubbed before sending to the LLM. + * + * @return {@code true} if PII should be scrubbed, {@code false} otherwise. + */ + public boolean isScrubPii() { + return scrubPii; + } + + /** + * Enable obfuscation of sensitive information present in the content field. Optional + * + * @param scrubPii A boolean indicating whether to remove PII. + * @return this. + */ + public ConversationInputAlpha2 setScrubPii(boolean scrubPii) { + this.scrubPii = scrubPii; + return this; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationMessage.java b/sdk/src/main/java/io/dapr/client/domain/ConversationMessage.java new file mode 100644 index 000000000..c26c0d041 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationMessage.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Interface representing a conversation message with role-specific content. + * Supports different message types: system, user, assistant, developer, and tool. + */ +public interface ConversationMessage { + + /** + * Gets the role of the message sender. + * + * @return the message role + */ + ConversationMessageRole getRole(); + + /** + * Gets the name of the participant in the message. + * + * @return the participant name, or null if not specified + */ + String getName(); + + /** + * Gets the content of the message. + * + * @return the message content + */ + List getContent(); +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationMessageContent.java b/sdk/src/main/java/io/dapr/client/domain/ConversationMessageContent.java new file mode 100644 index 000000000..78a4ad604 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationMessageContent.java @@ -0,0 +1,40 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +/** + * Represents the content of a conversation message. + */ +public class ConversationMessageContent { + + private final String text; + + /** + * Constructor. + * + * @param text the text content of the message + */ + public ConversationMessageContent(String text) { + this.text = text; + } + + /** + * Gets the text content of the message. + * + * @return the text content + */ + public String getText() { + return text; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationMessageRole.java b/sdk/src/main/java/io/dapr/client/domain/ConversationMessageRole.java new file mode 100644 index 000000000..0bfd1b076 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationMessageRole.java @@ -0,0 +1,44 @@ +/* + * Copyright 2022 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +/** + * Enum representing the different roles a conversation message can have. + */ +public enum ConversationMessageRole { + /** + * System message that sets the behavior or context for the conversation. + */ + SYSTEM, + + /** + * User message containing input from the human user. + */ + USER, + + /** + * Assistant message containing responses from the AI model. + */ + ASSISTANT, + + /** + * Tool message containing results from function/tool calls. + */ + TOOL, + + /** + * Developer message for development and debugging purposes. + */ + DEVELOPER +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationRequestAlpha2.java b/sdk/src/main/java/io/dapr/client/domain/ConversationRequestAlpha2.java new file mode 100644 index 000000000..2f85fbd7d --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationRequestAlpha2.java @@ -0,0 +1,209 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; +import java.util.Map; + +/** + * Represents the Alpha2 conversation configuration with enhanced features including + * tools, improved message handling, and better compatibility with OpenAI ChatCompletion API. + */ +public class ConversationRequestAlpha2 { + + private final String name; + private final List inputs; + private String contextId; + private boolean scrubPii; + private double temperature; + private List tools; + private String toolChoice; + private Map parameters; + private Map metadata; + + /** + * Constructs a ConversationRequestAlpha2 with a component name and conversation inputs. + * + * @param name The name of the Dapr conversation component. See a list of all available conversation components + * @see + * @param inputs the list of Dapr conversation inputs (Alpha2 format) + */ + public ConversationRequestAlpha2(String name, List inputs) { + this.name = name; + this.inputs = inputs; + } + + /** + * Gets the conversation component name. + * + * @return the conversation component name + */ + public String getName() { + return name; + } + + /** + * Gets the list of Dapr conversation input (Alpha2 format). + * + * @return the list of conversation input + */ + public List getInputs() { + return inputs; + } + + /** + * Gets the context identifier. + * + * @return the context identifier + */ + public String getContextId() { + return contextId; + } + + /** + * Sets the context identifier. + * + * @param contextId the context identifier to set + * @return the current instance of {@link ConversationRequestAlpha2} + */ + public ConversationRequestAlpha2 setContextId(String contextId) { + this.contextId = contextId; + return this; + } + + /** + * Checks if PII scrubbing is enabled. + * + * @return true if PII scrubbing is enabled, false otherwise + */ + public boolean isScrubPii() { + return scrubPii; + } + + /** + * Enable obfuscation of sensitive information returning from the LLM. Optional. + * + * @param scrubPii whether to enable PII scrubbing + * @return the current instance of {@link ConversationRequestAlpha2} + */ + public ConversationRequestAlpha2 setScrubPii(boolean scrubPii) { + this.scrubPii = scrubPii; + return this; + } + + /** + * Gets the temperature of the model. Used to optimize for consistency and creativity. Optional + * + * @return the temperature value + */ + public double getTemperature() { + return temperature; + } + + /** + * Sets the temperature of the model. Used to optimize for consistency and creativity. Optional + * + * @param temperature the temperature value to set + * @return the current instance of {@link ConversationRequestAlpha2} + */ + public ConversationRequestAlpha2 setTemperature(double temperature) { + this.temperature = temperature; + return this; + } + + /** + * Gets the tools available to be used by the LLM during the conversation. + * + * @return the list of tools + */ + public List getTools() { + return tools; + } + + /** + * Sets the tools available to be used by the LLM during the conversation. + * These are sent on a per request basis. + * + * @param tools the tools to set + * @return the current instance of {@link ConversationRequestAlpha2} + */ + public ConversationRequestAlpha2 setTools(List tools) { + this.tools = tools; + return this; + } + + /** + * Gets the tool choice setting which controls which (if any) tool is called by the model. + * + * @return the tool choice setting + */ + public String getToolChoice() { + return toolChoice; + } + + /** + * Sets the tool choice setting which controls which (if any) tool is called by the model. + * - "none" means the model will not call any tool and instead generates a message + * - "auto" means the model can pick between generating a message or calling one or more tools + * - "required" requires one or more functions to be called + * - Alternatively, a specific tool name may be used here + * + * @param toolChoice the tool choice setting to set + * @return the current instance of {@link ConversationRequestAlpha2} + */ + public ConversationRequestAlpha2 setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + return this; + } + + /** + * Gets the parameters for all custom fields. + * + * @return the parameters map + */ + public Map getParameters() { + return parameters; + } + + /** + * Sets the parameters for all custom fields. + * + * @param parameters the parameters to set + * @return the current instance of {@link ConversationRequestAlpha2} + */ + public ConversationRequestAlpha2 setParameters(Map parameters) { + this.parameters = parameters; + return this; + } + + /** + * Gets the metadata passing to conversation components. + * + * @return the metadata map + */ + public Map getMetadata() { + return metadata; + } + + /** + * Sets the metadata passing to conversation components. + * + * @param metadata the metadata to set + * @return the current instance of {@link ConversationRequestAlpha2} + */ + public ConversationRequestAlpha2 setMetadata(Map metadata) { + this.metadata = metadata; + return this; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationResponseAlpha2.java b/sdk/src/main/java/io/dapr/client/domain/ConversationResponseAlpha2.java new file mode 100644 index 000000000..9ef6dac7a --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationResponseAlpha2.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Alpha2 response from the Dapr Conversation API with enhanced features. + */ +public class ConversationResponseAlpha2 { + + private final String contextId; + private final List outputs; + + /** + * Constructor. + * + * @param contextId context id supplied to LLM. + * @param outputs outputs from the LLM (Alpha2 format). + */ + public ConversationResponseAlpha2(String contextId, List outputs) { + this.contextId = contextId; + this.outputs = List.copyOf(outputs); + } + + /** + * The ID of an existing chat (like in ChatGPT). + * + * @return String identifier. + */ + public String getContextId() { + return this.contextId; + } + + /** + * Get list of conversation outputs (Alpha2 format). + * + * @return List{@link ConversationResultAlpha2}. + */ + public List getOutputs() { + return this.outputs; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationResultAlpha2.java b/sdk/src/main/java/io/dapr/client/domain/ConversationResultAlpha2.java new file mode 100644 index 000000000..369caeb65 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationResultAlpha2.java @@ -0,0 +1,42 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Alpha2 result for conversation output with enhanced choice-based structure. + */ +public class ConversationResultAlpha2 { + + private final List choices; + + /** + * Constructor. + * + * @param choices the list of conversation result choices. + */ + public ConversationResultAlpha2(List choices) { + this.choices = List.copyOf(choices); + } + + /** + * Gets the list of conversation result choices. + * + * @return the list of conversation result choices + */ + public List getChoices() { + return choices; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationResultChoices.java b/sdk/src/main/java/io/dapr/client/domain/ConversationResultChoices.java new file mode 100644 index 000000000..468286987 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationResultChoices.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +/** + * Represents a conversation result choice with finish reason, index, and message. + */ +public class ConversationResultChoices { + + private final String finishReason; + private final long index; + private final ConversationResultMessage message; + + /** + * Constructor. + * + * @param finishReason the reason the model stopped generating tokens + * @param index the index of the choice in the list of choices + * @param message the result message + */ + public ConversationResultChoices(String finishReason, long index, ConversationResultMessage message) { + this.finishReason = finishReason; + this.index = index; + this.message = message; + } + + /** + * Gets the reason the model stopped generating tokens. + * This will be "stop" if the model hit a natural stop point or a provided stop sequence, + * "length" if the maximum number of tokens specified in the request was reached, + * "content_filter" if content was omitted due to a flag from content filters, + * "tool_calls" if the model called a tool. + * + * @return the finish reason + */ + public String getFinishReason() { + return finishReason; + } + + /** + * Gets the index of the choice in the list of choices. + * + * @return the index + */ + public long getIndex() { + return index; + } + + /** + * Gets the result message. + * + * @return the message + */ + public ConversationResultMessage getMessage() { + return message; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationResultMessage.java b/sdk/src/main/java/io/dapr/client/domain/ConversationResultMessage.java new file mode 100644 index 000000000..bbeeebae6 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationResultMessage.java @@ -0,0 +1,72 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Represents a conversation result message with content and optional tool calls. + */ +public class ConversationResultMessage { + + private final String content; + private final List toolCalls; + + /** + * Constructor. + * + * @param content the contents of the message + * @param toolCalls the tool calls generated by the model (optional) + */ + public ConversationResultMessage(String content, List toolCalls) { + this.content = content; + this.toolCalls = toolCalls != null ? List.copyOf(toolCalls) : null; + } + + /** + * Constructor for message without tool calls. + * + * @param content the contents of the message + */ + public ConversationResultMessage(String content) { + this(content, null); + } + + /** + * Gets the contents of the message. + * + * @return the message content + */ + public String getContent() { + return content; + } + + /** + * Gets the tool calls generated by the model. + * + * @return the tool calls, or null if none + */ + public List getToolCalls() { + return toolCalls; + } + + /** + * Checks if the message has tool calls. + * + * @return true if there are tool calls, false otherwise + */ + public boolean hasToolCalls() { + return toolCalls != null && !toolCalls.isEmpty(); + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationToolCalls.java b/sdk/src/main/java/io/dapr/client/domain/ConversationToolCalls.java new file mode 100644 index 000000000..73646e664 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationToolCalls.java @@ -0,0 +1,61 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +/** + * Represents a tool call request sent from the LLM to the client to execute. + */ +public class ConversationToolCalls { + + private String id; + private final ConversationToolCallsOfFunction function; + + /** + * Constructor without ID. + * + * @param function the function to call + */ + public ConversationToolCalls(ConversationToolCallsOfFunction function) { + this.function = function; + } + + /** + * Gets the unique identifier for the tool call. + * + * @return the tool call ID, or null if not provided + */ + public String getId() { + return id; + } + + /** + * Set with ID. + * + * @param id the unique identifier for the tool call + * @return this instance for method chaining + */ + public ConversationToolCalls setId(String id) { + this.id = id; + return this; + } + + /** + * Gets the function to call. + * + * @return the function details + */ + public ConversationToolCallsOfFunction getFunction() { + return function; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationToolCallsOfFunction.java b/sdk/src/main/java/io/dapr/client/domain/ConversationToolCallsOfFunction.java new file mode 100644 index 000000000..0edacd7d9 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationToolCallsOfFunction.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +/** + * Represents a function call within a tool call. + */ +public class ConversationToolCallsOfFunction { + + private final String name; + private final String arguments; + + /** + * Constructor. + * + * @param name the name of the function to call + * @param arguments the arguments to call the function with, as generated by the model in JSON format + */ + public ConversationToolCallsOfFunction(String name, String arguments) { + this.name = name; + this.arguments = arguments; + } + + /** + * Gets the name of the function to call. + * + * @return the function name + */ + public String getName() { + return name; + } + + /** + * Gets the arguments to call the function with, as generated by the model in JSON format. + * Note that the model does not always generate valid JSON, and may hallucinate parameters + * not defined by your function schema. Validate the arguments in your code before calling your function. + * + * @return the function arguments in JSON format + */ + public String getArguments() { + return arguments; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationTools.java b/sdk/src/main/java/io/dapr/client/domain/ConversationTools.java new file mode 100644 index 000000000..f36fa0545 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationTools.java @@ -0,0 +1,40 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +/** + * Represents tool definitions that can be used during conversation. + */ +public class ConversationTools { + + private final ConversationToolsFunction function; + + /** + * Constructor. + * + * @param function the function definition + */ + public ConversationTools(ConversationToolsFunction function) { + this.function = function; + } + + /** + * Gets the function definition. + * + * @return the function definition + */ + public ConversationToolsFunction getFunction() { + return function; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationToolsFunction.java b/sdk/src/main/java/io/dapr/client/domain/ConversationToolsFunction.java new file mode 100644 index 000000000..11e3afdf3 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationToolsFunction.java @@ -0,0 +1,75 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.Map; + +/** + * Represents a function definition for conversation tools. + */ +public class ConversationToolsFunction { + + private String description; + private final String name; + private final Map parameters; + + /** + * Constructor. + * + * @param name the function name + * @param parameters the function parameters schema + */ + public ConversationToolsFunction(String name, Map parameters) { + this.name = name; + this.parameters = parameters; + } + + /** + * Gets the function name. + * + * @return the function name + */ + public String getName() { + return name; + } + + /** + * Gets the function description. + * + * @return the function description + */ + public String getDescription() { + return description; + } + + /** + * Sets the function description. + * + * @param description the function description + * @return this instance for method chaining + */ + public ConversationToolsFunction setDescription(String description) { + this.description = description; + return this; + } + + /** + * Gets the function parameters schema. + * + * @return the function parameters + */ + public Map getParameters() { + return parameters; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java b/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java new file mode 100644 index 000000000..5dbaa58a5 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java @@ -0,0 +1,61 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Developer message for development and debugging purposes. + * Used for providing additional context or instructions during development. + */ +public class DeveloperMessage implements ConversationMessage { + + private String name; + private final List content; + + /** + * Creates a developer message with content. + * + * @param content the content of the developer message + */ + public DeveloperMessage(List content) { + this.content = List.copyOf(content); + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.DEVELOPER; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the name of the developer participant. + * + * @param name the name to set + * @return this instance for method chaining + */ + public DeveloperMessage setName(String name) { + this.name = name; + return this; + } + + @Override + public List getContent() { + return content; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java b/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java new file mode 100644 index 000000000..aacdb80d5 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java @@ -0,0 +1,59 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * System message that sets the behavior or context for the conversation. + * Used to provide instructions or context to the AI model. + */ +public class SystemMessage implements ConversationMessage { + + private String name; + private final List content; + + /** + * Creates a system message with content. + * + * @param content the content of the system message + */ + public SystemMessage(List content) { + this.content = List.copyOf(content); + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.SYSTEM; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the name of the system participant. + * + * @param name the name to set + */ + public void setName(String name) { + this.name = name; + } + + @Override + public List getContent() { + return content; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java b/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java new file mode 100644 index 000000000..e88e37af4 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Tool message containing results from function/tool calls. + * Used to provide the response from a tool execution back to the AI model. + */ +public class ToolMessage implements ConversationMessage { + + private String toolId; + private String name; + private final List content; + + /** + * Creates a tool message with content. + * + * @param content the content containing the tool execution result + */ + public ToolMessage(List content) { + this.content = List.copyOf(content); + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.TOOL; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the tool identifier. + * + * @param toolId the tool identifier to set + * @return this instance for method chaining + */ + public ToolMessage setToolId(String toolId) { + this.toolId = toolId; + return this; + } + + /** + * Sets the name of the tool participant. + * + * @param name the name to set + * @return this instance for method chaining + */ + public ToolMessage setName(String name) { + this.name = name; + return this; + } + + @Override + public List getContent() { + return content; + } + + public String getToolId() { + return toolId; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/UserMessage.java b/sdk/src/main/java/io/dapr/client/domain/UserMessage.java new file mode 100644 index 000000000..30ae023a8 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/UserMessage.java @@ -0,0 +1,61 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * User message containing input from the human user. + * Represents questions, requests, or other input from the end user. + */ +public class UserMessage implements ConversationMessage { + + private String name; + private final List content; + + /** + * Creates a user message with content. + * + * @param content the content of the user message + */ + public UserMessage(List content) { + this.content = List.copyOf(content); + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.USER; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the name of the user participant. + * + * @param name the name to set + * @return this instance for method chaining + */ + public UserMessage setName(String name) { + this.name = name; + return this; + } + + @Override + public List getContent() { + return content; + } +} diff --git a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java index 8f0667b8d..f7b5584cc 100644 --- a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java +++ b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java @@ -16,12 +16,29 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.Lists; import com.google.protobuf.Any; import com.google.protobuf.ByteString; +import io.dapr.client.domain.AssistantMessage; import io.dapr.client.domain.BulkPublishEntry; import io.dapr.client.domain.BulkPublishRequest; import io.dapr.client.domain.BulkPublishResponse; import io.dapr.client.domain.CloudEvent; +import io.dapr.client.domain.ConversationToolCallsOfFunction; +import io.dapr.client.domain.ConversationToolsFunction; +import io.dapr.client.domain.ConversationInputAlpha2; +import io.dapr.client.domain.ConversationMessage; +import io.dapr.client.domain.ConversationMessageContent; +import io.dapr.client.domain.ConversationRequestAlpha2; +import io.dapr.client.domain.ConversationResponseAlpha2; +import io.dapr.client.domain.ConversationResultAlpha2; +import io.dapr.client.domain.ConversationResultChoices; +import io.dapr.client.domain.ConversationToolCalls; +import io.dapr.client.domain.ConversationTools; +import io.dapr.client.domain.DeleteJobRequest; +import io.dapr.client.domain.DeveloperMessage; +import io.dapr.client.domain.GetJobRequest; +import io.dapr.client.domain.GetJobResponse; import io.dapr.client.domain.ConstantFailurePolicy; import io.dapr.client.domain.ConversationInput; import io.dapr.client.domain.ConversationRequest; @@ -35,7 +52,10 @@ import io.dapr.client.domain.QueryStateRequest; import io.dapr.client.domain.QueryStateResponse; import io.dapr.client.domain.ScheduleJobRequest; +import io.dapr.client.domain.SystemMessage; +import io.dapr.client.domain.ToolMessage; import io.dapr.client.domain.UnlockResponseStatus; +import io.dapr.client.domain.UserMessage; import io.dapr.client.domain.query.Query; import io.dapr.serializer.DaprObjectSerializer; import io.dapr.serializer.DefaultObjectSerializer; @@ -1329,6 +1349,473 @@ public void deleteJobShouldThrowWhenNameIsEmptyRequest() { assertEquals("Name in the request cannot be null or empty", exception.getMessage()); } + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenNameIsNull() { + List messages = new ArrayList<>(); + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("System info"))); + systemMsg.setName("system"); + messages.add(systemMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2(null, List.of(input)); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenNameIsEmpty() { + List messages = new ArrayList<>(); + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("System info"))); + systemMsg.setName("system"); + messages.add(systemMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("", List.of(input)); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenNameIsWhitespace() { + List messages = new ArrayList<>(); + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("System info"))); + systemMsg.setName("system"); + messages.add(systemMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2(" ", null); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenInputIsNull() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("abc", null); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("Conversation Inputs cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenInputIsEmpty() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("abc", new ArrayList<>()); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("Conversation Inputs cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseAlpha2ExceptionThrownTest() { + doAnswer((Answer) invocation -> { + throw newStatusRuntimeException("INVALID_ARGUMENT", "bad argument"); + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + + assertThrows(IllegalArgumentException.class, () -> previewClient.converseAlpha2(request).block()); + } + + @Test + public void converseAlpha2CallbackExceptionThrownTest() { + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onError(newStatusRuntimeException("INVALID_ARGUMENT", "bad argument")); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + List messages = new ArrayList<>(); + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("System info"))); + systemMsg.setName("system"); + messages.add(systemMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); + Mono result = previewClient.converseAlpha2(request); + + assertThrowsDaprException( + ExecutionException.class, + "INVALID_ARGUMENT", + "INVALID_ARGUMENT: bad argument", + () -> result.block()); + } + + @Test + public void converseAlpha2MinimalRequestTest() { + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .setContextId("test-context") + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Hello! How can I help you today?") + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + List messages = new ArrayList<>(); + DeveloperMessage devMsg = new DeveloperMessage(List.of(new ConversationMessageContent("Debug info"))); + devMsg.setName("developer"); + messages.add(devMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertEquals("test-context", response.getContextId()); + assertEquals(1, response.getOutputs().size()); + + ConversationResultAlpha2 result = response.getOutputs().get(0); + assertEquals(1, result.getChoices().size()); + + ConversationResultChoices choice = result.getChoices().get(0); + assertEquals("stop", choice.getFinishReason()); + assertEquals(0, choice.getIndex()); + assertEquals("Hello! How can I help you today?", choice.getMessage().getContent()); + } + + @Test + public void converseAlpha2ComplexRequestTest() { + // Create messages + List messages = new ArrayList<>(); + UserMessage userMessage = new UserMessage(List.of(new ConversationMessageContent("Hello, how are you?"))); + userMessage.setName("John"); + messages.add(userMessage); + + // Create input + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + input.setScrubPii(true); + + // Create tools + Map functionParams = new HashMap<>(); + functionParams.put("location", "Required location parameter"); + List tools = new ArrayList<>(); + ConversationToolsFunction function = new ConversationToolsFunction("get_weather", functionParams); + function.setDescription("Get current weather"); + + ConversationTools tool = new ConversationTools(function); + tools.add(tool); + + Map metadata = new HashMap<>(); + metadata.put("key1", "value1"); + + Map parameters = new HashMap<>(); + parameters.put("max_tokens", "1000"); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); + request.setContextId("test-context"); + request.setTemperature(0.7); + request.setScrubPii(true); + request.setTools(tools); + request.setToolChoice("auto"); + request.setMetadata(metadata); + request.setParameters(parameters); + + // Mock response with tool calls + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .setContextId("test-context") + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("tool_calls") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("I'll help you get the weather information.") + .addToolCalls(DaprProtos.ConversationToolCalls.newBuilder() + .setId("call_123") + .setFunction(DaprProtos.ConversationToolCallsOfFunction.newBuilder() + .setName("get_weather") + .setArguments("{\"location\": \"New York\"}") + .build()) + .build()) + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertEquals("test-context", response.getContextId()); + + ConversationResultChoices choice = response.getOutputs().get(0).getChoices().get(0); + assertEquals("tool_calls", choice.getFinishReason()); + assertEquals("I'll help you get the weather information.", choice.getMessage().getContent()); + assertEquals(1, choice.getMessage().getToolCalls().size()); + + ConversationToolCalls toolCall = choice.getMessage().getToolCalls().get(0); + assertEquals("call_123", toolCall.getId()); + assertEquals("get_weather", toolCall.getFunction().getName()); + assertEquals("{\"location\": \"New York\"}", toolCall.getFunction().getArguments()); + + // Verify the request was built correctly + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequestAlpha2.class); + verify(daprStub).converseAlpha2(captor.capture(), any()); + + DaprProtos.ConversationRequestAlpha2 capturedRequest = captor.getValue(); + assertEquals("openai", capturedRequest.getName()); + assertEquals("test-context", capturedRequest.getContextId()); + assertEquals(0.7, capturedRequest.getTemperature(), 0.001); + assertTrue(capturedRequest.getScrubPii()); + assertEquals("auto", capturedRequest.getToolChoice()); + assertEquals("value1", capturedRequest.getMetadataMap().get("key1")); + assertEquals(1, capturedRequest.getToolsCount()); + assertEquals("get_weather", capturedRequest.getTools(0).getFunction().getName()); + } + + @Test + public void converseAlpha2AllMessageTypesTest() { + List messages = new ArrayList<>(); + + // System message + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("You are a helpful assistant."))); + systemMsg.setName("system"); + messages.add(systemMsg); + + // User message + UserMessage userMsg = new UserMessage(List.of(new ConversationMessageContent("Hello!"))); + userMsg.setName("user"); + messages.add(userMsg); + + // Assistant message + AssistantMessage assistantMsg = new AssistantMessage(List.of(new ConversationMessageContent("Hi there!")), + List.of(new ConversationToolCalls(new ConversationToolCallsOfFunction("abc", "parameters")))); + assistantMsg.setName("assistant"); + messages.add(assistantMsg); + + // Tool message + ToolMessage toolMsg = new ToolMessage(List.of(new ConversationMessageContent("Weather data: 72F"))); + toolMsg.setName("tool"); + messages.add(toolMsg); + + // Developer message + DeveloperMessage devMsg = new DeveloperMessage(List.of(new ConversationMessageContent("Debug info"))); + devMsg.setName("developer"); + messages.add(devMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); + + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Processed all message types") + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertEquals("Processed all message types", response.getOutputs().get(0).getChoices().get(0).getMessage().getContent()); + + // Verify all message types were processed + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequestAlpha2.class); + verify(daprStub).converseAlpha2(captor.capture(), any()); + + DaprProtos.ConversationRequestAlpha2 capturedRequest = captor.getValue(); + assertEquals(1, capturedRequest.getInputsCount()); + assertEquals(5, capturedRequest.getInputs(0).getMessagesCount()); + + // Verify each message type was converted correctly + List capturedMessages = capturedRequest.getInputs(0).getMessagesList(); + assertTrue(capturedMessages.get(0).hasOfSystem()); + assertTrue(capturedMessages.get(1).hasOfUser()); + assertTrue(capturedMessages.get(2).hasOfAssistant()); + assertTrue(capturedMessages.get(3).hasOfTool()); + assertTrue(capturedMessages.get(4).hasOfDeveloper()); + } + + @Test + public void converseAlpha2ResponseWithoutMessageTest() { + List messages = new ArrayList<>(); + DeveloperMessage devMsg = new DeveloperMessage(List.of(new ConversationMessageContent("Debug info"))); + devMsg.setName("developer"); + messages.add(devMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); + + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(0) + // No message set + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + ConversationResultChoices choice = response.getOutputs().get(0).getChoices().get(0); + assertEquals("stop", choice.getFinishReason()); + assertEquals(0, choice.getIndex()); + assertNull(choice.getMessage()); + } + + @Test + public void converseAlpha2MultipleResultsTest() { + List messages = new ArrayList<>(); + DeveloperMessage devMsg = new DeveloperMessage(List.of(new ConversationMessageContent("Debug info"))); + devMsg.setName("developer"); + messages.add(devMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); + + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("First choice") + .build()) + .build()) + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(1) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Second choice") + .build()) + .build()) + .build()) + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("length") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Third result") + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertEquals(2, response.getOutputs().size()); + + // First result with 2 choices + ConversationResultAlpha2 firstResult = response.getOutputs().get(0); + assertEquals(2, firstResult.getChoices().size()); + assertEquals("First choice", firstResult.getChoices().get(0).getMessage().getContent()); + assertEquals("Second choice", firstResult.getChoices().get(1).getMessage().getContent()); + + // Second result with 1 choice + ConversationResultAlpha2 secondResult = response.getOutputs().get(1); + assertEquals(1, secondResult.getChoices().size()); + assertEquals("Third result", secondResult.getChoices().get(0).getMessage().getContent()); + } + + @Test + public void converseAlpha2ToolCallWithoutFunctionTest() { + List messages = new ArrayList<>(); + UserMessage userMsg = new UserMessage(List.of(new ConversationMessageContent("Debug info"))); + userMsg.setName("developer"); + messages.add(userMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("tool_calls") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Test content") + .addToolCalls(DaprProtos.ConversationToolCalls.newBuilder() + .setId("call_123") + // No function set + .build()) + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + ConversationToolCalls toolCall = response.getOutputs().get(0).getChoices().get(0) + .getMessage().getToolCalls().get(0); + assertEquals("call_123", toolCall.getId()); + assertNull(toolCall.getFunction()); + } + private DaprProtos.QueryStateResponse buildQueryStateResponse(List> resp,String token) throws JsonProcessingException { List items = new ArrayList<>(); diff --git a/sdk/src/test/java/io/dapr/client/ProtobufValueHelperTest.java b/sdk/src/test/java/io/dapr/client/ProtobufValueHelperTest.java new file mode 100644 index 000000000..c345f34ff --- /dev/null +++ b/sdk/src/test/java/io/dapr/client/ProtobufValueHelperTest.java @@ -0,0 +1,423 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client; + +import com.google.protobuf.ListValue; +import com.google.protobuf.NullValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ProtobufValueHelperTest { + + @Test + public void testToProtobufValue_Null() throws IOException { + Value result = ProtobufValueHelper.toProtobufValue(null); + + assertNotNull(result); + assertTrue(result.hasNullValue()); + assertEquals(NullValue.NULL_VALUE, result.getNullValue()); + } + + @Test + public void testToProtobufValue_Boolean_True() throws IOException { + Value result = ProtobufValueHelper.toProtobufValue(true); + + assertNotNull(result); + assertTrue(result.hasBoolValue()); + assertEquals(true, result.getBoolValue()); + } + + @Test + public void testToProtobufValue_Boolean_False() throws IOException { + Value result = ProtobufValueHelper.toProtobufValue(false); + + assertNotNull(result); + assertTrue(result.hasBoolValue()); + assertEquals(false, result.getBoolValue()); + } + + @Test + public void testToProtobufValue_String() throws IOException { + String testString = "Hello, World!"; + Value result = ProtobufValueHelper.toProtobufValue(testString); + + assertNotNull(result); + assertTrue(result.hasStringValue()); + assertEquals(testString, result.getStringValue()); + } + + @Test + public void testToProtobufValue_String_Empty() throws IOException { + String emptyString = ""; + Value result = ProtobufValueHelper.toProtobufValue(emptyString); + + assertNotNull(result); + assertTrue(result.hasStringValue()); + assertEquals(emptyString, result.getStringValue()); + } + + @Test + public void testToProtobufValue_Integer() throws IOException { + Integer testInt = 42; + Value result = ProtobufValueHelper.toProtobufValue(testInt); + + assertNotNull(result); + assertTrue(result.hasNumberValue()); + assertEquals(42.0, result.getNumberValue(), 0.001); + } + + @Test + public void testToProtobufValue_Long() throws IOException { + Long testLong = 9876543210L; + Value result = ProtobufValueHelper.toProtobufValue(testLong); + + assertNotNull(result); + assertTrue(result.hasNumberValue()); + assertEquals(9876543210.0, result.getNumberValue(), 0.001); + } + + @Test + public void testToProtobufValue_Double() throws IOException { + Double testDouble = 3.14159; + Value result = ProtobufValueHelper.toProtobufValue(testDouble); + + assertNotNull(result); + assertTrue(result.hasNumberValue()); + assertEquals(testDouble, result.getNumberValue(), 0.00001); + } + + @Test + public void testToProtobufValue_Float() throws IOException { + Float testFloat = 2.718f; + Value result = ProtobufValueHelper.toProtobufValue(testFloat); + + assertNotNull(result); + assertTrue(result.hasNumberValue()); + assertEquals(2.718, result.getNumberValue(), 0.001); + } + + @Test + public void testToProtobufValue_BigInteger() throws IOException { + BigInteger testBigInt = new BigInteger("123456789012345678901234567890"); + Value result = ProtobufValueHelper.toProtobufValue(testBigInt); + + assertNotNull(result); + assertTrue(result.hasNumberValue()); + assertEquals(1.2345678901234568E29, result.getNumberValue(), 1E20); + } + + @Test + public void testToProtobufValue_BigDecimal() throws IOException { + BigDecimal testBigDecimal = new BigDecimal("123.456789"); + Value result = ProtobufValueHelper.toProtobufValue(testBigDecimal); + + assertNotNull(result); + assertTrue(result.hasNumberValue()); + assertEquals(123.456789, result.getNumberValue(), 0.000001); + } + + @Test + public void testToProtobufValue_EmptyList() throws IOException { + List emptyList = new ArrayList<>(); + Value result = ProtobufValueHelper.toProtobufValue(emptyList); + + assertNotNull(result); + assertTrue(result.hasListValue()); + ListValue listValue = result.getListValue(); + assertEquals(0, listValue.getValuesCount()); + } + + @Test + public void testToProtobufValue_SimpleList() throws IOException { + List testList = Arrays.asList("hello", 42, true, null); + Value result = ProtobufValueHelper.toProtobufValue(testList); + + assertNotNull(result); + assertTrue(result.hasListValue()); + ListValue listValue = result.getListValue(); + assertEquals(4, listValue.getValuesCount()); + + // Verify each element + assertEquals("hello", listValue.getValues(0).getStringValue()); + assertEquals(42.0, listValue.getValues(1).getNumberValue(), 0.001); + assertEquals(true, listValue.getValues(2).getBoolValue()); + assertEquals(NullValue.NULL_VALUE, listValue.getValues(3).getNullValue()); + } + + @Test + public void testToProtobufValue_NestedList() throws IOException { + List innerList = Arrays.asList(1, 2, 3); + List outerList = Arrays.asList("outer", innerList, "end"); + Value result = ProtobufValueHelper.toProtobufValue(outerList); + + assertNotNull(result); + assertTrue(result.hasListValue()); + ListValue listValue = result.getListValue(); + assertEquals(3, listValue.getValuesCount()); + + // Verify nested list + assertEquals("outer", listValue.getValues(0).getStringValue()); + assertTrue(listValue.getValues(1).hasListValue()); + ListValue nestedList = listValue.getValues(1).getListValue(); + assertEquals(3, nestedList.getValuesCount()); + assertEquals(1.0, nestedList.getValues(0).getNumberValue(), 0.001); + assertEquals(2.0, nestedList.getValues(1).getNumberValue(), 0.001); + assertEquals(3.0, nestedList.getValues(2).getNumberValue(), 0.001); + assertEquals("end", listValue.getValues(2).getStringValue()); + } + + @Test + public void testToProtobufValue_EmptyMap() throws IOException { + Map emptyMap = new HashMap<>(); + Value result = ProtobufValueHelper.toProtobufValue(emptyMap); + + assertNotNull(result); + assertTrue(result.hasStructValue()); + Struct struct = result.getStructValue(); + assertEquals(0, struct.getFieldsCount()); + } + + @Test + public void testToProtobufValue_SimpleMap() throws IOException { + Map testMap = new LinkedHashMap<>(); + testMap.put("name", "John Doe"); + testMap.put("age", 30); + testMap.put("active", true); + testMap.put("description", null); + + Value result = ProtobufValueHelper.toProtobufValue(testMap); + + assertNotNull(result); + assertTrue(result.hasStructValue()); + Struct struct = result.getStructValue(); + assertEquals(4, struct.getFieldsCount()); + + // Verify each field + assertEquals("John Doe", struct.getFieldsMap().get("name").getStringValue()); + assertEquals(30.0, struct.getFieldsMap().get("age").getNumberValue(), 0.001); + assertEquals(true, struct.getFieldsMap().get("active").getBoolValue()); + assertEquals(NullValue.NULL_VALUE, struct.getFieldsMap().get("description").getNullValue()); + } + + @Test + public void testToProtobufValue_NestedMap() throws IOException { + Map innerMap = new HashMap<>(); + innerMap.put("city", "New York"); + innerMap.put("zipcode", 10001); + + Map outerMap = new HashMap<>(); + outerMap.put("name", "John"); + outerMap.put("address", innerMap); + outerMap.put("hobbies", Arrays.asList("reading", "coding")); + + Value result = ProtobufValueHelper.toProtobufValue(outerMap); + + assertNotNull(result); + assertTrue(result.hasStructValue()); + Struct struct = result.getStructValue(); + assertEquals(3, struct.getFieldsCount()); + + // Verify nested structure + assertEquals("John", struct.getFieldsMap().get("name").getStringValue()); + + // Verify nested map + assertTrue(struct.getFieldsMap().get("address").hasStructValue()); + Struct nestedStruct = struct.getFieldsMap().get("address").getStructValue(); + assertEquals("New York", nestedStruct.getFieldsMap().get("city").getStringValue()); + assertEquals(10001.0, nestedStruct.getFieldsMap().get("zipcode").getNumberValue(), 0.001); + + // Verify nested list + assertTrue(struct.getFieldsMap().get("hobbies").hasListValue()); + ListValue hobbiesList = struct.getFieldsMap().get("hobbies").getListValue(); + assertEquals(2, hobbiesList.getValuesCount()); + assertEquals("reading", hobbiesList.getValues(0).getStringValue()); + assertEquals("coding", hobbiesList.getValues(1).getStringValue()); + } + + @Test + public void testToProtobufValue_MapWithNonStringKeys() throws IOException { + Map intKeyMap = new HashMap<>(); + intKeyMap.put(1, "one"); + intKeyMap.put(2, "two"); + + Value result = ProtobufValueHelper.toProtobufValue(intKeyMap); + + assertNotNull(result); + assertTrue(result.hasStructValue()); + Struct struct = result.getStructValue(); + assertEquals(2, struct.getFieldsCount()); + + // Keys should be converted to strings + assertTrue(struct.getFieldsMap().containsKey("1")); + assertTrue(struct.getFieldsMap().containsKey("2")); + assertEquals("one", struct.getFieldsMap().get("1").getStringValue()); + assertEquals("two", struct.getFieldsMap().get("2").getStringValue()); + } + + @Test + public void testToProtobufValue_CustomObject() throws IOException { + // Test with a custom object that will fall back to toString() + TestCustomObject customObj = new TestCustomObject("test", 123); + Value result = ProtobufValueHelper.toProtobufValue(customObj); + + assertNotNull(result); + assertTrue(result.hasStringValue()); + assertEquals("TestCustomObject{name='test', value=123}", result.getStringValue()); + } + + @Test + public void testToProtobufValue_ComplexNestedStructure() throws IOException { + // Create a complex nested structure + Map config = new HashMap<>(); + config.put("timeout", 30); + config.put("retries", 3); + + Map server = new HashMap<>(); + server.put("host", "localhost"); + server.put("port", 8080); + server.put("ssl", true); + server.put("config", config); + + List tags = Arrays.asList("prod", "critical", "monitoring"); + + Map application = new HashMap<>(); + application.put("name", "my-app"); + application.put("version", "1.0.0"); + application.put("server", server); + application.put("tags", tags); + application.put("metadata", null); + + Value result = ProtobufValueHelper.toProtobufValue(application); + + assertNotNull(result); + assertTrue(result.hasStructValue()); + Struct appStruct = result.getStructValue(); + + // Verify top-level fields + assertEquals("my-app", appStruct.getFieldsMap().get("name").getStringValue()); + assertEquals("1.0.0", appStruct.getFieldsMap().get("version").getStringValue()); + assertEquals(NullValue.NULL_VALUE, appStruct.getFieldsMap().get("metadata").getNullValue()); + + // Verify server object + assertTrue(appStruct.getFieldsMap().get("server").hasStructValue()); + Struct serverStruct = appStruct.getFieldsMap().get("server").getStructValue(); + assertEquals("localhost", serverStruct.getFieldsMap().get("host").getStringValue()); + assertEquals(8080.0, serverStruct.getFieldsMap().get("port").getNumberValue(), 0.001); + assertEquals(true, serverStruct.getFieldsMap().get("ssl").getBoolValue()); + + // Verify nested config + assertTrue(serverStruct.getFieldsMap().get("config").hasStructValue()); + Struct configStruct = serverStruct.getFieldsMap().get("config").getStructValue(); + assertEquals(30.0, configStruct.getFieldsMap().get("timeout").getNumberValue(), 0.001); + assertEquals(3.0, configStruct.getFieldsMap().get("retries").getNumberValue(), 0.001); + + // Verify tags list + assertTrue(appStruct.getFieldsMap().get("tags").hasListValue()); + ListValue tagsList = appStruct.getFieldsMap().get("tags").getListValue(); + assertEquals(3, tagsList.getValuesCount()); + assertEquals("prod", tagsList.getValues(0).getStringValue()); + assertEquals("critical", tagsList.getValues(1).getStringValue()); + assertEquals("monitoring", tagsList.getValues(2).getStringValue()); + } + + @Test + public void testToProtobufValue_OpenAPIFunctionSchema() throws IOException { + // Test with the exact schema structure provided by the user + Map functionSchema = new LinkedHashMap<>(); + functionSchema.put("type", "function"); + functionSchema.put("name", "get_horoscope"); + functionSchema.put("description", "Get today's horoscope for an astrological sign."); + + Map parameters = new LinkedHashMap<>(); + parameters.put("type", "object"); + + Map properties = new LinkedHashMap<>(); + Map signProperty = new LinkedHashMap<>(); + signProperty.put("type", "string"); + signProperty.put("description", "An astrological sign like Taurus or Aquarius"); + properties.put("sign", signProperty); + + parameters.put("properties", properties); + parameters.put("required", Arrays.asList("sign")); + + functionSchema.put("parameters", parameters); + + Value result = ProtobufValueHelper.toProtobufValue(functionSchema); + + assertNotNull(result); + assertTrue(result.hasStructValue()); + Struct rootStruct = result.getStructValue(); + + // Verify root level fields + assertEquals("function", rootStruct.getFieldsMap().get("type").getStringValue()); + assertEquals("get_horoscope", rootStruct.getFieldsMap().get("name").getStringValue()); + assertEquals("Get today's horoscope for an astrological sign.", + rootStruct.getFieldsMap().get("description").getStringValue()); + + // Verify parameters object + assertTrue(rootStruct.getFieldsMap().get("parameters").hasStructValue()); + Struct parametersStruct = rootStruct.getFieldsMap().get("parameters").getStructValue(); + assertEquals("object", parametersStruct.getFieldsMap().get("type").getStringValue()); + + // Verify properties object + assertTrue(parametersStruct.getFieldsMap().get("properties").hasStructValue()); + Struct propertiesStruct = parametersStruct.getFieldsMap().get("properties").getStructValue(); + + // Verify sign property + assertTrue(propertiesStruct.getFieldsMap().get("sign").hasStructValue()); + Struct signStruct = propertiesStruct.getFieldsMap().get("sign").getStructValue(); + assertEquals("string", signStruct.getFieldsMap().get("type").getStringValue()); + assertEquals("An astrological sign like Taurus or Aquarius", + signStruct.getFieldsMap().get("description").getStringValue()); + + // Verify required array + assertTrue(parametersStruct.getFieldsMap().get("required").hasListValue()); + ListValue requiredList = parametersStruct.getFieldsMap().get("required").getListValue(); + assertEquals(1, requiredList.getValuesCount()); + assertEquals("sign", requiredList.getValues(0).getStringValue()); + } + + /** + * Helper class for testing custom object conversion + */ + private static class TestCustomObject { + private final String name; + private final int value; + + public TestCustomObject(String name, int value) { + this.name = name; + this.value = value; + } + + @Override + public String toString() { + return "TestCustomObject{name='" + name + "', value=" + value + "}"; + } + } +}