diff --git a/build.gradle b/build.gradle
index 90b1ca1f1..f94462590 100644
--- a/build.gradle
+++ b/build.gradle
@@ -259,6 +259,8 @@ dependencies {
     api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
     testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
     implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
+    implementation group: 'ai.djl', name: 'api', version: '0.28.0'
+    implementation group: 'ai.djl.huggingface', name: 'tokenizers', version: '0.28.0'
     // ml-common excluded reflection for runtime so we need to add it by ourselves.
     // https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9
     // TODO: Remove following three lines of dependencies if ml-common include them in their jar
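Note (not part of the patch): the two new DJL artifacts provide the Hugging Face tokenizer bindings that the rest of this change builds on. A minimal standalone sketch of the API surface being used — newInstance() downloads the tokenizer config from the Hugging Face hub on first use, and the encode(text, false, true) call mirrors HFModelTokenizer.reset() below; the class name is illustrative, the model id is the same default used later in HFModelTokenizerFactory:

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

public class TokenizerApiSketch {
    public static void main(String[] args) throws Exception {
        // Requires network access on first run; the tokenizer files are cached afterwards.
        try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(
            "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill")) {
            // addSpecialTokens=false, withOverflowingTokens=true, as in HFModelTokenizer.reset()
            Encoding encoding = tokenizer.encode("hello world", false, true);
            for (String token : encoding.getTokens()) {
                System.out.println(token);
            }
        }
    }
}
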
" + e); + } + } + + public static Map parseInputStreamToTokenWeights(InputStream inputStream) { + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { + Map tokenWeights = new HashMap<>(); + String line; + while ((line = reader.readLine()) != null) { + if (line.trim().isEmpty()) { + continue; + } + String[] parts = line.split("\t"); + if (parts.length != 2) { + throw new IllegalArgumentException("Invalid line in token weights file: " + line); + } + String token = parts[0]; + float weight = Float.parseFloat(parts[1]); + tokenWeights.put(token, weight); + } + return tokenWeights; + } catch (IOException e) { + throw new RuntimeException("Failed to parse token weights file. " + e); + } + } + + public static Map fetchTokenWeights(String tokenizerId, String fileName) { + Map tokenWeights = new HashMap<>(); + String url = HUGGING_FACE_BASE_URL + tokenizerId + "/" + HUGGING_FACE_RESOLVE_PATH + fileName; + + InputStream inputStream = null; + try { + inputStream = withDJLContext(() -> Utils.openUrl(url)); + } catch (PrivilegedActionException e) { + throw new RuntimeException("Failed to download file from " + url, e); + } + + return parseInputStreamToTokenWeights(inputStream); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzer.java b/src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzer.java new file mode 100644 index 000000000..70c12e8d6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzer.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.analysis; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.Tokenizer; + +import java.util.function.Supplier; + +public class HFModelAnalyzer extends Analyzer { + public static final String NAME = "hf_model_tokenizer"; + Supplier tokenizerSupplier; + + public HFModelAnalyzer() { + this.tokenizerSupplier = HFModelTokenizerFactory::createDefault; + } + + HFModelAnalyzer(Supplier tokenizerSupplier) { + this.tokenizerSupplier = tokenizerSupplier; + } + + @Override + protected TokenStreamComponents createComponents(String fieldName) { + final Tokenizer src = tokenizerSupplier.get(); + return new TokenStreamComponents(src, src); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzerProvider.java b/src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzerProvider.java new file mode 100644 index 000000000..b15789f99 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzerProvider.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.analysis; + +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.analysis.AbstractIndexAnalyzerProvider; + +public class HFModelAnalyzerProvider extends AbstractIndexAnalyzerProvider { + private final HFModelAnalyzer analyzer; + + public HFModelAnalyzerProvider(IndexSettings indexSettings, Environment environment, String name, Settings settings) { + super(indexSettings, name, settings); + HFModelTokenizerFactory tokenizerFactory = new HFModelTokenizerFactory(indexSettings, environment, name, settings); + analyzer = new HFModelAnalyzer(tokenizerFactory::create); + } + + @Override + public HFModelAnalyzer get() { + 
return analyzer; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizer.java b/src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizer.java new file mode 100644 index 000000000..a05ad8632 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizer.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.analysis; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.Objects; + +import com.google.common.io.CharStreams; +import org.apache.lucene.analysis.Tokenizer; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.OffsetAttribute; +import org.apache.lucene.analysis.tokenattributes.PayloadAttribute; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import org.apache.lucene.util.BytesRef; + +public class HFModelTokenizer extends Tokenizer { + public static final String NAME = "hf_model_tokenizer"; + private static final Float DEFAULT_TOKEN_WEIGHT = 1.0f; + + private final CharTermAttribute termAtt; + private final PayloadAttribute payloadAtt; + private final OffsetAttribute offsetAtt; + private final HuggingFaceTokenizer tokenizer; + private final Map tokenWeights; + + private Encoding encoding; + private int tokenIdx = 0; + private int overflowingIdx = 0; + + public HFModelTokenizer(HuggingFaceTokenizer huggingFaceTokenizer) { + this(huggingFaceTokenizer, null); + } + + public HFModelTokenizer(HuggingFaceTokenizer huggingFaceTokenizer, Map weights) { + termAtt = addAttribute(CharTermAttribute.class); + offsetAtt = addAttribute(OffsetAttribute.class); + if (Objects.nonNull(weights)) { + payloadAtt = addAttribute(PayloadAttribute.class); + } else { + payloadAtt = null; + } + tokenizer = huggingFaceTokenizer; + tokenWeights = weights; + } + + @Override + public void reset() throws IOException { + super.reset(); + tokenIdx = 0; + overflowingIdx = -1; + String inputStr = CharStreams.toString(input); + encoding = tokenizer.encode(inputStr, false, true); + } + + private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) { + return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0; + } + + public static byte[] floatToBytes(float value) { + return ByteBuffer.allocate(4).putFloat(value).array(); + } + + public static float bytesToFloat(byte[] bytes) { + return ByteBuffer.wrap(bytes).getFloat(); + } + + @Override + final public boolean incrementToken() throws IOException { + clearAttributes(); + Encoding curEncoding = overflowingIdx == -1 ? 
diff --git a/src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizerFactory.java b/src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizerFactory.java
new file mode 100644
index 000000000..03acf0b4e
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizerFactory.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.analysis;
+
+import org.apache.lucene.analysis.Tokenizer;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.env.Environment;
+import org.opensearch.index.IndexSettings;
+import org.opensearch.index.analysis.AbstractTokenizerFactory;
+
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+
+import java.util.Map;
+import java.util.Objects;
+
+public class HFModelTokenizerFactory extends AbstractTokenizerFactory {
+    private final HuggingFaceTokenizer tokenizer;
+    private final Map<String, Float> tokenWeights;
+
+    /**
+     * Lazily and atomically loads the default HF tokenizer the first time the outer
+     * class accesses the static fields (initialization-on-demand holder idiom).
+     */
+    private static class DefaultTokenizerHolder {
+        static final HuggingFaceTokenizer TOKENIZER;
+        static final Map<String, Float> TOKEN_WEIGHTS;
+        static private final String DEFAULT_TOKENIZER_ID = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill";
+        static private final String DEFAULT_TOKEN_WEIGHTS_FILE = "query_token_weights.txt";
+
+        static {
+            try {
+                TOKENIZER = DJLUtils.buildHuggingFaceTokenizer(DEFAULT_TOKENIZER_ID);
+                TOKEN_WEIGHTS = DJLUtils.fetchTokenWeights(DEFAULT_TOKENIZER_ID, DEFAULT_TOKEN_WEIGHTS_FILE);
+            } catch (Exception e) {
+                throw new RuntimeException("Failed to initialize default hf_model_tokenizer", e);
+            }
+        }
+    }
+
+    static public Tokenizer createDefault() {
+        return new HFModelTokenizer(DefaultTokenizerHolder.TOKENIZER, DefaultTokenizerHolder.TOKEN_WEIGHTS);
+    }
+
+    public HFModelTokenizerFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
+        // For custom tokenizer, the factory is created during IndexModule.newIndexService
+        // And can be accessed via indexService.getIndexAnalyzers()
+        super(indexSettings, settings, name);
+        String tokenizerId = settings.get("tokenizer_id", null);
+        Objects.requireNonNull(tokenizerId, "tokenizer_id is required");
+        String tokenWeightsFileName = settings.get("token_weights_file", null);
+        tokenizer = DJLUtils.buildHuggingFaceTokenizer(tokenizerId);
+        if (tokenWeightsFileName != null) {
+            tokenWeights = DJLUtils.fetchTokenWeights(tokenizerId, tokenWeightsFileName);
+        } else {
+            tokenWeights = null;
+        }
+    }
+
+    @Override
+    public Tokenizer create() {
+        // the create method will be called for every single analyze request
+        return new HFModelTokenizer(tokenizer, tokenWeights);
+    }
+}
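Note (not part of the patch): a sketch of index settings that would exercise this factory. tokenizer_id and token_weights_file are the two settings read in the constructor above; the index, tokenizer, and analyzer names are illustrative:

PUT /my-index
{
  "settings": {
    "analysis": {
      "tokenizer": {
        "my_hf_tokenizer": {
          "type": "hf_model_tokenizer",
          "tokenizer_id": "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill",
          "token_weights_file": "query_token_weights.txt"
        }
      },
      "analyzer": {
        "my_hf_analyzer": {
          "type": "custom",
          "tokenizer": "my_hf_tokenizer"
        }
      }
    }
  }
}
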
diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
index 1350a7963..931cbe4b9 100644
--- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
+++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -7,6 +7,7 @@
 import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
 import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
@@ -14,6 +15,7 @@
 import java.util.Optional;
 import java.util.function.Supplier;
 
+import org.apache.lucene.analysis.Analyzer;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
 import org.opensearch.cluster.service.ClusterService;
@@ -24,8 +26,19 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.env.Environment;
 import org.opensearch.env.NodeEnvironment;
+import org.opensearch.index.analysis.AnalyzerProvider;
+import org.opensearch.index.analysis.PreBuiltAnalyzerProviderFactory;
+import org.opensearch.index.analysis.PreConfiguredTokenizer;
+import org.opensearch.index.analysis.TokenizerFactory;
+import org.opensearch.indices.analysis.AnalysisModule;
+import org.opensearch.indices.analysis.PreBuiltCacheFactory;
 import org.opensearch.ingest.Processor;
 import org.opensearch.ml.client.MachineLearningNodeClient;
+import org.opensearch.neuralsearch.analysis.DJLUtils;
+import org.opensearch.neuralsearch.analysis.HFModelAnalyzer;
+import org.opensearch.neuralsearch.analysis.HFModelAnalyzerProvider;
+import org.opensearch.neuralsearch.analysis.HFModelTokenizer;
+import org.opensearch.neuralsearch.analysis.HFModelTokenizerFactory;
 import org.opensearch.neuralsearch.executors.HybridQueryExecutor;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
@@ -56,6 +69,7 @@
 import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
 import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
 import org.opensearch.plugins.ActionPlugin;
+import org.opensearch.plugins.AnalysisPlugin;
 import org.opensearch.plugins.ExtensiblePlugin;
 import org.opensearch.plugins.IngestPlugin;
 import org.opensearch.plugins.Plugin;
@@ -77,7 +91,14 @@
  * Neural Search plugin class
  */
 @Log4j2
-public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin {
+public class NeuralSearch extends Plugin
+    implements
+        ActionPlugin,
+        SearchPlugin,
+        IngestPlugin,
+        ExtensiblePlugin,
+        SearchPipelinePlugin,
+        AnalysisPlugin {
     private MLCommonsClientAccessor clientAccessor;
     private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
     private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
@@ -103,6 +124,7 @@ public Collection<Object> createComponents(
         NeuralSparseQueryBuilder.initialize(clientAccessor);
         HybridQueryExecutor.initialize(threadPool);
         normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
+        DJLUtils.buildDJLCachePath(environment.dataFiles()[0]);
         return List.of(clientAccessor);
     }
 
@@ -200,4 +222,30 @@ public List<SearchExtSpec<?>> getSearchExts() {
             )
         );
     }
+
+    @Override
+    public Map<String, AnalysisModule.AnalysisProvider<TokenizerFactory>> getTokenizers() {
+        return Map.of(HFModelTokenizer.NAME, HFModelTokenizerFactory::new);
+    }
+
+    @Override
+    public List<PreConfiguredTokenizer> getPreConfiguredTokenizers() {
+        List<PreConfiguredTokenizer> tokenizers = new ArrayList<>();
+        tokenizers.add(PreConfiguredTokenizer.singleton(HFModelTokenizer.NAME, HFModelTokenizerFactory::createDefault));
+        return tokenizers;
+    }
+
+    @Override
+    public Map<String, AnalysisModule.AnalysisProvider<AnalyzerProvider<? extends Analyzer>>> getAnalyzers() {
+        return Map.of(HFModelAnalyzer.NAME, HFModelAnalyzerProvider::new);
+    }
+
+    @Override
+    public List<PreBuiltAnalyzerProviderFactory> getPreBuiltAnalyzerProviderFactories() {
+        List<PreBuiltAnalyzerProviderFactory> factories = new ArrayList<>();
+        factories.add(
+            new PreBuiltAnalyzerProviderFactory(HFModelAnalyzer.NAME, PreBuiltCacheFactory.CachingStrategy.ONE, HFModelAnalyzer::new)
+        );
+        return factories;
+    }
 }
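Note (not part of the patch): with the pre-configured registrations above, the tokenizer should be usable without any index settings. A quick smoke-test sketch; the exact tokens returned depend on the downloaded default tokenizer:

GET /_analyze
{
  "tokenizer": "hf_model_tokenizer",
  "text": "hello world"
}
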
diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
index be9719452..c51699d78 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
@@ -16,10 +16,15 @@
 import org.apache.commons.lang.StringUtils;
 import org.apache.commons.lang.builder.EqualsBuilder;
 import org.apache.commons.lang.builder.HashCodeBuilder;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
 import org.apache.lucene.document.FeatureField;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.Query;
+import org.opensearch.OpenSearchException;
 import org.opensearch.Version;
 import org.opensearch.client.Client;
 import org.opensearch.common.SetOnce;
@@ -36,6 +41,7 @@
 import org.opensearch.index.query.QueryBuilder;
 import org.opensearch.index.query.QueryRewriteContext;
 import org.opensearch.index.query.QueryShardContext;
+import org.opensearch.neuralsearch.analysis.HFModelTokenizer;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
 import org.opensearch.neuralsearch.util.TokenWeightUtil;
@@ -74,10 +80,13 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQueryBuilder>
+    @VisibleForTesting
+    static final ParseField ANALYZER_FIELD = new ParseField("analyzer");
     private String fieldName;
     private String queryText;
     private String modelId;
+    private String analyzer;
     private Float maxTokenScore;
     private Supplier<Map<String, Float>> queryTokensSupplier;
     // A field that for neural_sparse_two_phase_processor, if twoPhaseSharedQueryToken is not null,
@@ -93,6 +102,7 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQueryBuilder>
+            this.analyzer = in.readOptionalString();
             Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
@@ -140,6 +153,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float
             .queryName(this.queryName)
             .queryText(this.queryText)
             .modelId(this.modelId)
+            .analyzer(this.analyzer)
             .maxTokenScore(this.maxTokenScore)
             .twoPhasePruneRatio(-1f * pruneRatio);
         if (Objects.nonNull(this.queryTokensSupplier)) {
@@ -167,6 +181,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
         } else {
             out.writeString(StringUtils.defaultString(this.modelId, StringUtils.EMPTY));
         }
+        if (isClusterOnOrAfterMinReqVersionForAnalyzer()) {
+            out.writeOptionalString(this.analyzer);
+        }
         out.writeOptionalFloat(maxTokenScore);
         if (!Objects.isNull(this.queryTokensSupplier) && !Objects.isNull(this.queryTokensSupplier.get())) {
             out.writeBoolean(true);
@@ -186,6 +203,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
         if (Objects.nonNull(modelId)) {
             xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
         }
+        if (Objects.nonNull(analyzer)) {
+            xContentBuilder.field(ANALYZER_FIELD.getPreferredName(), analyzer);
+        }
         if (Objects.nonNull(maxTokenScore)) {
             xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore);
         }
@@ -275,6 +295,9 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw
         if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.modelId())) {
             throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", MODEL_ID_FIELD.getPreferredName()));
         }
+        if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.analyzer())) {
+            throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", ANALYZER_FIELD.getPreferredName()));
+        }
 
         return sparseEncodingQueryBuilder;
     }
@@ -294,6 +317,8 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui
                 sparseEncodingQueryBuilder.queryText(parser.text());
             } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                 sparseEncodingQueryBuilder.modelId(parser.text());
+            } else if (ANALYZER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+                sparseEncodingQueryBuilder.analyzer(parser.text());
             } else if (MAX_TOKEN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                 sparseEncodingQueryBuilder.maxTokenScore(parser.floatValue());
             } else {
@@ -324,6 +349,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
         if (Objects.nonNull(queryTokensSupplier)) {
             return this;
         }
+        if (Objects.nonNull(analyzer)) {
+            return this;
+        }
         validateForRewrite(queryText, modelId);
         SetOnce<Map<String, Float>> queryTokensSetOnce = new SetOnce<>();
         queryRewriteContext.registerAsyncAction(getModelInferenceAsync(queryTokensSetOnce));
@@ -363,14 +391,36 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
 
+    private Map<String, Float> getQueryTokens(QueryShardContext context) {
+        if (Objects.nonNull(queryTokensSupplier)) {
+            return queryTokensSupplier.get();
+        } else if (Objects.nonNull(analyzer)) {
+            Map<String, Float> queryTokens = new HashMap<>();
+            Analyzer luceneAnalyzer = context.convertToShardContext().getIndexAnalyzers().getAnalyzers().get(this.analyzer);
+            try (TokenStream stream = luceneAnalyzer.tokenStream(fieldName, queryText)) {
+                stream.reset();
+                CharTermAttribute term = stream.addAttribute(CharTermAttribute.class);
+                PayloadAttribute payload = stream.addAttribute(PayloadAttribute.class);
+
+                while (stream.incrementToken()) {
+                    String token = term.toString();
+                    Float weight = Objects.isNull(payload.getPayload()) ? 1 : HFModelTokenizer.bytesToFloat(payload.getPayload().bytes);
+                    queryTokens.put(token, weight);
+                }
+                stream.end();
+            } catch (IOException e) {
+                throw new OpenSearchException("failed to analyze query text. ", e);
+            }
+            return queryTokens;
+        }
+        throw new IllegalArgumentException("Query tokens cannot be null.");
+    }
+
     @Override
     protected Query doToQuery(QueryShardContext context) throws IOException {
         final MappedFieldType ft = context.fieldMapper(fieldName);
         validateFieldType(ft);
-        Map<String, Float> queryTokens = queryTokensSupplier.get();
-        if (Objects.isNull(queryTokens)) {
-            throw new IllegalArgumentException("Query tokens cannot be null.");
-        }
+        Map<String, Float> queryTokens = getQueryTokens(context);
         BooleanQuery.Builder builder = new BooleanQuery.Builder();
         for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
             builder.add(FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD);
@@ -447,4 +497,7 @@ private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()
         return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
     }
 
+    private static boolean isClusterOnOrAfterMinReqVersionForAnalyzer() {
+        return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_ANALYZER);
+    }
 }
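Note (not part of the patch): end to end, the new analyzer path lets a neural_sparse query produce token weights locally and skip model inference entirely. A request sketch; the index and field names are illustrative, and "analyzer" can name any analyzer registered above:

GET /my-index/_search
{
  "query": {
    "neural_sparse": {
      "passage_embedding": {
        "query_text": "hello world",
        "analyzer": "hf_model_tokenizer"
      }
    }
  }
}
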
diff --git a/src/main/plugin-metadata/plugin-security.policy b/src/main/plugin-metadata/plugin-security.policy
index db2413e86..dc04de6ab 100644
--- a/src/main/plugin-metadata/plugin-security.policy
+++ b/src/main/plugin-metadata/plugin-security.policy
@@ -4,4 +4,10 @@ grant {
   permission java.lang.RuntimePermission "accessDeclaredMembers";
   permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
   permission java.lang.RuntimePermission "setContextClassLoader";
+
+  permission java.net.SocketPermission "*", "connect,resolve";
+  permission java.lang.RuntimePermission "loadLibrary.*";
+  permission java.lang.RuntimePermission "setContextClassLoader";
+  permission java.util.PropertyPermission "DJL_CACHE_DIR", "read,write";
+  permission java.util.PropertyPermission "java.library.path", "read,write";
 };
", e); + } + return queryTokens; + } + throw new IllegalArgumentException("Query tokens cannot be null."); + } + @Override protected Query doToQuery(QueryShardContext context) throws IOException { final MappedFieldType ft = context.fieldMapper(fieldName); validateFieldType(ft); - Map queryTokens = queryTokensSupplier.get(); - if (Objects.isNull(queryTokens)) { - throw new IllegalArgumentException("Query tokens cannot be null."); - } + Map queryTokens = getQueryTokens(context); BooleanQuery.Builder builder = new BooleanQuery.Builder(); for (Map.Entry entry : queryTokens.entrySet()) { builder.add(FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD); @@ -447,4 +497,7 @@ private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID); } + private static boolean isClusterOnOrAfterMinReqVersionForAnalyzer() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_ANALYZER); + } } diff --git a/src/main/plugin-metadata/plugin-security.policy b/src/main/plugin-metadata/plugin-security.policy index db2413e86..dc04de6ab 100644 --- a/src/main/plugin-metadata/plugin-security.policy +++ b/src/main/plugin-metadata/plugin-security.policy @@ -4,4 +4,10 @@ grant { permission java.lang.RuntimePermission "accessDeclaredMembers"; permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; permission java.lang.RuntimePermission "setContextClassLoader"; + + permission java.net.SocketPermission "*", "connect,resolve"; + permission java.lang.RuntimePermission "loadLibrary.*"; + permission java.lang.RuntimePermission "setContextClassLoader"; + permission java.util.PropertyPermission "DJL_CACHE_DIR", "read,write"; + permission java.util.PropertyPermission "java.library.path", "read,write"; }; diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 9a969e71b..1d30b0d21 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -8,6 +8,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Collection; import java.util.List; import java.util.Map; @@ -61,6 +63,8 @@ public class NeuralSearchTests extends OpenSearchQueryTestCase { private ClusterService clusterService; @Mock private ThreadPool threadPool; + @Mock + private Environment environment; @Before public void setup() { @@ -77,6 +81,7 @@ public void setup() { public void testCreateComponents() { // clientAccessor can not be null, and this is the only way to access it from this test plugin.getProcessors(ingestParameters); + when(environment.dataFiles()).thenReturn(new Path[] { Paths.get("test") }); Collection components = plugin.createComponents( null, clusterService, @@ -84,7 +89,7 @@ public void testCreateComponents() { null, null, null, - null, + environment, null, null, null,