From 6dcbd003c9af16c6cc468ca6ec09238f1babe8eb Mon Sep 17 00:00:00 2001 From: br3no Date: Thu, 25 Apr 2024 21:03:07 +0200 Subject: [PATCH 01/10] adding support for asymmetric embedding models Signed-off-by: br3no --- .../ml/MLCommonsClientAccessor.java | 326 ++++++++++++++---- .../processor/TextEmbeddingProcessor.java | 15 +- .../query/NeuralQueryBuilder.java | 15 +- .../ml/MLCommonsClientAccessorTests.java | 125 ++++++- .../processor/TextEmbeddingProcessorIT.java | 221 +++--------- .../TextEmbeddingProcessorTests.java | 24 +- .../query/NeuralQueryBuilderTests.java | 8 +- .../UploadAsymmetricModelRequestBody.json | 17 + 8 files changed, 489 insertions(+), 262 deletions(-) create mode 100644 src/test/resources/processor/UploadAsymmetricModelRequestBody.json diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index f9ddf73a9..5a663724a 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -11,16 +11,24 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import java.util.stream.Collectors; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.Nullable; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import 
org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.output.model.ModelTensor; @@ -40,6 +48,7 @@ public class MLCommonsClientAccessor { private static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); private final MachineLearningNodeClient mlClient; + private final Map modelAsymmetryCache = new ConcurrentHashMap<>(); /** * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating @@ -54,7 +63,29 @@ public void inferenceSentence( @NonNull final String inputText, @NonNull final ActionListener> listener ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> { + inferenceSentence(modelId, inputText, null, listener); + } + + /** + * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating + * point vector as a response. Supports passing {@link MLAlgoParams} to the inference. If the model is + * asymmetric, passing a + * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} is + * mandatory. This method will check whether the model being used is asymmetric and correctly handle the + * parameter, so it's okay to always pass the parameter (even if the model is symmetric). 
+ * + * @param modelId {@link String} + * @param inputText {@link List} of {@link String} on which inference needs to happen + * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference + * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + */ + public void inferenceSentence( + @NonNull final String modelId, + @NonNull final String inputText, + @Nullable final MLAlgoParams mlAlgoParams, + @NonNull final ActionListener> listener + ) { + inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), mlAlgoParams, ActionListener.wrap(response -> { if (response.size() != 1) { listener.onFailure( new IllegalStateException( @@ -69,43 +100,98 @@ public void inferenceSentence( } /** - * Abstraction to call predict function of api of MLClient with default targetResponse filters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. We are not making this function generic enough to take any function or TaskType as currently we - * need to run only TextEmbedding tasks only. + * Abstraction to call predict function of api of MLClient with default targetResponse filters. It + * uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The + * return will be sent using the actionListener which will have a {@link List} of {@link List} of + * {@link Float} in the order of inputText. We are not making this function generic enough to take + * any function or TaskType as currently we need to run only TextEmbedding tasks only. 
* - * @param modelId {@link String} + * @param modelId {@link String} * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + * @param listener {@link ActionListener} which will be called when prediction is completed or + * errored out */ public void inferenceSentences( @NonNull final String modelId, @NonNull final List inputText, @NonNull final ActionListener>> listener ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener); + inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, null, listener); } /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. We are not making this function generic enough to take any function or TaskType as currently we - * need to run only TextEmbedding tasks only. + * Abstraction to call predict function of api of MLClient with default targetResponse filters. It + * uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The + * return will be sent using the actionListener which will have a {@link List} of {@link List} of + * {@link Float} in the order of inputText. We are not making this function generic enough to take + * any function or TaskType as currently we need to run only TextEmbedding tasks only. Supports + * passing {@link MLAlgoParams} to the inference. If the model is asymmetric, passing a + * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} + * is mandatory. 
This method will check whether the model being used is asymmetric and correctly + * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). * - * @param targetResponseFilters {@link List} of {@link String} which filters out the responses * @param modelId {@link String} * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. + * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference + * @param listener {@link ActionListener} which will be called when prediction is completed or + */ + public void inferenceSentences( + @NonNull final String modelId, + @NonNull final List inputText, + @Nullable final MLAlgoParams mlAlgoParams, + @NonNull final ActionListener>> listener + ) { + inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, mlAlgoParams, listener); + } + + /** + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. + * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. + * The return will be sent using the actionListener which will have a {@link List} of {@link List} + * of {@link Float} in the order of inputText. We are not making this function generic enough to + * take any function or TaskType as currently we need to run only TextEmbedding tasks only. + * + * @param targetResponseFilters {@link List} of {@link String} which filters out the responses + * @param modelId {@link String} + * @param inputText {@link List} of {@link String} on which inference needs to happen + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out. 
+ */ + public void inferenceSentences( + @NonNull final List targetResponseFilters, + @NonNull final String modelId, + @NonNull final List inputText, + @NonNull final ActionListener>> listener + ) { + inferenceSentences(targetResponseFilters, modelId, inputText, null, listener); + } + + /** + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. + * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. + * The return will be sent using the actionListener which will have a {@link List} of {@link List} + * of {@link Float} in the order of inputText. We are not making this function generic enough to + * take any function or TaskType as currently we need to run only TextEmbedding tasks only. Supports + * passing {@link MLAlgoParams} to the inference. If the model is asymmetric, passing a + * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} + * is mandatory. This method will check whether the model being used is asymmetric and correctly + * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). + * + * @param targetResponseFilters {@link List} of {@link String} which filters out the responses + * @param modelId {@link String} + * @param inputText {@link List} of {@link String} on which inference needs to happen + * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out. 
*/ public void inferenceSentences( @NonNull final List targetResponseFilters, @NonNull final String modelId, @NonNull final List inputText, + @Nullable final MLAlgoParams mlAlgoParams, @NonNull final ActionListener>> listener ) { - retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); + retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, mlAlgoParams, 0, listener); } public void inferenceSentencesWithMapResult( @@ -113,35 +199,65 @@ public void inferenceSentencesWithMapResult( @NonNull final List inputText, @NonNull final ActionListener>> listener ) { - retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); + retryableInferenceSentencesWithMapResult(modelId, inputText, null, 0, listener); } /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a list of floats in the order of inputText. + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. + * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. + * The return will be sent using the actionListener which will have a list of floats in the order + * of inputText. * - * @param modelId {@link String} - * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. + * @param modelId {@link String} + * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to + * happen + * @param listener {@link ActionListener} which will be called when prediction is completed or + * errored out. 
*/ public void inferenceSentences( @NonNull final String modelId, @NonNull final Map inputObjects, @NonNull final ActionListener> listener ) { - retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); + inferenceSentences(modelId, inputObjects, null, listener); } /** - * Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the - * {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing - * the similarity scores of the texts w.r.t. the query text, in the order of the input texts. + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. + * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. + * The return will be sent using the actionListener which will have a list of floats in the order + * of inputText. Supports passing {@link MLAlgoParams} to the inference. If the model is asymmetric, + * passing a + * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} + * is mandatory. This method will check whether the model being used is asymmetric and correctly + * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). * - * @param modelId {@link String} ML-Commons Model Id + * @param modelId {@link String} + * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to + * happen + * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference + * @param listener {@link ActionListener} which will be called when prediction is completed or + * errored out. 
+ */ + public void inferenceSentences( + @NonNull final String modelId, + @NonNull final Map inputObjects, + @Nullable final MLAlgoParams mlAlgoParams, + @NonNull final ActionListener> listener + ) { + retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, null, 0, listener); + } + + /** + * Abstraction to call predict function of api of MLClient. It uses the custom model provided as + * modelId and the {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via + * actionListener as a list of floats representing the similarity scores of the texts w.r.t. the + * query text, in the order of the input texts. + * + * @param modelId {@link String} ML-Commons Model Id * @param queryText {@link String} The query to compare all the inputText to * @param inputText {@link List} of {@link String} The texts to compare to the query - * @param listener {@link ActionListener} receives the result of the inference + * @param listener {@link ActionListener} receives the result of the inference */ public void inferenceSimilarity( @NonNull final String modelId, @@ -155,42 +271,95 @@ public void inferenceSimilarity( private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, + final MLAlgoParams mlAlgoParams, final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLTextInput(null, inputText); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List> result = buildMapResultFromResponse(mlOutput); - listener.onResponse(result); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener); - } else { - listener.onFailure(e); - } - })); + + Consumer runPrediction = isAsymmetricModel -> { + MLInput mlInput = createMLTextInput(null, inputText, isAsymmetricModel ? 
mlAlgoParams : null); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List> result = buildMapResultFromResponse(mlOutput); + listener.onResponse(result); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithMapResult(modelId, inputText, mlAlgoParams, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + }; + + checkModelAsymmetryAndThenPredict(modelId, listener::onFailure, runPrediction); } private void retryableInferenceSentencesWithVectorResult( final List targetResponseFilters, final String modelId, final List inputText, + final MLAlgoParams mlAlgoParams, final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLTextInput(targetResponseFilters, inputText); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List> vector = buildVectorFromResponse(mlOutput); - listener.onResponse(vector); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); - } else { - listener.onFailure(e); + + Consumer runPrediction = isAsymmetricModel -> { + MLInput mlInput = createMLTextInput(targetResponseFilters, inputText, isAsymmetricModel ? 
mlAlgoParams : null); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List> vector = buildVectorFromResponse(mlOutput); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithVectorResult( + targetResponseFilters, + modelId, + inputText, + mlAlgoParams, + retryTimeAdd, + listener + ); + } else { + listener.onFailure(e); + } + })); + }; + + checkModelAsymmetryAndThenPredict(modelId, listener::onFailure, runPrediction); + } + + /** + * Check if the model is asymmetric and then run the prediction. Model asymmetry is a concept + * that is specific to TextEmbeddingModelConfig. If the model is not a TextEmbeddingModel, then + * this check is not applicable. + * + * The asymmetry of a model is static for a given model. To avoid repeated checks for the same + * model, we cache the model asymmetry status. Non-TextEmbeddingModels are cached as false. + * + * @param modelId The model id to check + * @param onFailure The action to take if the model cannot be retrieved + * @param runPrediction The action to take if the model is successfully retrieved + */ + private void checkModelAsymmetryAndThenPredict(String modelId, Consumer onFailure, Consumer runPrediction) { + CheckedConsumer checkModelAsymmetryListener = model -> { + MLModelConfig modelConfig = model.getModelConfig(); + if (!(modelConfig instanceof TextEmbeddingModelConfig)) { + modelAsymmetryCache.putIfAbsent(modelId, false); + return; } - })); + final TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig; + final boolean isAsymmetricModel = textEmbeddingModelConfig.getPassagePrefix() != null + || textEmbeddingModelConfig.getQueryPrefix() != null; + modelAsymmetryCache.putIfAbsent(modelId, isAsymmetricModel); + }; + if (modelAsymmetryCache.containsKey(modelId)) { + runPrediction.accept(modelAsymmetryCache.get(modelId)); + } else { + 
mlClient.getModel(modelId, ActionListener.wrap(mlModel -> { + checkModelAsymmetryListener.accept(mlModel); + runPrediction.accept(modelAsymmetryCache.get(modelId)); + }, onFailure)); + } } private void retryableInferenceSimilarityWithVectorResult( @@ -213,10 +382,10 @@ private void retryableInferenceSimilarityWithVectorResult( })); } - private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { + private MLInput createMLTextInput(final List targetResponseFilters, List inputText, MLAlgoParams mlAlgoParams) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + return new MLInput(FunctionName.TEXT_EMBEDDING, mlAlgoParams, inputDataset); } private MLInput createMLTextPairsInput(final String query, final List inputText) { @@ -264,25 +433,42 @@ private void retryableInferenceSentencesWithSingleVectorResult( final List targetResponseFilters, final String modelId, final Map inputObjects, + final MLAlgoParams mlAlgoParams, final int retryTime, final ActionListener> listener ) { - MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List vector = buildSingleVectorFromResponse(mlOutput); - log.debug("Inference Response for input sentence is : {} ", vector); - listener.onResponse(vector); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); - } else { - listener.onFailure(e); - } - })); + + Consumer predictConsumer = isAsymmetricModel -> { + MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects, isAsymmetricModel ? 
mlAlgoParams : null); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List vector = buildSingleVectorFromResponse(mlOutput); + log.debug("Inference Response for input sentence is : {} ", vector); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithSingleVectorResult( + targetResponseFilters, + modelId, + inputObjects, + mlAlgoParams, + retryTimeAdd, + listener + ); + } else { + listener.onFailure(e); + } + })); + }; + + checkModelAsymmetryAndThenPredict(modelId, listener::onFailure, predictConsumer); } - private MLInput createMLMultimodalInput(final List targetResponseFilters, final Map input) { + private MLInput createMLMultimodalInput( + final List targetResponseFilters, + final Map input, + MLAlgoParams mlAlgoParams + ) { List inputText = new ArrayList<>(); inputText.add(input.get(INPUT_TEXT)); if (input.containsKey(INPUT_IMAGE)) { @@ -290,6 +476,6 @@ private MLInput createMLMultimodalInput(final List targetResponseFilters } final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + return new MLInput(FunctionName.TEXT_EMBEDDING, mlAlgoParams, inputDataset); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index c8f9f080d..23dc6af49 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -13,6 +13,8 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; +import 
org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import lombok.extern.log4j.Log4j2; @@ -47,10 +49,15 @@ public void doExecute( List inferenceList, BiConsumer handler ) { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + mlCommonsClientAccessor.inferenceSentences( + this.modelId, + inferenceList, + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 915a79117..c8b2a1d4a 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -44,6 +44,8 @@ import org.opensearch.knn.index.query.parser.RescoreParser; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.common.MinClusterVersionUtil; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import com.google.common.annotations.VisibleForTesting; @@ -333,10 +335,15 @@ protected QueryBuilder doRewrite(QueryRewriteContext 
queryRewriteContext) { inferenceInput.put(INPUT_IMAGE, queryImage()); } queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> { - vectorSetOnce.set(vectorAsListToArray(floatList)); - actionListener.onResponse(null); - }, actionListener::onFailure))) + ((client, actionListener) -> ML_CLIENT.inferenceSentences( + modelId(), + inferenceInput, + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(), + ActionListener.wrap(floatList -> { + vectorSetOnce.set(vectorAsListToArray(floatList)); + actionListener.onResponse(null); + }, actionListener::onFailure) + )) ); return new NeuralQueryBuilder( fieldName(), diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 3749e63dc..3c0376909 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -23,7 +23,11 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; @@ -59,8 +63,14 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { 
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentence(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST.get(0), singleSentenceResultListener); + accessor.inferenceSentence( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST.get(0), + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + singleSentenceResultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -68,6 +78,19 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } + private void setupMocksForTextEmbeddingModelAsymmetryCheck(boolean isAsymmetric) { + MLModel modelMock = mock(MLModel.class); + TextEmbeddingModelConfig textEmbeddingModelConfigMock = mock(TextEmbeddingModelConfig.class); + Mockito.when(textEmbeddingModelConfigMock.getPassagePrefix()).thenReturn(isAsymmetric ? "passage: " : null); + Mockito.when(textEmbeddingModelConfigMock.getQueryPrefix()).thenReturn(isAsymmetric ? 
"query: " : null); + Mockito.when(modelMock.getModelConfig()).thenReturn(textEmbeddingModelConfigMock); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(modelMock); + return null; + }).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); + } + public void testInferenceSentences_whenValidInputThenSuccess() { final List> vectorList = new ArrayList<>(); vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY)); @@ -76,6 +99,8 @@ public void testInferenceSentences_whenValidInputThenSuccess() { actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -92,6 +117,8 @@ public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() { actionListener.onResponse(createModelTensorOutput(new Float[] {})); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -107,6 +134,8 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { actionListener.onFailure(exception); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences( TestCommonConstants.TARGET_RESPONSE_FILTERS, 
TestCommonConstants.MODEL_ID, @@ -130,6 +159,9 @@ public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Time actionListener.onFailure(nodeNodeConnectedException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences( TestCommonConstants.TARGET_RESPONSE_FILTERS, TestCommonConstants.MODEL_ID, @@ -149,6 +181,9 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { actionListener.onFailure(illegalStateException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences( TestCommonConstants.TARGET_RESPONSE_FILTERS, TestCommonConstants.MODEL_ID, @@ -161,6 +196,62 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { Mockito.verify(resultListener).onFailure(illegalStateException); } + public void testInferenceSentences_whenModelAsymmetric_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(true); + + accessor.inferenceSentence( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST.get(0), + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + singleSentenceResultListener + ); + + Mockito.verify(client) + .predict( + 
Mockito.eq(TestCommonConstants.MODEL_ID), + Mockito.argThat((MLInput input) -> input.getParameters() != null), + Mockito.isA(ActionListener.class) + ); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceSentences_whenGetModelException_thenFailure() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + RuntimeException exception = new RuntimeException("Bam!"); + setupMocksForTextEmbeddingModelAsymmetryCheck(exception); + + accessor.inferenceSentence( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST.get(0), + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + singleSentenceResultListener + ); + + Mockito.verify(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + private void setupMocksForTextEmbeddingModelAsymmetryCheck(Exception exception) { + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(exception); + return null; + }).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); + } + public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { final Map map = Map.of("key", "value"); final ActionListener>> resultListener = mock(ActionListener.class); @@ -169,6 +260,9 @@ public void 
testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { actionListener.onResponse(createModelTensorOutput(map)); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -185,6 +279,9 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx actionListener.onResponse(modelTensorOutput); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -209,6 +306,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc actionListener.onResponse(modelTensorOutput); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -236,6 +336,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTha actionListener.onResponse(modelTensorOutput); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -255,6 +358,9 @@ public void 
testInferenceSentencesWithMapResult_whenRetryableException_retry3Tim return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); final ActionListener>> resultListener = mock(ActionListener.class); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client, times(4)) @@ -270,6 +376,9 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); final ActionListener>> resultListener = mock(ActionListener.class); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client, times(1)) @@ -285,6 +394,8 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); Mockito.verify(client) @@ -300,6 +411,9 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { actionListener.onFailure(exception); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); Mockito.verify(client) @@ -318,6 +432,9 @@ public void 
testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR actionListener.onFailure(nodeNodeConnectedException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); Mockito.verify(client, times(4)) @@ -333,6 +450,8 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSimilarity( TestCommonConstants.MODEL_ID, "is it sunny", @@ -354,6 +473,8 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSimilarity( TestCommonConstants.MODEL_ID, "is it sunny", @@ -378,6 +499,8 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSimilarity( TestCommonConstants.MODEL_ID, "is it sunny", diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 4afa4031d..a596e8618 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -6,32 +6,24 
@@ import java.io.IOException; import java.net.URISyntaxException; -import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.Optional; import java.util.Set; import org.apache.commons.lang3.StringUtils; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; -import org.apache.lucene.search.join.ScoreMode; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import com.google.common.collect.ImmutableList; -import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { @@ -66,7 +58,9 @@ public void setUp() throws Exception { public void testTextEmbeddingProcessor() throws Exception { String modelId = null; try { - modelId = uploadTextEmbeddingModel(); + modelId = uploadTextEmbeddingModel( + Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())) + ); loadModel(modelId); createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); createTextEmbeddingIndex(); @@ -77,179 +71,70 @@ public void testTextEmbeddingProcessor() throws Exception { } } - public void testTextEmbeddingProcessor_batch() throws Exception { - String modelId = null; - try { - modelId = uploadTextEmbeddingModel(); - loadModel(modelId); - createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING, 2); - createTextEmbeddingIndex(); - ingestBatchDocumentWithBulk("batch_", 2, Collections.emptySet(), Collections.emptySet()); - 
assertEquals(2, getDocCount(INDEX_NAME)); - - ingestDocument(String.format(LOCALE, INGEST_DOC1, "success"), "1"); - ingestDocument(String.format(LOCALE, INGEST_DOC2, "success"), "2"); - - assertEquals(getDocById(INDEX_NAME, "1").get("_source"), getDocById(INDEX_NAME, "batch_1").get("_source")); - assertEquals(getDocById(INDEX_NAME, "2").get("_source"), getDocById(INDEX_NAME, "batch_2").get("_source")); - } finally { - wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); - } - } - - public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws Exception { - String modelId = null; - try { - modelId = uploadTextEmbeddingModel(); - loadModel(modelId); - createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING); - createTextEmbeddingIndex(); - ingestDocument(INGEST_DOC3, "3"); - ingestDocument(INGEST_DOC4, "4"); - - assertDoc( - (Map) getDocById(INDEX_NAME, "3").get("_source"), - TEXT_FIELD_VALUE_1, - Optional.of(TEXT_FIELD_VALUE_3) - ); - assertDoc((Map) getDocById(INDEX_NAME, "4").get("_source"), TEXT_FIELD_VALUE_2, Optional.empty()); - - NeuralQueryBuilder neuralQueryBuilderQuery = new NeuralQueryBuilder( - LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING, - QUERY_TEXT, - "", - modelId, - 10, - null, - null, - null, - null, - null, - null - ); - QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery( - LEVEL_1_FIELD + "." 
+ LEVEL_2_FIELD, - neuralQueryBuilderQuery, - ScoreMode.Total - ); - QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total); - - Map searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2); - assertNotNull(searchResponseAsMap); - - Map hits = (Map) searchResponseAsMap.get("hits"); - assertNotNull(hits); - - assertEquals(1.0, hits.get("max_score")); - List> listOfHits = (List>) hits.get("hits"); - assertNotNull(listOfHits); - assertEquals(2, listOfHits.size()); - - Map innerHitDetails = listOfHits.get(0); - assertEquals("3", innerHitDetails.get("_id")); - assertEquals(1.0, innerHitDetails.get("_score")); - - innerHitDetails = listOfHits.get(1); - assertEquals("4", innerHitDetails.get("_id")); - assertTrue((double) innerHitDetails.get("_score") <= 1.0); - } finally { - wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); - } - } - - private void assertDoc(Map sourceMap, String textFieldValue, Optional level3ExpectedValue) { - assertNotNull(sourceMap); - assertTrue(sourceMap.containsKey(LEVEL_1_FIELD)); - Map nestedPassages = (Map) sourceMap.get(LEVEL_1_FIELD); - assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD)); - Map level2 = (Map) nestedPassages.get(LEVEL_2_FIELD); - assertEquals(textFieldValue, level2.get(LEVEL_3_FIELD_TEXT)); - Map level3 = (Map) level2.get(LEVEL_3_FIELD_CONTAINER); - List embeddings = (List) level3.get(LEVEL_3_FIELD_EMBEDDING); - assertEquals(768, embeddings.size()); - for (Double embedding : embeddings) { - assertTrue(embedding >= 0.0 && embedding <= 1.0); - } - if (level3ExpectedValue.isPresent()) { - assertEquals(level3ExpectedValue.get(), level3.get("level_4_text_field")); - } + private String uploadTextEmbeddingModel(String requestBody) throws Exception { + return registerModelGroupAndUploadModel(requestBody); } - public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception { - String modelId = null; - try { - modelId = 
uploadTextEmbeddingModel(); - loadModel(modelId); - URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json"); - Objects.requireNonNull(pipelineURLPath); - String requestBody = Files.readString(Path.of(pipelineURLPath.toURI())); - createPipelineProcessor(requestBody, PIPELINE_NAME, modelId, null); - createTextEmbeddingIndex(); - int docCount = 5; - ingestBatchDocumentWithBulk("batch_", docCount, Collections.emptySet(), Collections.emptySet()); - assertEquals(5, getDocCount(INDEX_NAME)); - - for (int i = 0; i < docCount; ++i) { - String template = List.of(INGEST_DOC1, INGEST_DOC2).get(i % 2); - String payload = String.format(LOCALE, template, "success"); - ingestDocument(payload, String.valueOf(i + 1)); - } - - for (int i = 0; i < docCount; ++i) { - assertEquals( - getDocById(INDEX_NAME, String.valueOf(i + 1)).get("_source"), - getDocById(INDEX_NAME, "batch_" + (i + 1)).get("_source") - ); - - } - } finally { - wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); - } + private void createTextEmbeddingIndex() throws Exception { + createIndexWithConfiguration( + INDEX_NAME, + Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())), + PIPELINE_NAME + ); } - public void testTextEmbeddingProcessor_withFailureAndSkip() throws Exception { + public void testAsymmetricTextEmbeddingProcessor() throws Exception { String modelId = null; try { - modelId = uploadTextEmbeddingModel(); + modelId = uploadTextEmbeddingModel( + Files.readString(Path.of(classLoader.getResource("processor/UploadAsymmetricModelRequestBody.json").toURI())) + ); loadModel(modelId); - URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json"); - Objects.requireNonNull(pipelineURLPath); - String requestBody = Files.readString(Path.of(pipelineURLPath.toURI())); - createPipelineProcessor(requestBody, PIPELINE_NAME, modelId, null); + createPipelineProcessor(modelId, PIPELINE_NAME, 
ProcessorType.TEXT_EMBEDDING); createTextEmbeddingIndex(); - int docCount = 5; - ingestBatchDocumentWithBulk("batch_", docCount, Set.of(0), Set.of(1)); - assertEquals(3, getDocCount(INDEX_NAME)); - - for (int i = 2; i < docCount; ++i) { - String template = List.of(INGEST_DOC1, INGEST_DOC2).get(i % 2); - String payload = String.format(LOCALE, template, "success"); - ingestDocument(payload, String.valueOf(i + 1)); - } - - for (int i = 2; i < docCount; ++i) { - assertEquals( - getDocById(INDEX_NAME, String.valueOf(i + 1)).get("_source"), - getDocById(INDEX_NAME, "batch_" + (i + 1)).get("_source") - ); - - } + ingestDocument(); + assertEquals(1, getDocCount(INDEX_NAME)); } finally { wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); } } - private String uploadTextEmbeddingModel() throws Exception { - String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); - return registerModelGroupAndUploadModel(requestBody); - } - - private void createTextEmbeddingIndex() throws Exception { - createIndexWithConfiguration( - INDEX_NAME, - Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())), - PIPELINE_NAME + private void ingestDocument() throws Exception { + String ingestDocument = "{\n" + + " \"title\": \"This is a good day\",\n" + + " \"description\": \"daily logging\",\n" + + " \"favor_list\": [\n" + + " \"test\",\n" + + " \"hello\",\n" + + " \"mock\"\n" + + " ],\n" + + " \"favorites\": {\n" + + " \"game\": \"overwatch\",\n" + + " \"movie\": null\n" + + " },\n" + + " \"nested_passages\": [\n" + + " {\n" + + " \"text\": \"hello\"\n" + + " },\n" + + " {\n" + + " \"text\": \"world\"\n" + + " }\n" + + " ]\n" + + "}\n"; + Response response = makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(ingestDocument), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); + Map map = XContentHelper.convertToMap( + 
XContentType.JSON.xContent(), + EntityUtils.toString(response.getEntity()), + false + ); + assertEquals("created", map.get("result")); } private void ingestDocument(String doc, String id) throws Exception { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 97e85e46e..925bd01be 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -44,6 +44,7 @@ import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; @@ -151,10 +152,10 @@ public void testExecute_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -184,7 +185,8 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); 
BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(RuntimeException.class)); @@ -230,10 +232,10 @@ public void testExecute_withListTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -306,10 +308,10 @@ public void testExecute_withMapTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -347,10 +349,10 @@ public void testNestedFieldInMapping_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), 
isA(MLAlgoParams.class), isA(ActionListener.class)); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -585,10 +587,10 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 6d8e810f3..5efbf3869 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -649,10 +649,10 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); + ActionListener> listener = invocation.getArgument(3); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any(), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -685,10 +685,10 @@ public void 
testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); + ActionListener> listener = invocation.getArgument(3); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any(), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); diff --git a/src/test/resources/processor/UploadAsymmetricModelRequestBody.json b/src/test/resources/processor/UploadAsymmetricModelRequestBody.json new file mode 100644 index 000000000..8c5b6ec18 --- /dev/null +++ b/src/test/resources/processor/UploadAsymmetricModelRequestBody.json @@ -0,0 +1,17 @@ +{ + "name": "traced_small_model", + "version": "1.0.0", + "model_format": "TORCH_SCRIPT", + "model_task_type": "text_embedding", + "model_content_hash_value": "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021", + "model_group_id": "%s", + "model_config": { + "model_type": "bert", + "embedding_dimension": 768, + "framework_type": "sentence_transformers", + "passage_prefix" : "passage: ", + "query_prefix" : "query: ", + "all_config": "{\"architectures\":[\"BertModel\"],\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6}" + }, + "url": "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/traced_small_model.zip?raw=true" +} From 760c1dbb5a2a0de72909a237ee4cff0d02b02c88 Mon Sep 17 00:00:00 2001 From: br3no Date: Fri, 26 Apr 2024 09:08:55 +0200 Subject: [PATCH 02/10] adding changelog entry Signed-off-by: br3no --- CHANGELOG.md | 1 + 1 
file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..c72b87e0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Add support for asymmetric embedding models ([#710](https://github.com/opensearch-project/neural-search/pull/710)) ### Enhancements ### Bug Fixes ### Infrastructure From e819b139dcad70b9225207d4e9998564ec544dc7 Mon Sep 17 00:00:00 2001 From: br3no Date: Fri, 26 Apr 2024 19:17:00 +0200 Subject: [PATCH 03/10] missing parameter Signed-off-by: br3no --- .../org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 5a663724a..a1a7a1601 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -245,7 +245,7 @@ public void inferenceSentences( @Nullable final MLAlgoParams mlAlgoParams, @NonNull final ActionListener> listener ) { - retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, null, 0, listener); + retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, mlAlgoParams, 0, listener); } /** From 9a920a010d5b0ca7240819bdb57973877c2b249c Mon Sep 17 00:00:00 2001 From: br3no Date: Tue, 22 Oct 2024 18:32:31 +0200 Subject: [PATCH 04/10] Adapt new tests to asymmetric model inference Signed-off-by: br3no --- .../processor/TextEmbeddingProcessorTests.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java
b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 925bd01be..1d83c8c95 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -409,10 +409,10 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHa List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -467,10 +467,10 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -518,10 +518,10 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + 
ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); From 2932b840dc739c8c8b06cd6a201d057c99c81dfd Mon Sep 17 00:00:00 2001 From: br3no Date: Tue, 22 Oct 2024 18:50:27 +0200 Subject: [PATCH 05/10] revert accidental removal of tests from TextEmbeddingProcessorIT Signed-off-by: br3no --- .../processor/TextEmbeddingProcessorIT.java | 183 +++++++++++++++++- 1 file changed, 175 insertions(+), 8 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index a596e8618..1d620e697 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -6,24 +6,32 @@ import java.io.IOException; import java.net.URISyntaxException; +import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.Set; import org.apache.commons.lang3.StringUtils; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.lucene.search.join.ScoreMode; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.QueryBuilder; +import
org.opensearch.index.query.QueryBuilders; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import com.google.common.collect.ImmutableList; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { @@ -58,9 +66,7 @@ public void setUp() throws Exception { public void testTextEmbeddingProcessor() throws Exception { String modelId = null; try { - modelId = uploadTextEmbeddingModel( - Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())) - ); + modelId = uploadTextEmbeddingModel(); loadModel(modelId); createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); createTextEmbeddingIndex(); @@ -71,7 +77,170 @@ public void testTextEmbeddingProcessor() throws Exception { } } - private String uploadTextEmbeddingModel(String requestBody) throws Exception { + public void testTextEmbeddingProcessor_batch() throws Exception { + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING, 2); + createTextEmbeddingIndex(); + ingestBatchDocumentWithBulk("batch_", 2, Collections.emptySet(), Collections.emptySet()); + assertEquals(2, getDocCount(INDEX_NAME)); + + ingestDocument(String.format(LOCALE, INGEST_DOC1, "success"), "1"); + ingestDocument(String.format(LOCALE, INGEST_DOC2, "success"), "2"); + + assertEquals(getDocById(INDEX_NAME, "1").get("_source"), getDocById(INDEX_NAME, "batch_1").get("_source")); + assertEquals(getDocById(INDEX_NAME, "2").get("_source"), getDocById(INDEX_NAME, "batch_2").get("_source")); + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + + public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws Exception { + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, 
PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING); + createTextEmbeddingIndex(); + ingestDocument(INGEST_DOC3, "3"); + ingestDocument(INGEST_DOC4, "4"); + + assertDoc( + (Map) getDocById(INDEX_NAME, "3").get("_source"), + TEXT_FIELD_VALUE_1, + Optional.of(TEXT_FIELD_VALUE_3) + ); + assertDoc((Map) getDocById(INDEX_NAME, "4").get("_source"), TEXT_FIELD_VALUE_2, Optional.empty()); + + NeuralQueryBuilder neuralQueryBuilderQuery = new NeuralQueryBuilder( + LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING, + QUERY_TEXT, + "", + modelId, + 10, + null, + null, + null, + null, + null, + null + ); + QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery( + LEVEL_1_FIELD + "." + LEVEL_2_FIELD, + neuralQueryBuilderQuery, + ScoreMode.Total + ); + QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total); + + Map searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2); + assertNotNull(searchResponseAsMap); + + Map hits = (Map) searchResponseAsMap.get("hits"); + assertNotNull(hits); + + assertEquals(1.0, hits.get("max_score")); + List> listOfHits = (List>) hits.get("hits"); + assertNotNull(listOfHits); + assertEquals(2, listOfHits.size()); + + Map innerHitDetails = listOfHits.get(0); + assertEquals("3", innerHitDetails.get("_id")); + assertEquals(1.0, innerHitDetails.get("_score")); + + innerHitDetails = listOfHits.get(1); + assertEquals("4", innerHitDetails.get("_id")); + assertTrue((double) innerHitDetails.get("_score") <= 1.0); + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + + private void assertDoc(Map sourceMap, String textFieldValue, Optional level3ExpectedValue) { + assertNotNull(sourceMap); + assertTrue(sourceMap.containsKey(LEVEL_1_FIELD)); + Map nestedPassages = (Map) sourceMap.get(LEVEL_1_FIELD); + assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD)); + Map level2 = (Map) 
nestedPassages.get(LEVEL_2_FIELD); + assertEquals(textFieldValue, level2.get(LEVEL_3_FIELD_TEXT)); + Map level3 = (Map) level2.get(LEVEL_3_FIELD_CONTAINER); + List embeddings = (List) level3.get(LEVEL_3_FIELD_EMBEDDING); + assertEquals(768, embeddings.size()); + for (Double embedding : embeddings) { + assertTrue(embedding >= 0.0 && embedding <= 1.0); + } + if (level3ExpectedValue.isPresent()) { + assertEquals(level3ExpectedValue.get(), level3.get("level_4_text_field")); + } + } + + public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception { + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json"); + Objects.requireNonNull(pipelineURLPath); + String requestBody = Files.readString(Path.of(pipelineURLPath.toURI())); + createPipelineProcessor(requestBody, PIPELINE_NAME, modelId, null); + createTextEmbeddingIndex(); + int docCount = 5; + ingestBatchDocumentWithBulk("batch_", docCount, Collections.emptySet(), Collections.emptySet()); + assertEquals(5, getDocCount(INDEX_NAME)); + + for (int i = 0; i < docCount; ++i) { + String template = List.of(INGEST_DOC1, INGEST_DOC2).get(i % 2); + String payload = String.format(LOCALE, template, "success"); + ingestDocument(payload, String.valueOf(i + 1)); + } + + for (int i = 0; i < docCount; ++i) { + assertEquals( + getDocById(INDEX_NAME, String.valueOf(i + 1)).get("_source"), + getDocById(INDEX_NAME, "batch_" + (i + 1)).get("_source") + ); + + } + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + + public void testTextEmbeddingProcessor_withFailureAndSkip() throws Exception { + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json"); + Objects.requireNonNull(pipelineURLPath); + String requestBody = 
Files.readString(Path.of(pipelineURLPath.toURI())); + createPipelineProcessor(requestBody, PIPELINE_NAME, modelId, null); + createTextEmbeddingIndex(); + int docCount = 5; + ingestBatchDocumentWithBulk("batch_", docCount, Set.of(0), Set.of(1)); + assertEquals(3, getDocCount(INDEX_NAME)); + + for (int i = 2; i < docCount; ++i) { + String template = List.of(INGEST_DOC1, INGEST_DOC2).get(i % 2); + String payload = String.format(LOCALE, template, "success"); + ingestDocument(payload, String.valueOf(i + 1)); + } + + for (int i = 2; i < docCount; ++i) { + assertEquals( + getDocById(INDEX_NAME, String.valueOf(i + 1)).get("_source"), + getDocById(INDEX_NAME, "batch_" + (i + 1)).get("_source") + ); + + } + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + + private String uploadTextEmbeddingModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); return registerModelGroupAndUploadModel(requestBody); } @@ -86,11 +255,9 @@ private void createTextEmbeddingIndex() throws Exception { public void testAsymmetricTextEmbeddingProcessor() throws Exception { String modelId = null; try { - modelId = uploadTextEmbeddingModel( - Files.readString(Path.of(classLoader.getResource("processor/UploadAsymmetricModelRequestBody.json").toURI())) - ); + modelId = uploadTextEmbeddingModel(); loadModel(modelId); - createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING, 2); createTextEmbeddingIndex(); ingestDocument(); assertEquals(1, getDocCount(INDEX_NAME)); From 4f5c0cd897467b4b57ad872dd4ca29f13fb78131 Mon Sep 17 00:00:00 2001 From: br3no Date: Tue, 22 Oct 2024 18:57:55 +0200 Subject: [PATCH 06/10] Another silly mistake corrected... 
Signed-off-by: br3no --- .../neuralsearch/processor/TextEmbeddingProcessorIT.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 1d620e697..4432e39d1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -244,6 +244,11 @@ private String uploadTextEmbeddingModel() throws Exception { return registerModelGroupAndUploadModel(requestBody); } + private String uploadAsymmetricEmbeddingModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadAsymmetricModelRequestBody.json").toURI())); + return registerModelGroupAndUploadModel(requestBody); + } + private void createTextEmbeddingIndex() throws Exception { createIndexWithConfiguration( INDEX_NAME, @@ -255,7 +260,7 @@ private void createTextEmbeddingIndex() throws Exception { public void testAsymmetricTextEmbeddingProcessor() throws Exception { String modelId = null; try { - modelId = uploadTextEmbeddingModel(); + modelId = uploadAsymmetricEmbeddingModel(); loadModel(modelId); createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING, 2); createTextEmbeddingIndex(); From 8b8f8717661d189076c81640f1b9420011d91bfa Mon Sep 17 00:00:00 2001 From: br3no Date: Wed, 23 Oct 2024 10:29:01 +0200 Subject: [PATCH 07/10] refactor test Signed-off-by: br3no --- .../processor/TextEmbeddingProcessorIT.java | 41 +------------------ src/test/resources/processor/ingest_doc5.json | 21 ++++++++++ 2 files changed, 23 insertions(+), 39 deletions(-) create mode 100644 src/test/resources/processor/ingest_doc5.json diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java 
b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 4432e39d1..1611daaed 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -51,6 +51,7 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI())); private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI())); private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI())); + private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI())); private final String BULK_ITEM_TEMPLATE = Files.readString( Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI()) ); @@ -264,51 +265,13 @@ public void testAsymmetricTextEmbeddingProcessor() throws Exception { loadModel(modelId); createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING, 2); createTextEmbeddingIndex(); - ingestDocument(); + ingestDocument(INGEST_DOC5, null); assertEquals(1, getDocCount(INDEX_NAME)); } finally { wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); } } - private void ingestDocument() throws Exception { - String ingestDocument = "{\n" - + " \"title\": \"This is a good day\",\n" - + " \"description\": \"daily logging\",\n" - + " \"favor_list\": [\n" - + " \"test\",\n" - + " \"hello\",\n" - + " \"mock\"\n" - + " ],\n" - + " \"favorites\": {\n" - + " \"game\": \"overwatch\",\n" - + " \"movie\": null\n" - + " },\n" - + " \"nested_passages\": [\n" - + " {\n" - + " \"text\": \"hello\"\n" - + " },\n" - + " {\n" - + " \"text\": \"world\"\n" - + " }\n" - + " ]\n" - + "}\n"; - Response response = makeRequest( - client(), - 
"POST", - INDEX_NAME + "/_doc?refresh", - null, - toHttpEntity(ingestDocument), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) - ); - Map map = XContentHelper.convertToMap( - XContentType.JSON.xContent(), - EntityUtils.toString(response.getEntity()), - false - ); - assertEquals("created", map.get("result")); - } - private void ingestDocument(String doc, String id) throws Exception { String endpoint; if (StringUtils.isEmpty(id)) { diff --git a/src/test/resources/processor/ingest_doc5.json b/src/test/resources/processor/ingest_doc5.json new file mode 100644 index 000000000..e3302c75a --- /dev/null +++ b/src/test/resources/processor/ingest_doc5.json @@ -0,0 +1,21 @@ +{ + "title": "This is a good day", + "description": "daily logging", + "favor_list": [ + "test", + "hello", + "mock" + ], + "favorites": { + "game": "overwatch", + "movie": null + }, + "nested_passages": [ + { + "text": "hello" + }, + { + "text": "world" + } + ] +} From 2ac1f0963c0e9ae36137bc397c7466f97ea558b2 Mon Sep 17 00:00:00 2001 From: br3no Date: Tue, 5 Nov 2024 23:23:29 +0100 Subject: [PATCH 08/10] After further review round Signed-off-by: br3no --- .../ml/MLCommonsClientAccessor.java | 430 ++++++++++-------- .../processor/SparseEncodingProcessor.java | 15 +- .../processor/TextEmbeddingProcessor.java | 19 +- .../TextImageEmbeddingProcessor.java | 12 +- .../rerank/MLOpenSearchRerankProcessor.java | 7 +- .../query/NeuralQueryBuilder.java | 24 +- .../query/NeuralSparseQueryBuilder.java | 4 +- .../ml/MLCommonsClientAccessorTests.java | 110 +++-- .../processor/InferenceProcessorTests.java | 37 +- .../SparseEncodingProcessorTests.java | 35 +- .../TextEmbeddingProcessorTests.java | 85 +++- .../TextImageEmbeddingProcessorTests.java | 25 +- .../MLOpenSearchRerankProcessorTests.java | 13 +- .../query/NeuralQueryBuilderTests.java | 10 +- .../query/NeuralSparseQueryBuilderTests.java | 4 +- 15 files changed, 504 insertions(+), 326 deletions(-) diff --git 
a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index a1a7a1601..2daa06c1e 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -11,12 +11,12 @@ import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.stream.Collectors; import org.opensearch.common.CheckedConsumer; -import org.opensearch.common.Nullable; +import org.opensearch.common.cache.Cache; +import org.opensearch.common.cache.CacheBuilder; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -46,46 +46,181 @@ @RequiredArgsConstructor @Log4j2 public class MLCommonsClientAccessor { - private static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); - private final MachineLearningNodeClient mlClient; - private final Map modelAsymmetryCache = new ConcurrentHashMap<>(); /** - * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating - * point vector as a response. - * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + * Inference parameters for calls to the MLCommons client. 
*/ - public void inferenceSentence( - @NonNull final String modelId, - @NonNull final String inputText, - @NonNull final ActionListener> listener - ) { - inferenceSentence(modelId, inputText, null, listener); + public static class InferenceRequest { + + private static final List DEFAULT_TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); + + private final String modelId; + private final List inputTexts; + private final MLAlgoParams mlAlgoParams; + private final List targetResponseFilters; + private final Map inputObjects; + private final String queryText; + + public InferenceRequest( + @NonNull String modelId, + List inputTexts, + MLAlgoParams mlAlgoParams, + List targetResponseFilters, + Map inputObjects, + String queryText + ) { + this.modelId = modelId; + this.inputTexts = inputTexts; + this.mlAlgoParams = mlAlgoParams; + this.targetResponseFilters = targetResponseFilters == null ? DEFAULT_TARGET_RESPONSE_FILTERS : targetResponseFilters; + this.inputObjects = inputObjects; + this.queryText = queryText; + } + + public String getModelId() { + return modelId; + } + + public List getInputTexts() { + return inputTexts; + } + + public MLAlgoParams getMlAlgoParams() { + return mlAlgoParams; + } + + public List getTargetResponseFilters() { + return targetResponseFilters; + } + + public Map getInputObjects() { + return inputObjects; + } + + public String getQueryText() { + return queryText; + } + + /** + * Builder for {@link InferenceRequest}. Supports fluent construction of the request object. 
+ */ + public static class Builder { + + private final String modelId; + private List inputTexts; + private MLAlgoParams mlAlgoParams; + private List targetResponseFilters; + private Map inputObjects; + private String queryText; + + /** + * @param modelId the model id to use for inference + */ + public Builder(String modelId) { + this.modelId = modelId; + } + + /** + * @param inputTexts a {@link List} of input texts to use for inference + * @return this builder + */ + public Builder inputTexts(List inputTexts) { + this.inputTexts = inputTexts; + return this; + } + + /** + * @param inputText an input text to add to the list of input texts. Repeated calls will add + * more input texts. + * @return this builder + */ + public Builder inputText(String inputText) { + if (this.inputTexts != null) { + this.inputTexts.add(inputText); + } else { + this.inputTexts = new ArrayList<>(); + this.inputTexts.add(inputText); + } + return this; + } + + /** + * @param mlAlgoParams the {@link MLAlgoParams} to use for inference. 
+ * @return this builder + */ + public Builder mlAlgoParams(MLAlgoParams mlAlgoParams) { + this.mlAlgoParams = mlAlgoParams; + return this; + } + + /** + * @param targetResponseFilters a {@link List} of target response filters to use for + * inference + * @return this builder + */ + public Builder targetResponseFilters(List targetResponseFilters) { + this.targetResponseFilters = targetResponseFilters; + return this; + } + + /** + * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs + * to happen + * @return this builder + */ + public Builder inputObjects(Map inputObjects) { + this.inputObjects = inputObjects; + return this; + } + + /** + * @param queryText the query text to use for similarity inference + * @return this builder + */ + public Builder queryText(String queryText) { + this.queryText = queryText; + return this; + } + + /** + * @return a new {@link InferenceRequest} object with the parameters set in this builder + */ + public InferenceRequest build() { + return new InferenceRequest(modelId, inputTexts, mlAlgoParams, targetResponseFilters, inputObjects, queryText); + } + + } } + private final MachineLearningNodeClient mlClient; + private final Cache modelAsymmetryCache = CacheBuilder.builder().setMaximumWeight(10_000).build(); + /** - * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating - * point vector as a response. Supports passing {@link MLAlgoParams} to the inference. If the model is - * asymmetric, passing a - * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} is - * mandatory. This method will check whether the model being used is asymmetric and correctly handle the - * parameter, so it's okay to always pass the parameter (even if the model is symmetric). + * Wrapper around {@link #inferenceSentencesMap} that expects a single input text and produces a + * single floating point vector as a response. + *

+ * If the model is asymmetric, the {@link InferenceRequest} must contain an + * {@link + * org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} s + * {@link MLAlgoParams}. This method will check whether the model being used is asymmetric and + * correctly handle the parameter, so it's okay to always pass the parameter (even if the model is + * symmetric). * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out */ - public void inferenceSentence( - @NonNull final String modelId, - @NonNull final String inputText, - @Nullable final MLAlgoParams mlAlgoParams, - @NonNull final ActionListener> listener - ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), mlAlgoParams, ActionListener.wrap(response -> { + public void inferenceSentence(@NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener> listener) { + if (inferenceRequest.inputTexts.size() != 1) { + listener.onFailure( + new IllegalArgumentException( + "Unexpected number of input texts. 
Expected 1 input text, but got [" + inferenceRequest.inputTexts.size() + "]" + ) + ); + return; + } + + inferenceSentences(inferenceRequest, ActionListener.wrap(response -> { if (response.size() != 1) { listener.onFailure( new IllegalStateException( @@ -95,111 +230,57 @@ public void inferenceSentence( return; } - listener.onResponse(response.get(0)); + listener.onResponse(response.getFirst()); }, listener::onFailure)); } /** * Abstraction to call predict function of api of MLClient with default targetResponse filters. It - * uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The + * uses the custom model provided as modelId and runs the {@link FunctionName#TEXT_EMBEDDING}. The * return will be sent using the actionListener which will have a {@link List} of {@link List} of * {@link Float} in the order of inputText. We are not making this function generic enough to take * any function or TaskType as currently we need to run only TextEmbedding tasks only. + *

+ * If the model is asymmetric, the {@link InferenceRequest} must contain an + * {@link + * org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} as + * {@link MLAlgoParams}. This method will check whether the model being used is asymmetric and + * correctly handle the parameter, so it's okay to always pass the parameter (even if the model is + * symmetric). * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or - * errored out + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out */ public void inferenceSentences( - @NonNull final String modelId, - @NonNull final List inputText, + @NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener>> listener ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, null, listener); - } - - /** - * Abstraction to call predict function of api of MLClient with default targetResponse filters. It - * uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The - * return will be sent using the actionListener which will have a {@link List} of {@link List} of - * {@link Float} in the order of inputText. We are not making this function generic enough to take - * any function or TaskType as currently we need to run only TextEmbedding tasks only. Supports - * passing {@link MLAlgoParams} to the inference. If the model is asymmetric, passing a - * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} - * is mandatory. 
This method will check whether the model being used is asymmetric and correctly - * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). - * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference - * @param listener {@link ActionListener} which will be called when prediction is completed or - */ - public void inferenceSentences( - @NonNull final String modelId, - @NonNull final List inputText, - @Nullable final MLAlgoParams mlAlgoParams, - @NonNull final ActionListener>> listener - ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, mlAlgoParams, listener); - } - - /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. - * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. - * The return will be sent using the actionListener which will have a {@link List} of {@link List} - * of {@link Float} in the order of inputText. We are not making this function generic enough to - * take any function or TaskType as currently we need to run only TextEmbedding tasks only. - * - * @param targetResponseFilters {@link List} of {@link String} which filters out the responses - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is - * completed or errored out. 
- */ - public void inferenceSentences( - @NonNull final List targetResponseFilters, - @NonNull final String modelId, - @NonNull final List inputText, - @NonNull final ActionListener>> listener - ) { - inferenceSentences(targetResponseFilters, modelId, inputText, null, listener); - } - - /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. - * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. - * The return will be sent using the actionListener which will have a {@link List} of {@link List} - * of {@link Float} in the order of inputText. We are not making this function generic enough to - * take any function or TaskType as currently we need to run only TextEmbedding tasks only. Supports - * passing {@link MLAlgoParams} to the inference. If the model is asymmetric, passing a - * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} - * is mandatory. This method will check whether the model being used is asymmetric and correctly - * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). - * - * @param targetResponseFilters {@link List} of {@link String} which filters out the responses - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference - * @param listener {@link ActionListener} which will be called when prediction is - * completed or errored out. 
- */ - public void inferenceSentences( - @NonNull final List targetResponseFilters, - @NonNull final String modelId, - @NonNull final List inputText, - @Nullable final MLAlgoParams mlAlgoParams, - @NonNull final ActionListener>> listener - ) { - retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, mlAlgoParams, 0, listener); + if (inferenceRequest.inputTexts.isEmpty()) { + listener.onFailure(new IllegalArgumentException("inputTexts must be provided")); + return; + } + retryableInferenceSentencesWithVectorResult( + inferenceRequest.targetResponseFilters, + inferenceRequest.modelId, + inferenceRequest.inputTexts, + inferenceRequest.mlAlgoParams, + 0, + listener + ); } public void inferenceSentencesWithMapResult( - @NonNull final String modelId, - @NonNull final List inputText, + @NonNull InferenceRequest inferenceRequest, @NonNull final ActionListener>> listener ) { - retryableInferenceSentencesWithMapResult(modelId, inputText, null, 0, listener); + retryableInferenceSentencesWithMapResult( + inferenceRequest.modelId, + inferenceRequest.inputTexts, + inferenceRequest.mlAlgoParams, + 0, + listener + ); } /** @@ -207,45 +288,35 @@ public void inferenceSentencesWithMapResult( * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. * The return will be sent using the actionListener which will have a list of floats in the order * of inputText. + *

+ * If the model is asymmetric, the {@link InferenceRequest} must contain an + * {@link + * org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} as + * {@link MLAlgoParams}. This method will check whether the model being used is asymmetric and + * correctly handle the parameter, so it's okay to always pass the parameter (even if the model is + * symmetric). * - * @param modelId {@link String} - * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to - * happen - * @param listener {@link ActionListener} which will be called when prediction is completed or - * errored out. + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * Must contain inputObjects. + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out. */ - public void inferenceSentences( - @NonNull final String modelId, - @NonNull final Map inputObjects, + public void inferenceSentencesMap( + @NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener> listener ) { - inferenceSentences(modelId, inputObjects, null, listener); - } - - /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. - * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. - * The return will be sent using the actionListener which will have a list of floats in the order - * of inputText. Supports passing {@link MLAlgoParams} to the inference. If the model is asymmetric, - * passing a - * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} - * is mandatory. This method will check whether the model being used is asymmetric and correctly - * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). 
- * - * @param modelId {@link String} - * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to - * happen - * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference - * @param listener {@link ActionListener} which will be called when prediction is completed or - * errored out. - */ - public void inferenceSentences( - @NonNull final String modelId, - @NonNull final Map inputObjects, - @Nullable final MLAlgoParams mlAlgoParams, - @NonNull final ActionListener> listener - ) { - retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, mlAlgoParams, 0, listener); + if (inferenceRequest.inputObjects == null) { + listener.onFailure(new IllegalArgumentException("inputObjects must be provided")); + return; + } + retryableInferenceSentencesWithSingleVectorResult( + inferenceRequest.targetResponseFilters, + inferenceRequest.modelId, + inferenceRequest.inputObjects, + inferenceRequest.mlAlgoParams, + 0, + listener + ); } /** @@ -254,18 +325,22 @@ public void inferenceSentences( * actionListener as a list of floats representing the similarity scores of the texts w.r.t. the * query text, in the order of the input texts. * - * @param modelId {@link String} ML-Commons Model Id - * @param queryText {@link String} The query to compare all the inputText to - * @param inputText {@link List} of {@link String} The texts to compare to the query - * @param listener {@link ActionListener} receives the result of the inference + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * Must contain queryText. 
+ * @param listener {@link ActionListener} receives the result of the inference */ - public void inferenceSimilarity( - @NonNull final String modelId, - @NonNull final String queryText, - @NonNull final List inputText, - @NonNull final ActionListener> listener - ) { - retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener); + public void inferenceSimilarity(@NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener> listener) { + if (inferenceRequest.queryText == null) { + listener.onFailure(new IllegalArgumentException("queryText must be provided")); + return; + } + retryableInferenceSimilarityWithVectorResult( + inferenceRequest.modelId, + inferenceRequest.queryText, + inferenceRequest.inputTexts, + 0, + listener + ); } private void retryableInferenceSentencesWithMapResult( @@ -329,30 +404,29 @@ private void retryableInferenceSentencesWithVectorResult( } /** - * Check if the model is asymmetric and then run the prediction. Model asymmetry is a concept - * that is specific to TextEmbeddingModelConfig. If the model is not a TextEmbeddingModel, then - * this check is not applicable. - * + * Check if the model is asymmetric and then run the prediction. Model asymmetry is a concept that + * is specific to TextEmbeddingModelConfig. If the model is not a TextEmbeddingModel, then this + * check is not applicable. + *

* The asymmetry of a model is static for a given model. To avoid repeated checks for the same * model, we cache the model asymmetry status. Non-TextEmbeddingModels are cached as false. * - * @param modelId The model id to check - * @param onFailure The action to take if the model cannot be retrieved + * @param modelId The model id to check + * @param onFailure The action to take if the model cannot be retrieved * @param runPrediction The action to take if the model is successfully retrieved */ private void checkModelAsymmetryAndThenPredict(String modelId, Consumer onFailure, Consumer runPrediction) { CheckedConsumer checkModelAsymmetryListener = model -> { MLModelConfig modelConfig = model.getModelConfig(); - if (!(modelConfig instanceof TextEmbeddingModelConfig)) { - modelAsymmetryCache.putIfAbsent(modelId, false); + if (!(modelConfig instanceof TextEmbeddingModelConfig textEmbeddingModelConfig)) { + modelAsymmetryCache.computeIfAbsent(modelId, k -> false); return; } - final TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig; final boolean isAsymmetricModel = textEmbeddingModelConfig.getPassagePrefix() != null || textEmbeddingModelConfig.getQueryPrefix() != null; - modelAsymmetryCache.putIfAbsent(modelId, isAsymmetricModel); + modelAsymmetryCache.computeIfAbsent(modelId, k -> isAsymmetricModel); }; - if (modelAsymmetryCache.containsKey(modelId)) { + if (modelAsymmetryCache.get(modelId) != null) { runPrediction.accept(modelAsymmetryCache.get(modelId)); } else { mlClient.getModel(modelId, ActionListener.wrap(mlModel -> { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index e01840fbb..682a54126 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -14,6 +14,7 @@ import 
org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -48,17 +49,19 @@ public void doExecute( List inferenceList, BiConsumer handler ) { - mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + mlCommonsClientAccessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(), + ActionListener.wrap(resultMaps -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); } @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { mlCommonsClientAccessor.inferenceSentencesWithMapResult( - this.modelId, - inferenceList, + new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(), ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 23dc6af49..16862c27c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -18,10 +18,12 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import lombok.extern.log4j.Log4j2; +import 
org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; /** - * This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use, - * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results. + * This processor is used for user input data text embedding processing, model_id can be used to + * indicate which model user use, and field_map can be used to indicate which fields needs text + * embedding and the corresponding keys for the text embedding results. */ @Log4j2 public final class TextEmbeddingProcessor extends InferenceProcessor { @@ -29,6 +31,10 @@ public final class TextEmbeddingProcessor extends InferenceProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + private static final AsymmetricTextEmbeddingParameters PASSAGE_PARAMETERS = AsymmetricTextEmbeddingParameters.builder() + .embeddingContentType(EmbeddingContentType.PASSAGE) + .build(); + public TextEmbeddingProcessor( String tag, String description, @@ -50,9 +56,7 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentences( - this.modelId, - inferenceList, - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), ActionListener.wrap(vectors -> { setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); handler.accept(ingestDocument, null); @@ -62,6 +66,9 @@ public void doExecute( @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(handler::accept, onException)); + mlCommonsClientAccessor.inferenceSentences( + new 
InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), + ActionListener.wrap(handler::accept, onException) + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index e808869f9..f71110cdc 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -25,6 +25,7 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.util.ProcessorDocumentUtils; /** @@ -113,10 +114,13 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer { - setVectorFieldsToDocument(ingestDocument, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + mlCommonsClientAccessor.inferenceSentencesMap( + new InferenceRequest.Builder(this.modelId).inputObjects(inferenceMap).build(), + ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); } } catch (Exception e) { handler.accept(null, e); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java index d8d9e8ec3..29508bd40 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -12,6 +12,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import 
org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; @@ -73,9 +74,9 @@ public void rescoreSearchResponse( List ctxList = (List) ctxObj; List contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList()); mlCommonsClientAccessor.inferenceSimilarity( - modelId, - (String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD), - contexts, + new InferenceRequest.Builder(modelId).queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD)) + .inputTexts(contexts) + .build(), listener ); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index c8b2a1d4a..2fccfafb0 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -57,11 +57,12 @@ import lombok.Setter; import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; /** - * NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a wrapper around a - * k-NN vector query. It uses a ML language model to produce a dense vector from a query string that is then used as - * the query vector for the k-NN search. + * NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a + * wrapper around a k-NN vector query. It uses a ML language model to produce a dense vector from a + * query string that is then used as the query vector for the k-NN search. 
*/ @Log4j2 @@ -86,6 +87,9 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder static final ParseField K_FIELD = new ParseField("k"); private static final int DEFAULT_K = 10; + private static final AsymmetricTextEmbeddingParameters QUERY_PARAMETERS = AsymmetricTextEmbeddingParameters.builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .build(); private static MLCommonsClientAccessor ML_CLIENT; @@ -335,10 +339,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { inferenceInput.put(INPUT_IMAGE, queryImage()); } queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentences( - modelId(), - inferenceInput, - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(), + ((client, actionListener) -> ML_CLIENT.inferenceSentencesMap( + new InferenceRequest.Builder(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(), ActionListener.wrap(floatList -> { vectorSetOnce.set(vectorAsListToArray(floatList)); actionListener.onResponse(null); @@ -368,8 +370,12 @@ protected Query doToQuery(QueryShardContext queryShardContext) { @Override protected boolean doEquals(NeuralQueryBuilder obj) { - if (this == obj) return true; - if (obj == null || getClass() != obj.getClass()) return false; + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } EqualsBuilder equalsBuilder = new EqualsBuilder(); equalsBuilder.append(fieldName, obj.fieldName); equalsBuilder.append(queryText, obj.queryText); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index f46997d5e..ef9657759 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -37,6 +37,7 @@ 
import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; @@ -341,8 +342,7 @@ private BiConsumer> getModelInferenceAsync(SetOnce ML_CLIENT.inferenceSentencesWithMapResult( - modelId(), - List.of(queryText), + new InferenceRequest.Builder(modelId()).inputTexts(List.of(queryText)).build(), ActionListener.wrap(mapResultList -> { Map queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0); if (Objects.nonNull(twoPhaseSharedQueryToken)) { diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 3c0376909..d45c6f15c 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -34,6 +34,7 @@ import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.neuralsearch.constants.TestCommonConstants; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.NodeNotConnectedException; @@ -66,9 +67,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentence( - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST.get(0), - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + new 
InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)).build(), singleSentenceResultListener ); @@ -101,7 +100,10 @@ public void testInferenceSentences_whenValidInputThenSuccess() { }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentences( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -119,7 +121,10 @@ public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() { }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentences( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -137,9 +142,9 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - TestCommonConstants.TARGET_RESPONSE_FILTERS, - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST) + 
.targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) + .build(), resultListener ); @@ -163,9 +168,9 @@ public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Time setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - TestCommonConstants.TARGET_RESPONSE_FILTERS, - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST) + .targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) + .build(), resultListener ); @@ -185,9 +190,9 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - TestCommonConstants.TARGET_RESPONSE_FILTERS, - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) + .inputTexts(TestCommonConstants.SENTENCES_LIST) + .build(), resultListener ); @@ -206,9 +211,9 @@ public void testInferenceSentences_whenModelAsymmetric_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(true); accessor.inferenceSentence( - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST.get(0), - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)) + .mlAlgoParams(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()) + .build(), singleSentenceResultListener ); @@ -233,9 +238,9 @@ public void testInferenceSentences_whenGetModelException_thenFailure() { setupMocksForTextEmbeddingModelAsymmetryCheck(exception); accessor.inferenceSentence( - TestCommonConstants.MODEL_ID, - 
TestCommonConstants.SENTENCES_LIST.get(0), - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)) + .mlAlgoParams(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()) + .build(), singleSentenceResultListener ); @@ -263,7 +268,10 @@ public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -282,7 +290,10 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -309,7 +320,10 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new 
InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -339,7 +353,10 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTha setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -361,7 +378,10 @@ public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Tim setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -379,7 +399,10 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client, times(1)) 
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -396,7 +419,10 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + accessor.inferenceSentencesMap( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + singleSentenceResultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -414,7 +440,10 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + accessor.inferenceSentencesMap( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + singleSentenceResultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -422,7 +451,7 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } - public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenRetryThreeTimes() { + public void testInferenceSentencesMapMultimodal_whenNodeNotConnectedException_thenRetryThreeTimes() { final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( mock(DiscoveryNode.class), "Node not connected" @@ -435,7 +464,10 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR setupMocksForTextEmbeddingModelAsymmetryCheck(false); - 
accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + accessor.inferenceSentencesMap( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + singleSentenceResultListener + ); Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -453,9 +485,9 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).queryText("is it sunny") + .inputTexts(List.of("it is sunny today", "roses are red")) + .build(), singleSentenceResultListener ); @@ -476,9 +508,9 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).queryText("is it sunny") + .inputTexts(List.of("it is sunny today", "roses are red")) + .build(), singleSentenceResultListener ); @@ -502,9 +534,9 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).queryText("is it sunny") + .inputTexts(List.of("it is sunny today", "roses are red")) + .build(), singleSentenceResultListener ); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java 
b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java index dc86975bd..52ee1009e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -23,10 +23,10 @@ import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Consumer; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -63,7 +63,7 @@ public void test_batchExecute_emptyInput() { ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); verify(resultHandler).accept(captor.capture()); assertTrue(captor.getValue().isEmpty()); - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor, never()).inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecuteWithEmpty_allFailedValidation() { @@ -85,7 +85,7 @@ public void test_batchExecuteWithEmpty_allFailedValidation() { ); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor, never()).inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecuteWithNull_allFailedValidation() { @@ -104,7 +104,7 @@ public void test_batchExecuteWithNull_allFailedValidation() { assertEquals("list type field [key1] has null, cannot process it", captor.getValue().get(i).getException().getMessage()); 
assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor, never()).inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecute_partialFailedValidation() { @@ -123,9 +123,9 @@ public void test_batchExecute_partialFailedValidation() { for (int i = 0; i < docCount; ++i) { assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(2, inferenceTextCaptor.getValue().size()); + ArgumentCaptor inferenceRequestArgumentCaptor = ArgumentCaptor.forClass(InferenceRequest.class); + verify(clientAccessor).inferenceSentences(inferenceRequestArgumentCaptor.capture(), any()); + assertEquals(2, inferenceRequestArgumentCaptor.getValue().getInputTexts().size()); } public void test_batchExecute_happyCase() { @@ -144,9 +144,9 @@ public void test_batchExecute_happyCase() { assertNull(captor.getValue().get(i).getException()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(4, inferenceTextCaptor.getValue().size()); + ArgumentCaptor inferenceRequestArgumentCaptor = ArgumentCaptor.forClass(InferenceRequest.class); + verify(clientAccessor).inferenceSentences(inferenceRequestArgumentCaptor.capture(), any()); + assertEquals(4, inferenceRequestArgumentCaptor.getValue().getInputTexts().size()); } public void test_batchExecute_sort() { @@ -165,10 +165,10 @@ public void test_batchExecute_sort() { 
assertNull(captor.getValue().get(i).getException()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(4, inferenceTextCaptor.getValue().size()); - assertEquals(Arrays.asList("cc", "bbb", "ddd", "aaaaa"), inferenceTextCaptor.getValue()); + ArgumentCaptor inferenceRequestArgumentCaptor = ArgumentCaptor.forClass(InferenceRequest.class); + verify(clientAccessor).inferenceSentences(inferenceRequestArgumentCaptor.capture(), any()); + assertEquals(4, inferenceRequestArgumentCaptor.getValue().getInputTexts().size()); + assertEquals(Arrays.asList("cc", "bbb", "ddd", "aaaaa"), inferenceRequestArgumentCaptor.getValue().getInputTexts()); List doc1Embeddings = (List) (captor.getValue().get(0).getIngestDocument().getFieldValue("embedding_key1", List.class)); List doc2Embeddings = (List) (captor.getValue().get(1).getIngestDocument().getFieldValue("embedding_key1", List.class)); @@ -197,7 +197,7 @@ public void test_doBatchExecute_exception() { assertNotNull(captor.getValue().get(i).getException()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - verify(clientAccessor).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecute_subBatches() { @@ -245,7 +245,10 @@ public void doExecute( @Override void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { // use to verify if doBatchExecute is called from InferenceProcessor - clientAccessor.inferenceSentences(MODEL_ID, inferenceList, ActionListener.wrap(results -> {}, ex -> {})); + clientAccessor.inferenceSentences( + new InferenceRequest.Builder(MODEL_ID).inputTexts(inferenceList).build(), + 
ActionListener.wrap(results -> {}, ex -> {}) + ); allInferenceInputs.add(inferenceList); if (this.exception != null) { onException.accept(this.exception); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 9486ee2ca..ac2a1f0d7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -4,10 +4,10 @@ */ package org.opensearch.neuralsearch.processor; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; @@ -100,10 +100,11 @@ public void testExecute_successful() { List> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -132,7 +133,8 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutAnyMap() { DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentencesMap(argThat(request -> request.getInputTexts() 
!= null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(any(IngestDocument.class), isNull()); @@ -150,10 +152,11 @@ public void testExecute_withListTypeInput_successful() { List> dataAsMapList = createMockMapResult(6); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -169,10 +172,11 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { SparseEncodingProcessor processor = createInstance(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -206,10 +210,11 @@ public void testExecute_withMapTypeInput_successful() { List> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), 
isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -223,10 +228,11 @@ public void test_batchExecute_successful() { SparseEncodingProcessor processor = createInstance(docCount); List> dataAsMapList = createMockMapResult(10); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); @@ -244,10 +250,11 @@ public void test_batchExecute_exception() { List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); SparseEncodingProcessor processor = createInstance(docCount); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 1d83c8c95..ea009e777 
100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.processor; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -44,7 +44,6 @@ import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; -import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; @@ -152,10 +151,14 @@ public void testExecute_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -186,7 +189,10 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep config ); doThrow(new RuntimeException()).when(accessor) - .inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + 
isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(RuntimeException.class)); @@ -214,7 +220,8 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() { DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(any(IngestDocument.class), isNull()); @@ -232,10 +239,14 @@ public void testExecute_withListTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -308,10 +319,14 @@ public void testExecute_withMapTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> 
request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -349,10 +364,14 @@ public void testNestedFieldInMapping_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -409,10 +428,14 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHa List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -467,10 +490,14 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = 
invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -518,10 +545,14 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -587,10 +618,14 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + 
isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -630,7 +665,7 @@ public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentE ActionListener>> listener = invocation.getArgument(2); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); @@ -830,10 +865,10 @@ public void test_batchExecute_successful() { List> modelTensorList = createMockVectorWithLength(10); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); @@ -851,10 +886,10 @@ public void test_batchExecute_exception() { List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(docCount); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != 
null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index 8f0018f52..a22a7ef76 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.processor; -import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -192,10 +192,11 @@ public void testExecute_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -229,7 +230,8 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); 
processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(RuntimeException.class)); @@ -245,10 +247,11 @@ public void testExecute_withListTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -277,10 +280,11 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { TextImageEmbeddingProcessor processor = createInstance(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -336,10 +340,11 @@ public void testExecute_whenInferencesAreEmpty_thenSuccessful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> 
request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index dbd1c2bd6..3f745b9c3 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -5,8 +5,7 @@ package org.opensearch.neuralsearch.processor.rerank; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -114,11 +113,12 @@ private void setupParams(Map params) { private void setupSimilarityRescoring() { doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); List scores = List.of(1f, 2f, 3f); listener.onResponse(scores); return null; - }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + }).when(mlCommonsClientAccessor) + .inferenceSimilarity(argThat(request -> request.getQueryText() != null && request.getInputTexts() != null), any()); } private void setupSearchResults() throws IOException { @@ -345,11 +345,12 @@ public void testRerank_HappyPath() throws IOException { public void testRerank_whenScoresAndHitsHaveDiffLengths_thenFail() throws IOException { doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); List scores = List.of(1f, 2f); listener.onResponse(scores); 
return null; - }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + }).when(mlCommonsClientAccessor) + .inferenceSimilarity(argThat(request -> request.getQueryText() != null && request.getInputTexts() != null), any()); setupSearchResults(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 5efbf3869..cc3dc5fbb 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -5,7 +5,7 @@ package org.opensearch.neuralsearch.query; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; @@ -649,10 +649,10 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any(), any()); + }).when(mlCommonsClientAccessor).inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -685,10 +685,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe List expectedVector = 
Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any(), any()); + }).when(mlCommonsClientAccessor).inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 7509efd42..eb350a9ea 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -623,10 +623,10 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() Map expectedMap = Map.of("1", 1f, "2", 2f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(List.of(Map.of("response", List.of(expectedMap)))); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any()); + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any()); NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); From 61347b02c80212df662cbe7a43f006438606e428 Mon Sep 17 00:00:00 2001 From: br3no Date: Fri, 8 Nov 2024 10:56:35 +0100 Subject: [PATCH 09/10] using lombok in InferenceRequest DTO Signed-off-by: br3no --- 
.../ml/MLCommonsClientAccessor.java | 138 +++--------------- .../processor/SparseEncodingProcessor.java | 4 +- .../processor/TextEmbeddingProcessor.java | 4 +- .../TextImageEmbeddingProcessor.java | 2 +- .../rerank/MLOpenSearchRerankProcessor.java | 4 +- .../query/NeuralQueryBuilder.java | 2 +- .../query/NeuralSparseQueryBuilder.java | 2 +- .../ml/MLCommonsClientAccessorTests.java | 56 ++++--- .../processor/InferenceProcessorTests.java | 2 +- 9 files changed, 64 insertions(+), 150 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 2daa06c1e..3b86c20e4 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -14,6 +14,9 @@ import java.util.function.Consumer; import java.util.stream.Collectors; +import lombok.Builder; +import lombok.Getter; +import lombok.Singular; import org.opensearch.common.CheckedConsumer; import org.opensearch.common.cache.Cache; import org.opensearch.common.cache.CacheBuilder; @@ -47,19 +50,24 @@ @Log4j2 public class MLCommonsClientAccessor { + public static final int MAXIMUM_CACHE_ENTRIES = 10_000; + /** * Inference parameters for calls to the MLCommons client. 
*/ + @Getter + @Builder public static class InferenceRequest { private static final List DEFAULT_TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); - private final String modelId; - private final List inputTexts; - private final MLAlgoParams mlAlgoParams; - private final List targetResponseFilters; - private final Map inputObjects; - private final String queryText; + private final String modelId; // required + @Singular + private List inputTexts; + private MLAlgoParams mlAlgoParams; + private List targetResponseFilters; + private Map inputObjects; + private String queryText; public InferenceRequest( @NonNull String modelId, @@ -76,124 +84,12 @@ public InferenceRequest( this.inputObjects = inputObjects; this.queryText = queryText; } - - public String getModelId() { - return modelId; - } - - public List getInputTexts() { - return inputTexts; - } - - public MLAlgoParams getMlAlgoParams() { - return mlAlgoParams; - } - - public List getTargetResponseFilters() { - return targetResponseFilters; - } - - public Map getInputObjects() { - return inputObjects; - } - - public String getQueryText() { - return queryText; - } - - /** - * Builder for {@link InferenceRequest}. Supports fluent construction of the request object. - */ - public static class Builder { - - private final String modelId; - private List inputTexts; - private MLAlgoParams mlAlgoParams; - private List targetResponseFilters; - private Map inputObjects; - private String queryText; - - /** - * @param modelId the model id to use for inference - */ - public Builder(String modelId) { - this.modelId = modelId; - } - - /** - * @param inputTexts a {@link List} of input texts to use for inference - * @return this builder - */ - public Builder inputTexts(List inputTexts) { - this.inputTexts = inputTexts; - return this; - } - - /** - * @param inputText an input text to add to the list of input texts. Repeated calls will add - * more input texts. 
- * @return this builder - */ - public Builder inputText(String inputText) { - if (this.inputTexts != null) { - this.inputTexts.add(inputText); - } else { - this.inputTexts = new ArrayList<>(); - this.inputTexts.add(inputText); - } - return this; - } - - /** - * @param mlAlgoParams the {@link MLAlgoParams} to use for inference. - * @return this builder - */ - public Builder mlAlgoParams(MLAlgoParams mlAlgoParams) { - this.mlAlgoParams = mlAlgoParams; - return this; - } - - /** - * @param targetResponseFilters a {@link List} of target response filters to use for - * inference - * @return this builder - */ - public Builder targetResponseFilters(List targetResponseFilters) { - this.targetResponseFilters = targetResponseFilters; - return this; - } - - /** - * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs - * to happen - * @return this builder - */ - public Builder inputObjects(Map inputObjects) { - this.inputObjects = inputObjects; - return this; - } - - /** - * @param queryText the query text to use for similarity inference - * @return this builder - */ - public Builder queryText(String queryText) { - this.queryText = queryText; - return this; - } - - /** - * @return a new {@link InferenceRequest} object with the parameters set in this builder - */ - public InferenceRequest build() { - return new InferenceRequest(modelId, inputTexts, mlAlgoParams, targetResponseFilters, inputObjects, queryText); - } - - } } private final MachineLearningNodeClient mlClient; - private final Cache modelAsymmetryCache = CacheBuilder.builder().setMaximumWeight(10_000).build(); + private final Cache modelAsymmetryCache = CacheBuilder.builder() + .setMaximumWeight(MAXIMUM_CACHE_ENTRIES) + .build(); /** * Wrapper around {@link #inferenceSentencesMap} that expects a single input text and produces a diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java 
b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 682a54126..54dc7417c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -50,7 +50,7 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult( - new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(), + InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), ActionListener.wrap(resultMaps -> { setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); handler.accept(ingestDocument, null); @@ -61,7 +61,7 @@ public void doExecute( @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { mlCommonsClientAccessor.inferenceSentencesWithMapResult( - new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(), + InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 16862c27c..05422850d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -56,7 +56,7 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentences( - new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), + InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), ActionListener.wrap(vectors 
-> { setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); handler.accept(ingestDocument, null); @@ -67,7 +67,7 @@ public void doExecute( @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { mlCommonsClientAccessor.inferenceSentences( - new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), + InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), ActionListener.wrap(handler::accept, onException) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index f71110cdc..672dfbf4d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -115,7 +115,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer { setVectorFieldsToDocument(ingestDocument, vectors); handler.accept(ingestDocument, null); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java index 29508bd40..b4a285bf2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -74,7 +74,9 @@ public void rescoreSearchResponse( List ctxList = (List) ctxObj; List contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList()); mlCommonsClientAccessor.inferenceSimilarity( - new InferenceRequest.Builder(modelId).queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD)) + InferenceRequest.builder() + .modelId(modelId) + 
.queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD)) .inputTexts(contexts) .build(), listener diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 2fccfafb0..81ea3bcaf 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -340,7 +340,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { } queryRewriteContext.registerAsyncAction( ((client, actionListener) -> ML_CLIENT.inferenceSentencesMap( - new InferenceRequest.Builder(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(), + InferenceRequest.builder().modelId(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(), ActionListener.wrap(floatList -> { vectorSetOnce.set(vectorAsListToArray(floatList)); actionListener.onResponse(null); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index ef9657759..05244de9c 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -342,7 +342,7 @@ private BiConsumer> getModelInferenceAsync(SetOnce ML_CLIENT.inferenceSentencesWithMapResult( - new InferenceRequest.Builder(modelId()).inputTexts(List.of(queryText)).build(), + InferenceRequest.builder().modelId(modelId()).inputTexts(List.of(queryText)).build(), ActionListener.wrap(mapResultList -> { Map queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0); if (Objects.nonNull(twoPhaseSharedQueryToken)) { diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java 
b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index d45c6f15c..10fe165f3 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -67,7 +67,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentence( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)).build(), singleSentenceResultListener ); @@ -101,7 +101,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), resultListener ); @@ -122,7 +122,7 @@ public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), resultListener ); @@ -142,7 +142,9 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST) + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + 
.inputTexts(TestCommonConstants.SENTENCES_LIST) .targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) .build(), resultListener @@ -168,7 +170,9 @@ public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Time setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST) + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .inputTexts(TestCommonConstants.SENTENCES_LIST) .targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) .build(), resultListener @@ -190,7 +194,9 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) .inputTexts(TestCommonConstants.SENTENCES_LIST) .build(), resultListener @@ -211,7 +217,9 @@ public void testInferenceSentences_whenModelAsymmetric_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(true); accessor.inferenceSentence( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)) + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .inputText(TestCommonConstants.SENTENCES_LIST.get(0)) .mlAlgoParams(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()) .build(), singleSentenceResultListener @@ -238,7 +246,9 @@ public void testInferenceSentences_whenGetModelException_thenFailure() { setupMocksForTextEmbeddingModelAsymmetryCheck(exception); accessor.inferenceSentence( - new 
InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)) + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .inputText(TestCommonConstants.SENTENCES_LIST.get(0)) .mlAlgoParams(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()) .build(), singleSentenceResultListener @@ -269,7 +279,7 @@ public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentencesWithMapResult( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), resultListener ); @@ -291,7 +301,7 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentencesWithMapResult( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), resultListener ); @@ -321,7 +331,7 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentencesWithMapResult( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), resultListener ); @@ -354,7 +364,7 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTha setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentencesWithMapResult( - new 
InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), resultListener ); @@ -379,7 +389,7 @@ public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Tim setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentencesWithMapResult( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), resultListener ); @@ -400,7 +410,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentencesWithMapResult( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), resultListener ); @@ -420,7 +430,7 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentencesMap( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), singleSentenceResultListener ); @@ -441,7 +451,7 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentencesMap( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + 
InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), singleSentenceResultListener ); @@ -465,7 +475,7 @@ public void testInferenceSentencesMapMultimodal_whenNodeNotConnectedException_th setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentencesMap( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), singleSentenceResultListener ); @@ -485,7 +495,9 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSimilarity( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).queryText("is it sunny") + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .queryText("is it sunny") .inputTexts(List.of("it is sunny today", "roses are red")) .build(), singleSentenceResultListener @@ -508,7 +520,9 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSimilarity( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).queryText("is it sunny") + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .queryText("is it sunny") .inputTexts(List.of("it is sunny today", "roses are red")) .build(), singleSentenceResultListener @@ -534,7 +548,9 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSimilarity( - new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).queryText("is it sunny") + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .queryText("is it sunny") .inputTexts(List.of("it is sunny today", "roses are red")) .build(), singleSentenceResultListener 
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java index 52ee1009e..80b05a1e8 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -246,7 +246,7 @@ public void doExecute( void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { // use to verify if doBatchExecute is called from InferenceProcessor clientAccessor.inferenceSentences( - new InferenceRequest.Builder(MODEL_ID).inputTexts(inferenceList).build(), + InferenceRequest.builder().modelId(MODEL_ID).inputTexts(inferenceList).build(), ActionListener.wrap(results -> {}, ex -> {}) ); allInferenceInputs.add(inferenceList); From 65a4b3917db54ef73451875a6554c0287e0b6b32 Mon Sep 17 00:00:00 2001 From: br3no Date: Fri, 8 Nov 2024 11:07:01 +0100 Subject: [PATCH 10/10] make max cache entries private Signed-off-by: br3no --- .../org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 3b86c20e4..f55823b7b 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -50,7 +50,7 @@ @Log4j2 public class MLCommonsClientAccessor { - public static final int MAXIMUM_CACHE_ENTRIES = 10_000; + private static final int MAXIMUM_CACHE_ENTRIES = 10_000; /** * Inference parameters for calls to the MLCommons client.