diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..c72b87e0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Add support for asymmetric embedding models ([#710](https://github.com/opensearch-project/neural-search/pull/710)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index f9ddf73a9..f55823b7b 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -11,16 +11,27 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import java.util.stream.Collectors; +import lombok.Builder; +import lombok.Getter; +import lombok.Singular; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.cache.Cache; +import org.opensearch.common.cache.CacheBuilder; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.output.model.ModelTensor; @@ -38,23 +49,74 @@ @RequiredArgsConstructor @Log4j2 public class MLCommonsClientAccessor { - private static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); + + private static final int MAXIMUM_CACHE_ENTRIES = 10_000; + + /** + * Inference parameters for calls to the MLCommons client. + */ + @Getter + @Builder + public static class InferenceRequest { + + private static final List DEFAULT_TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); + + private final String modelId; // required + @Singular + private List inputTexts; + private MLAlgoParams mlAlgoParams; + private List targetResponseFilters; + private Map inputObjects; + private String queryText; + + public InferenceRequest( + @NonNull String modelId, + List inputTexts, + MLAlgoParams mlAlgoParams, + List targetResponseFilters, + Map inputObjects, + String queryText + ) { + this.modelId = modelId; + this.inputTexts = inputTexts; + this.mlAlgoParams = mlAlgoParams; + this.targetResponseFilters = targetResponseFilters == null ? DEFAULT_TARGET_RESPONSE_FILTERS : targetResponseFilters; + this.inputObjects = inputObjects; + this.queryText = queryText; + } + } + private final MachineLearningNodeClient mlClient; + private final Cache modelAsymmetryCache = CacheBuilder.builder() + .setMaximumWeight(MAXIMUM_CACHE_ENTRIES) + .build(); /** - * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating - * point vector as a response. 
+ * Wrapper around {@link #inferenceSentencesMap} that expects a single input text and produces a + * single floating point vector as a response. + *
+ * If the model is asymmetric, the {@link InferenceRequest} must contain an + * {@link + * org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} s + * {@link MLAlgoParams}. This method will check whether the model being used is asymmetric and + * correctly handle the parameter, so it's okay to always pass the parameter (even if the model is + * symmetric). * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out */ - public void inferenceSentence( - @NonNull final String modelId, - @NonNull final String inputText, - @NonNull final ActionListener> listener - ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> { + public void inferenceSentence(@NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener> listener) { + if (inferenceRequest.inputTexts.size() != 1) { + listener.onFailure( + new IllegalArgumentException( + "Unexpected number of input texts. Expected 1 input text, but got [" + inferenceRequest.inputTexts.size() + "]" + ) + ); + return; + } + + inferenceSentences(inferenceRequest, ActionListener.wrap(response -> { if (response.size() != 1) { listener.onFailure( new IllegalStateException( @@ -64,133 +126,210 @@ public void inferenceSentence( return; } - listener.onResponse(response.get(0)); + listener.onResponse(response.getFirst()); }, listener::onFailure)); } /** - * Abstraction to call predict function of api of MLClient with default targetResponse filters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. We are not making this function generic enough to take any function or TaskType as currently we - * need to run only TextEmbedding tasks only. - * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out - */ - public void inferenceSentences( - @NonNull final String modelId, - @NonNull final List inputText, - @NonNull final ActionListener>> listener - ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener); - } - - /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. We are not making this function generic enough to take any function or TaskType as currently we - * need to run only TextEmbedding tasks only. + * Abstraction to call predict function of api of MLClient with default targetResponse filters. It + * uses the custom model provided as modelId and runs the {@link FunctionName#TEXT_EMBEDDING}. 
The + * return will be sent using the actionListener which will have a {@link List} of {@link List} of + * {@link Float} in the order of inputText. We are not making this function generic enough to take + * any function or TaskType as currently we need to run only TextEmbedding tasks only. + *
+ * If the model is asymmetric, the {@link InferenceRequest} must contain an + * {@link + * org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} as + * {@link MLAlgoParams}. This method will check whether the model being used is asymmetric and + * correctly handle the parameter, so it's okay to always pass the parameter (even if the model is + * symmetric). * - * @param targetResponseFilters {@link List} of {@link String} which filters out the responses - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out */ public void inferenceSentences( - @NonNull final List targetResponseFilters, - @NonNull final String modelId, - @NonNull final List inputText, + @NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener>> listener ) { - retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); + if (inferenceRequest.inputTexts.isEmpty()) { + listener.onFailure(new IllegalArgumentException("inputTexts must be provided")); + return; + } + retryableInferenceSentencesWithVectorResult( + inferenceRequest.targetResponseFilters, + inferenceRequest.modelId, + inferenceRequest.inputTexts, + inferenceRequest.mlAlgoParams, + 0, + listener + ); } public void inferenceSentencesWithMapResult( - @NonNull final String modelId, - @NonNull final List inputText, + @NonNull InferenceRequest inferenceRequest, @NonNull final ActionListener>> listener ) { - retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); + retryableInferenceSentencesWithMapResult( + inferenceRequest.modelId, + inferenceRequest.inputTexts, + inferenceRequest.mlAlgoParams, + 0, + listener + ); } /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a list of floats in the order of inputText. + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. + * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. + * The return will be sent using the actionListener which will have a list of floats in the order + * of inputText. + *
+ * If the model is asymmetric, the {@link InferenceRequest} must contain an + * {@link + * org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} as + * {@link MLAlgoParams}. This method will check whether the model being used is asymmetric and + * correctly handle the parameter, so it's okay to always pass the parameter (even if the model is + * symmetric). * - * @param modelId {@link String} - * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * Must contain inputObjects. + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out. */ - public void inferenceSentences( - @NonNull final String modelId, - @NonNull final Map inputObjects, + public void inferenceSentencesMap( + @NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener> listener ) { - retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); + if (inferenceRequest.inputObjects == null) { + listener.onFailure(new IllegalArgumentException("inputObjects must be provided")); + return; + } + retryableInferenceSentencesWithSingleVectorResult( + inferenceRequest.targetResponseFilters, + inferenceRequest.modelId, + inferenceRequest.inputObjects, + inferenceRequest.mlAlgoParams, + 0, + listener + ); } /** - * Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the - * {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing - * the similarity scores of the texts w.r.t. the query text, in the order of the input texts. + * Abstraction to call predict function of api of MLClient. It uses the custom model provided as + * modelId and the {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via + * actionListener as a list of floats representing the similarity scores of the texts w.r.t. the + * query text, in the order of the input texts. * - * @param modelId {@link String} ML-Commons Model Id - * @param queryText {@link String} The query to compare all the inputText to - * @param inputText {@link List} of {@link String} The texts to compare to the query - * @param listener {@link ActionListener} receives the result of the inference + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * Must contain queryText. 
+ * @param listener {@link ActionListener} receives the result of the inference */ - public void inferenceSimilarity( - @NonNull final String modelId, - @NonNull final String queryText, - @NonNull final List inputText, - @NonNull final ActionListener> listener - ) { - retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener); + public void inferenceSimilarity(@NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener> listener) { + if (inferenceRequest.queryText == null) { + listener.onFailure(new IllegalArgumentException("queryText must be provided")); + return; + } + retryableInferenceSimilarityWithVectorResult( + inferenceRequest.modelId, + inferenceRequest.queryText, + inferenceRequest.inputTexts, + 0, + listener + ); } private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, + final MLAlgoParams mlAlgoParams, final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLTextInput(null, inputText); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List> result = buildMapResultFromResponse(mlOutput); - listener.onResponse(result); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener); - } else { - listener.onFailure(e); - } - })); + + Consumer runPrediction = isAsymmetricModel -> { + MLInput mlInput = createMLTextInput(null, inputText, isAsymmetricModel ? mlAlgoParams : null); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List> result = buildMapResultFromResponse(mlOutput); + listener.onResponse(result); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithMapResult(modelId, inputText, mlAlgoParams, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + }; + + checkModelAsymmetryAndThenPredict(modelId, listener::onFailure, runPrediction); } private void retryableInferenceSentencesWithVectorResult( final List targetResponseFilters, final String modelId, final List inputText, + final MLAlgoParams mlAlgoParams, final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLTextInput(targetResponseFilters, inputText); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List> vector = buildVectorFromResponse(mlOutput); - listener.onResponse(vector); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); - } else { - listener.onFailure(e); + + Consumer runPrediction = isAsymmetricModel -> { + MLInput mlInput = createMLTextInput(targetResponseFilters, inputText, isAsymmetricModel ? 
mlAlgoParams : null); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List> vector = buildVectorFromResponse(mlOutput); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithVectorResult( + targetResponseFilters, + modelId, + inputText, + mlAlgoParams, + retryTimeAdd, + listener + ); + } else { + listener.onFailure(e); + } + })); + }; + + checkModelAsymmetryAndThenPredict(modelId, listener::onFailure, runPrediction); + } + + /** + * Check if the model is asymmetric and then run the prediction. Model asymmetry is a concept that + * is specific to TextEmbeddingModelConfig. If the model is not a TextEmbeddingModel, then this + * check is not applicable. + *
+ * The asymmetry of a model is static for a given model. To avoid repeated checks for the same + * model, we cache the model asymmetry status. Non-TextEmbeddingModels are cached as false. + * + * @param modelId The model id to check + * @param onFailure The action to take if the model cannot be retrieved + * @param runPrediction The action to take if the model is successfully retrieved + */ + private void checkModelAsymmetryAndThenPredict(String modelId, Consumer onFailure, Consumer runPrediction) { + CheckedConsumer checkModelAsymmetryListener = model -> { + MLModelConfig modelConfig = model.getModelConfig(); + if (!(modelConfig instanceof TextEmbeddingModelConfig textEmbeddingModelConfig)) { + modelAsymmetryCache.computeIfAbsent(modelId, k -> false); + return; } - })); + final boolean isAsymmetricModel = textEmbeddingModelConfig.getPassagePrefix() != null + || textEmbeddingModelConfig.getQueryPrefix() != null; + modelAsymmetryCache.computeIfAbsent(modelId, k -> isAsymmetricModel); + }; + if (modelAsymmetryCache.get(modelId) != null) { + runPrediction.accept(modelAsymmetryCache.get(modelId)); + } else { + mlClient.getModel(modelId, ActionListener.wrap(mlModel -> { + checkModelAsymmetryListener.accept(mlModel); + runPrediction.accept(modelAsymmetryCache.get(modelId)); + }, onFailure)); + } } private void retryableInferenceSimilarityWithVectorResult( @@ -213,10 +352,10 @@ private void retryableInferenceSimilarityWithVectorResult( })); } - private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { + private MLInput createMLTextInput(final List targetResponseFilters, List inputText, MLAlgoParams mlAlgoParams) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + return new MLInput(FunctionName.TEXT_EMBEDDING, mlAlgoParams, inputDataset); } private MLInput createMLTextPairsInput(final String query, final List inputText) { @@ -264,25 +403,42 @@ private void retryableInferenceSentencesWithSingleVectorResult( final List targetResponseFilters, final String modelId, final Map inputObjects, + final MLAlgoParams mlAlgoParams, final int retryTime, final ActionListener> listener ) { - MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List vector = buildSingleVectorFromResponse(mlOutput); - log.debug("Inference Response for input sentence is : {} ", vector); - listener.onResponse(vector); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); - } else { - listener.onFailure(e); - } - })); + + Consumer predictConsumer = isAsymmetricModel -> { + MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects, isAsymmetricModel ? 
mlAlgoParams : null); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List vector = buildSingleVectorFromResponse(mlOutput); + log.debug("Inference Response for input sentence is : {} ", vector); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithSingleVectorResult( + targetResponseFilters, + modelId, + inputObjects, + mlAlgoParams, + retryTimeAdd, + listener + ); + } else { + listener.onFailure(e); + } + })); + }; + + checkModelAsymmetryAndThenPredict(modelId, listener::onFailure, predictConsumer); } - private MLInput createMLMultimodalInput(final List targetResponseFilters, final Map input) { + private MLInput createMLMultimodalInput( + final List targetResponseFilters, + final Map input, + MLAlgoParams mlAlgoParams + ) { List inputText = new ArrayList<>(); inputText.add(input.get(INPUT_TEXT)); if (input.containsKey(INPUT_IMAGE)) { @@ -290,6 +446,6 @@ private MLInput createMLMultimodalInput(final List targetResponseFilters } final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + return new MLInput(FunctionName.TEXT_EMBEDDING, mlAlgoParams, inputDataset); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index e01840fbb..54dc7417c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -14,6 +14,7 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -48,17 +49,19 @@ public void doExecute( List inferenceList, BiConsumer handler ) { - mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + mlCommonsClientAccessor.inferenceSentencesWithMapResult( + InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), + ActionListener.wrap(resultMaps -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); } @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { mlCommonsClientAccessor.inferenceSentencesWithMapResult( - this.modelId, - inferenceList, + InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index c8f9f080d..05422850d 100644 
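Before the processor and query-builder call sites below, a minimal caller-side sketch of the new InferenceRequest API may help. The names accessor, modelId and the listener bodies are illustrative placeholders; the builder methods and the AsymmetricTextEmbeddingParameters usage mirror the call sites added later in this patch, and per the accessor Javadoc it is safe to pass the parameter even for symmetric models, since the accessor checks the model config first.

import java.util.List;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;

final class InferenceRequestUsageSketch {
    // Illustrative only: shows how an ingest-side caller builds an InferenceRequest.
    static void embedPassages(MLCommonsClientAccessor accessor, String modelId, List<String> passages) {
        InferenceRequest request = InferenceRequest.builder()
            .modelId(modelId)                 // required
            .inputTexts(passages)             // texts to embed, results come back in the same order
            .mlAlgoParams(AsymmetricTextEmbeddingParameters.builder()
                .embeddingContentType(EmbeddingContentType.PASSAGE)   // QUERY on the search side
                .build())
            .build();
        accessor.inferenceSentences(request, ActionListener.wrap(
            vectors -> { /* one List<Float> per input text */ },
            e -> { /* handle failure */ }));
    }
}
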
--- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -13,13 +13,17 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; /** - * This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use, - * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results. + * This processor is used for user input data text embedding processing, model_id can be used to + * indicate which model user use, and field_map can be used to indicate which fields needs text + * embedding and the corresponding keys for the text embedding results. */ @Log4j2 public final class TextEmbeddingProcessor extends InferenceProcessor { @@ -27,6 +31,10 @@ public final class TextEmbeddingProcessor extends InferenceProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + private static final AsymmetricTextEmbeddingParameters PASSAGE_PARAMETERS = AsymmetricTextEmbeddingParameters.builder() + .embeddingContentType(EmbeddingContentType.PASSAGE) + .build(); + public TextEmbeddingProcessor( String tag, String description, @@ -47,14 +55,20 @@ public void doExecute( List inferenceList, BiConsumer handler ) { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + mlCommonsClientAccessor.inferenceSentences( + InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), + ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); } @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(handler::accept, onException)); + mlCommonsClientAccessor.inferenceSentences( + InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), + ActionListener.wrap(handler::accept, onException) + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index e808869f9..672dfbf4d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -25,6 +25,7 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import 
org.opensearch.neuralsearch.util.ProcessorDocumentUtils; /** @@ -113,10 +114,13 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer { - setVectorFieldsToDocument(ingestDocument, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + mlCommonsClientAccessor.inferenceSentencesMap( + InferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(), + ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); } } catch (Exception e) { handler.accept(null, e); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java index d8d9e8ec3..b4a285bf2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -12,6 +12,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; @@ -73,9 +74,11 @@ public void rescoreSearchResponse( List ctxList = (List) ctxObj; List contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList()); mlCommonsClientAccessor.inferenceSimilarity( - modelId, - (String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD), - contexts, + InferenceRequest.builder() + .modelId(modelId) + .queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD)) + .inputTexts(contexts) + .build(), listener ); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 915a79117..81ea3bcaf 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -44,6 +44,8 @@ import org.opensearch.knn.index.query.parser.RescoreParser; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.common.MinClusterVersionUtil; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import com.google.common.annotations.VisibleForTesting; @@ -55,11 +57,12 @@ import lombok.Setter; import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; /** - * NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a wrapper around a - * k-NN vector query. It uses a ML language model to produce a dense vector from a query string that is then used as - * the query vector for the k-NN search. + * NeuralQueryBuilder is responsible for producing "neural" query types. 
A "neural" query type is a + * wrapper around a k-NN vector query. It uses a ML language model to produce a dense vector from a + * query string that is then used as the query vector for the k-NN search. */ @Log4j2 @@ -84,6 +87,9 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder static final ParseField K_FIELD = new ParseField("k"); private static final int DEFAULT_K = 10; + private static final AsymmetricTextEmbeddingParameters QUERY_PARAMETERS = AsymmetricTextEmbeddingParameters.builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .build(); private static MLCommonsClientAccessor ML_CLIENT; @@ -333,10 +339,13 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { inferenceInput.put(INPUT_IMAGE, queryImage()); } queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> { - vectorSetOnce.set(vectorAsListToArray(floatList)); - actionListener.onResponse(null); - }, actionListener::onFailure))) + ((client, actionListener) -> ML_CLIENT.inferenceSentencesMap( + InferenceRequest.builder().modelId(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(), + ActionListener.wrap(floatList -> { + vectorSetOnce.set(vectorAsListToArray(floatList)); + actionListener.onResponse(null); + }, actionListener::onFailure) + )) ); return new NeuralQueryBuilder( fieldName(), @@ -361,8 +370,12 @@ protected Query doToQuery(QueryShardContext queryShardContext) { @Override protected boolean doEquals(NeuralQueryBuilder obj) { - if (this == obj) return true; - if (obj == null || getClass() != obj.getClass()) return false; + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } EqualsBuilder equalsBuilder = new EqualsBuilder(); equalsBuilder.append(fieldName, obj.fieldName); equalsBuilder.append(queryText, obj.queryText); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index f46997d5e..05244de9c 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -37,6 +37,7 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; @@ -341,8 +342,7 @@ private BiConsumer> getModelInferenceAsync(SetOnce ML_CLIENT.inferenceSentencesWithMapResult( - modelId(), - List.of(queryText), + InferenceRequest.builder().modelId(modelId()).inputTexts(List.of(queryText)).build(), ActionListener.wrap(mapResultList -> { Map queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0); if (Objects.nonNull(twoPhaseSharedQueryToken)) { diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 3749e63dc..10fe165f3 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -23,13 +23,18 @@ import 
org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.neuralsearch.constants.TestCommonConstants; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.NodeNotConnectedException; @@ -59,8 +64,12 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentence(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST.get(0), singleSentenceResultListener); + accessor.inferenceSentence( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)).build(), + singleSentenceResultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -68,6 +77,19 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } + private void setupMocksForTextEmbeddingModelAsymmetryCheck(boolean isAsymmetric) { + MLModel modelMock = mock(MLModel.class); + TextEmbeddingModelConfig textEmbeddingModelConfigMock = mock(TextEmbeddingModelConfig.class); + Mockito.when(textEmbeddingModelConfigMock.getPassagePrefix()).thenReturn(isAsymmetric ? "passage: " : null); + Mockito.when(textEmbeddingModelConfigMock.getQueryPrefix()).thenReturn(isAsymmetric ? 
"query: " : null); + Mockito.when(modelMock.getModelConfig()).thenReturn(textEmbeddingModelConfigMock); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(modelMock); + return null; + }).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); + } + public void testInferenceSentences_whenValidInputThenSuccess() { final List> vectorList = new ArrayList<>(); vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY)); @@ -76,7 +98,12 @@ public void testInferenceSentences_whenValidInputThenSuccess() { actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentences( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -92,7 +119,12 @@ public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() { actionListener.onResponse(createModelTensorOutput(new Float[] {})); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentences( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -107,10 +139,14 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { actionListener.onFailure(exception); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences( - TestCommonConstants.TARGET_RESPONSE_FILTERS, - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .inputTexts(TestCommonConstants.SENTENCES_LIST) + .targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) + .build(), resultListener ); @@ -130,10 +166,15 @@ public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Time actionListener.onFailure(nodeNodeConnectedException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences( - TestCommonConstants.TARGET_RESPONSE_FILTERS, - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .inputTexts(TestCommonConstants.SENTENCES_LIST) + .targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) 
+ .build(), resultListener ); @@ -149,10 +190,15 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { actionListener.onFailure(illegalStateException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences( - TestCommonConstants.TARGET_RESPONSE_FILTERS, - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) + .inputTexts(TestCommonConstants.SENTENCES_LIST) + .build(), resultListener ); @@ -161,6 +207,66 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { Mockito.verify(resultListener).onFailure(illegalStateException); } + public void testInferenceSentences_whenModelAsymmetric_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(true); + + accessor.inferenceSentence( + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .inputText(TestCommonConstants.SENTENCES_LIST.get(0)) + .mlAlgoParams(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()) + .build(), + singleSentenceResultListener + ); + + Mockito.verify(client) + .predict( + Mockito.eq(TestCommonConstants.MODEL_ID), + Mockito.argThat((MLInput input) -> input.getParameters() != null), + Mockito.isA(ActionListener.class) + ); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceSentences_whenGetModelException_thenFailure() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + RuntimeException exception = new RuntimeException("Bam!"); + setupMocksForTextEmbeddingModelAsymmetryCheck(exception); + + accessor.inferenceSentence( + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .inputText(TestCommonConstants.SENTENCES_LIST.get(0)) + .mlAlgoParams(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()) + .build(), + singleSentenceResultListener + ); + + Mockito.verify(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + private void setupMocksForTextEmbeddingModelAsymmetryCheck(Exception exception) { + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(1); + 
actionListener.onFailure(exception); + return null; + }).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); + } + public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { final Map map = Map.of("key", "value"); final ActionListener>> resultListener = mock(ActionListener.class); @@ -169,7 +275,13 @@ public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { actionListener.onResponse(createModelTensorOutput(map)); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentencesWithMapResult( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -185,7 +297,13 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx actionListener.onResponse(modelTensorOutput); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentencesWithMapResult( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -209,7 +327,13 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc actionListener.onResponse(modelTensorOutput); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentencesWithMapResult( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -236,7 +360,13 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTha actionListener.onResponse(modelTensorOutput); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentencesWithMapResult( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), 
Mockito.isA(ActionListener.class)); @@ -255,7 +385,13 @@ public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Tim return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); final ActionListener>> resultListener = mock(ActionListener.class); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentencesWithMapResult( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -270,7 +406,13 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); final ActionListener>> resultListener = mock(ActionListener.class); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentencesWithMapResult( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client, times(1)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -285,7 +427,12 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentencesMap( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + singleSentenceResultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -300,7 +447,13 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { actionListener.onFailure(exception); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentencesMap( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + singleSentenceResultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -308,7 +461,7 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } - public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenRetryThreeTimes() { + public void 
testInferenceSentencesMapMultimodal_whenNodeNotConnectedException_thenRetryThreeTimes() { final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( mock(DiscoveryNode.class), "Node not connected" @@ -318,7 +471,13 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR actionListener.onFailure(nodeNodeConnectedException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + + accessor.inferenceSentencesMap( + InferenceRequest.builder().modelId(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + singleSentenceResultListener + ); Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -333,10 +492,14 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .queryText("is it sunny") + .inputTexts(List.of("it is sunny today", "roses are red")) + .build(), singleSentenceResultListener ); @@ -354,10 +517,14 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .queryText("is it sunny") + .inputTexts(List.of("it is sunny today", "roses are red")) + .build(), singleSentenceResultListener ); @@ -378,10 +545,14 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), + InferenceRequest.builder() + .modelId(TestCommonConstants.MODEL_ID) + .queryText("is it sunny") + .inputTexts(List.of("it is sunny today", "roses are red")) + .build(), singleSentenceResultListener ); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java index dc86975bd..80b05a1e8 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -23,10 +23,10 @@ import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Consumer; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import static 
org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -63,7 +63,7 @@ public void test_batchExecute_emptyInput() { ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); verify(resultHandler).accept(captor.capture()); assertTrue(captor.getValue().isEmpty()); - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor, never()).inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecuteWithEmpty_allFailedValidation() { @@ -85,7 +85,7 @@ public void test_batchExecuteWithEmpty_allFailedValidation() { ); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor, never()).inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecuteWithNull_allFailedValidation() { @@ -104,7 +104,7 @@ public void test_batchExecuteWithNull_allFailedValidation() { assertEquals("list type field [key1] has null, cannot process it", captor.getValue().get(i).getException().getMessage()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor, never()).inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecute_partialFailedValidation() { @@ -123,9 +123,9 @@ public void test_batchExecute_partialFailedValidation() { for (int i = 0; i < docCount; ++i) { assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(2, inferenceTextCaptor.getValue().size()); + ArgumentCaptor inferenceRequestArgumentCaptor = ArgumentCaptor.forClass(InferenceRequest.class); + verify(clientAccessor).inferenceSentences(inferenceRequestArgumentCaptor.capture(), any()); + assertEquals(2, inferenceRequestArgumentCaptor.getValue().getInputTexts().size()); } public void test_batchExecute_happyCase() { @@ -144,9 +144,9 @@ public void test_batchExecute_happyCase() { assertNull(captor.getValue().get(i).getException()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(4, inferenceTextCaptor.getValue().size()); + ArgumentCaptor inferenceRequestArgumentCaptor = ArgumentCaptor.forClass(InferenceRequest.class); + verify(clientAccessor).inferenceSentences(inferenceRequestArgumentCaptor.capture(), any()); + assertEquals(4, inferenceRequestArgumentCaptor.getValue().getInputTexts().size()); } public void test_batchExecute_sort() { @@ -165,10 +165,10 @@ public void test_batchExecute_sort() { assertNull(captor.getValue().get(i).getException()); assertEquals(wrapperList.get(i).getIngestDocument(), 
captor.getValue().get(i).getIngestDocument()); } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(4, inferenceTextCaptor.getValue().size()); - assertEquals(Arrays.asList("cc", "bbb", "ddd", "aaaaa"), inferenceTextCaptor.getValue()); + ArgumentCaptor inferenceRequestArgumentCaptor = ArgumentCaptor.forClass(InferenceRequest.class); + verify(clientAccessor).inferenceSentences(inferenceRequestArgumentCaptor.capture(), any()); + assertEquals(4, inferenceRequestArgumentCaptor.getValue().getInputTexts().size()); + assertEquals(Arrays.asList("cc", "bbb", "ddd", "aaaaa"), inferenceRequestArgumentCaptor.getValue().getInputTexts()); List doc1Embeddings = (List) (captor.getValue().get(0).getIngestDocument().getFieldValue("embedding_key1", List.class)); List doc2Embeddings = (List) (captor.getValue().get(1).getIngestDocument().getFieldValue("embedding_key1", List.class)); @@ -197,7 +197,7 @@ public void test_doBatchExecute_exception() { assertNotNull(captor.getValue().get(i).getException()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - verify(clientAccessor).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecute_subBatches() { @@ -245,7 +245,10 @@ public void doExecute( @Override void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { // use to verify if doBatchExecute is called from InferenceProcessor - clientAccessor.inferenceSentences(MODEL_ID, inferenceList, ActionListener.wrap(results -> {}, ex -> {})); + clientAccessor.inferenceSentences( + InferenceRequest.builder().modelId(MODEL_ID).inputTexts(inferenceList).build(), + ActionListener.wrap(results -> {}, ex -> {}) + ); allInferenceInputs.add(inferenceList); if (this.exception != null) { onException.accept(this.exception); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 9486ee2ca..ac2a1f0d7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -4,10 +4,10 @@ */ package org.opensearch.neuralsearch.processor; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; @@ -100,10 +100,11 @@ public void testExecute_successful() { List> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -132,7 
+133,8 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutAnyMap() { DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(any(IngestDocument.class), isNull()); @@ -150,10 +152,11 @@ public void testExecute_withListTypeInput_successful() { List> dataAsMapList = createMockMapResult(6); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -169,10 +172,11 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { SparseEncodingProcessor processor = createInstance(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -206,10 +210,11 @@ public void testExecute_withMapTypeInput_successful() { List> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -223,10 +228,11 @@ public void test_batchExecute_successful() { SparseEncodingProcessor processor = createInstance(docCount); List> dataAsMapList = createMockMapResult(10); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); @@ -244,10 +250,11 @@ public void test_batchExecute_exception() { List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); SparseEncodingProcessor processor = createInstance(docCount); 
doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 4afa4031d..1611daaed 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -51,6 +51,7 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI())); private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI())); private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI())); + private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI())); private final String BULK_ITEM_TEMPLATE = Files.readString( Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI()) ); @@ -244,6 +245,11 @@ private String uploadTextEmbeddingModel() throws Exception { return registerModelGroupAndUploadModel(requestBody); } + private String uploadAsymmetricEmbeddingModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadAsymmetricModelRequestBody.json").toURI())); + return registerModelGroupAndUploadModel(requestBody); + } + private void createTextEmbeddingIndex() throws Exception { createIndexWithConfiguration( INDEX_NAME, @@ -252,6 +258,20 @@ private void createTextEmbeddingIndex() throws Exception { ); } + public void testAsymmetricTextEmbeddingProcessor() throws Exception { + String modelId = null; + try { + modelId = uploadAsymmetricEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING, 2); + createTextEmbeddingIndex(); + ingestDocument(INGEST_DOC5, null); + assertEquals(1, getDocCount(INDEX_NAME)); + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + private void ingestDocument(String doc, String id) throws Exception { String endpoint; if (StringUtils.isEmpty(id)) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 97e85e46e..ea009e777 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.processor; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static 
org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -151,10 +151,14 @@ public void testExecute_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -184,7 +188,11 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(RuntimeException.class)); @@ -212,7 +220,8 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() { DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(any(IngestDocument.class), isNull()); @@ -230,10 +239,14 @@ public void testExecute_withListTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -306,10 +319,14 @@ public void testExecute_withMapTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -347,10 +364,14 @@ public void testNestedFieldInMapping_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + 
ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -407,10 +428,14 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHa List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -465,10 +490,14 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -516,10 +545,14 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -585,10 +618,14 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -628,7 +665,7 @@ 
public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentE ActionListener>> listener = invocation.getArgument(2); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); @@ -828,10 +865,10 @@ public void test_batchExecute_successful() { List> modelTensorList = createMockVectorWithLength(10); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); @@ -849,10 +886,10 @@ public void test_batchExecute_exception() { List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(docCount); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index 8f0018f52..a22a7ef76 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.processor; -import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -192,10 +192,11 @@ public void testExecute_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -229,7 +230,8 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep DESCRIPTION, config ); - doThrow(new 
RuntimeException()).when(accessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(RuntimeException.class)); @@ -245,10 +247,11 @@ public void testExecute_withListTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -277,10 +280,11 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { TextImageEmbeddingProcessor processor = createInstance(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -336,10 +340,11 @@ public void testExecute_whenInferencesAreEmpty_thenSuccessful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index dbd1c2bd6..3f745b9c3 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -5,8 +5,7 @@ package org.opensearch.neuralsearch.processor.rerank; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -114,11 +113,12 @@ private void setupParams(Map params) { private void setupSimilarityRescoring() { doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); List scores = 
List.of(1f, 2f, 3f); listener.onResponse(scores); return null; - }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + }).when(mlCommonsClientAccessor) + .inferenceSimilarity(argThat(request -> request.getQueryText() != null && request.getInputTexts() != null), any()); } private void setupSearchResults() throws IOException { @@ -345,11 +345,12 @@ public void testRerank_HappyPath() throws IOException { public void testRerank_whenScoresAndHitsHaveDiffLengths_thenFail() throws IOException { doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); List scores = List.of(1f, 2f); listener.onResponse(scores); return null; - }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + }).when(mlCommonsClientAccessor) + .inferenceSimilarity(argThat(request -> request.getQueryText() != null && request.getInputTexts() != null), any()); setupSearchResults(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 6d8e810f3..cc3dc5fbb 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -5,7 +5,7 @@ package org.opensearch.neuralsearch.query; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; @@ -649,10 +649,10 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); + ActionListener> listener = invocation.getArgument(1); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); + }).when(mlCommonsClientAccessor).inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -685,10 +685,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); + ActionListener> listener = invocation.getArgument(1); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); + }).when(mlCommonsClientAccessor).inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 
7509efd42..eb350a9ea 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
@@ -623,10 +623,10 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
        Map expectedMap = Map.of("1", 1f, "2", 2f);
        MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
        doAnswer(invocation -> {
-           ActionListener>> listener = invocation.getArgument(2);
+           ActionListener>> listener = invocation.getArgument(1);
            listener.onResponse(List.of(Map.of("response", List.of(expectedMap))));
            return null;
-       }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any());
+       }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any());
        NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor);
        final CountDownLatch inProgressLatch = new CountDownLatch(1);
diff --git a/src/test/resources/processor/UploadAsymmetricModelRequestBody.json b/src/test/resources/processor/UploadAsymmetricModelRequestBody.json
new file mode 100644
index 000000000..8c5b6ec18
--- /dev/null
+++ b/src/test/resources/processor/UploadAsymmetricModelRequestBody.json
@@ -0,0 +1,17 @@
+{
+  "name": "traced_small_model",
+  "version": "1.0.0",
+  "model_format": "TORCH_SCRIPT",
+  "model_task_type": "text_embedding",
+  "model_content_hash_value": "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021",
+  "model_group_id": "%s",
+  "model_config": {
+    "model_type": "bert",
+    "embedding_dimension": 768,
+    "framework_type": "sentence_transformers",
+    "passage_prefix" : "passage: ",
+    "query_prefix" : "query: ",
+    "all_config": "{\"architectures\":[\"BertModel\"],\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6}"
+  },
+  "url": "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/traced_small_model.zip?raw=true"
+}
diff --git a/src/test/resources/processor/ingest_doc5.json b/src/test/resources/processor/ingest_doc5.json
new file mode 100644
index 000000000..e3302c75a
--- /dev/null
+++ b/src/test/resources/processor/ingest_doc5.json
@@ -0,0 +1,21 @@
+{
+  "title": "This is a good day",
+  "description": "daily logging",
+  "favor_list": [
+    "test",
+    "hello",
+    "mock"
+  ],
+  "favorites": {
+    "game": "overwatch",
+    "movie": null
+  },
+  "nested_passages": [
+    {
+      "text": "hello"
+    },
+    {
+      "text": "world"
+    }
+  ]
+}
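
Note (not part of the change set above): the test edits all follow one migration pattern, so a compact sketch may help reviewers. Inputs that were previously passed positionally (model id, input texts, input map, query text) now travel inside MLCommonsClientAccessor.InferenceRequest, which also moves the mocked ActionListener to argument index 1 and replaces the old anyString()/anyList()/anyMap() matchers with argThat(...) checks on request fields. The sketch below is illustrative only: the result element type List<List<Float>> is an assumption (generic parameters are not visible in the hunks above), and "sampleModelId" is a made-up value rather than a constant from the repository.

import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import java.util.List;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;

class InferenceRequestMigrationSketch {

    void stubAndCall() {
        MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class);

        // Stubbing: the listener is now invocation argument 1, because the request
        // object replaced the former (modelId, inputTexts) argument pair.
        doAnswer(invocation -> {
            // List<List<Float>> is assumed here; the hunks above do not show the generics.
            ActionListener<List<List<Float>>> listener = invocation.getArgument(1);
            listener.onResponse(List.of(List.of(0.1f, 0.2f)));
            return null;
        }).when(accessor)
            .inferenceSentences(argThat((InferenceRequest r) -> r.getInputTexts() != null), isA(ActionListener.class));

        // Call site: every input is carried by one builder-constructed request.
        InferenceRequest request = InferenceRequest.builder()
            .modelId("sampleModelId") // hypothetical id, for illustration only
            .inputTexts(List.of("it is sunny today", "roses are red"))
            .build();
        accessor.inferenceSentences(request, ActionListener.wrap(vectors -> {}, e -> {}));
    }
}

The same shape applies to the other entry points touched above: inferenceSentencesMap is matched on getInputObjects(), inferenceSimilarity on getQueryText() together with getInputTexts(), and inferenceSentencesWithMapResult now takes only the request plus listener, which is why its stubs drop the third any() matcher.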