Skip to content

Commit e0f150b

Browse files
author
Saurabh Singh
committed
Support for traversing BKD tree with prefetching
1 parent 8e8e37d commit e0f150b

File tree

5 files changed

+158
-4
lines changed

5 files changed

+158
-4
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,8 @@ API Changes
984984
* GITHUB#13820, GITHUB#13825, GITHUB#13830: Corrects DataInput.readGroupVInts to be public and not-final, removes the protected
985985
DataInput.readGroupVInt method. (Zhang Chao, Robert Muir, Uwe Schindler, Dawid Weiss)
986986

987+
* GITHUB#15376, GITHUB#15197: Added prefetching to BKD tree traversal. PointValues gains two new APIs: visitDocIDs, which visits doc IDs starting from a given position, and prepareOrVisitDocIDs, which prefetches the IO before visiting doc IDs. (Saurabh Singh)
988+
987989
New Features
988990
---------------------
989991

lucene/core/src/java/org/apache/lucene/index/PointValues.java

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import java.io.UncheckedIOException;
2121
import java.math.BigInteger;
2222
import java.net.InetAddress;
23+
import java.util.ArrayList;
24+
import java.util.List;
2325
import org.apache.lucene.document.BinaryPoint;
2426
import org.apache.lucene.document.DoublePoint;
2527
import org.apache.lucene.document.Field;
@@ -274,6 +276,17 @@ public interface PointTree extends Cloneable {
274276

275277
/** Visit all the docs and values below the current node. */
276278
void visitDocValues(IntersectVisitor visitor) throws IOException;
279+
280+
/** Visit all the docs below the node at position pos */
281+
default void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException {}
282+
;
283+
284+
/**
285+
* call prefetch for docs below the current node if vistor supports prefetching otherwise visit
286+
* docIds
287+
*/
288+
default void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {}
289+
;
277290
}
278291

279292
/**
@@ -341,13 +354,54 @@ default void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOExcep
341354
default void grow(int count) {}
342355
}
343356

357+
/**
358+
* We can recurse the {@link PointTree} using prefetch capable visitor. This visitor caches the blocks
359+
* the blocks during recursion, calling prefetch on required blocks. This should potentially trigger IO
360+
* for these blocks asynchronously. Once the recursion is complete all the cached blocks are visited one by one.
361+
* @lucene.experimental
362+
*/
363+
public abstract static class PrefetchCapableVisitor implements IntersectVisitor {
364+
365+
int lastMatchingOrdinal = -1;
366+
List<Long> prefetchedBlocks = new ArrayList<>();
367+
368+
/**
369+
* return the last matched block ordinal - this is used to avoid prefetching call for contiguous
370+
* ordinals assuming contiguous ordinals prefetching can be taken care by readaheads.
371+
*/
372+
public int lastMatchedBlock() {
373+
return lastMatchingOrdinal;
374+
}
375+
376+
/** set last matched block ordinal * */
377+
public void setLastMatchedBlock(int leafNodeOrdinal) {
378+
lastMatchingOrdinal = leafNodeOrdinal;
379+
}
380+
381+
/** save prefetched block for visting later on * */
382+
public void savePrefetchedBlockForLaterVisit(long leafFp) {
383+
prefetchedBlocks.add(leafFp);
384+
}
385+
386+
/** returns the saved prefetch blocks * */
387+
public List<Long> savedPrefetchedBlocks() {
388+
return new ArrayList<>(prefetchedBlocks);
389+
}
390+
}
391+
344392
/**
345393
* Finds all documents and points matching the provided visitor. This method does not enforce live
346394
* documents, so it's up to the caller to test whether each document is deleted, if necessary.
347395
*/
348396
public final void intersect(IntersectVisitor visitor) throws IOException {
349397
final PointTree pointTree = getPointTree();
350398
intersect(visitor, pointTree);
399+
if (visitor instanceof PrefetchCapableVisitor prefetchCapableVisitor) {
400+
List<Long> fps = prefetchCapableVisitor.savedPrefetchedBlocks();
401+
for (int fp = 0; fp < fps.size(); ++fp) {
402+
pointTree.visitDocIDs(fps.get(fp), visitor);
403+
}
404+
}
351405
assert pointTree.moveToParent() == false;
352406
}
353407

