Skip to content

Commit 2bdef68

Browse files
author
Saurabh Singh
committed
Support for traversing BKD tree with prefetching
1 parent 8e8e37d commit 2bdef68

File tree

4 files changed

+153
-4
lines changed

4 files changed

+153
-4
lines changed

lucene/core/src/java/org/apache/lucene/index/PointValues.java

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import java.io.UncheckedIOException;
2121
import java.math.BigInteger;
2222
import java.net.InetAddress;
23+
import java.util.ArrayList;
24+
import java.util.List;
2325
import org.apache.lucene.document.BinaryPoint;
2426
import org.apache.lucene.document.DoublePoint;
2527
import org.apache.lucene.document.Field;
@@ -274,6 +276,17 @@ public interface PointTree extends Cloneable {
274276

275277
/** Visit all the docs and values below the current node. */
276278
void visitDocValues(IntersectVisitor visitor) throws IOException;
279+
280+
/** Visit all the docs below the node at position pos */
281+
default void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException {}
282+
;
283+
284+
/**
285+
* call prefetch for docs below the current node if vistor supports prefetching otherwise visit
286+
* docIds
287+
*/
288+
default void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {}
289+
;
277290
}
278291

279292
/**
@@ -282,6 +295,7 @@ public interface PointTree extends Cloneable {
282295
* @lucene.experimental
283296
*/
284297
public interface IntersectVisitor {
298+
285299
/**
286300
* Called for all documents in a leaf cell that's fully contained by the query. The consumer
287301
* should blindly accept the docID.
@@ -341,13 +355,48 @@ default void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOExcep
341355
default void grow(int count) {}
342356
}
343357

358+
public abstract static class PrefetchCapableVisitor implements IntersectVisitor {
359+
360+
int lastMatchingOrdinal = -1;
361+
List<Long> prefetchedBlocks = new ArrayList<>();
362+
363+
/**
364+
* return the last matched block ordinal - this is used to avoid prefetching call for contiguous
365+
* ordinals assuming contiguous ordinals prefetching can be taken care by readaheads.
366+
*/
367+
public int lastMatchedBlock() {
368+
return lastMatchingOrdinal;
369+
}
370+
371+
/** set last matched block ordinal * */
372+
public void setLastMatchedBlock(int leafNodeOrdinal) {
373+
lastMatchingOrdinal = leafNodeOrdinal;
374+
}
375+
376+
/** save prefetched block for visting later on * */
377+
public void savePrefetchedBlockForLaterVisit(long leafFp) {
378+
prefetchedBlocks.add(leafFp);
379+
}
380+
381+
/** returns the saved prefetch blocks * */
382+
public List<Long> savedPrefetchedBlocks() {
383+
return new ArrayList<>(prefetchedBlocks);
384+
}
385+
}
386+
344387
/**
345388
* Finds all documents and points matching the provided visitor. This method does not enforce live
346389
* documents, so it's up to the caller to test whether each document is deleted, if necessary.
347390
*/
348391
public final void intersect(IntersectVisitor visitor) throws IOException {
349392
final PointTree pointTree = getPointTree();
350393
intersect(visitor, pointTree);
394+
if (visitor instanceof PrefetchCapableVisitor prefetchCapableVisitor) {
395+
List<Long> fps = prefetchCapableVisitor.savedPrefetchedBlocks();
396+
for (int fp = 0; fp < fps.size(); ++fp) {
397+
pointTree.visitDocIDs(fps.get(fp), visitor);
398+
}
399+
}
351400
assert pointTree.moveToParent() == false;
352401
}
353402

