Skip to content

Commit b3d2f4f

Browse files
committed
adding tests
1 parent 1c0894a commit b3d2f4f

File tree

4 files changed

+193
-57
lines changed

4 files changed

+193
-57
lines changed

fdb-extensions/fdb-extensions.gradle

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,38 @@ dependencies {
3636
testAnnotationProcessor(libs.autoService)
3737
}
3838

39+
def siftSmallFile = layout.buildDirectory.file('downloads/siftsmall.tar.gz')
40+
def extractDir = layout.buildDirectory.dir("extracted")
41+
42+
// Task that downloads the siftsmall.tar.gz archive exactly once unless it changed
43+
tasks.register('downloadSiftSmall', de.undercouch.gradle.tasks.download.Download) {
44+
src 'https://huggingface.co/datasets/vecdata/siftsmall/resolve/3106e1b83049c44713b1ce06942d0ab474bbdfb6/siftsmall.tar.gz'
45+
dest siftSmallFile.get().asFile
46+
onlyIfModified true
47+
tempAndMove true
48+
retries 3
49+
}
50+
51+
tasks.register('extractSiftSmall', Copy) {
52+
dependsOn 'downloadSiftSmall'
53+
from(tarTree(resources.gzip(siftSmallFile)))
54+
into extractDir
55+
56+
doLast {
57+
println "Extracted files into: ${extractDir.get().asFile}"
58+
fileTree(extractDir).visit { details ->
59+
if (!details.isDirectory()) {
60+
println " - ${details.file}"
61+
}
62+
}
63+
}
64+
}
65+
66+
test {
67+
dependsOn tasks.named('extractSiftSmall')
68+
inputs.dir extractDir
69+
}
70+
3971
publishing {
4072
publications {
4173
library(MavenPublication) {

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,18 @@
2222

2323
import com.christianheina.langx.half4j.Half;
2424
import com.google.common.base.Suppliers;
25+
import com.google.common.collect.AbstractIterator;
26+
import com.google.common.collect.ImmutableList;
2527

2628
import javax.annotation.Nonnull;
29+
import javax.annotation.Nullable;
30+
import java.io.EOFException;
31+
import java.io.IOException;
32+
import java.nio.ByteBuffer;
33+
import java.nio.ByteOrder;
34+
import java.nio.channels.FileChannel;
2735
import java.util.Arrays;
36+
import java.util.List;
2837
import java.util.Objects;
2938
import java.util.function.Supplier;
3039
import java.util.stream.Collectors;
@@ -221,4 +230,109 @@ public static Vector<?> fromBytes(@Nonnull final byte[] bytes, int precision) {
221230
// TODO
222231
throw new UnsupportedOperationException("not implemented yet");
223232
}
233+
234+
public abstract static class StoredVecsIterator<N extends Number, T> extends AbstractIterator<T> {
235+
@Nonnull
236+
private final FileChannel fileChannel;
237+
238+
protected StoredVecsIterator(@Nonnull final FileChannel fileChannel) {
239+
this.fileChannel = fileChannel;
240+
}
241+
242+
@Nonnull
243+
protected abstract N[] newComponentArray(int size);
244+
245+
@Nonnull
246+
protected abstract N toComponent(@Nonnull ByteBuffer byteBuffer);
247+
248+
@Nonnull
249+
protected abstract T toTarget(@Nonnull N[] components);
250+
251+
252+
@Nullable
253+
@Override
254+
protected T computeNext() {
255+
try {
256+
final ByteBuffer headerBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
257+
// allocate a buffer for reading floats later; you may reuse
258+
headerBuf.clear();
259+
final int bytesRead = fileChannel.read(headerBuf);
260+
if (bytesRead < 4) {
261+
if (bytesRead == -1) {
262+
return endOfData();
263+
}
264+
throw new IOException("corrupt fvecs file");
265+
}
266+
headerBuf.flip();
267+
final int dims = headerBuf.getInt();
268+
if (dims <= 0) {
269+
throw new IOException("Invalid dimension " + dims + " at position " + (fileChannel.position() - 4));
270+
}
271+
final ByteBuffer vecBuf = ByteBuffer.allocate(dims * 4).order(ByteOrder.LITTLE_ENDIAN);
272+
while (vecBuf.hasRemaining()) {
273+
int read = fileChannel.read(vecBuf);
274+
if (read < 0) {
275+
throw new EOFException("unexpected EOF when reading vector data");
276+
}
277+
}
278+
vecBuf.flip();
279+
final N[] rawVecData = newComponentArray(dims);
280+
for (int i = 0; i < dims; i++) {
281+
rawVecData[i] = toComponent(vecBuf);
282+
}
283+
284+
return toTarget(rawVecData);
285+
} catch (final IOException ioE) {
286+
throw new RuntimeException(ioE);
287+
}
288+
}
289+
}
290+
291+
public static class StoredFVecsIterator extends StoredVecsIterator<Double, DoubleVector> {
292+
public StoredFVecsIterator(@Nonnull final FileChannel fileChannel) {
293+
super(fileChannel);
294+
}
295+
296+
@Nonnull
297+
@Override
298+
protected Double[] newComponentArray(final int size) {
299+
return new Double[size];
300+
}
301+
302+
@Nonnull
303+
@Override
304+
protected Double toComponent(@Nonnull final ByteBuffer byteBuffer) {
305+
return (double)byteBuffer.getFloat();
306+
}
307+
308+
@Nonnull
309+
@Override
310+
protected DoubleVector toTarget(@Nonnull final Double[] components) {
311+
return new DoubleVector(components);
312+
}
313+
}
314+
315+
public static class StoredIVecsIterator extends StoredVecsIterator<Integer, List<Integer>> {
316+
public StoredIVecsIterator(@Nonnull final FileChannel fileChannel) {
317+
super(fileChannel);
318+
}
319+
320+
@Nonnull
321+
@Override
322+
protected Integer[] newComponentArray(final int size) {
323+
return new Integer[size];
324+
}
325+
326+
@Nonnull
327+
@Override
328+
protected Integer toComponent(@Nonnull final ByteBuffer byteBuffer) {
329+
return byteBuffer.getInt();
330+
}
331+
332+
@Nonnull
333+
@Override
334+
protected List<Integer> toTarget(@Nonnull final Integer[] components) {
335+
return ImmutableList.copyOf(components);
336+
}
337+
}
224338
}

fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* HNSWHelpers.java
2+
* NodeHelpers.java
33
*
44
* This source file is part of the FoundationDB open source project
55
*

fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java

Lines changed: 46 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@
3030
import com.apple.foundationdb.tuple.Tuple;
3131
import com.apple.test.Tags;
3232
import com.christianheina.langx.half4j.Half;
33+
import com.google.common.base.Verify;
3334
import com.google.common.collect.ImmutableList;
3435
import com.google.common.collect.Maps;
3536
import org.assertj.core.util.Lists;
3637
import org.junit.jupiter.api.Assertions;
3738
import org.junit.jupiter.api.BeforeEach;
38-
import org.junit.jupiter.api.Disabled;
3939
import org.junit.jupiter.api.Tag;
4040
import org.junit.jupiter.api.Test;
4141
import org.junit.jupiter.api.Timeout;
@@ -53,8 +53,13 @@
5353
import java.io.FileReader;
5454
import java.io.FileWriter;
5555
import java.io.IOException;
56+
import java.nio.channels.FileChannel;
57+
import java.nio.file.Path;
58+
import java.nio.file.Paths;
59+
import java.nio.file.StandardOpenOption;
5660
import java.util.ArrayList;
5761
import java.util.Comparator;
62+
import java.util.Iterator;
5863
import java.util.List;
5964
import java.util.Map;
6065
import java.util.NavigableSet;
@@ -208,9 +213,10 @@ private int basicInsertBatch(final HNSW hnsw, final int batchSize,
208213
final long beginTs = System.nanoTime();
209214
for (int i = 0; i < batchSize; i ++) {
210215
final var newNodeReference = insertFunction.apply(tr);
211-
if (newNodeReference != null) {
212-
hnsw.insert(tr, newNodeReference).join();
216+
if (newNodeReference == null) {
217+
return i;
213218
}
219+
hnsw.insert(tr, newNodeReference).join();
214220
}
215221
final long endTs = System.nanoTime();
216222
logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId,
@@ -243,7 +249,6 @@ private int insertBatch(final HNSW hnsw, final int batchSize,
243249
}
244250

245251
@Test
246-
@Timeout(value = 150, unit = TimeUnit.MINUTES)
247252
public void testSIFTInsert10k() throws Exception {
248253
final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric();
249254
final int k = 10;
@@ -255,76 +260,62 @@ public void testSIFTInsert10k() throws Exception {
255260
HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(),
256261
OnWriteListener.NOOP, onReadListener);
257262

258-
final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv";
259-
final int dimensions = 128;
263+
final Path siftSmallPath = Paths.get(".out/extracted/siftsmall/siftsmall_base.fvecs");
260264

261-
final AtomicReference<HalfVector> queryVectorAtomic = new AtomicReference<>();
262-
final NavigableSet<NodeReferenceWithDistance> trueResults = new ConcurrentSkipListSet<>(
263-
Comparator.comparing(NodeReferenceWithDistance::getDistance));
265+
try (final var fileChannel = FileChannel.open(siftSmallPath, StandardOpenOption.READ)) {
266+
final Iterator<Vector.DoubleVector> vectorIterator = new Vector.StoredFVecsIterator(fileChannel);
264267

265-
try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) {
266-
for (int i = 0; i < 10000;) {
268+
int i = 0;
269+
while (vectorIterator.hasNext()) {
267270
i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener,
268271
tr -> {
269-
final String line;
270-
try {
271-
line = br.readLine();
272-
} catch (IOException e) {
273-
throw new RuntimeException(e);
272+
if (!vectorIterator.hasNext()) {
273+
return null;
274274
}
275275

276-
final String[] values = Objects.requireNonNull(line).split("\t");
277-
Assertions.assertEquals(dimensions, values.length);
278-
final Half[] halfs = new Half[dimensions];
276+
final Vector.DoubleVector doubleVector = vectorIterator.next();
279277

280-
for (int c = 0; c < values.length; c++) {
281-
final String value = values[c];
282-
halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value));
283-
}
284278
final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic);
285-
final HalfVector currentVector = new HalfVector(halfs);
286-
final HalfVector queryVector = queryVectorAtomic.get();
287-
if (queryVector == null) {
288-
queryVectorAtomic.set(currentVector);
289-
return null;
290-
} else {
291-
final double currentDistance =
292-
Vector.comparativeDistance(metric, currentVector, queryVector);
293-
if (trueResults.size() < k || trueResults.last().getDistance() > currentDistance) {
294-
trueResults.add(
295-
new NodeReferenceWithDistance(currentPrimaryKey, currentVector,
296-
Vector.comparativeDistance(metric, currentVector, queryVector)));
297-
}
298-
if (trueResults.size() > k) {
299-
trueResults.remove(trueResults.last());
300-
}
301-
return new NodeReferenceWithVector(currentPrimaryKey, currentVector);
302-
}
279+
final HalfVector currentVector = doubleVector.toHalfVector();
280+
return new NodeReferenceWithVector(currentPrimaryKey, currentVector);
303281
});
304282
}
305283
}
306284

307-
onReadListener.reset();
308-
final long beginTs = System.nanoTime();
309-
final List<? extends NodeReferenceAndNode<?>> results =
310-
db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVectorAtomic.get()).join());
311-
final long endTs = System.nanoTime();
285+
final Path siftSmallGroundTruthPath = Paths.get(".out/extracted/siftsmall/siftsmall_groundtruth.ivecs");
286+
final Path siftSmallQueryPath = Paths.get(".out/extracted/siftsmall/siftsmall_query.fvecs");
312287

313-
for (NodeReferenceAndNode<?> nodeReferenceAndNode : results) {
314-
final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance();
315-
logger.info("retrieved result nodeId = {} at distance= {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0),
316-
nodeReferenceWithDistance.getDistance());
317-
}
318288

319-
for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults) {
320-
logger.info("true result nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0),
321-
nodeReferenceWithDistance.getDistance());
289+
try (final var queryChannel = FileChannel.open(siftSmallQueryPath, StandardOpenOption.READ);
290+
final var groundTruthChannel = FileChannel.open(siftSmallGroundTruthPath, StandardOpenOption.READ)) {
291+
final Iterator<Vector.DoubleVector> queryIterator = new Vector.StoredFVecsIterator(queryChannel);
292+
final Iterator<List<Integer>> groundTruthIterator = new Vector.StoredIVecsIterator(groundTruthChannel);
293+
294+
Verify.verify(queryIterator.hasNext() == groundTruthIterator.hasNext());
295+
296+
while (queryIterator.hasNext()) {
297+
final HalfVector queryVector = queryIterator.next().toHalfVector();
298+
onReadListener.reset();
299+
final long beginTs = System.nanoTime();
300+
final List<? extends NodeReferenceAndNode<?>> results =
301+
db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVector).join());
302+
final long endTs = System.nanoTime();
303+
logger.info("retrieved result in elapsedTimeMs={}", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs));
304+
305+
for (NodeReferenceAndNode<?> nodeReferenceAndNode : results) {
306+
final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance();
307+
logger.info("retrieved result nodeId = {} at distance = {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0),
308+
nodeReferenceWithDistance.getDistance());
309+
}
310+
311+
logger.info("true result vector={}", groundTruthIterator.next());
312+
}
322313
}
323314

324315
System.out.println(onReadListener.getNodeCountByLayer());
325316
System.out.println(onReadListener.getBytesReadByLayer());
326317

327-
logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs));
318+
// logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs));
328319
}
329320

330321
@Test
@@ -499,7 +490,6 @@ public void testSIFTVectors() throws Exception {
499490
standardDeviation / mean);
500491
}
501492

502-
503493
@ParameterizedTest
504494
@ValueSource(ints = {2, 3, 10, 100, 768})
505495
public void testManyVectorsStandardDeviation(final int dimensionality) {

0 commit comments

Comments
 (0)