Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial loading implementation for FAISS HNSW #2405

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Infrastructure
* Removed JDK 11 and 17 version from CI runs [#1921](https://github.com/opensearch-project/k-NN/pull/1921)
* Upgrade min JDK compatibility to JDK 21 [#2422](https://github.com/opensearch-project/k-NN/pull/2422)
* Added initial implementation of partial loading [#2405](https://github.com/opensearch-project/k-NN/pull/2405)
### Documentation
### Maintenance
* Update package name to fix compilation issue [#2513](https://github.com/opensearch-project/k-NN/pull/2513)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import org.apache.lucene.util.VectorUtil;

/**
 * Distance functions used by partial loading to score vectors in Java instead of the native engine.
 * <p>
 * Every function returns a value where a SMALLER result means the two vectors are CLOSER, so all
 * constants can be ranked with the same min-ordering (the convention FAISS uses for distances).
 */
public enum KNNVectorDistanceFunction {
    /** Squared Euclidean (L2^2) distance; smaller is closer. */
    EUCLIDEAN {
        @Override
        public float distance(float[] vec1, float[] vec2) {
            return VectorUtil.squareDistance(vec1, vec2);
        }

        @Override
        public float distance(byte[] vec1, byte[] vec2) {
            return VectorUtil.squareDistance(vec1, vec2);
        }
    },
    /** Negated dot product: a larger dot product (more similar) yields a smaller distance. */
    DOT_PRODUCT {
        @Override
        public float distance(float[] vec1, float[] vec2) {
            return -VectorUtil.dotProduct(vec1, vec2);
        }

        @Override
        public float distance(byte[] vec1, byte[] vec2) {
            return -VectorUtil.dotProduct(vec1, vec2);
        }
    },
    /**
     * Cosine distance, i.e. {@code 1 - cosine similarity}.
     * <p>
     * Fix: the previous implementation returned the raw cosine SIMILARITY, where a larger value
     * means closer — the opposite ordering of EUCLIDEAN and DOT_PRODUCT above. Used as a distance,
     * that would rank the farthest vectors as nearest. Subtracting from 1 restores the
     * smaller-is-closer contract shared by the other constants.
     */
    COSINE {
        @Override
        public float distance(float[] vec1, float[] vec2) {
            return 1f - VectorUtil.cosine(vec1, vec2);
        }

        @Override
        public float distance(byte[] vec1, byte[] vec2) {
            return 1f - VectorUtil.cosine(vec1, vec2);
        }
    };

    /**
     * Computes the distance between two float vectors.
     *
     * @param vec1 first vector; must be the same length as {@code vec2}
     * @param vec2 second vector
     * @return distance value; smaller means closer
     */
    public abstract float distance(float[] vec1, float[] vec2);

    /**
     * Computes the distance between two byte vectors.
     *
     * @param vec1 first vector; must be the same length as {@code vec2}
     * @param vec2 second vector
     * @return distance value; smaller means closer
     */
    public abstract float distance(byte[] vec1, byte[] vec2);
}
19 changes: 19 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.EUCLIDEAN;
}

@Override
public KNNVectorDistanceFunction getKnnVectorDistanceFunction() {
return KNNVectorDistanceFunction.EUCLIDEAN;
}

@Override
public float scoreToDistanceTranslation(float score) {
if (score == 0) {
Expand Down Expand Up @@ -82,6 +87,11 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.COSINE;
}

@Override
public KNNVectorDistanceFunction getKnnVectorDistanceFunction() {
return KNNVectorDistanceFunction.COSINE;
}

@Override
public void validateVector(byte[] vector) {
if (isZeroVector(vector)) {
Expand Down Expand Up @@ -133,6 +143,11 @@ public float scoreTranslation(float rawScore) {
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
}

@Override
public KNNVectorDistanceFunction getKnnVectorDistanceFunction() {
return KNNVectorDistanceFunction.DOT_PRODUCT;
}
},
HAMMING("hamming") {
@Override
Expand Down Expand Up @@ -177,6 +192,10 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {

public abstract float scoreTranslation(float rawScore);

public KNNVectorDistanceFunction getKnnVectorDistanceFunction() {
throw new UnsupportedOperationException(String.format("Space [%s] does not have a knn vector distance function", getValue()));
}

/**
* Get KNNVectorSimilarityFunction that maps to this SpaceType
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.partialloading.PartialLoadingContext;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.locks.ReadWriteLock;
Expand Down Expand Up @@ -100,6 +102,10 @@ default boolean decRef() {
return true;
}

/**
 * Returns the partial loading context attached to this allocation, or {@code null} when the
 * allocation is fully loaded into native memory (i.e. partial loading is not in use).
 * Overridden by {@code IndexAllocation} when a context was supplied at construction time.
 */
default PartialLoadingContext getPartialLoadingContext() {
return null;
}

/**
* Represents native indices loaded into memory. Because these indices are backed by files, they should be
* freed when file is deleted.
Expand All @@ -121,6 +127,27 @@ class IndexAllocation implements NativeMemoryAllocation {
@Getter
private final boolean isBinaryIndex;
private final RefCountedReleasable<IndexAllocation> refCounted;
@Getter
private final PartialLoadingContext partialLoadingContext;

/**
 * Constructor for a partially-loaded index allocation. No native memory is associated with
 * this allocation — it delegates with a memory address and size of 0 — and search is served
 * through the supplied {@link PartialLoadingContext} instead.
 *
 * @param executorService Executor service used to close the allocation
 * @param knnEngine KNNEngine associated with the index allocation
 * @param vectorFileName Vector file name. Ex: _0_165_my_field.faiss
 * @param openSearchIndexName Name of OpenSearch index this index is associated with
 * @param isBinaryIndex Whether the underlying index is a binary index
 * @param partialLoadingContext Context holding the partially-loaded index; it is closed when
 *                              this allocation is cleaned up. May be null.
 */
IndexAllocation(
ExecutorService executorService,
KNNEngine knnEngine,
String vectorFileName,
String openSearchIndexName,
boolean isBinaryIndex,
PartialLoadingContext partialLoadingContext
) {
this(executorService, 0, 0, knnEngine, vectorFileName, openSearchIndexName, null, isBinaryIndex, partialLoadingContext);
}

/**
* Constructor
Expand All @@ -140,7 +167,7 @@ class IndexAllocation implements NativeMemoryAllocation {
String vectorFileName,
String openSearchIndexName
) {
this(executorService, memoryAddress, sizeKb, knnEngine, vectorFileName, openSearchIndexName, null, false);
this(executorService, memoryAddress, sizeKb, knnEngine, vectorFileName, openSearchIndexName, null, false, null);
}

/**
Expand All @@ -163,6 +190,41 @@ class IndexAllocation implements NativeMemoryAllocation {
String openSearchIndexName,
SharedIndexState sharedIndexState,
boolean isBinaryIndex
) {
this(
executorService,
memoryAddress,
sizeKb,
knnEngine,
vectorFileName,
openSearchIndexName,
sharedIndexState,
isBinaryIndex,
null
);
}

/**
 * Constructor
 *
 * @param executorService Executor service used to close the allocation
 * @param memoryAddress Pointer in memory to the index
 * @param sizeKb Size this index consumes in kilobytes
 * @param knnEngine KNNEngine associated with the index allocation
 * @param vectorFileName Vector file name. Ex: _0_165_my_field.faiss
 * @param openSearchIndexName Name of OpenSearch index this index is associated with
 * @param sharedIndexState Shared index state. If no shared state is present, pass null.
 * @param isBinaryIndex Whether the underlying index is a binary index
 * @param partialLoadingContext Partial loading context. If partial loading is disabled, pass null.
 */
IndexAllocation(
ExecutorService executorService,
long memoryAddress,
int sizeKb,
KNNEngine knnEngine,
String vectorFileName,
String openSearchIndexName,
SharedIndexState sharedIndexState,
boolean isBinaryIndex,
PartialLoadingContext partialLoadingContext
) {
this.executor = executorService;
this.closed = false;
Expand All @@ -175,6 +237,7 @@ class IndexAllocation implements NativeMemoryAllocation {
this.sharedIndexState = sharedIndexState;
this.isBinaryIndex = isBinaryIndex;
this.refCounted = new RefCountedReleasable<>("IndexAllocation-Reference", this, this::closeInternal);
this.partialLoadingContext = partialLoadingContext;
}

protected void closeInternal() {
Expand Down Expand Up @@ -218,6 +281,14 @@ private void cleanup() {
if (sharedIndexState != null) {
SharedIndexStateManager.getInstance().release(sharedIndexState);
}

if (partialLoadingContext != null) {
try {
partialLoadingContext.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.partialloading.PartialLoadingContext;
import org.opensearch.knn.partialloading.faiss.FaissIndex;
import org.opensearch.knn.partialloading.search.PartialLoadingMode;
import org.opensearch.knn.training.TrainingDataConsumer;
import org.opensearch.knn.training.VectorReader;

import java.io.Closeable;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

Expand Down Expand Up @@ -87,6 +91,15 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde
final Directory directory = indexEntryContext.getDirectory();
final int indexSizeKb = Math.toIntExact(directory.fileLength(vectorFileName) / 1024);

// TMP
final PartialLoadingMode partialLoadingMode = PartialLoadingMode.DISABLED;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am little confused between mapping and setting. What was decided on whether to use mapping or setting?

Ideally, if the performance and recall are equal, we should eventually have an option of deprecating whichever approach is not memory-efficient. Will having a mapping make it a one-way door?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Mapping is preferred as we want it to be configured at the field level.
  2. You mean MEMORY_EFFICIENT mode's performance and recall are equal to the baseline where everything is loaded into memory, right? Its performance can never be equal to the baseline, as it involves load costs. Even when MMapDirectory was configured, it was shown that the FAISS baseline had the best performance.
    The whole point of MEMORY_EFFICIENT is to give users an option to operate a big vector index within a memory-constrained environment.

// final PartialLoadingMode partialLoadingMode = PartialLoadingMode.MEMORY_EFFICIENT;
// TMP

if (partialLoadingMode != PartialLoadingMode.DISABLED) {
return createPartialLoadedIndexAllocation(directory, indexEntryContext, knnEngine, vectorFileName, partialLoadingMode);
}

// Try to open an index input then pass it down to native engine for loading an index.
try (IndexInput readStream = directory.openInput(vectorFileName, IOContext.READONCE)) {
final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(readStream);
Expand All @@ -96,6 +109,45 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde
}
}

private NativeMemoryAllocation.IndexAllocation createPartialLoadedIndexAllocation(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we create a separate class which extends NativeMemoryLoadStrategy for this? and possibly PartialIndexAllocation pojo to hold the context?

This will simplify the code and isolate partialLoading related code ideally under one fork rather than an if fork in each class. Let me know how it turns out?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decision whether to partial load or not is being made within NativeMemoryLoadStrategy::load by fetching mode from mapping. I think it would be good to have PartialIndexAllocation to separate the loading logic, but it seems hard to have a subclass of NativeMemoryLoadStrategy.
Will factor out the partial-loading-related logic into PartialIndexAllocation in the next revision, as you suggested.

Directory directory,
NativeMemoryEntryContext.IndexEntryContext indexEntryContext,
KNNEngine knnEngine,
String vectorFileName,
PartialLoadingMode partialLoadingMode
) throws IOException {
validatePartialLoadingSupported(indexEntryContext, knnEngine);

// Try to open an index input then pass it down to native engine for loading an index.
FaissIndex faissIndex = null;
try (IndexInput input = directory.openInput(vectorFileName, IOContext.READONCE)) {
faissIndex = FaissIndex.partiallyLoad(input);
}

// Create partial loading context.
final PartialLoadingContext partialLoadingContext = new PartialLoadingContext(faissIndex, vectorFileName, partialLoadingMode);

return new NativeMemoryAllocation.IndexAllocation(
executor,
knnEngine,
vectorFileName,
indexEntryContext.getOpenSearchIndexName(),
IndexUtil.isBinaryIndex(knnEngine, indexEntryContext.getParameters()),
partialLoadingContext
);
}

private void validatePartialLoadingSupported(NativeMemoryEntryContext.IndexEntryContext indexEntryContext, KNNEngine knnEngine)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validation seems too late, are there any validations like this while creating the mapping?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree! My strategy is to have two separate PRs: 1. Core logic for partial loading 2. Extending mapping
This PR has the core logic, and I will make sure the early validation is performed during mapping creation, as you suggested.
Will add TODO.

throws UnsupportedEncodingException {
if (IndexUtil.isBinaryIndex(knnEngine, indexEntryContext.getParameters())) {
throw new UnsupportedEncodingException("Partial loading search does not support binary index.");
}

if (IndexUtil.isByteIndex(indexEntryContext.getParameters())) {
throw new UnsupportedEncodingException("Partial loading search does not support byte index.");
}
}

private NativeMemoryAllocation.IndexAllocation createIndexAllocation(
final NativeMemoryEntryContext.IndexEntryContext indexEntryContext,
final KNNEngine knnEngine,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
shardId = context.getShardId();
System.out.println(" +++++++++++++++++++++++++++++++++ parentFilter = context.getParentFilter(), " + parentFilter);
}

if (parentFilter == null && expandNested) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
* Place holder for the score of the document
*/
public class KNNQueryResult {
private final int id;
private final float score;
private int id;
private float score;

public KNNQueryResult(final int id, final float score) {
this.id = id;
Expand All @@ -24,4 +24,9 @@ public int getId() {
/** Returns the score associated with this result document. */
public float getScore() {
return this.score;
}

/**
 * Reinitializes this instance with a new id/score pair so the same object can be reused
 * across results instead of allocating a new KNNQueryResult per hit.
 *
 * @param id new document id
 * @param score new score for the document
 */
public void reset(final int id, final float score) {
this.id = id;
this.score = score;
}
}
Loading
Loading