Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.18] Fix AlibabaCloudSearchCompletionAction not accepting ChatCompletionInputs (#125023) #125416

Merged
merged 2 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/125023.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 125023
summary: Fix `AlibabaCloudSearchCompletionAction` not accepting `ChatCompletionInputs`
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;

import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
Expand All @@ -51,18 +49,8 @@ public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompl

@Override
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
listener.onFailure(
new ElasticsearchStatusException(
format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
RestStatus.INTERNAL_SERVER_ERROR
)
);
return;
}

var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
if (docsOnlyInput.getInputs().size() % 2 == 0) {
var completionInput = inferenceInputs.castTo(ChatCompletionInput.class);
if (completionInput.getInputs().size() % 2 == 0) {
listener.onFailure(
new ElasticsearchStatusException(
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.apache.lucene.tests.util.LuceneTestCase.expectThrows;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;

public class AlibabaCloudSearchCompletionActionTests extends ESTestCase {

private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private HttpClientManager clientManager;

@Before
public void init() throws IOException {
webServer.start();
threadPool = createThreadPool(inferenceUtilityPool());
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
}

@After
public void shutdown() throws IOException {
clientManager.close();
terminate(threadPool);
webServer.close();
}

public void testExecute_Success() {
var sender = mock(Sender.class);

var resultString = randomAlphaOfLength(100);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString))));

return Void.TYPE;
}).when(sender).send(any(), any(), any(), any());
var action = createAction(threadPool, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString))));
}

public void testExecute_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
var sender = mock(Sender.class);
doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
var action = createAction(threadPool, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("error"));
}

public void testExecute_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
var sender = mock(Sender.class);
doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
var action = createAction(threadPool, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is(constructFailedToSendRequestMessage("AlibabaCloud Search completion")));
}

public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatCompletionInput() {
var action = createAction(threadPool, mock(Sender.class));

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
assertThrows(IllegalArgumentException.class, () -> {
action.execute(new DocumentsOnlyInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
});
}

public void testExecute_ListenerThrowsElasticsearchStatusException_WhenInputSizeIsEven() {
var action = createAction(threadPool, mock(Sender.class));

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(
new ChatCompletionInput(List.of(randomAlphaOfLength(10), randomAlphaOfLength(10))),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);

var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is(
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
+ "all preceding inputs are the completion history as pairs of user input and the assistant's response."
)
);
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
}

private ExecutableAction createAction(ThreadPool threadPool, Sender sender) {
var model = AlibabaCloudSearchCompletionModelTests.createModel(
"completion_test",
TaskType.COMPLETION,
AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),
getSecretSettingsMap("secret")
);

var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
return new AlibabaCloudSearchCompletionAction(sender, model, serviceComponents);
}
}