diff --git a/client/build.gradle b/client/build.gradle index 78d92f9947..399b034c4d 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -16,6 +16,7 @@ plugins { dependencies { implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') + implementation project(path: ":${rootProject.name}-memory") compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0' diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 457f1b7d1f..44421a90c8 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -34,6 +34,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; /** * A client to provide interfaces for machine learning jobs. This will be used by other plugins. @@ -553,4 +554,22 @@ default void getConfig(String configId, ActionListener listener) { * @param listener a listener to be notified of the result */ void getConfig(String configId, String tenantId, ActionListener listener); + + /** + * Create conversational memory for conversation + * @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/ + * @return the result future + */ + default ActionFuture createConversation(String name) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + createConversation(name, actionFuture); + return actionFuture; + } + + /** + * Create conversational memory for conversation + * @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/ + * @param listener action listener + */ + void createConversation(String name, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 695b9f0892..42da38d03a 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -89,6 +89,9 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; import lombok.AccessLevel; import lombok.RequiredArgsConstructor; @@ -318,6 +321,11 @@ public void getConfig(String configId, String tenantId, ActionListener client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener)); } + public void createConversation(String name, ActionListener listener) { + CreateConversationRequest createConversationRequest = new CreateConversationRequest(name); + client.execute(CreateConversationAction.INSTANCE, createConversationRequest, getCreateConversationResponseActionListener(listener)); + } + private ActionListener getMlListToolsResponseActionListener(ActionListener> listener) { ActionListener internalListener = ActionListener.wrap(mlModelListResponse -> { listener.onResponse(mlModelListResponse.getToolMetadataList()); @@ -386,6 +394,16 @@ private ActionListener getMLRegisterModelResponseAction return wrapActionListener(listener, MLRegisterModelResponse::fromActionResponse); } + private ActionListener getCreateConversationResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, response -> { + CreateConversationResponse conversationResponse = CreateConversationResponse.fromActionResponse(response); + return conversationResponse; + }); + return actionListener; + } + private ActionListener wrapActionListener( final ActionListener listener, final Function recreate diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 776aefd2cf..8dc8229fc1 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -49,7 +49,6 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; -import org.opensearch.ml.common.transport.config.MLConfigGetResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -59,6 +58,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; public class MachineLearningClientTest { @@ -107,7 +107,7 @@ public class MachineLearningClientTest { MLRegisterAgentResponse registerAgentResponse; @Mock - MLConfigGetResponse configGetResponse; + CreateConversationResponse createConversationResponse; private final String modekId = "test_model_id"; private MLModel mlModel; @@ -256,6 +256,11 @@ public void deleteAgent(String agentId, String tenantId, ActionListener listener) { listener.onResponse(mlConfig); } + + @Override + public void createConversation(String name, ActionListener listener) { + listener.onResponse(createConversationResponse); + } }; } @@ -554,4 +559,9 @@ public void listTools() { public void getConfig() { assertEquals(mlConfig, machineLearningClient.getConfig("configId").actionGet()); } + + @Test + public void createConversation() { + assertEquals(createConversationResponse, machineLearningClient.createConversation("Conversation for a RAG pipeline").actionGet()); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 76df1d9a8c..fd75639987 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -141,6 +141,9 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -219,6 +222,9 @@ public class MachineLearningNodeClientTest { @Mock ActionListener getMlConfigListener; + @Mock + ActionListener createConversationResponseActionListener; + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -1453,6 +1459,25 @@ public void onFailure(Exception e) { }); verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any()); + } + + public void createConversation() { + String name = "Conversation for a RAG pipeline"; + String conversationId = "conversationId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + CreateConversationResponse output = new CreateConversationResponse(conversationId); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class); + machineLearningNodeClient.createConversation(name, createConversationResponseActionListener); + + verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any()); + verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture()); + assertEquals(conversationId, argumentCaptor.getValue().getId()); } private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java index 79f6fb6bf0..9ba60558a0 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java @@ -17,15 +17,21 @@ */ package org.opensearch.ml.memory.action.conversation; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import lombok.AllArgsConstructor; @@ -67,4 +73,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par return builder; } + public static CreateConversationResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateConnectorResponse) { + return (CreateConversationResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new CreateConversationResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into CreateConversationResponse", e); + } + + } + } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java index 542a1f652d..6a0a9fceab 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java @@ -17,10 +17,16 @@ */ package org.opensearch.ml.memory.action.conversation; +import static org.junit.Assert.assertEquals; + import java.io.IOException; +import java.io.UncheckedIOException; +import org.junit.Before; +import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.BytesStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -32,6 +38,13 @@ public class CreateConversationResponseTests extends OpenSearchTestCase { + CreateConversationResponse response; + + @Before + public void setup() { + response = new CreateConversationResponse("test-id"); + } + public void testCreateConversationResponseStreaming() throws IOException { CreateConversationResponse response = new CreateConversationResponse("test-id"); assert (response.getId().equals("test-id")); @@ -51,4 +64,34 @@ public void testToXContent() throws IOException { String result = BytesReference.bytes(builder).utf8ToString(); assert (result.equals(expected)); } + + @Test + public void fromActionResponseWithCreateConversationResponseSuccess() { + CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(response); + assertEquals(response.getId(), responseFromActionResponse.getId()); + } + + @Test + public void fromActionResponseSuccess() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + response.writeTo(out); + } + }; + CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(actionResponse); + assertNotSame(response, responseFromActionResponse); + assertEquals(response.getId(), responseFromActionResponse.getId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponseIOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + CreateConversationResponse.fromActionResponse(actionResponse); + } }