@@ -358,7 +412,8 @@ private static void intersect(IntersectVisitor visitor, PointTree pointTree) thr
358412
if (compare == Relation.CELL_INSIDE_QUERY) {
359413
// This cell is fully inside the query shape: recursively add all points in this cell
360414
// without filtering
361-
pointTree.visitDocIDs(visitor);
415+
// pointTree.visitDocIDs( visitor);
416+
pointTree.prepareOrVisitDocIDs(visitor);
362417
} else if (compare == Relation.CELL_CROSSES_QUERY) {
363418
// The cell crosses the shape boundary, or the cell fully contains the query, so we fall
364419
// through and do full filtering:

lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ private boolean matches(byte[] packedValue) {
147147
}
148148

149149
private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
150-
return new IntersectVisitor() {
150+
return new PointValues.PrefetchCapableVisitor() {
151151

152152
DocIdSetBuilder.BulkAdder adder;
153153

@@ -194,7 +194,7 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
194194

195195
/** Create a visitor that sets documents that do NOT match the range. */
196196
private IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) {
197-
return new IntersectVisitor() {
197+
return new PointValues.PrefetchCapableVisitor() {
198198

199199
@Override
200200
public void visit(int docID) {

lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,69 @@ public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException
589589
addAll(visitor, false);
590590
}
591591

592+
@Override
593+
public void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
594+
resetNodeDataPosition();
595+
prefetchAll(visitor, false);
596+
}
597+
598+
@Override
599+
public void visitDocIDs(long position, IntersectVisitor visitor) throws IOException {
600+
visitDocIDs(position, visitor, false);
601+
}
602+
603+
private void visitDocIDs(long position, IntersectVisitor visitor, boolean grown)
604+
throws IOException {
605+
leafNodes.seek(position);
606+
int count = leafNodes.readVInt();
607+
if (!grown) {
608+
visitor.grow(count);
609+
}
610+
docIdsWriter.readInts(leafNodes, count, visitor, scratchIterator.docIDs);
611+
}
612+
613+
private int getLeafNodeOrdinal() {
614+
assert isLeafNode() : "nodeID=" + nodeID + " is not a leaf";
615+
return nodeID - leafNodeOffset;
616+
}
617+
618+
public void prefetchAll(IntersectVisitor visitor, boolean grown) throws IOException {
619+
if (grown == false) {
620+
final long size = size();
621+
if (size <= Integer.MAX_VALUE) {
622+
visitor.grow((int) size);
623+
grown = true;
624+
}
625+
}
626+
if (isLeafNode()) {
627+
// int count = isLastLeaf() ? config.maxPointsInLeafNode() : lastLeafNodePointCount;
628+
long leafFp = getLeafBlockFP();
629+
int leafNodeOrdinal = getLeafNodeOrdinal();
630+
if (visitor instanceof PrefetchCapableVisitor prefetchCapableVisitor) {
631+
// Only call prefetch is this is the first leaf node ordinal or the first match in
632+
// contigiuous sequence of matches for leaf nodes
633+
// boolean prefetched = false;
634+
if (prefetchCapableVisitor.lastMatchedBlock() == -1
635+
|| prefetchCapableVisitor.lastMatchedBlock() + 1 < leafNodeOrdinal) {
636+
// System.out.println("Prefetched called on " + leafNodeOrdinal);
637+
leafNodes.prefetch(leafFp, 1);
638+
// prefetched = true;
639+
}
640+
prefetchCapableVisitor.setLastMatchedBlock(leafNodeOrdinal);
641+
prefetchCapableVisitor.savePrefetchedBlockForLaterVisit(leafFp);
642+
} else {
643+
visitDocIDs(getLeafBlockFP(), visitor, true);
644+
}
645+
} else {
646+
pushLeft();
647+
prefetchAll(visitor, grown);
648+
pop();
649+
pushRight();
650+
prefetchAll(visitor, grown);
651+
pop();
652+
}
653+
}
654+
592655
public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException {
593656
if (grown == false) {
594657
final long size = size();

lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
package org.apache.lucene.tests.index;
1818

1919
import java.io.IOException;
20+
import java.util.ArrayList;
2021
import java.util.Arrays;
2122
import java.util.Iterator;
23+
import java.util.List;
2224
import java.util.Objects;
2325
import org.apache.lucene.index.BinaryDocValues;
2426
import org.apache.lucene.index.DocValues;
@@ -1597,13 +1599,23 @@ public void visitDocValues(IntersectVisitor visitor) throws IOException {
15971599
pointValues.getBytesPerDimension(),
15981600
visitor));
15991601
}
1602+
1603+
@Override
1604+
public void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException {
1605+
in.visitDocIDs(pos, visitor);
1606+
}
1607+
1608+
@Override
1609+
public void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
1610+
in.prepareOrVisitDocIDs(visitor);
1611+
}
16001612
}
16011613

16021614
/**
16031615
* Validates in the 1D case that all points are visited in order, and point values are in bounds
16041616
* of the last cell checked
16051617
*/
1606-
static class AssertingIntersectVisitor implements IntersectVisitor {
1618+
static class AssertingIntersectVisitor extends PointValues.PrefetchCapableVisitor {
16071619
final IntersectVisitor in;
16081620
final int numDataDims;
16091621
final int numIndexDims;
@@ -1614,6 +1626,8 @@ static class AssertingIntersectVisitor implements IntersectVisitor {
16141626
private Relation lastCompareResult;
16151627
private int lastDocID = -1;
16161628
private int docBudget;
1629+
int lastMatchedBlock;
1630+
private List<Long> prefetchedBlocks;
16171631

16181632
AssertingIntersectVisitor(
16191633
int numDataDims, int numIndexDims, int bytesPerDim, IntersectVisitor in) {
@@ -1716,6 +1730,26 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
17161730
lastCompareResult = in.compare(minPackedValue, maxPackedValue);
17171731
return lastCompareResult;
17181732
}
1733+
1734+
@Override
1735+
public int lastMatchedBlock() {
1736+
return lastMatchedBlock;
1737+
}
1738+
1739+
@Override
1740+
public void setLastMatchedBlock(int leafNodeOrdinal) {
1741+
lastMatchedBlock = leafNodeOrdinal;
1742+
}
1743+
1744+
@Override
1745+
public void savePrefetchedBlockForLaterVisit(long leafFp) {
1746+
prefetchedBlocks.add(leafFp);
1747+
}
1748+
1749+
@Override
1750+
public List<Long> savedPrefetchedBlocks() {
1751+
return new ArrayList<>(prefetchedBlocks);
1752+
}
17191753
}
17201754

17211755
@Override

0 commit comments

Comments
 (0)