Skip to content

Commit 70a5f38

Browse files
committed
Add support for skipping over non-competitive documents in FirstPassGroupingCollector
1 parent 249152f commit 70a5f38

File tree

2 files changed

+55
-21
lines changed

2 files changed

+55
-21
lines changed

lucene/grouping/src/java/org/apache/lucene/search/grouping/FirstPassGroupingCollector.java

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@ public class FirstPassGroupingCollector<T> extends SimpleCollector {
5050
private final LeafFieldComparator[] leafComparators;
5151
private final int[] reversed;
5252
private final int topNGroups;
53-
private final boolean needsScores;
5453
private final HashMap<T, CollectedSearchGroup<T>> groupMap;
5554
private final int compIDXEnd;
55+
private final ScoreMode scoreMode;
56+
private final boolean canSetMinScore;
5657

5758
// Set once we reach topNGroups unique groups:
5859
/**
@@ -62,6 +63,9 @@ public class FirstPassGroupingCollector<T> extends SimpleCollector {
6263

6364
private int docBase;
6465
private int spareSlot;
66+
private Scorable scorer;
67+
private int bottomSlot;
68+
private float minCompetitiveScore;
6569

6670
/**
6771
* Create the first pass collector.
@@ -83,7 +87,6 @@ public FirstPassGroupingCollector(
8387
// and specialize it?
8488

8589
this.topNGroups = topNGroups;
86-
this.needsScores = groupSort.needsScores();
8790
final SortField[] sortFields = groupSort.getSort();
8891
comparators = new FieldComparator<?>[sortFields.length];
8992
leafComparators = new LeafFieldComparator[sortFields.length];
@@ -105,13 +108,21 @@ public FirstPassGroupingCollector(
105108
reversed[i] = sortField.getReverse() ? -1 : 1;
106109
}
107110

111+
if (SortField.FIELD_SCORE.equals(sortFields[0]) == true) {
112+
scoreMode = ScoreMode.TOP_SCORES;
113+
canSetMinScore = true;
114+
} else {
115+
scoreMode = groupSort.needsScores() ? ScoreMode.TOP_DOCS_WITH_SCORES : ScoreMode.TOP_DOCS;
116+
canSetMinScore = false;
117+
}
118+
108119
spareSlot = topNGroups;
109120
groupMap = CollectionUtil.newHashMap(topNGroups);
110121
}
111122

112123
@Override
113124
public ScoreMode scoreMode() {
114-
return needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
125+
return scoreMode;
115126
}
116127

117128
/**
@@ -162,10 +173,12 @@ public Collection<SearchGroup<T>> getTopGroups(int groupOffset) throws IOExcepti
162173

163174
@Override
164175
public void setScorer(Scorable scorer) throws IOException {
176+
this.scorer = scorer;
165177
groupSelector.setScorer(scorer);
166178
for (LeafFieldComparator comparator : leafComparators) {
167179
comparator.setScorer(scorer);
168180
}
181+
setMinCompetitiveScore(scorer);
169182
}
170183

171184
private boolean isCompetitive(int doc) throws IOException {
@@ -273,9 +286,7 @@ private void collectNewGroup(final int doc) throws IOException {
273286
assert orderedGroups.size() == topNGroups;
274287

275288
final int lastComparatorSlot = orderedGroups.last().comparatorSlot;
276-
for (LeafFieldComparator fc : leafComparators) {
277-
fc.setBottom(lastComparatorSlot);
278-
}
289+
setBottomSlot(lastComparatorSlot);
279290
}
280291
}
281292

@@ -331,9 +342,7 @@ private void collectExistingGroup(final int doc, final CollectedSearchGroup<T> g
331342
// If we changed the value of the last group, or changed which group was last, then update
332343
// bottom:
333344
if (group == newLast || prevLast != newLast) {
334-
for (LeafFieldComparator fc : leafComparators) {
335-
fc.setBottom(newLast.comparatorSlot);
336-
}
345+
setBottomSlot(newLast.comparatorSlot);
337346
}
338347
}
339348
}
@@ -364,13 +373,12 @@ public int compare(CollectedSearchGroup<?> o1, CollectedSearchGroup<?> o2) {
364373
orderedGroups.addAll(groupMap.values());
365374
assert orderedGroups.size() > 0;
366375

367-
for (LeafFieldComparator fc : leafComparators) {
368-
fc.setBottom(orderedGroups.last().comparatorSlot);
369-
}
376+
setBottomSlot(orderedGroups.last().comparatorSlot);
370377
}
371378

372379
@Override
373380
protected void doSetNextReader(LeafReaderContext readerContext) throws IOException {
381+
minCompetitiveScore = 0f;
374382
docBase = readerContext.docBase;
375383
for (int i = 0; i < comparators.length; i++) {
376384
leafComparators[i] = comparators[i].getLeafComparator(readerContext);
@@ -388,4 +396,25 @@ public GroupSelector<T> getGroupSelector() {
388396
private boolean isGroupMapFull() {
389397
return groupMap.size() >= topNGroups;
390398
}
399+
400+
private void setBottomSlot(final int bottomSlot) throws IOException {
401+
for (LeafFieldComparator fc : leafComparators) {
402+
fc.setBottom(bottomSlot);
403+
}
404+
405+
this.bottomSlot = bottomSlot;
406+
setMinCompetitiveScore(scorer);
407+
}
408+
409+
private void setMinCompetitiveScore(final Scorable scorer) throws IOException {
410+
if (canSetMinScore == false || isGroupMapFull() == false) {
411+
return;
412+
}
413+
414+
final float minScore = (float) comparators[0].value(bottomSlot);
415+
if (minScore > minCompetitiveScore) {
416+
scorer.setMinCompetitiveScore(minScore);
417+
minCompetitiveScore = minScore;
418+
}
419+
}
391420
}

lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
*/
1717
package org.apache.lucene.search.grouping;
1818

19+
import static org.hamcrest.Matchers.lessThanOrEqualTo;
20+
1921
import java.io.IOException;
2022
import java.util.ArrayList;
2123
import java.util.Arrays;
@@ -1520,14 +1522,14 @@ private void assertEquals(
15201522
"expected.groups.length != actual.groups.length",
15211523
expected.groups.length,
15221524
actual.groups.length);
1523-
assertEquals(
1524-
"expected.totalHitCount != actual.totalHitCount",
1525-
expected.totalHitCount,
1526-
actual.totalHitCount);
1527-
assertEquals(
1528-
"expected.totalGroupedHitCount != actual.totalGroupedHitCount",
1529-
expected.totalGroupedHitCount,
1530-
actual.totalGroupedHitCount);
1525+
assertThat(
1526+
"expected.totalHitCount >= actual.totalHitCount",
1527+
actual.totalHitCount,
1528+
lessThanOrEqualTo(expected.totalHitCount));
1529+
assertThat(
1530+
"expected.totalGroupedHitCount >= actual.totalGroupedHitCount",
1531+
actual.totalGroupedHitCount,
1532+
lessThanOrEqualTo(expected.totalGroupedHitCount));
15311533
if (expected.totalGroupCount != null && verifyTotalGroupCount) {
15321534
assertEquals(
15331535
"expected.totalGroupCount != actual.totalGroupCount",
@@ -1556,7 +1558,10 @@ private void assertEquals(
15561558

15571559
// TODO
15581560
// assertEquals(expectedGroup.maxScore, actualGroup.maxScore);
1559-
assertEquals(expectedGroup.totalHits().value(), actualGroup.totalHits().value());
1561+
assertThat(
1562+
"expectedGroup.totalHits().value() >= actualGroup.totalHits().value()",
1563+
actualGroup.totalHits().value(),
1564+
lessThanOrEqualTo(expectedGroup.totalHits().value()));
15601565

15611566
final ScoreDoc[] expectedFDs = expectedGroup.scoreDocs();
15621567
final ScoreDoc[] actualFDs = actualGroup.scoreDocs();

0 commit comments

Comments
 (0)