|
17 | 17 | import com.google.protobuf.Any;
|
18 | 18 | import com.google.protobuf.ByteString;
|
19 | 19 | import com.google.protobuf.Empty;
|
| 20 | +import com.google.protobuf.Message; |
20 | 21 | import io.dapr.client.domain.ActorMetadata;
|
21 | 22 | import io.dapr.client.domain.AppConnectionPropertiesHealthMetadata;
|
22 | 23 | import io.dapr.client.domain.AppConnectionPropertiesMetadata;
|
|
27 | 28 | import io.dapr.client.domain.CloudEvent;
|
28 | 29 | import io.dapr.client.domain.ComponentMetadata;
|
29 | 30 | import io.dapr.client.domain.ConfigurationItem;
|
| 31 | +import io.dapr.client.domain.ConversationFunction; |
30 | 32 | import io.dapr.client.domain.ConversationInput;
|
| 33 | +import io.dapr.client.domain.ConversationInputAlpha2; |
| 34 | +import io.dapr.client.domain.ConversationMessage; |
| 35 | +import io.dapr.client.domain.ConversationMessageContent; |
31 | 36 | import io.dapr.client.domain.ConversationOutput;
|
32 | 37 | import io.dapr.client.domain.ConversationRequest;
|
| 38 | +import io.dapr.client.domain.ConversationRequestAlpha2; |
33 | 39 | import io.dapr.client.domain.ConversationResponse;
|
| 40 | +import io.dapr.client.domain.ConversationResponseAlpha2; |
| 41 | +import io.dapr.client.domain.ConversationResultAlpha2; |
| 42 | +import io.dapr.client.domain.ConversationResultChoices; |
| 43 | +import io.dapr.client.domain.ConversationResultMessage; |
| 44 | +import io.dapr.client.domain.ConversationToolCalls; |
| 45 | +import io.dapr.client.domain.ConversationToolCallsFunction; |
| 46 | +import io.dapr.client.domain.ConversationTools; |
34 | 47 | import io.dapr.client.domain.DaprMetadata;
|
35 | 48 | import io.dapr.client.domain.DeleteJobRequest;
|
36 | 49 | import io.dapr.client.domain.DeleteStateRequest;
|
|
100 | 113 | import reactor.util.retry.Retry;
|
101 | 114 |
|
102 | 115 | import javax.annotation.Nonnull;
|
| 116 | + |
103 | 117 | import java.io.IOException;
|
104 | 118 | import java.time.Duration;
|
105 | 119 | import java.time.Instant;
|
@@ -1628,6 +1642,193 @@ private void validateConversationRequest(ConversationRequest conversationRequest
|
1628 | 1642 | }
|
1629 | 1643 | }
|
1630 | 1644 |
|
| 1645 | + /** |
| 1646 | + * {@inheritDoc} |
| 1647 | + */ |
| 1648 | + @Override |
| 1649 | + public Mono<ConversationResponseAlpha2> converseAlpha2(ConversationRequestAlpha2 conversationRequestAlpha2) { |
| 1650 | + |
| 1651 | + try { |
| 1652 | + validateConversationRequestAlpha2(conversationRequestAlpha2); |
| 1653 | + |
| 1654 | + DaprProtos.ConversationRequestAlpha2.Builder protosConversationRequestBuilder = |
| 1655 | + DaprProtos.ConversationRequestAlpha2 |
| 1656 | + .newBuilder() |
| 1657 | + .setTemperature(conversationRequestAlpha2.getTemperature()) |
| 1658 | + .setScrubPii(conversationRequestAlpha2.isScrubPii()) |
| 1659 | + .setName(conversationRequestAlpha2.getName()); |
| 1660 | + |
| 1661 | + if (conversationRequestAlpha2.getContextId() != null) { |
| 1662 | + protosConversationRequestBuilder.setContextId(conversationRequestAlpha2.getContextId()); |
| 1663 | + } |
| 1664 | + |
| 1665 | + if (conversationRequestAlpha2.getToolChoice() != null) { |
| 1666 | + protosConversationRequestBuilder.setToolChoice(conversationRequestAlpha2.getToolChoice()); |
| 1667 | + } |
| 1668 | + |
| 1669 | + if (conversationRequestAlpha2.getTools() != null) { |
| 1670 | + for (ConversationTools tool : conversationRequestAlpha2.getTools()) { |
| 1671 | + |
| 1672 | + ConversationFunction conversationFunction = tool.getFunction(); |
| 1673 | + |
| 1674 | + Map<String, Any> protosConversationToolFunctionParameters = conversationFunction.getParameters() |
| 1675 | + .entrySet().stream() |
| 1676 | + .collect(Collectors.toMap( |
| 1677 | + Map.Entry::getKey, |
| 1678 | + e -> Any.pack((Message) e.getValue()) |
| 1679 | + )); |
| 1680 | + DaprProtos.ConversationToolsFunction protosConversationToolsFunction = |
| 1681 | + DaprProtos.ConversationToolsFunction.newBuilder() |
| 1682 | + .setName(conversationFunction.getName()) |
| 1683 | + .setDescription(conversationFunction.getDescription()) |
| 1684 | + .putAllParameters(protosConversationToolFunctionParameters) |
| 1685 | + .build(); |
| 1686 | + |
| 1687 | + DaprProtos.ConversationTools conversationTool = DaprProtos.ConversationTools.newBuilder() |
| 1688 | + .setFunction(protosConversationToolsFunction).build(); |
| 1689 | + |
| 1690 | + protosConversationRequestBuilder.addTools(conversationTool); |
| 1691 | + } |
| 1692 | + } |
| 1693 | + |
| 1694 | + for (ConversationInputAlpha2 input : conversationRequestAlpha2.getInputs()) { |
| 1695 | + DaprProtos.ConversationInputAlpha2.Builder conversationInputBuilder = DaprProtos.ConversationInputAlpha2 |
| 1696 | + .newBuilder() |
| 1697 | + .setScrubPii(input.isScrubPii()); |
| 1698 | + |
| 1699 | + if (input.getMessages() != null) { |
| 1700 | + |
| 1701 | + for (ConversationMessage conversationMessage : input.getMessages()) { |
| 1702 | + DaprProtos.ConversationMessage.Builder messageBuilder = |
| 1703 | + DaprProtos.ConversationMessage.newBuilder(); |
| 1704 | + |
| 1705 | + ConversationMessage.Role role = conversationMessage.getRole(); |
| 1706 | + switch (role) { |
| 1707 | + case TOOL: |
| 1708 | + messageBuilder.setOfTool(DaprProtos.ConversationMessageOfTool.newBuilder() |
| 1709 | + .setToolId(conversationMessage.getToolId()) |
| 1710 | + .setName(conversationMessage.getName()) |
| 1711 | + .addAllContent(getConversationMessageContent(conversationMessage)).build()); |
| 1712 | + break; |
| 1713 | + case USER: |
| 1714 | + messageBuilder.setOfUser(DaprProtos.ConversationMessageOfUser.newBuilder() |
| 1715 | + .setName(conversationMessage.getName()) |
| 1716 | + .addAllContent(getConversationMessageContent(conversationMessage)).build()); |
| 1717 | + break; |
| 1718 | + case ASSISTANT: |
| 1719 | + messageBuilder.setOfAssistant(DaprProtos.ConversationMessageOfAssistant.newBuilder() |
| 1720 | + .setName(conversationMessage.getName()) |
| 1721 | + .addAllToolCalls(getConversationToolCalls(conversationMessage)) |
| 1722 | + .addAllContent(getConversationMessageContent(conversationMessage)).build()); |
| 1723 | + break; |
| 1724 | + case DEVELOPER: |
| 1725 | + messageBuilder.setOfDeveloper(DaprProtos.ConversationMessageOfDeveloper.newBuilder() |
| 1726 | + .setName(conversationMessage.getName()) |
| 1727 | + .addAllContent(getConversationMessageContent(conversationMessage)).build()); |
| 1728 | + break; |
| 1729 | + case SYSTEM: |
| 1730 | + messageBuilder.setOfSystem(DaprProtos.ConversationMessageOfSystem.newBuilder() |
| 1731 | + .setName(conversationMessage.getName()) |
| 1732 | + .addAllContent(getConversationMessageContent(conversationMessage)).build()); |
| 1733 | + break; |
| 1734 | + default: throw new IllegalArgumentException("No role of type " + role + " found"); |
| 1735 | + } |
| 1736 | + |
| 1737 | + conversationInputBuilder.addMessages(messageBuilder.build()); |
| 1738 | + } |
| 1739 | + } |
| 1740 | + |
| 1741 | + protosConversationRequestBuilder.addInputs(conversationInputBuilder.build()); |
| 1742 | + } |
| 1743 | + |
| 1744 | + Mono<DaprProtos.ConversationResponseAlpha2> conversationResponseMono = Mono.deferContextual( |
| 1745 | + context -> this.createMono( |
| 1746 | + it -> intercept(context, asyncStub) |
| 1747 | + .converseAlpha2(protosConversationRequestBuilder.build(), it) |
| 1748 | + ) |
| 1749 | + ); |
| 1750 | + |
| 1751 | + return conversationResponseMono.map(conversationResponse -> { |
| 1752 | + List<ConversationResultAlpha2> results = new ArrayList<>(); |
| 1753 | + |
| 1754 | + for (DaprProtos.ConversationResultAlpha2 result : conversationResponse.getOutputsList()) { |
| 1755 | + List<ConversationResultChoices> choices = new ArrayList<>(); |
| 1756 | + |
| 1757 | + for (DaprProtos.ConversationResultChoices choice : result.getChoicesList()) { |
| 1758 | + ConversationResultMessage message = null; |
| 1759 | + if (choice.hasMessage()) { |
| 1760 | + List<ConversationToolCalls> toolCalls = new ArrayList<>(); |
| 1761 | + |
| 1762 | + for (DaprProtos.ConversationToolCalls toolCall : choice.getMessage().getToolCallsList()) { |
| 1763 | + ConversationToolCallsFunction function = null; |
| 1764 | + if (toolCall.hasFunction()) { |
| 1765 | + function = new ConversationToolCallsFunction( |
| 1766 | + toolCall.getFunction().getName(), |
| 1767 | + toolCall.getFunction().getArguments() |
| 1768 | + ); |
| 1769 | + } |
| 1770 | + |
| 1771 | + toolCalls.add(new ConversationToolCalls(toolCall.getId(), function)); |
| 1772 | + } |
| 1773 | + |
| 1774 | + message = new ConversationResultMessage( |
| 1775 | + choice.getMessage().getContent(), |
| 1776 | + toolCalls |
| 1777 | + ); |
| 1778 | + } |
| 1779 | + |
| 1780 | + choices.add(new ConversationResultChoices(choice.getFinishReason(), choice.getIndex(), message)); |
| 1781 | + } |
| 1782 | + |
| 1783 | + results.add(new ConversationResultAlpha2(choices)); |
| 1784 | + } |
| 1785 | + |
| 1786 | + return new ConversationResponseAlpha2(conversationResponse.getContextId(), results); |
| 1787 | + }); |
| 1788 | + } catch (Exception ex) { |
| 1789 | + return DaprException.wrapMono(ex); |
| 1790 | + } |
| 1791 | + } |
| 1792 | + |
| 1793 | + private List<DaprProtos.ConversationMessageContent> getConversationMessageContent( |
| 1794 | + ConversationMessage conversationMessage) { |
| 1795 | + |
| 1796 | + List<DaprProtos.ConversationMessageContent> conversationMessageContents = new ArrayList<>(); |
| 1797 | + for (ConversationMessageContent conversationMessageContent: conversationMessage.getContent()) { |
| 1798 | + conversationMessageContents.add(DaprProtos.ConversationMessageContent.newBuilder() |
| 1799 | + .setText(conversationMessageContent.getText()) |
| 1800 | + .build()); |
| 1801 | + } |
| 1802 | + |
| 1803 | + return conversationMessageContents; |
| 1804 | + } |
| 1805 | + |
| 1806 | + private List<DaprProtos.ConversationToolCalls> getConversationToolCalls( |
| 1807 | + ConversationMessage conversationMessage) { |
| 1808 | + List<DaprProtos.ConversationToolCalls> conversationToolCalls = new ArrayList<>(); |
| 1809 | + for (ConversationToolCalls conversationToolCall: conversationMessage.getToolCalls()) { |
| 1810 | + conversationToolCalls.add(DaprProtos.ConversationToolCalls.newBuilder() |
| 1811 | + .setId(conversationToolCall.getId()) |
| 1812 | + .setFunction(DaprProtos.ConversationToolCallsOfFunction.newBuilder() |
| 1813 | + .setName(conversationToolCall.getFunction().getName()) |
| 1814 | + .setArguments(conversationToolCall.getFunction().getArguments()) |
| 1815 | + .build()) |
| 1816 | + .build()); |
| 1817 | + } |
| 1818 | + |
| 1819 | + return conversationToolCalls; |
| 1820 | + } |
| 1821 | + |
| 1822 | + private void validateConversationRequestAlpha2(ConversationRequestAlpha2 conversationRequest) { |
| 1823 | + if ((conversationRequest.getName() == null) || (conversationRequest.getName().trim().isEmpty())) { |
| 1824 | + throw new IllegalArgumentException("LLM name cannot be null or empty."); |
| 1825 | + } |
| 1826 | + |
| 1827 | + if ((conversationRequest.getInputs() == null) || (conversationRequest.getInputs().isEmpty())) { |
| 1828 | + throw new IllegalArgumentException("Conversation inputs cannot be null or empty."); |
| 1829 | + } |
| 1830 | + } |
| 1831 | + |
1631 | 1832 | private DaprMetadata buildDaprMetadata(DaprProtos.GetMetadataResponse response) throws IOException {
|
1632 | 1833 | String id = response.getId();
|
1633 | 1834 | String runtimeVersion = response.getRuntimeVersion();
|
|
0 commit comments