Skip to content

Commit

Permalink
Merge branch 'opensearch-project:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
owaiskazi19 authored Feb 3, 2025
2 parents bfff2b4 + 9e014fa commit 86ae88b
Show file tree
Hide file tree
Showing 15 changed files with 2,301 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ default void deleteModel(String modelId, ActionListener<DeleteResponse> listener
*/
default ActionFuture<DeleteResponse> deleteTask(String taskId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteModel(taskId, actionFuture);
deleteTask(taskId, actionFuture);
return actionFuture;
}

Expand Down Expand Up @@ -361,7 +361,7 @@ default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Null
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param nodeIds the node ids. May be null for all nodes.
* @param listener a listener to be notified of the result
*/
default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
Expand All @@ -372,7 +372,7 @@ default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUnde
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param nodeIds the node ids. May be null for all nodes.
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
Expand Down Expand Up @@ -480,8 +480,7 @@ default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
* @param listener a listener to be notified of the result
*/
default void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, null, actionFuture);
deleteAgent(agentId, null, listener);
}

/**
Expand Down Expand Up @@ -543,5 +542,15 @@ default ActionFuture<MLConfig> getConfig(String configId) {
* @param configId ML config id
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, ActionListener<MLConfig> listener);
default void getConfig(String configId, ActionListener<MLConfig> listener) {
getConfig(configId, null, listener);
}

/**
* Delete agent
* @param configId ML config id
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).build();
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).tenantId(tenantId).build();

client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,6 @@ public void setUp() {
.build();

machineLearningClient = new MachineLearningClient() {
@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
listener.onResponse(output);
}

@Override
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
Expand All @@ -169,21 +165,11 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
listener.onResponse(output);
}

@Override
public void getModel(String modelId, ActionListener<MLModel> listener) {
listener.onResponse(mlModel);
}

@Override
public void getModel(String modelId, String tenantId, ActionListener<MLModel> listener) {
listener.onResponse(mlModel);
}

@Override
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
Expand All @@ -194,21 +180,11 @@ public void searchModel(SearchRequest searchRequest, ActionListener<SearchRespon
listener.onResponse(searchResponse);
}

@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
listener.onResponse(mlTask);
}

@Override
public void getTask(String taskId, String tenantId, ActionListener<MLTask> listener) {
listener.onResponse(mlTask);
}

@Override
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
Expand All @@ -224,21 +200,11 @@ public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterMode
listener.onResponse(registerModelResponse);
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}

@Override
public void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
Expand All @@ -259,11 +225,6 @@ public void deleteConnector(String connectorId, String tenantId, ActionListener<
listener.onResponse(deleteResponse);
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void listTools(ActionListener<List<ToolMetadata>> listener) {
listener.onResponse(toolsList);
Expand All @@ -286,18 +247,13 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
listener.onResponse(registerAgentResponse);
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ public void deleteTask() {
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener);
machineLearningNodeClient.deleteTask(taskId, null, deleteTaskActionListener);

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
verify(deleteTaskActionListener).onResponse(argumentCaptor.capture());
Expand Down Expand Up @@ -1276,6 +1276,185 @@ public void getConfigRejectedMasterKey() {
assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage());
}

@Test
public void predict_withTenantId() {
String tenantId = "testTenant";
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
MLPredictionOutput predictionOutput = MLPredictionOutput
.builder()
.status("Success")
.predictionResult(output)
.taskId("taskId")
.build();
actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build());
return null;
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

ArgumentCaptor<MLPredictionTaskRequest> requestCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
machineLearningNodeClient.predict("modelId", tenantId, mlInput, dataFrameActionListener);

verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), requestCaptor.capture(), any());
assertEquals(tenantId, requestCaptor.getValue().getTenantId());
assertEquals("modelId", requestCaptor.getValue().getModelId());
}

@Test
public void getTask_withFailure() {
String taskId = "taskId";
String errorMessage = "Task not found";

doAnswer(invocation -> {
ActionListener<MLTaskGetResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new IllegalArgumentException(errorMessage));
return null;
}).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any());

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

machineLearningNodeClient.getTask(taskId, new ActionListener<>() {
@Override
public void onResponse(MLTask mlTask) {
fail("Expected failure but got success");
}

@Override
public void onFailure(Exception e) {
assertEquals(errorMessage, e.getMessage());
}
});

verify(client).execute(eq(MLTaskGetAction.INSTANCE), isA(MLTaskGetRequest.class), any());
}

@Test
public void deploy_withTenantId() {
String modelId = "testModel";
String tenantId = "testTenant";
String taskId = "taskId";
String status = MLTaskState.CREATED.name();

doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, status);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());

ArgumentCaptor<MLDeployModelRequest> requestCaptor = ArgumentCaptor.forClass(MLDeployModelRequest.class);
machineLearningNodeClient.deploy(modelId, tenantId, deployModelActionListener);

verify(client).execute(eq(MLDeployModelAction.INSTANCE), requestCaptor.capture(), any());
assertEquals(modelId, requestCaptor.getValue().getModelId());
assertEquals(tenantId, requestCaptor.getValue().getTenantId());
}

@Test
public void trainAndPredict_withNullInput() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("ML Input can't be null");

machineLearningNodeClient.trainAndPredict(null, trainingActionListener);
}

@Test
public void trainAndPredict_withNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data set can't be null");

MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build();
machineLearningNodeClient.trainAndPredict(mlInput, trainingActionListener);
}

@Test
public void getTask_withTaskIdAndTenantId() {
String taskId = "taskId";
String tenantId = "testTenant";
String modelId = "modelId";

doAnswer(invocation -> {
ActionListener<MLTaskGetResponse> actionListener = invocation.getArgument(2);
MLTask mlTask = MLTask.builder().taskId(taskId).modelId(modelId).functionName(FunctionName.KMEANS).build();
MLTaskGetResponse output = MLTaskGetResponse.builder().mlTask(mlTask).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any());

ArgumentCaptor<MLTaskGetRequest> requestCaptor = ArgumentCaptor.forClass(MLTaskGetRequest.class);
ArgumentCaptor<MLTask> taskCaptor = ArgumentCaptor.forClass(MLTask.class);

machineLearningNodeClient.getTask(taskId, tenantId, getTaskActionListener);

verify(client).execute(eq(MLTaskGetAction.INSTANCE), requestCaptor.capture(), any());
verify(getTaskActionListener).onResponse(taskCaptor.capture());

// Verify request parameters
assertEquals(taskId, requestCaptor.getValue().getTaskId());
assertEquals(tenantId, requestCaptor.getValue().getTenantId());

// Verify response
assertEquals(taskId, taskCaptor.getValue().getTaskId());
assertEquals(modelId, taskCaptor.getValue().getModelId());
assertEquals(FunctionName.KMEANS, taskCaptor.getValue().getFunctionName());
}

@Test
public void deleteTask_withTaskId() {
String taskId = "taskId";

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, taskId, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<MLTaskDeleteRequest> requestCaptor = ArgumentCaptor.forClass(MLTaskDeleteRequest.class);
ArgumentCaptor<DeleteResponse> responseCaptor = ArgumentCaptor.forClass(DeleteResponse.class);

machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener);

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), requestCaptor.capture(), any());
verify(deleteTaskActionListener).onResponse(responseCaptor.capture());

// Verify request parameter
assertEquals(taskId, requestCaptor.getValue().getTaskId());

// Verify response
assertEquals(taskId, responseCaptor.getValue().getId());
assertEquals("DELETED", responseCaptor.getValue().getResult().toString());
}

@Test
public void deleteTask_withFailure() {
String taskId = "taskId";
String errorMessage = "Task deletion failed";

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new RuntimeException(errorMessage));
return null;
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

machineLearningNodeClient.deleteTask(taskId, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
fail("Expected failure but got success");
}

@Override
public void onFailure(Exception e) {
assertEquals(errorMessage, e.getMessage());
}
});

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
2 changes: 1 addition & 1 deletion common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies {
testImplementation "org.opensearch.test:framework:${opensearch_version}"

compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.11.0'
compileOnly group: 'org.json', name: 'json', version: '20231013'
testImplementation group: 'org.json', name: 'json', version: '20231013'
implementation('com.google.guava:guava:32.1.3-jre') {
Expand Down
Loading

0 comments on commit 86ae88b

Please sign in to comment.