@@ -358,7 +407,8 @@ private static void intersect(IntersectVisitor visitor, PointTree pointTree) thr
358407
if (compare == Relation.CELL_INSIDE_QUERY) {
359408
// This cell is fully inside the query shape: recursively add all points in this cell
360409
// without filtering
361-
pointTree.visitDocIDs(visitor);
410+
// pointTree.visitDocIDs( visitor);
411+
pointTree.prepareOrVisitDocIDs(visitor);
362412
} else if (compare == Relation.CELL_CROSSES_QUERY) {
363413
// The cell crosses the shape boundary, or the cell fully contains the query, so we fall
364414
// through and do full filtering:

lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ private boolean matches(byte[] packedValue) {
147147
}
148148

149149
private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
150-
return new IntersectVisitor() {
150+
return new PointValues.PrefetchCapableVisitor() {
151151

152152
DocIdSetBuilder.BulkAdder adder;
153153

@@ -194,7 +194,7 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
194194

195195
/** Create a visitor that sets documents that do NOT match the range. */
196196
private IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) {
197-
return new IntersectVisitor() {
197+
return new PointValues.PrefetchCapableVisitor() {
198198

199199
@Override
200200
public void visit(int docID) {

lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,71 @@ public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException
589589
addAll(visitor, false);
590590
}
591591

592+
@Override
593+
public void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
594+
resetNodeDataPosition();
595+
prefetchAll(visitor, false);
596+
}
597+
598+
@Override
599+
public void visitDocIDs(long position, IntersectVisitor visitor) throws IOException {
600+
visitDocIDs(position, visitor, false);
601+
}
602+
603+
private void visitDocIDs(long position, IntersectVisitor visitor, boolean grown)
604+
throws IOException {
605+
leafNodes.seek(position);
606+
int count = leafNodes.readVInt();
607+
if (!grown) {
608+
if (count <= Integer.MAX_VALUE) {
609+
visitor.grow(count);
610+
}
611+
}
612+
docIdsWriter.readInts(leafNodes, count, visitor, scratchIterator.docIDs);
613+
}
614+
615+
private int getLeafNodeOrdinal() {
616+
assert isLeafNode() : "nodeID=" + nodeID + " is not a leaf";
617+
return nodeID - leafNodeOffset;
618+
}
619+
620+
public void prefetchAll(IntersectVisitor visitor, boolean grown) throws IOException {
621+
if (grown == false) {
622+
final long size = size();
623+
if (size <= Integer.MAX_VALUE) {
624+
visitor.grow((int) size);
625+
grown = true;
626+
}
627+
}
628+
if (isLeafNode()) {
629+
// int count = isLastLeaf() ? config.maxPointsInLeafNode() : lastLeafNodePointCount;
630+
long leafFp = getLeafBlockFP();
631+
int leafNodeOrdinal = getLeafNodeOrdinal();
632+
if (visitor instanceof PrefetchCapableVisitor prefetchCapableVisitor) {
633+
// Only call prefetch is this is the first leaf node ordinal or the first match in
634+
// contigiuous sequence of matches for leaf nodes
635+
// boolean prefetched = false;
636+
if (prefetchCapableVisitor.lastMatchedBlock() == -1
637+
|| prefetchCapableVisitor.lastMatchedBlock() + 1 < leafNodeOrdinal) {
638+
// System.out.println("Prefetched called on " + leafNodeOrdinal);
639+
leafNodes.prefetch(leafFp, 1);
640+
// prefetched = true;
641+
}
642+
prefetchCapableVisitor.setLastMatchedBlock(leafNodeOrdinal);
643+
prefetchCapableVisitor.savePrefetchedBlockForLaterVisit(leafFp);
644+
} else {
645+
visitDocIDs(getLeafBlockFP(), visitor, true);
646+
}
647+
} else {
648+
pushLeft();
649+
prefetchAll(visitor, grown);
650+
pop();
651+
pushRight();
652+
prefetchAll(visitor, grown);
653+
pop();
654+
}
655+
}
656+
592657
public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException {
593658
if (grown == false) {
594659
final long size = size();

lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
package org.apache.lucene.tests.index;
1818

1919
import java.io.IOException;
20+
import java.util.ArrayList;
2021
import java.util.Arrays;
2122
import java.util.Iterator;
23+
import java.util.List;
2224
import java.util.Objects;
2325
import org.apache.lucene.index.BinaryDocValues;
2426
import org.apache.lucene.index.DocValues;
@@ -1597,13 +1599,23 @@ public void visitDocValues(IntersectVisitor visitor) throws IOException {
15971599
pointValues.getBytesPerDimension(),
15981600
visitor));
15991601
}
1602+
1603+
@Override
1604+
public void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException {
1605+
in.visitDocIDs(pos, visitor);
1606+
}
1607+
1608+
@Override
1609+
public void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
1610+
in.prepareOrVisitDocIDs(visitor);
1611+
}
16001612
}
16011613

16021614
/**
16031615
* Validates in the 1D case that all points are visited in order, and point values are in bounds
16041616
* of the last cell checked
16051617
*/
1606-
static class AssertingIntersectVisitor implements IntersectVisitor {
1618+
static class AssertingIntersectVisitor extends PointValues.PrefetchCapableVisitor {
16071619
final IntersectVisitor in;
16081620
final int numDataDims;
16091621
final int numIndexDims;
@@ -1614,6 +1626,8 @@ static class AssertingIntersectVisitor implements IntersectVisitor {
16141626
private Relation lastCompareResult;
16151627
private int lastDocID = -1;
16161628
private int docBudget;
1629+
int lastMatchedBlock;
1630+
private List<Long> prefetchedBlocks;
16171631

16181632
AssertingIntersectVisitor(
16191633
int numDataDims, int numIndexDims, int bytesPerDim, IntersectVisitor in) {
@@ -1716,6 +1730,26 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
17161730
lastCompareResult = in.compare(minPackedValue, maxPackedValue);
17171731
return lastCompareResult;
17181732
}
1733+
1734+
@Override
1735+
public int lastMatchedBlock() {
1736+
return lastMatchedBlock;
1737+
}
1738+
1739+
@Override
1740+
public void setLastMatchedBlock(int leafNodeOrdinal) {
1741+
lastMatchedBlock = leafNodeOrdinal;
1742+
}
1743+
1744+
@Override
1745+
public void savePrefetchedBlockForLaterVisit(long leafFp) {
1746+
prefetchedBlocks.add(leafFp);
1747+
}
1748+
1749+
@Override
1750+
public List<Long> savedPrefetchedBlocks() {
1751+
return new ArrayList<>(prefetchedBlocks);
1752+
}
17191753
}
17201754

17211755
@Override

0 commit comments

Comments
 (0)