3030import com .apple .foundationdb .tuple .Tuple ;
3131import com .apple .test .Tags ;
3232import com .christianheina .langx .half4j .Half ;
33+ import com .google .common .base .Verify ;
3334import com .google .common .collect .ImmutableList ;
3435import com .google .common .collect .Maps ;
3536import org .assertj .core .util .Lists ;
3637import org .junit .jupiter .api .Assertions ;
3738import org .junit .jupiter .api .BeforeEach ;
38- import org .junit .jupiter .api .Disabled ;
3939import org .junit .jupiter .api .Tag ;
4040import org .junit .jupiter .api .Test ;
4141import org .junit .jupiter .api .Timeout ;
5353import java .io .FileReader ;
5454import java .io .FileWriter ;
5555import java .io .IOException ;
56+ import java .nio .channels .FileChannel ;
57+ import java .nio .file .Path ;
58+ import java .nio .file .Paths ;
59+ import java .nio .file .StandardOpenOption ;
5660import java .util .ArrayList ;
5761import java .util .Comparator ;
62+ import java .util .Iterator ;
5863import java .util .List ;
5964import java .util .Map ;
6065import java .util .NavigableSet ;
@@ -208,9 +213,10 @@ private int basicInsertBatch(final HNSW hnsw, final int batchSize,
208213 final long beginTs = System .nanoTime ();
209214 for (int i = 0 ; i < batchSize ; i ++) {
210215 final var newNodeReference = insertFunction .apply (tr );
211- if (newNodeReference ! = null ) {
212- hnsw . insert ( tr , newNodeReference ). join () ;
216+ if (newNodeReference = = null ) {
217+ return i ;
213218 }
219+ hnsw .insert (tr , newNodeReference ).join ();
214220 }
215221 final long endTs = System .nanoTime ();
216222 logger .info ("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}" , batchSize , nextNodeId ,
@@ -243,7 +249,6 @@ private int insertBatch(final HNSW hnsw, final int batchSize,
243249 }
244250
245251 @ Test
246- @ Timeout (value = 150 , unit = TimeUnit .MINUTES )
247252 public void testSIFTInsert10k () throws Exception {
248253 final Metric metric = Metrics .EUCLIDEAN_METRIC .getMetric ();
249254 final int k = 10 ;
@@ -255,76 +260,62 @@ public void testSIFTInsert10k() throws Exception {
255260 HNSW .DEFAULT_CONFIG .toBuilder ().setMetric (metric ).setM (32 ).setMMax (32 ).setMMax0 (64 ).build (),
256261 OnWriteListener .NOOP , onReadListener );
257262
258- final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv" ;
259- final int dimensions = 128 ;
263+ final Path siftSmallPath = Paths .get (".out/extracted/siftsmall/siftsmall_base.fvecs" );
260264
261- final AtomicReference <HalfVector > queryVectorAtomic = new AtomicReference <>();
262- final NavigableSet <NodeReferenceWithDistance > trueResults = new ConcurrentSkipListSet <>(
263- Comparator .comparing (NodeReferenceWithDistance ::getDistance ));
265+ try (final var fileChannel = FileChannel .open (siftSmallPath , StandardOpenOption .READ )) {
266+ final Iterator <Vector .DoubleVector > vectorIterator = new Vector .StoredFVecsIterator (fileChannel );
264267
265- try ( BufferedReader br = new BufferedReader ( new FileReader ( tsvFile ))) {
266- for ( int i = 0 ; i < 10000 ; ) {
268+ int i = 0 ;
269+ while ( vectorIterator . hasNext () ) {
267270 i += basicInsertBatch (hnsw , 100 , nextNodeIdAtomic , onReadListener ,
268271 tr -> {
269- final String line ;
270- try {
271- line = br .readLine ();
272- } catch (IOException e ) {
273- throw new RuntimeException (e );
272+ if (!vectorIterator .hasNext ()) {
273+ return null ;
274274 }
275275
276- final String [] values = Objects .requireNonNull (line ).split ("\t " );
277- Assertions .assertEquals (dimensions , values .length );
278- final Half [] halfs = new Half [dimensions ];
276+ final Vector .DoubleVector doubleVector = vectorIterator .next ();
279277
280- for (int c = 0 ; c < values .length ; c ++) {
281- final String value = values [c ];
282- halfs [c ] = HNSWHelpers .halfValueOf (Double .parseDouble (value ));
283- }
284278 final Tuple currentPrimaryKey = createNextPrimaryKey (nextNodeIdAtomic );
285- final HalfVector currentVector = new HalfVector (halfs );
286- final HalfVector queryVector = queryVectorAtomic .get ();
287- if (queryVector == null ) {
288- queryVectorAtomic .set (currentVector );
289- return null ;
290- } else {
291- final double currentDistance =
292- Vector .comparativeDistance (metric , currentVector , queryVector );
293- if (trueResults .size () < k || trueResults .last ().getDistance () > currentDistance ) {
294- trueResults .add (
295- new NodeReferenceWithDistance (currentPrimaryKey , currentVector ,
296- Vector .comparativeDistance (metric , currentVector , queryVector )));
297- }
298- if (trueResults .size () > k ) {
299- trueResults .remove (trueResults .last ());
300- }
301- return new NodeReferenceWithVector (currentPrimaryKey , currentVector );
302- }
279+ final HalfVector currentVector = doubleVector .toHalfVector ();
280+ return new NodeReferenceWithVector (currentPrimaryKey , currentVector );
303281 });
304282 }
305283 }
306284
307- onReadListener .reset ();
308- final long beginTs = System .nanoTime ();
309- final List <? extends NodeReferenceAndNode <?>> results =
310- db .run (tr -> hnsw .kNearestNeighborsSearch (tr , k , 100 , queryVectorAtomic .get ()).join ());
311- final long endTs = System .nanoTime ();
285+ final Path siftSmallGroundTruthPath = Paths .get (".out/extracted/siftsmall/siftsmall_groundtruth.ivecs" );
286+ final Path siftSmallQueryPath = Paths .get (".out/extracted/siftsmall/siftsmall_query.fvecs" );
312287
313- for (NodeReferenceAndNode <?> nodeReferenceAndNode : results ) {
314- final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode .getNodeReferenceWithDistance ();
315- logger .info ("retrieved result nodeId = {} at distance= {}" , nodeReferenceWithDistance .getPrimaryKey ().getLong (0 ),
316- nodeReferenceWithDistance .getDistance ());
317- }
318288
319- for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults ) {
320- logger .info ("true result nodeId ={} at distance={}" , nodeReferenceWithDistance .getPrimaryKey ().getLong (0 ),
321- nodeReferenceWithDistance .getDistance ());
289+ try (final var queryChannel = FileChannel .open (siftSmallQueryPath , StandardOpenOption .READ );
290+ final var groundTruthChannel = FileChannel .open (siftSmallGroundTruthPath , StandardOpenOption .READ )) {
291+ final Iterator <Vector .DoubleVector > queryIterator = new Vector .StoredFVecsIterator (queryChannel );
292+ final Iterator <List <Integer >> groundTruthIterator = new Vector .StoredIVecsIterator (groundTruthChannel );
293+
294+ Verify .verify (queryIterator .hasNext () == groundTruthIterator .hasNext ());
295+
296+ while (queryIterator .hasNext ()) {
297+ final HalfVector queryVector = queryIterator .next ().toHalfVector ();
298+ onReadListener .reset ();
299+ final long beginTs = System .nanoTime ();
300+ final List <? extends NodeReferenceAndNode <?>> results =
301+ db .run (tr -> hnsw .kNearestNeighborsSearch (tr , k , 100 , queryVector ).join ());
302+ final long endTs = System .nanoTime ();
303+ logger .info ("retrieved result in elapsedTimeMs={}" , TimeUnit .NANOSECONDS .toMillis (endTs - beginTs ));
304+
305+ for (NodeReferenceAndNode <?> nodeReferenceAndNode : results ) {
306+ final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode .getNodeReferenceWithDistance ();
307+ logger .info ("retrieved result nodeId = {} at distance = {}" , nodeReferenceWithDistance .getPrimaryKey ().getLong (0 ),
308+ nodeReferenceWithDistance .getDistance ());
309+ }
310+
311+ logger .info ("true result vector={}" , groundTruthIterator .next ());
312+ }
322313 }
323314
324315 System .out .println (onReadListener .getNodeCountByLayer ());
325316 System .out .println (onReadListener .getBytesReadByLayer ());
326317
327- logger .info ("search transaction took elapsedTime={}ms" , TimeUnit .NANOSECONDS .toMillis (endTs - beginTs ));
318+ // logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs));
328319 }
329320
330321 @ Test
@@ -499,7 +490,6 @@ public void testSIFTVectors() throws Exception {
499490 standardDeviation / mean );
500491 }
501492
502-
503493 @ ParameterizedTest
504494 @ ValueSource (ints = {2 , 3 , 10 , 100 , 768 })
505495 public void testManyVectorsStandardDeviation (final int dimensionality ) {
0 commit comments