1616
1717package io .github .jbellis .jvector .quantization ;
1818
19- import io .github .jbellis .jvector .annotations .VisibleForTesting ;
2019import io .github .jbellis .jvector .disk .RandomAccessReader ;
2120import io .github .jbellis .jvector .graph .RandomAccessVectorValues ;
2221import io .github .jbellis .jvector .graph .similarity .ScoreFunction ;
3736
3837public abstract class PQVectors implements CompressedVectors {
3938 private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider .getInstance ().getVectorTypeSupport ();
40- static final int MAX_CHUNK_SIZE = Integer .MAX_VALUE - 16 ; // standard Java array size limit with some headroom
4139
4240 final ProductQuantization pq ;
4341 protected ByteSequence <?>[] compressedDataChunks ;
@@ -55,51 +53,19 @@ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException
5553 int vectorCount = in .readInt ();
5654 int compressedDimension = in .readInt ();
5755
58- int [] params = calculateChunkParameters (vectorCount , compressedDimension );
59- int vectorsPerChunk = params [0 ];
60- int totalChunks = params [1 ];
61- int fullSizeChunks = params [2 ];
62- int remainingVectors = params [3 ];
56+ PQLayout layout = new PQLayout (vectorCount ,compressedDimension );
57+ ByteSequence <?>[] chunks = new ByteSequence <?>[layout .totalChunks ];
6358
64- ByteSequence <?>[] chunks = new ByteSequence <?>[totalChunks ];
65- int chunkBytes = vectorsPerChunk * compressedDimension ;
66-
67- for (int i = 0 ; i < fullSizeChunks ; i ++) {
68- chunks [i ] = vectorTypeSupport .readByteSequence (in , chunkBytes );
59+ for (int i = 0 ; i < layout .fullSizeChunks ; i ++) {
60+ chunks [i ] = vectorTypeSupport .readByteSequence (in , layout .fullChunkBytes );
6961 }
7062
7163 // Last chunk might be smaller
72- if (totalChunks > fullSizeChunks ) {
73- chunks [fullSizeChunks ] = vectorTypeSupport .readByteSequence (in , remainingVectors * compressedDimension );
64+ if (layout . totalChunks > layout . fullSizeChunks ) {
65+ chunks [layout . fullSizeChunks ] = vectorTypeSupport .readByteSequence (in , layout . lastChunkBytes );
7466 }
7567
76- return new ImmutablePQVectors (pq , chunks , vectorCount , vectorsPerChunk );
77- }
78-
79- /**
80- * Calculate chunking parameters for the given vector count and compressed dimension
81- * @return array of [vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors]
82- */
83- @ VisibleForTesting
84- static int [] calculateChunkParameters (int vectorCount , int compressedDimension ) {
85- if (vectorCount < 0 ) {
86- throw new IllegalArgumentException ("Invalid vector count " + vectorCount );
87- }
88- if (compressedDimension < 0 ) {
89- throw new IllegalArgumentException ("Invalid compressed dimension " + compressedDimension );
90- }
91-
92- long totalSize = (long ) vectorCount * compressedDimension ;
93- int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension ;
94- if (vectorsPerChunk == 0 ) {
95- throw new IllegalArgumentException ("Compressed dimension " + compressedDimension + " too large for chunking" );
96- }
97-
98- int fullSizeChunks = vectorCount / vectorsPerChunk ;
99- int totalChunks = vectorCount % vectorsPerChunk == 0 ? fullSizeChunks : fullSizeChunks + 1 ;
100-
101- int remainingVectors = vectorCount % vectorsPerChunk ;
102- return new int [] {vectorsPerChunk , totalChunks , fullSizeChunks , remainingVectors };
68+ return new ImmutablePQVectors (pq , chunks , vectorCount , layout .fullChunkVectors );
10369 }
10470
10571 public static PQVectors load (RandomAccessReader in , long offset ) throws IOException {
@@ -118,20 +84,15 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept
11884 * @return the PQVectors instance
11985 */
12086 public static ImmutablePQVectors encodeAndBuild (ProductQuantization pq , int vectorCount , RandomAccessVectorValues ravv , ForkJoinPool simdExecutor ) {
121- // Calculate if we need to split into multiple chunks
12287 int compressedDimension = pq .compressedVectorSize ();
123- long totalSize = (long ) vectorCount * compressedDimension ;
124- int vectorsPerChunk = totalSize <= PQVectors .MAX_CHUNK_SIZE ? vectorCount : PQVectors .MAX_CHUNK_SIZE / compressedDimension ;
125-
126- int numChunks = vectorCount / vectorsPerChunk ;
127- final ByteSequence <?>[] chunks = new ByteSequence <?>[numChunks ];
128- int chunkSize = vectorsPerChunk * compressedDimension ;
129- for (int i = 0 ; i < numChunks - 1 ; i ++)
130- chunks [i ] = vectorTypeSupport .createByteSequence (chunkSize );
131-
132- // Last chunk might be smaller
133- int remainingVectors = vectorCount - (vectorsPerChunk * (numChunks - 1 ));
134- chunks [numChunks - 1 ] = vectorTypeSupport .createByteSequence (remainingVectors * compressedDimension );
88+ PQLayout layout = new PQLayout (vectorCount ,compressedDimension );
89+ final ByteSequence <?>[] chunks = new ByteSequence <?>[layout .totalChunks ];
90+ for (int i = 0 ; i < layout .fullSizeChunks ; i ++) {
91+ chunks [i ] = vectorTypeSupport .createByteSequence (layout .fullChunkBytes );
92+ }
93+ if (layout .lastChunkVectors > 0 ) {
94+ chunks [layout .fullSizeChunks ] = vectorTypeSupport .createByteSequence (layout .lastChunkBytes );
95+ }
13596
13697 // Encode the vectors in parallel into the compressed data chunks
13798 // The changes are concurrent, but because they are coordinated and do not overlap, we can use parallel streams
@@ -142,7 +103,7 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
142103 .forEach (ordinal -> {
143104 // Retrieve the slice and mutate it.
144105 var localRavv = ravvCopy .get ();
145- var slice = PQVectors .get (chunks , ordinal , vectorsPerChunk , pq .getSubspaceCount ());
106+ var slice = PQVectors .get (chunks , ordinal , layout . fullChunkVectors , pq .getSubspaceCount ());
146107 var vector = localRavv .getVector (ordinal );
147108 if (vector != null )
148109 pq .encodeTo (vector , slice );
@@ -151,7 +112,7 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
151112 }))
152113 .join ();
153114
154- return new ImmutablePQVectors (pq , chunks , vectorCount , vectorsPerChunk );
115+ return new ImmutablePQVectors (pq , chunks , vectorCount , layout . fullChunkVectors );
155116 }
156117
157118 @ Override
@@ -443,4 +404,73 @@ public String toString() {
443404 ", count=" + count () +
444405 '}' ;
445406 }
407+
408+ /**
409+ * Chunk Dimensions and Layout
410+ * This is emulative of modern Java records, but keeps to J11 standards.
411+ * This class consolidates the layout calculations for PQ data into one place
412+ */
413+ static class PQLayout {
414+
415+ /**
416+ * total number of vectors
417+ **/
418+ public final int vectorCount ;
419+ /**
420+ * total number of chunks, including any partial
421+ **/
422+ public final int totalChunks ;
423+ /**
424+ * total number of fully-filled chunks
425+ **/
426+ public final int fullSizeChunks ;
427+ /**
428+ * number of vectors per fullSize chunk
429+ **/
430+ public final int fullChunkVectors ;
431+ /**
432+ * number of vectors in last partially filled chunk, if any
433+ **/
434+ public final int lastChunkVectors ;
435+ /**
436+ * compressed dimension of vectors
437+ **/
438+ public final int compressedDimension ;
439+ /**
440+ * number of bytes in each fully-filled chunk
441+ **/
442+ public final int fullChunkBytes ;
443+ /**
444+ * number of bytes in the last partially-filled chunk, if any
445+ **/
446+ public final int lastChunkBytes ;
447+
448+ public PQLayout (int vectorCount , int compressedDimension ) {
449+ if (vectorCount <= 0 ) {
450+ throw new IllegalArgumentException ("Invalid vector count " + vectorCount );
451+ }
452+ this .vectorCount = vectorCount ;
453+
454+ if (compressedDimension <= 0 ) {
455+ throw new IllegalArgumentException ("Invalid compressed dimension " + compressedDimension );
456+ }
457+ this .compressedDimension = compressedDimension ;
458+
459+ // Get the aligned number of bytes needed to hold a given dimension
460+ // purely for overflow prevention
461+ int layoutBytesPerVector = compressedDimension == 1 ? 1 : Integer .highestOneBit (compressedDimension - 1 ) << 1 ;
462+ // truncation welcome here, biasing for smaller chunks
463+ int addressableVectorsPerChunk = Integer .MAX_VALUE / layoutBytesPerVector ;
464+
465+ fullChunkVectors = Math .min (vectorCount , addressableVectorsPerChunk );
466+ lastChunkVectors = vectorCount % fullChunkVectors ;
467+
468+ fullChunkBytes = fullChunkVectors * compressedDimension ;
469+ lastChunkBytes = lastChunkVectors * compressedDimension ;
470+
471+ fullSizeChunks = vectorCount / fullChunkVectors ;
472+ totalChunks = fullSizeChunks + (lastChunkVectors == 0 ? 0 : 1 );
473+ }
474+
475+ }
446476}
0 commit comments