Skip to content

Commit

Permalink
Added Conversation API in MLClient
Browse files Browse the repository at this point in the history
Signed-off-by: Owais <[email protected]>
  • Loading branch information
owaiskazi19 committed Feb 3, 2025
1 parent 86ae88b commit 8f48aee
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 2 deletions.
1 change: 1 addition & 0 deletions client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ plugins {
dependencies {
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-memory")
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
Expand Down Expand Up @@ -553,4 +554,22 @@ default void getConfig(String configId, ActionListener<MLConfig> listener) {
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener);

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
* @return the result future
*/
default ActionFuture<CreateConversationResponse> createConversation(String name) {
PlainActionFuture<CreateConversationResponse> actionFuture = PlainActionFuture.newFuture();
createConversation(name, actionFuture);
return actionFuture;
}

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
* @param listener action listener
*/
void createConversation(String name, ActionListener<CreateConversationResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -318,6 +321,11 @@ public void getConfig(String configId, String tenantId, ActionListener<MLConfig>
client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
}

public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
CreateConversationRequest createConversationRequest = new CreateConversationRequest(name);
client.execute(CreateConversationAction.INSTANCE, createConversationRequest, getCreateConversationResponseActionListener(listener));
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
Expand Down Expand Up @@ -386,6 +394,16 @@ private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseAction
return wrapActionListener(listener, MLRegisterModelResponse::fromActionResponse);
}

private ActionListener<CreateConversationResponse> getCreateConversationResponseActionListener(
ActionListener<CreateConversationResponse> listener
) {
ActionListener<CreateConversationResponse> actionListener = wrapActionListener(listener, response -> {
CreateConversationResponse conversationResponse = CreateConversationResponse.fromActionResponse(response);
return conversationResponse;
});
return actionListener;
}

private <T extends ActionResponse> ActionListener<T> wrapActionListener(
final ActionListener<T> listener,
final Function<ActionResponse, T> recreate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand All @@ -59,6 +58,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

public class MachineLearningClientTest {

Expand Down Expand Up @@ -107,7 +107,7 @@ public class MachineLearningClientTest {
MLRegisterAgentResponse registerAgentResponse;

@Mock
MLConfigGetResponse configGetResponse;
CreateConversationResponse createConversationResponse;

private final String modekId = "test_model_id";
private MLModel mlModel;
Expand Down Expand Up @@ -256,6 +256,11 @@ public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteRe
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
}

@Override
public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
listener.onResponse(createConversationResponse);
}
};
}

Expand Down Expand Up @@ -554,4 +559,9 @@ public void listTools() {
public void getConfig() {
assertEquals(mlConfig, machineLearningClient.getConfig("configId").actionGet());
}

@Test
public void createConversation() {
assertEquals(createConversationResponse, machineLearningClient.createConversation("Conversation for a RAG pipeline").actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -219,6 +222,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<MLConfig> getMlConfigListener;

@Mock
ActionListener<CreateConversationResponse> createConversationResponseActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -1453,6 +1459,25 @@ public void onFailure(Exception e) {
});

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
}

public void createConversation() {
String name = "Conversation for a RAG pipeline";
String conversationId = "conversationId";

doAnswer(invocation -> {
ActionListener<CreateConversationResponse> actionListener = invocation.getArgument(2);
CreateConversationResponse output = new CreateConversationResponse(conversationId);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any());

ArgumentCaptor<CreateConversationResponse> argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class);
machineLearningNodeClient.createConversation(name, createConversationResponseActionListener);

verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any());
verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture());
assertEquals(conversationId, argumentCaptor.getValue().getId());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@
*/
package org.opensearch.ml.memory.action.conversation;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;

import lombok.AllArgsConstructor;

Expand Down Expand Up @@ -67,4 +73,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
return builder;
}

public static CreateConversationResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLCreateConnectorResponse) {
return (CreateConversationResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new CreateConversationResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into CreateConversationResponse", e);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@
*/
package org.opensearch.ml.memory.action.conversation;

import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.io.UncheckedIOException;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
Expand All @@ -32,6 +38,13 @@

public class CreateConversationResponseTests extends OpenSearchTestCase {

CreateConversationResponse response;

@Before
public void setup() {
response = new CreateConversationResponse("test-id");
}

public void testCreateConversationResponseStreaming() throws IOException {
CreateConversationResponse response = new CreateConversationResponse("test-id");
assert (response.getId().equals("test-id"));
Expand All @@ -51,4 +64,34 @@ public void testToXContent() throws IOException {
String result = BytesReference.bytes(builder).utf8ToString();
assert (result.equals(expected));
}

@Test
public void fromActionResponseWithCreateConversationResponseSuccess() {
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(response);
assertEquals(response.getId(), responseFromActionResponse.getId());
}

@Test
public void fromActionResponseSuccess() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
response.writeTo(out);
}
};
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(actionResponse);
assertNotSame(response, responseFromActionResponse);
assertEquals(response.getId(), responseFromActionResponse.getId());
}

@Test(expected = UncheckedIOException.class)
public void fromActionResponseIOException() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
CreateConversationResponse.fromActionResponse(actionResponse);
}
}

0 comments on commit 8f48aee

Please sign in to comment.