@@ -0,0 +1,144 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.github.jbellis.jvector.bench;

import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.quantization.MutablePQVectors;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * Benchmark that measures the distance calculations of mutable Product Quantized vectors:
 * query-to-node score functions and node-to-node diversity functions.
 */
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=false"})
@Warmup(iterations = 2)
@Measurement(iterations = 3)
@Threads(1)
public class PQDistanceCalculationMutableVectorBenchmark {
private static final Logger log = LoggerFactory.getLogger(PQDistanceCalculationMutableVectorBenchmark.class);
private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport();

private List<VectorFloat<?>> vectors;
private PQVectors pqVectors;
private List<VectorFloat<?>> queryVectors;
private ProductQuantization pq;
private BuildScoreProvider buildScoreProvider;

@Param({"1536"})
private int dimension;

@Param({"10000"})
private int vectorCount;

@Param({"100"})
private int queryCount;

@Param({"16", "32", "64", "96", "192"})
private int M; // Number of subspaces for PQ

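// A bare @Param on an enum field makes JMH enumerate every VectorSimilarityFunction constant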
@Param
private VectorSimilarityFunction vsf;

@Setup
public void setup() throws IOException {
log.info("Creating dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount);

// Create random vectors
vectors = new ArrayList<>(vectorCount);
for (int i = 0; i < vectorCount; i++) {
vectors.add(createRandomVector(dimension));
}

// Create query vectors
queryVectors = new ArrayList<>(queryCount);
for (int i = 0; i < queryCount; i++) {
queryVectors.add(createRandomVector(dimension));
}

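// Train product quantization: M subspaces with 256 centroids each, so each vector encodes
// to M one-byte codes; the final flag requests globally centering the dataset before clustering.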
RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension);
// Create mutable PQ vectors
pq = ProductQuantization.compute(ravv, M, 256, true);
pqVectors = new MutablePQVectors(pq);
// encode vector-at-a-time, as an incremental index build would
for (int ordinal = 0; ordinal < vectors.size(); ordinal++) {
VectorFloat<?> v = vectors.get(ordinal);
// compress the new vector and add it to the PQVectors
((MutablePQVectors) pqVectors).encodeAndSet(ordinal, v);
}
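// pqBuildScoreProvider scores candidates entirely in the compressed domain; the
// diversityCalculation benchmark below exercises exactly this node-to-node path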
buildScoreProvider = BuildScoreProvider.pqBuildScoreProvider(vsf, pqVectors);
log.info("Created dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount);
}

@Benchmark
public void scoreCalculation(Blackhole blackhole) {
float totalSimilarity = 0;

for (VectorFloat<?> query : queryVectors) {

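// Build one approximate score function per query; each similarityTo(i) then reads
// node i's PQ codes and accumulates per-subspace similarities from the codebooks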
ScoreFunction.ApproximateScoreFunction asf = pqVectors.scoreFunctionFor(query, vsf);
for (int i = 0; i < vectorCount; i++) {
float similarity = asf.similarityTo(i);
totalSimilarity += similarity;
}
}

blackhole.consume(totalSimilarity);
}

@Benchmark
public void diversityCalculation(Blackhole blackhole) {
float totalSimilarity = 0;

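// Diversity scoring is node-to-node: no full-precision query is involved, only the PQ
// codes of ordinals i and q (q stands in here for a candidate neighbor's ordinal)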
for (int q = 0; q < queryCount; q++) {
for (int i = 0; i < vectorCount; i++) {
final ScoreFunction sf = buildScoreProvider.diversityProviderFor(i).scoreFunction();
float similarity = sf.similarityTo(q);
totalSimilarity += similarity;
}
}

blackhole.consume(totalSimilarity);
}

private VectorFloat<?> createRandomVector(int dimension) {
VectorFloat<?> vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension);
for (int i = 0; i < dimension; i++) {
vector.set(i, (float) Math.random());
}
return vector;
}
}
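For reference, a typical JMH invocation for this benchmark, pinning one parameter combination (the uber-jar name is illustrative; jvector's actual benchmark runner may differ):

java -jar benchmarks.jar PQDistanceCalculationMutableVectorBenchmark -p M=64 -p vsf=EUCLIDEAN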
@@ -211,33 +211,16 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
 var encodedChunk = getChunk(node2);
 var encodedOffset = getOffsetInChunk(node2);
 // compute the dot product of the query and the codebook centroids corresponding to the encoded points
-float dp = 0;
-for (int m = 0; m < subspaceCount; m++) {
-int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
-int centroidLength = pq.subvectorSizesAndOffsets[m][0];
-int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
-dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
-}
+float dp = VectorUtil.pqScoreDotProduct(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
 // scale to [0, 1]
 return (1 + dp) / 2;
 };
 case COSINE:
-float norm1 = VectorUtil.dotProduct(centeredQuery, centeredQuery);
 return (node2) -> {
 var encodedChunk = getChunk(node2);
 var encodedOffset = getOffsetInChunk(node2);
-// compute the dot product of the query and the codebook centroids corresponding to the encoded points
-float sum = 0;
-float norm2 = 0;
-for (int m = 0; m < subspaceCount; m++) {
-int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
-int centroidLength = pq.subvectorSizesAndOffsets[m][0];
-int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
-var codebookOffset = centroidIndex * centroidLength;
-sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength);
-norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
-}
-float cosine = sum / (float) Math.sqrt(norm1 * norm2);
+// compute the cosine of the query and the codebook centroids corresponding to the encoded points
+float cosine = VectorUtil.pqScoreCosine(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
 // scale to [0, 1]
 return (1 + cosine) / 2;
 };
@@ -246,13 +229,7 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
 var encodedChunk = getChunk(node2);
 var encodedOffset = getOffsetInChunk(node2);
 // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
-float sum = 0;
-for (int m = 0; m < subspaceCount; m++) {
-int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
-int centroidLength = pq.subvectorSizesAndOffsets[m][0];
-int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
-sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
-}
+float sum = VectorUtil.pqScoreEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
 // scale to [0, 1]
 return 1 / (1 + sum);
 };
@@ -273,40 +250,16 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve
 var node2Chunk = getChunk(node2);
 var node2Offset = getOffsetInChunk(node2);
 // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
-float dp = 0;
-for (int m = 0; m < subspaceCount; m++) {
-int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
-int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
-int centroidLength = pq.subvectorSizesAndOffsets[m][0];
-dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength);
-}
+float dp = VectorUtil.pqScoreDotProduct(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
 // scale to [0, 1]
 return (1 + dp) / 2;
 };
 case COSINE:
-float norm1 = 0.0f;
-for (int m1 = 0; m1 < subspaceCount; m1++) {
-int centroidIndex = Byte.toUnsignedInt(node1Chunk.get(m1 + node1Offset));
-int centroidLength = pq.subvectorSizesAndOffsets[m1][0];
-var codebookOffset = centroidIndex * centroidLength;
-norm1 += VectorUtil.dotProduct(pq.codebooks[m1], codebookOffset, pq.codebooks[m1], codebookOffset, centroidLength);
-}
-final float norm1final = norm1;
 return (node2) -> {
 var node2Chunk = getChunk(node2);
 var node2Offset = getOffsetInChunk(node2);
 // compute the dot product of the query and the codebook centroids corresponding to the encoded points
-float sum = 0;
-float norm2 = 0;
-for (int m = 0; m < subspaceCount; m++) {
-int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
-int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
-int centroidLength = pq.subvectorSizesAndOffsets[m][0];
-int codebookOffset = centroidIndex2 * centroidLength;
-sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], centroidIndex1 * centroidLength, centroidLength);
-norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
-}
-float cosine = sum / (float) Math.sqrt(norm1final * norm2);
+float cosine = VectorUtil.pqScoreCosine(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
 // scale to [0, 1]
 return (1 + cosine) / 2;
 };
@@ -315,13 +268,7 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve
 var node2Chunk = getChunk(node2);
 var node2Offset = getOffsetInChunk(node2);
 // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
-float sum = 0;
-for (int m = 0; m < subspaceCount; m++) {
-int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
-int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
-int centroidLength = pq.subvectorSizesAndOffsets[m][0];
-sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength);
-}
+float sum = VectorUtil.pqScoreEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
 // scale to [0, 1]
 return 1 / (1 + sum);
 };
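Reviewer note: the refactor above moves the per-subspace loops behind VectorUtil so vectorization providers can specialize them. Below is a minimal standalone sketch of the arithmetic being centralized; plain float[] arrays stand in for jvector's VectorFloat/ByteSequence types, and pqDotProductSketch is a hypothetical name, not part of the PR:

// Approximate <query, decoded(node)> without decoding the node: for each subspace m,
// fetch the centroid the node was quantized to and accumulate its dot product with
// the matching slice of the (centered) query.
static float pqDotProductSketch(float[][] codebooks,     // codebooks[m]: centroids for subspace m, flattened
                                int[][] sizesAndOffsets, // [m][0] = subvector length, [m][1] = offset in full vector
                                byte[] encoded,          // one centroid index per subspace
                                float[] query) {
    float dp = 0;
    for (int m = 0; m < encoded.length; m++) {
        int centroidIndex = Byte.toUnsignedInt(encoded[m]);
        int len = sizesAndOffsets[m][0];
        int queryOffset = sizesAndOffsets[m][1];
        int centroidBase = centroidIndex * len;
        for (int i = 0; i < len; i++) {
            dp += codebooks[m][centroidBase + i] * query[queryOffset + i];
        }
    }
    return dp;
}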
@@ -584,4 +584,90 @@ public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValu
 return squaredSum;
 }
 
+@Override
+public float pqScoreDotProduct(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
+float dp = 0;
+for (int m = 0; m < subspaceCount; m++) {
+int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
+int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+int centroidLength = subvectorSizesAndOffsets[m][0];
+dp += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
+}
+return dp;
+}
+
+
+@Override
+public float pqScoreCosine(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
+float sum = 0;
+float aMagnitude = 0;
+float bMagnitude = 0;
+for (int m = 0; m < subspaceCount; m++) {
+int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
+int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+int centroidLength = subvectorSizesAndOffsets[m][0];
+sum += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
+aMagnitude += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex1 * centroidLength, centroidLength);
+bMagnitude += dotProduct(codebooks[m], centroidIndex2 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
+}
+return (float)(sum / Math.sqrt(aMagnitude * bMagnitude));
+}
+
+@Override
+public float pqScoreEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
+float sum = 0;
+for (int m = 0; m < subspaceCount; m++) {
+int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
+int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+int centroidLength = subvectorSizesAndOffsets[m][0];
+
+sum += squareDistance(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
+}
+return sum;
+
+}
+
+@Override
+public float pqScoreDotProduct(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> encodedChunk, int encodedOffset, VectorFloat<?> centeredQuery, int subspaceCount) {
+float dp = 0;
+for (int m = 0; m < subspaceCount; m++) {
+int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+int centroidLength = subvectorSizesAndOffsets[m][0];
+int centroidOffset = subvectorSizesAndOffsets[m][1];
+dp += dotProduct(codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
+}
+return dp;
+}
+
+@Override
+public float pqScoreCosine(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> encodedChunk, int encodedOffset, VectorFloat<?> centeredQuery, int subspaceCount) {
+float sum = 0;
+float aMagnitude = 0;
+float bMagnitude = 0;
+
+for (int m = 0; m < subspaceCount; m++) {
+int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+int centroidLength = subvectorSizesAndOffsets[m][0];
+int centroidOffset = subvectorSizesAndOffsets[m][1];
+var codebookOffset = centroidIndex * centroidLength;
+sum += dotProduct(codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength);
+aMagnitude += dotProduct(codebooks[m], codebookOffset, codebooks[m], codebookOffset, centroidLength);
+bMagnitude += dotProduct(centeredQuery, centroidOffset, centeredQuery, centroidOffset, centroidLength);
+}
+float cosine = sum / (float) Math.sqrt(aMagnitude * bMagnitude);
+return cosine;
+}
+
+@Override
+public float pqScoreEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> encodedChunk, int encodedOffset, VectorFloat<?> centeredQuery, int subspaceCount) {
+float sum = 0;
+for (int m = 0; m < subspaceCount; m++) {
+int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
+int centroidLength = subvectorSizesAndOffsets[m][0];
+int centroidOffset = subvectorSizesAndOffsets[m][1];
+sum += squareDistance(codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
+}
+return sum;
+}
+
 }
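Note: summing per-subspace terms is exact relative to full-vector arithmetic on the reconstructed vectors, because the M subspaces partition the dimensions. For pqScoreCosine that is cos(a, b) = (sum_m <a_m, b_m>) / sqrt((sum_m |a_m|^2) * (sum_m |b_m|^2)), and the same partitioning argument justifies accumulating squareDistance per subspace in pqScoreEuclidean.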