diff --git a/CHANGELOG.md b/CHANGELOG.md index 70ec3ae51..698d37fa6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) ### Features +* Added a boolean flag that enables users to run the Lucene searcher on a FAISS index. [#2405](https://github.com/opensearch-project/k-NN/pull/2405) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 14e95887c..66d0ee0b9 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -120,6 +120,8 @@ public class KNNConstants { public static final List<String> FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16); public static final String FAISS_SIGNED_BYTE_SQ = "SQ8_direct_signed"; public static final String FAISS_SQ_CLIP = "clip"; + public static final String USE_LUCENE_HNSW_SEARCHER = "use_lucene_searcher"; + public static final boolean DEFAULT_USE_LUCENE_HNSW_SEARCHER = false; // Parameter defaults/limits public static final Integer ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT = 1; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index 2366a6d57..21ee1eb61 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -21,9 +21,21 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; import org.opensearch.common.UUIDs; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.codec.luceneonfaiss.FaissHNSWVectorReader; +import org.opensearch.knn.index.codec.luceneonfaiss.LuceneOnFaissUtils; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; @@ -32,6 +44,7 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; +import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -48,12 +61,71 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { private Map<String, String> quantizationStateCacheKeyPerField; private SegmentReadState segmentReadState; private final List<String> cacheKeys; + private Map<String, FaissHNSWVectorReader> faissHNSWVectorReaderMap; public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) { this.flatVectorsReader
= flatVectorsReader; this.segmentReadState = state; this.cacheKeys = getVectorCacheKeysFromSegmentReaderState(state); + this.faissHNSWVectorReaderMap = new HashMap<>(2 * state.fieldInfos.size(), 0.6f); loadCacheKeyMap(); + loadFaissIndexForLuceneSearcher(segmentReadState); + } + + private void loadFaissIndexForLuceneSearcher(SegmentReadState state) { + for (FieldInfo fieldInfo : state.fieldInfos) { + // Ex: {"index_description":"HNSW16,Flat","spaceType":"l2","name":"hnsw","data_type":"float", + // "parameters":{"use_lucene_searcher":true,"ef_search":100,"ef_construction":100,"encoder":{"name":"flat","parameters":{}}}} + final String parametersString = fieldInfo.getAttribute(KNNConstants.PARAMETERS); + if (parametersString != null) { + try { + try ( + XContentParser parser = XContentHelper.createParser(NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ) + ) { + // Extract the boolean flag. + final Map<String, Object> parameters = parser.map(); + final Object innerParameters = parameters.get(KNNConstants.PARAMETERS); + if (!LuceneOnFaissUtils.isUseLuceneOnFaiss(innerParameters)) { + continue; + } + + // Acquire the index file name. + final String faissIndexFile = KNNCodecUtil.getNativeEngineFileFromFieldInfo(fieldInfo, state.segmentInfo); + if (faissIndexFile == null) { + continue; + } + + // Load the FAISS index with an IndexInput. + final IndexInput indexInput = state.directory.openInput(faissIndexFile, + new IOContext(IOContext.Context.DEFAULT, + null, + null, + ReadAdvice.RANDOM + ) + ); + + try { + final FaissHNSWVectorReader vectorReader = new FaissHNSWVectorReader(indexInput); + faissHNSWVectorReaderMap.put(fieldInfo.getName(), vectorReader); + } catch (Exception e) { + // If something went wrong, close the stream and rethrow. + try { + indexInput.close(); + } catch (Exception ioException) { + // Ignore + } + throw e; + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } } /** @@ -124,18 +196,27 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits String cacheKey = quantizationStateCacheKeyPerField.get(field); FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(field); QuantizationState quantizationState = QuantizationStateCacheManager.getInstance() - .getQuantizationState( - new QuantizationStateReadConfig( - segmentReadState, - QuantizationService.getInstance().getQuantizationParams(fieldInfo), - field, - cacheKey - ) - ); + .getQuantizationState(new QuantizationStateReadConfig(segmentReadState, + QuantizationService.getInstance().getQuantizationParams(fieldInfo), + field, + cacheKey + )); ((QuantizationConfigKNNCollector) knnCollector).setQuantizationState(quantizationState); return; } - throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + + // Try searching with the Lucene searcher over the FAISS index. + final FaissHNSWVectorReader vectorReader = faissHNSWVectorReaderMap.get(field); + if (vectorReader != null) { + try { + vectorReader.search(target, knnCollector, acceptDocs); + } catch (Exception e) { + // Propagate the failure instead of swallowing it; otherwise the query would silently return incomplete results. + throw new RuntimeException(e); + } + } else { + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + } } /** @@ -187,8 +268,10 @@ public void close() throws IOException { final NativeMemoryCacheManager nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); cacheKeys.forEach(nativeMemoryCacheManager::invalidate); - // Close a
reader. - IOUtils.close(flatVectorsReader); + // Close all readers. + List<Closeable> readers = new ArrayList<>(faissHNSWVectorReaderMap.values()); + readers.add(flatVectorsReader); + IOUtils.close(readers); // Clean up quantized state cache. if (quantizationStateCacheKeyPerField != null) { diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissHNSW.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissHNSW.java new file mode 100644 index 000000000..2d2b62052 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissHNSW.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import lombok.Getter; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + +/** + * Ported implementation of the FAISS HNSW graph search algorithm. + * While it follows the same steps as the original FAISS implementation, differences in how the JVM and C++ handle floating-point + * calculations can lead to slight variations in results. However, such cases are very rare, and in most instances, the results are + * identical to FAISS. Even when there are ranking differences, they do not impact the precision or recall of the search. + * For more details, refer to the [FAISS HNSW implementation]( + * ...). + */ +@Getter +public class FaissHNSW { + // Cumulative number of neighbors per level. + private int[] cumNumberNeighborPerLevel; + // Offset to be added to cumNumberNeighborPerLevel[level] to get the actual start offset of the neighbor list. + private long[] offsets = null; + // Neighbor list storage. + private final Storage neighbors = new Storage(); + // Entry point in the HNSW graph. + private int entryPoint; + // Maximum level of the HNSW graph. + private int maxLevel = -1; + // Default efSearch parameter. This determines the navigation queue size. + // A larger value makes the search explore more candidates. + private int efSearch = 16; + // Total number of vectors stored in the graph. + private long totalNumberOfVectors; + + /** + * Partially loads the FAISS HNSW graph from the provided index input stream. + * The graph is divided into multiple sections, and this method marks the starting offset of each section and then skips to the next + * section instead of loading the entire graph into memory. During the search, bytes will be accessed via {@link IndexInput}. + * + * @param input An input stream for a FAISS HNSW graph file, allowing access to the neighbor list and vector locations. + * @param totalNumberOfVectors The total number of vectors stored in the graph. + * @return {@link FaissHNSW}, a graph search structure that represents the FAISS HNSW graph + * @throws IOException + */ + public static FaissHNSW load(IndexInput input, long totalNumberOfVectors) throws IOException { + // Total number of vectors + FaissHNSW faissHNSW = new FaissHNSW(); + faissHNSW.totalNumberOfVectors = totalNumberOfVectors; + + // We don't use `double[] assignProbas` for search. It is for index construction. + long size = input.readLong(); + input.skipBytes(Double.BYTES * size); + + // Cumulative number of neighbors per level. + size = input.readLong(); + faissHNSW.cumNumberNeighborPerLevel = new int[(int) size]; + if (size > 0) { + input.readInts(faissHNSW.cumNumberNeighborPerLevel, 0, (int) size); + } + + // We don't use `level`. + final Storage levels = new Storage(); + levels.markSection(input, Integer.BYTES); + + // Load `offsets` into memory.
+ size = input.readLong(); + faissHNSW.offsets = new long[(int) size]; + input.readLongs(faissHNSW.offsets, 0, faissHNSW.offsets.length); + + // Mark neighbor list section. + faissHNSW.neighbors.markSection(input, Integer.BYTES); + + // HNSW graph parameters + faissHNSW.entryPoint = input.readInt(); + + faissHNSW.maxLevel = input.readInt(); + + // We don't use this field. It's for index building. + final int efConstruction = input.readInt(); + + faissHNSW.efSearch = input.readInt(); + + // dummy read a deprecated field. + input.readInt(); + + return faissHNSW; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissHNSWFlatIndex.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissHNSWFlatIndex.java new file mode 100644 index 000000000..1bd076a90 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissHNSWFlatIndex.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import lombok.Getter; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + +/** + * A flat HNSW index that contains both an HNSW graph and flat vector storage. + * This is the ported version of `IndexHNSW` from FAISS. + * For more details, please refer to ... + */ +public class FaissHNSWFlatIndex extends FaissIndex { + public static final String IHNF = "IHNf"; + + @Getter + private FaissHNSW hnsw = new FaissHNSW(); + @Getter + private FaissIndexFlat storage; + + /** + * Partially loads both the HNSW graph and the underlying flat vectors. + * + * @param input An input stream for a FAISS HNSW graph file, allowing access to the neighbor list and vector locations. + * @return {@link FaissHNSWFlatIndex} instance consists of index hierarchy. 
+ * @throws IOException + */ + public static FaissHNSWFlatIndex load(IndexInput input) throws IOException { + // Read common header + FaissHNSWFlatIndex faissHNSWFlatIndex = new FaissHNSWFlatIndex(); + readCommonHeader(input, faissHNSWFlatIndex); + + // Partial load HNSW graph + faissHNSWFlatIndex.hnsw = FaissHNSW.load(input, faissHNSWFlatIndex.getTotalNumberOfVectors()); + + // Partial load flat vector storage + final FaissIndex faissIndex = FaissIndex.load(input); + if (faissIndex instanceof FaissIndexFlat) { + faissHNSWFlatIndex.storage = (FaissIndexFlat) faissIndex; + } else { + throw new IllegalStateException( + "Expected flat vector storage format under [" + IHNF + "] index type, but got " + faissIndex.getIndexType()); + } + return faissHNSWFlatIndex; + } + + @Override + public String getIndexType() { + return IHNF; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissHNSWVectorReader.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissHNSWVectorReader.java new file mode 100644 index 000000000..bdbc9f9f6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissHNSWVectorReader.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOSupplier; +import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.io.Closeable; +import java.io.IOException; + +public class FaissHNSWVectorReader implements Closeable { + private static FlatVectorsScorer VECTOR_SCORER = FlatVectorScorerUtil.getLucene99FlatVectorsScorer(); + + private IndexInput indexInput; + private FaissIdMapIndex faissIdMapIndex; + private FaissIndexFlat faissIndexFlat; + private LuceneFaissHnswGraph faissHnswGraph; + + public FaissHNSWVectorReader(IndexInput indexInput) throws IOException { + this.indexInput = indexInput; + faissIdMapIndex = (FaissIdMapIndex) FaissIndex.load(indexInput); + final FaissHNSWFlatIndex faissHNSWFlatIndex = faissIdMapIndex.getNestedIndex(); + faissIndexFlat = faissHNSWFlatIndex.getStorage(); + faissHnswGraph = new LuceneFaissHnswGraph(faissIdMapIndex.getNestedIndex(), indexInput); + } + + public void search(float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + search(VectorEncoding.FLOAT32, + () -> VECTOR_SCORER.getRandomVectorScorer(faissIndexFlat.getSimilarityFunction(), + faissIdMapIndex.getFloatValues(indexInput), + target + ), + knnCollector, + acceptDocs + ); + } + + private void search( + final VectorEncoding vectorEncoding, + final IOSupplier scorerSupplier, + final KnnCollector knnCollector, + final Bits acceptDocs + ) throws IOException { + if (faissIndexFlat.getTotalNumberOfVectors() == 0 || knnCollector.k() == 0 + || faissIndexFlat.getVectorEncoding() != vectorEncoding) { + return; + } + + final RandomVectorScorer scorer = scorerSupplier.get(); + final KnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); + + if (knnCollector.k() 
< scorer.maxOrd()) { + HnswGraphSearcher.search(scorer, collector, faissHnswGraph, acceptedOrds); + } else { + // if k is larger than the number of vectors, we can just iterate over all vectors + // and collect them. + for (int i = 0; i < scorer.maxOrd(); i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + if (!knnCollector.earlyTerminated()) { + knnCollector.incVisitedCount(1); + knnCollector.collect(scorer.ordToDoc(i), scorer.score(i)); + } else { + break; + } + } + } + } // End if + } + + @Override + public void close() throws IOException { + indexInput.close(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIdMapIndex.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIdMapIndex.java new file mode 100644 index 000000000..7e3ae762a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIdMapIndex.java @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; + +import java.io.IOException; + +/** + * A FAISS index with an ID mapping that maps the internal vector ID to a logical ID, along with the actual vector index. + * It first delegates the vector search to its nested vector index, then transforms the vector ID into a logical index that is + * understandable by upstream components. This is particularly useful when not all Lucene documents are indexed with a vector field. + * For example, if 70% of the documents have a vector field and the remaining 30% do not, the FAISS vector index will still assign + * increasing and continuous vector IDs starting from 0. + * However, these IDs only cover the sparse 30% of Lucene documents, so an ID mapping is needed to convert the internal physical vector ID + * into the corresponding Lucene document ID. + * If the mapping is an identity mapping, where each `i` is mapped to itself, we omit storing it to save memory. + */ +public class FaissIdMapIndex extends FaissIndex { + public static final String IXMP = "IxMp"; + + @Getter + private FaissHNSWFlatIndex nestedIndex; + private long[] vectorIdToDocIdMapping; + private long oneVectorByteSize; + + /** + * Partially load id mapping and its nested index to which vector searching will be delegated. + * + * @param input An input stream for a FAISS HNSW graph file, allowing access to the neighbor list and vector locations. + * @return {@link FaissIdMapIndex} instance consists of index hierarchy. + * @throws IOException + */ + public static FaissIdMapIndex load(IndexInput input) throws IOException { + FaissIdMapIndex faissIdMapIndex = new FaissIdMapIndex(); + readCommonHeader(input, faissIdMapIndex); + FaissIndex nestedIndex = FaissIndex.load(input); + if (nestedIndex instanceof FaissHNSWFlatIndex) { + faissIdMapIndex.nestedIndex = (FaissHNSWFlatIndex) nestedIndex; + } else { + throw new IllegalStateException("Invalid nested index. 
Expected FaissHNSWFlatIndex, but got " + nestedIndex.getIndexType()); + } + faissIdMapIndex.oneVectorByteSize = faissIdMapIndex.nestedIndex.getStorage().getOneVectorByteSize(); + + // Load `idMap` + final long numElements = input.readLong(); + long[] vectorIdToDocIdMapping = new long[(int) numElements]; + input.readLongs(vectorIdToDocIdMapping, 0, vectorIdToDocIdMapping.length); + + // If `idMap` is an identity function that maps `i` to `i`, then we don't need to keep it. + for (int i = 0; i < vectorIdToDocIdMapping.length; i++) { + if (vectorIdToDocIdMapping[i] != i) { + // Only keep it if it's not an identify mapping. + faissIdMapIndex.vectorIdToDocIdMapping = vectorIdToDocIdMapping; + break; + } + } + + return faissIdMapIndex; + } + + @Override + public String getIndexType() { + return IXMP; + } + + public FloatVectorValues getFloatValues(IndexInput indexInput) throws IOException { + if (vectorIdToDocIdMapping == null) { + return denseFloatValues(indexInput); + } + + return sparseFloatValues(indexInput); + } + + private FloatVectorValues denseFloatValues(IndexInput indexInput) throws IOException { + final FaissIndexFlat indexFlat = nestedIndex.getStorage(); + final Storage codes = indexFlat.getCodes(); + final int dimension = getDimension(); + final int totalNumVectors = (int) getTotalNumberOfVectors(); + + @RequiredArgsConstructor + class DenseFloatVectorValuesImpl extends FloatVectorValues { + final IndexInput data; + final float[] vector = new float[dimension]; + + @Override + public float[] vectorValue(int targetOrd) throws IOException { + data.seek(oneVectorByteSize * targetOrd); + data.readFloats(vector, 0, vector.length); + return vector; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return totalNumVectors; + } + + @Override + public FloatVectorValues copy() throws IOException { + return new DenseFloatVectorValuesImpl(indexInput.slice("FaissIndexFlat", codes.baseOffset, codes.sectionSize)); + } + } + + return new DenseFloatVectorValuesImpl(indexInput.slice("FaissIndexFlat", codes.baseOffset, codes.sectionSize)); + } + + private FloatVectorValues sparseFloatValues(IndexInput indexInput) throws IOException { + final FaissIndexFlat indexFlat = nestedIndex.getStorage(); + final Storage codes = indexFlat.getCodes(); + final int dimension = getDimension(); + final int totalNumVectors = (int) getTotalNumberOfVectors(); + + @RequiredArgsConstructor + class SparseFloatVectorValuesImpl extends FloatVectorValues { + final IndexInput data; + final float[] vector = new float[dimension]; + + @Override + public float[] vectorValue(int targetOrd) throws IOException { + data.seek(oneVectorByteSize * targetOrd); + data.readFloats(vector, 0, vector.length); + return vector; + } + + @Override + public int dimension() { + return dimension; + } + + public int ordToDoc(int ord) { + return (int) vectorIdToDocIdMapping[ord]; + } + + public Bits getAcceptOrds(final Bits acceptDocs) { + return acceptDocs == null ? 
null : new Bits() { + @Override + public boolean get(int ord) { + return acceptDocs.get((int) vectorIdToDocIdMapping[ord]); + } + + @Override + public int length() { + return totalNumVectors; + } + }; + } + + @Override + public int size() { + return totalNumVectors; + } + + @Override + public FloatVectorValues copy() throws IOException { + return new SparseFloatVectorValuesImpl(indexInput.slice("FaissIndexFlat", codes.baseOffset, codes.sectionSize)); + } + } + + return new SparseFloatVectorValuesImpl(indexInput.slice("FaissIndexFlat", codes.baseOffset, codes.sectionSize)); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndex.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndex.java new file mode 100644 index 000000000..23ff23cc8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndex.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import lombok.Getter; +import org.apache.lucene.store.IndexInput; +import org.opensearch.knn.index.SpaceType; + +import java.io.IOException; + +import static org.opensearch.knn.index.codec.luceneonfaiss.FaissHNSWFlatIndex.IHNF; +import static org.opensearch.knn.index.codec.luceneonfaiss.FaissIdMapIndex.IXMP; +import static org.opensearch.knn.index.codec.luceneonfaiss.FaissIndexFlat.IXF2; +import static org.opensearch.knn.index.codec.luceneonfaiss.FaissIndexFlat.IXFI; + +@Getter +public abstract class FaissIndex { + // Vector dimension + private int dimension; + // Total number of vectors saved within this index. + private long totalNumberOfVectors; + // Space type used to index vectors in this index. + private SpaceType spaceType; + + public static FaissIndex load(IndexInput input) throws IOException { + final String indexName = readFourBytes(input); + + switch (indexName) { + case IXMP: { + return FaissIdMapIndex.load(input); + } + case IHNF: { + return FaissHNSWFlatIndex.load(input); + } + case IXF2: + // Fallthrough + case IXFI: + return FaissIndexFlat.load(input, indexName); + default: { + throw new IllegalStateException("Partial loading does not support [" + indexName + "]."); + } + } + } + + static protected void readCommonHeader(IndexInput readStream, FaissIndex index) throws IOException { + index.dimension = readStream.readInt(); + index.totalNumberOfVectors = readStream.readLong(); + // consume 2 dummy deprecated fields. + readStream.readLong(); + readStream.readLong(); + + // We don't use this field + final boolean isTrained = readStream.readByte() == 1; + + final int metricTypeIndex = readStream.readInt(); + if (metricTypeIndex > 1) { + throw new IllegalStateException("Partial loading does not support metric type index=[" + metricTypeIndex + "] from FAISS."); + } + + if (metricTypeIndex == 0) { + index.spaceType = SpaceType.INNER_PRODUCT; + } else if (metricTypeIndex == 1) { + index.spaceType = SpaceType.L2; + } else { + throw new IllegalStateException("Partial loading does not support metric type index=" + metricTypeIndex + " from FAISS."); + } + } + + static private String readFourBytes(IndexInput input) throws IOException { + final byte[] fourBytes = new byte[4]; + input.readBytes(fourBytes, 0, fourBytes.length); + return new String(fourBytes); + } + + /** + * Returns a unique signature of the FAISS index. + * + * @return Index type string. 
+ */ + public abstract String getIndexType(); +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndexFlat.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndexFlat.java new file mode 100644 index 000000000..e8e28b40c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndexFlat.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import lombok.Getter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; +import java.util.Map; +import java.util.function.Supplier; + +@Getter +public abstract class FaissIndexFlat extends FaissIndex { + // Maps to IndexFlatL2 e.g. L2 distance + public static final String IXF2 = "IxF2"; + // Maps to IndexFlatIP e.g. InnerProduct + public static final String IXFI = "IxFI"; + private static Map> FLAT_INDEX_SUPPLIERS = + Map.of(IXF2, FaissIndexFlatL2::new, IXFI, FaissIndexFlatIP::new); + + private final Storage codes = new Storage(); + private int oneVectorByteSize; + private String indexType; + + public static FaissIndex load(final IndexInput input, final String indexType) throws IOException { + FaissIndexFlat faissIndexFlat = FLAT_INDEX_SUPPLIERS.getOrDefault( + indexType, + () -> {throw new IllegalStateException("Faiss index flat [" + indexType + "] is not supported.");} + ).get(); + + readCommonHeader(input, faissIndexFlat); + faissIndexFlat.oneVectorByteSize = Float.BYTES * faissIndexFlat.getDimension(); + + faissIndexFlat.codes.markSection(input, Float.BYTES); + if (faissIndexFlat.codes.getSectionSize() != (faissIndexFlat.getTotalNumberOfVectors() * faissIndexFlat.oneVectorByteSize)) { + throw new IllegalStateException("Got an inconsistent bytes size of vector [" + faissIndexFlat.codes.getSectionSize() + "] " + + "when faissIndexFlat.totalNumberOfVectors=" + faissIndexFlat.getTotalNumberOfVectors() + + ", faissIndexFlat.oneVectorByteSize=" + faissIndexFlat.oneVectorByteSize); + } + + faissIndexFlat.indexType = indexType; + + return faissIndexFlat; + } + + public VectorEncoding getVectorEncoding() { + // We only support float[] at the moment. 
+ return VectorEncoding.FLOAT32; + } + + public abstract VectorSimilarityFunction getSimilarityFunction(); +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndexFlatIP.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndexFlatIP.java new file mode 100644 index 000000000..e2f5cd2ae --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndexFlatIP.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import org.apache.lucene.index.VectorSimilarityFunction; + +public class FaissIndexFlatIP extends FaissIndexFlat { + + @Override + public VectorSimilarityFunction getSimilarityFunction() { + return VectorSimilarityFunction.DOT_PRODUCT; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndexFlatL2.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndexFlatL2.java new file mode 100644 index 000000000..49af2429c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/FaissIndexFlatL2.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import org.apache.lucene.index.VectorSimilarityFunction; + +public class FaissIndexFlatL2 extends FaissIndexFlat { + @Override + public VectorSimilarityFunction getSimilarityFunction() { + return VectorSimilarityFunction.EUCLIDEAN; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/LuceneFaissHnswGraph.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/LuceneFaissHnswGraph.java new file mode 100644 index 000000000..c09b6c1c6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/LuceneFaissHnswGraph.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.io.IOException; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +public class LuceneFaissHnswGraph extends HnswGraph { + private final FaissHNSW faissHnsw; + private final IndexInput indexInput; + private final int numVectors; + private int[] neighborIdList; + private int numNeighbors; + private int nextNeighborIndex; + + public LuceneFaissHnswGraph(FaissHNSWFlatIndex hnswFlatIndex, IndexInput indexInput) { + this.faissHnsw = hnswFlatIndex.getHnsw(); + this.indexInput = indexInput; + this.numVectors = (int) hnswFlatIndex.getStorage().getTotalNumberOfVectors(); + } + + @Override + public void seek(int level, int target) { + // Get a relative starting offset of neighbor list at `level`. + long o = faissHnsw.getOffsets()[target]; + + // `begin` and `end` represent for a pair of staring offset and end offset. + // But, what `end` represents is the maximum offset a neighbor list at a level can have. + // Therefore, it is required to traverse a list until getting a terminal `-1`. 
+ final long begin = o + faissHnsw.getCumNumberNeighborPerLevel()[level]; + final long end = o + faissHnsw.getCumNumberNeighborPerLevel()[level + 1]; + loadNeighborIdList(begin, end); + } + + private void loadNeighborIdList(final long begin, final long end) { + // Make sure we have sufficient space for neighbor list + final long maxLength = end - begin; + if (neighborIdList == null || neighborIdList.length < maxLength) { + neighborIdList = new int[(int) (maxLength * 1.5)]; + } + + // Seek to the first offset of neighbor list + try { + indexInput.seek(faissHnsw.getNeighbors().getBaseOffset() + Integer.BYTES * begin); + } catch (IOException e) { + throw new RuntimeException(e); + } + + // Fill the array with neighbor ids + int index = 0; + try { + for (long i = begin; i < end; i++) { + final int neighborId = indexInput.readInt(); + if (neighborId >= 0) { + neighborIdList[index++] = neighborId; + } else { + break; + } + } + + // Set variables for navigation + numNeighbors = index; + nextNeighborIndex = 0; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public int size() { + return numVectors; + } + + @Override + public int nextNeighbor() { + if (nextNeighborIndex < numNeighbors) { + return neighborIdList[nextNeighborIndex++]; + } + + // Neighbor list has been exhausted. + return NO_MORE_DOCS; + } + + @Override + public int numLevels() { + return faissHnsw.getMaxLevel(); + } + + @Override + public int entryNode() { + return faissHnsw.getEntryPoint(); + } + + @Override + public NodesIterator getNodesOnLevel(int i) { + throw new UnsupportedOperationException(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/LuceneOnFaissUtils.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/LuceneOnFaissUtils.java new file mode 100644 index 000000000..33d36e533 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/LuceneOnFaissUtils.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import org.opensearch.knn.common.KNNConstants; + +import java.util.Map; + +public final class LuceneOnFaissUtils { + private LuceneOnFaissUtils() { + } + + public static boolean isUseLuceneOnFaiss(Object mapObject) { + if (mapObject instanceof Map) { + Map map = (Map) mapObject; + Object value = map.get(KNNConstants.USE_LUCENE_HNSW_SEARCHER); + return (value instanceof Boolean) ? (Boolean) value : false; + } + + return false; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/Storage.java b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/Storage.java new file mode 100644 index 000000000..28cb156bf --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/luceneonfaiss/Storage.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.luceneonfaiss; + +import lombok.Getter; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + +public class Storage { + @Getter + protected long baseOffset; + @Getter + protected long sectionSize; + + /** + * Mark the starting offset and the size of section then skip to the next section. + * + * @param input Input read stream. + * @param singleElementSize Size of atomic element. In file, it only stores the number of elements and the size of element will be + * used to calculate the actual size of section. 
Ex: size=100, element=int, then the actual section size=400. + * @throws IOException + */ + public void markSection(IndexInput input, int singleElementSize) throws IOException { + this.sectionSize = input.readLong() * singleElementSize; + this.baseOffset = input.getFilePointer(); + // Skip the whole section and jump to the next section in the file. + try { + input.seek(baseOffset + sectionSize); + } catch (IOException e) { + throw new IOException("Failed to partial load where baseOffset=" + baseOffset + ", sectionSize=" + sectionSize, e); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index 3386f871c..ba172bb18 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -24,12 +24,14 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_USE_LUCENE_HNSW_SEARCHER; import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.USE_LUCENE_HNSW_SEARCHER; /** * Faiss HNSW method implementation @@ -105,6 +107,14 @@ private static MethodComponent initMethodComponent() { ) ) .addParameter(METHOD_ENCODER_PARAMETER, initEncoderParameter()) + .addParameter( + USE_LUCENE_HNSW_SEARCHER, + new Parameter.BooleanParameter( + USE_LUCENE_HNSW_SEARCHER, + DEFAULT_USE_LUCENE_HNSW_SEARCHER, + (v, context) -> true + ) + ) .setKnnLibraryIndexingContextGenerator(((methodComponent, methodComponentContext, knnMethodConfigContext) -> { MethodAsMapBuilder methodAsMapBuilder = MethodAsMapBuilder.builder( FAISS_HNSW_DESCRIPTION, diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java index 6c5eea08f..5e3255a52 100644 --- a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -51,6 +51,7 @@ public static class CreateQueryRequest { private QueryShardContext context; private RescoreContext rescoreContext; private Boolean expandNested; + private boolean forceUseLuceneSearcher; public Optional getFilter() { return Optional.ofNullable(filter); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index f032210aa..4711f34fd 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -42,6 +42,7 @@ import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; +import org.opensearch.knn.index.codec.luceneonfaiss.LuceneOnFaissUtils; import java.io.IOException; import java.util.Arrays; @@ -398,6 +399,7 @@ protected Query doToQuery(QueryShardContext context) { KNNEngine knnEngine = queryConfigFromMapping.getKnnEngine(); MethodComponentContext 
methodComponentContext = queryConfigFromMapping.getMethodComponentContext(); + final boolean forceUseLuceneSearcher = LuceneOnFaissUtils.isUseLuceneOnFaiss(methodComponentContext.getParameters()); SpaceType spaceType = queryConfigFromMapping.getSpaceType(); VectorDataType vectorDataType = queryConfigFromMapping.getVectorDataType(); RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext); @@ -528,6 +530,7 @@ protected Query doToQuery(QueryShardContext context) { .context(context) .rescoreContext(processedRescoreContext) .expandNested(expandNested) + .forceUseLuceneSearcher(forceUseLuceneSearcher) .build(); return KNNQueryFactory.create(createQueryRequest); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index b6770553b..a0da36bba 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -49,8 +49,8 @@ public static Query create(CreateQueryRequest createQueryRequest) { final Query filterQuery = getFilterQuery(createQueryRequest); final Map methodParameters = createQueryRequest.getMethodParameters(); final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null); - final KNNEngine knnEngine = createQueryRequest.getKnnEngine(); final boolean expandNested = createQueryRequest.getExpandNested().orElse(false); + final boolean forceUseLuceneSearcher = createQueryRequest.isForceUseLuceneSearcher(); BitSetProducer parentFilter = null; int shardId = -1; if (createQueryRequest.getContext().isPresent()) { @@ -70,7 +70,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { ); } - if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { + if (!forceUseLuceneSearcher && KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { final Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, createQueryRequest.getKnnEngine()); log.debug( @@ -121,7 +121,6 @@ public static Query create(CreateQueryRequest createQueryRequest) { requestEfSearch = (Integer) methodParameters.get(METHOD_PARAMETER_EF_SEARCH); } int luceneK = requestEfSearch == null ? k : Math.max(k, requestEfSearch); - log.debug("Creating Lucene k-NN query for index: {}, field:{}, k: {}", indexName, fieldName, k); switch (vectorDataType) { case BYTE: case BINARY:
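
End to end, the new path works as follows: FaissHNSWMethod registers `use_lucene_searcher` as an HNSW method parameter, the flag is persisted in the field's `parameters` attribute (see the example attribute string in NativeEngines990KnnVectorsReader above), the vectors reader builds a FaissHNSWVectorReader over the segment's FAISS file when the flag is set, and KNNQueryFactory routes such fields to the Lucene k-NN query via `forceUseLuceneSearcher`. The sketch below exercises the new reader directly and is illustrative only: the directory path, FAISS file name, and query vector are hypothetical, and it assumes the file holds an `IxMp`-wrapped `IHNf` flat index, which is the only layout this partial loader supports. In the plugin the reader is created per field inside NativeEngines990KnnVectorsReader, so standalone use like this is only for experiments or tests.

import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.knn.index.codec.luceneonfaiss.FaissHNSWVectorReader;

import java.nio.file.Paths;

public final class LuceneOnFaissSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical shard directory and FAISS file name; in the plugin the file name comes from
        // KNNCodecUtil.getNativeEngineFileFromFieldInfo(fieldInfo, segmentInfo).
        try (Directory directory = FSDirectory.open(Paths.get("/path/to/shard/index"))) {
            IndexInput input = directory.openInput("_0_165_my_vector.faiss", IOContext.DEFAULT);
            // FaissHNSWVectorReader takes ownership of the IndexInput and closes it on close().
            try (FaissHNSWVectorReader reader = new FaissHNSWVectorReader(input)) {
                float[] query = { 0.1f, 0.2f, 0.3f, 0.4f };
                // Collect the top 10 hits; a null Bits means no filter is applied.
                TopKnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
                reader.search(query, collector, null);
                TopDocs topDocs = collector.topDocs();
                System.out.println("hits=" + topDocs.scoreDocs.length);
            }
        }
    }
}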