Skip to content

Commit f56bf0a

Browse files
committed
Add binarySearchFirst and binarySearchLast for stable search with duplicit elements
1 parent 76e2a2c commit f56bf0a

File tree

2 files changed

+197
-41
lines changed

2 files changed

+197
-41
lines changed

src/main/java/org/apache/commons/lang3/ArrayUtils.java

+143-23
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,8 @@ public static <T> T arraycopy(final T source, final int sourcePos, final T dest,
14331433
}
14341434

14351435
/**
1436-
* Searches element in array sorted by key.
1436+
* Searches element in array sorted by key. If there are multiple elements matching, it returns first occurrence.
1437+
* If the array is not sorted, the result is undefined.
14371438
*
14381439
* @param array
14391440
* array sorted by key field
@@ -1445,25 +1446,26 @@ public static <T> T arraycopy(final T source, final int sourcePos, final T dest,
14451446
* comparator for keys
14461447
*
14471448
* @return
1448-
* index of the search key, if it is contained in the array; otherwise, (-first_greater - 1).
1449-
* The first_greater is the index of lowest greater element in the list - if all elements are lower, the
1450-
* first_greater is defined as array.length.
1449+
* index of the first occurrence of search key, if it is contained in the array; otherwise,
1450+
* (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements
1451+
* are lower, the first_greater is defined as array.length.
14511452
*
14521453
* @param <T>
14531454
* type of array element
14541455
* @param <K>
14551456
* type of key
14561457
*/
1457-
public static <K, T> int binarySearch(
1458+
public static <K, T> int binarySearchFirst(
14581459
T[] array,
14591460
K key,
14601461
Function<T, K> keyExtractor, Comparator<? super K> comparator
14611462
) {
1462-
return binarySearch0(array, 0, array.length, key, keyExtractor, comparator);
1463+
return binarySearchFirst0(array, 0, array.length, key, keyExtractor, comparator);
14631464
}
14641465

14651466
/**
1466-
* Searches element in array sorted by key, within range fromIndex (inclusive) - toIndex (exclusive).
1467+
* Searches element in array sorted by key, within range fromIndex (inclusive) - toIndex (exclusive). If there are
1468+
* multiple elements matching, it returns first occurrence. If the array is not sorted, the result is undefined.
14671469
*
14681470
* @param array
14691471
* array sorted by key field
@@ -1479,9 +1481,9 @@ public static <K, T> int binarySearch(
14791481
* comparator for keys
14801482
*
14811483
* @return
1482-
* index of the search key, if it is contained in the array within specified range; otherwise,
1483-
* (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements
1484-
* are lower, the first_greater is defined as toIndex.
1484+
* index of the first occurrence of search key, if it is contained in the array within specified range;
1485+
* otherwise, (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if
1486+
* all elements are lower, the first_greater is defined as toIndex.
14851487
*
14861488
* @throws ArrayIndexOutOfBoundsException
14871489
* when fromIndex or toIndex is out of array range
@@ -1493,28 +1495,124 @@ public static <K, T> int binarySearch(
14931495
* @param <K>
14941496
* type of key
14951497
*/
1496-
public static <T, K> int binarySearch(
1498+
public static <T, K> int binarySearchFirst(
14971499
T[] array,
14981500
int fromIndex, int toIndex,
14991501
K key,
15001502
Function<T, K> keyExtractor, Comparator<? super K> comparator
15011503
) {
1502-
if (fromIndex > toIndex) {
1503-
throw new IllegalArgumentException(
1504-
"fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")");
1505-
}
1506-
if (fromIndex < 0) {
1507-
throw new ArrayIndexOutOfBoundsException(fromIndex);
1508-
}
1509-
if (toIndex > array.length) {
1510-
throw new ArrayIndexOutOfBoundsException(toIndex);
1504+
checkRange(array.length, fromIndex, toIndex);
1505+
1506+
return binarySearchFirst0(array, fromIndex, toIndex, key, keyExtractor, comparator);
1507+
}
1508+
1509+
// common implementation for binarySearch methods, with same semantics:
1510+
private static <T, K> int binarySearchFirst0(
1511+
T[] array,
1512+
int fromIndex, int toIndex,
1513+
K key,
1514+
Function<T, K> keyExtractor, Comparator<? super K> comparator
1515+
) {
1516+
int l = fromIndex;
1517+
int h = toIndex - 1;
1518+
1519+
while (l <= h) {
1520+
final int m = (l + h) >>> 1; // unsigned shift to avoid overflow
1521+
final K value = keyExtractor.apply(array[m]);
1522+
final int c = comparator.compare(value, key);
1523+
if (c < 0) {
1524+
l = m + 1;
1525+
} else if (c > 0) {
1526+
h = m - 1;
1527+
} else if (l < h) {
1528+
// possibly multiple matching items remaining:
1529+
h = m;
1530+
} else {
1531+
// single matching item remaining:
1532+
return m;
1533+
}
15111534
}
15121535

1513-
return binarySearch0(array, fromIndex, toIndex, key, keyExtractor, comparator);
1536+
// not found, the l points to the lowest higher match:
1537+
return -l - 1;
1538+
}
1539+
1540+
/**
1541+
* Searches element in array sorted by key. If there are multiple elements matching, it returns last occurrence.
1542+
* If the array is not sorted, the result is undefined.
1543+
*
1544+
* @param array
1545+
* array sorted by key field
1546+
* @param key
1547+
* key to search for
1548+
* @param keyExtractor
1549+
* function to extract key from element
1550+
* @param comparator
1551+
* comparator for keys
1552+
*
1553+
* @return
1554+
* index of the last occurrence of search key, if it is contained in the array; otherwise,
1555+
* (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements
1556+
* are lower, the first_greater is defined as array.length.
1557+
*
1558+
* @param <T>
1559+
* type of array element
1560+
* @param <K>
1561+
* type of key
1562+
*/
1563+
public static <K, T> int binarySearchLast(
1564+
T[] array,
1565+
K key,
1566+
Function<T, K> keyExtractor, Comparator<? super K> comparator
1567+
) {
1568+
return binarySearchLast0(array, 0, array.length, key, keyExtractor, comparator);
1569+
}
1570+
1571+
/**
1572+
* Searches element in array sorted by key, within range fromIndex (inclusive) - toIndex (exclusive). If there are
1573+
* multiple elements matching, it returns last occurrence. If the array is not sorted, the result is undefined.
1574+
*
1575+
* @param array
1576+
* array sorted by key field
1577+
* @param fromIndex
1578+
* start index (inclusive)
1579+
* @param toIndex
1580+
* end index (exclusive)
1581+
* @param key
1582+
* key to search for
1583+
* @param keyExtractor
1584+
* function to extract key from element
1585+
* @param comparator
1586+
* comparator for keys
1587+
*
1588+
* @return
1589+
* index of the last occurrence of search key, if it is contained in the array within specified range;
1590+
* otherwise, (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if
1591+
* all elements are lower, the first_greater is defined as toIndex.
1592+
*
1593+
* @throws ArrayIndexOutOfBoundsException
1594+
* when fromIndex or toIndex is out of array range
1595+
* @throws IllegalArgumentException
1596+
* when fromIndex is greater than toIndex
1597+
*
1598+
* @param <T>
1599+
* type of array element
1600+
* @param <K>
1601+
* type of key
1602+
*/
1603+
public static <T, K> int binarySearchLast(
1604+
T[] array,
1605+
int fromIndex, int toIndex,
1606+
K key,
1607+
Function<T, K> keyExtractor, Comparator<? super K> comparator
1608+
) {
1609+
checkRange(array.length, fromIndex, toIndex);
1610+
1611+
return binarySearchLast0(array, fromIndex, toIndex, key, keyExtractor, comparator);
15141612
}
15151613

15161614
// common implementation for binarySearch methods, with same semantics:
1517-
private static <T, K> int binarySearch0(
1615+
private static <T, K> int binarySearchLast0(
15181616
T[] array,
15191617
int fromIndex, int toIndex,
15201618
K key,
@@ -1531,8 +1629,16 @@ private static <T, K> int binarySearch0(
15311629
l = m + 1;
15321630
} else if (c > 0) {
15331631
h = m - 1;
1632+
} else if (m + 1 < h) {
1633+
// matching, more than two items remaining:
1634+
l = m;
1635+
} else if (m + 1 == h) {
1636+
// two items remaining, next loops would result in unchanged l and h, we have to choose m or h:
1637+
final K valueH = keyExtractor.apply(array[h]);
1638+
final int cH = comparator.compare(valueH, key);
1639+
return cH == 0 ? h : m;
15341640
} else {
1535-
// 0, found
1641+
// one item remaining, single match:
15361642
return m;
15371643
}
15381644
}
@@ -9573,4 +9679,18 @@ public static String[] toStringArray(final Object[] array, final String valueFor
95739679
public ArrayUtils() {
95749680
// empty
95759681
}
9682+
9683+
static void checkRange(int length, int fromIndex, int toIndex) {
9684+
if (fromIndex > toIndex) {
9685+
throw new IllegalArgumentException(
9686+
"fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")");
9687+
}
9688+
if (fromIndex < 0) {
9689+
throw new ArrayIndexOutOfBoundsException(fromIndex);
9690+
}
9691+
if (toIndex > length) {
9692+
throw new ArrayIndexOutOfBoundsException(toIndex);
9693+
}
9694+
9695+
}
95769696
}

src/test/java/org/apache/commons/lang3/ArrayUtilsBinarySearchTest.java

+54-18
Original file line numberDiff line numberDiff line change
@@ -30,63 +30,99 @@
3030
public class ArrayUtilsBinarySearchTest extends AbstractLangTest {
3131

3232
@Test
33-
public void binarySearch_whenLowHigherThanEnd_throw() {
33+
public void binarySearchFirst_whenLowHigherThanEnd_throw() {
3434
final Data[] list = createList(0, 1);
35-
assertThrowsExactly(IllegalArgumentException.class, () -> ArrayUtils.binarySearch(list, 1, 0, 0, Data::getValue, Integer::compare));
35+
assertThrowsExactly(IllegalArgumentException.class, () ->
36+
ArrayUtils.binarySearchFirst(list, 1, 0, 0, Data::getValue, Integer::compare));
3637
}
3738

3839
@Test
39-
public void binarySearch_whenLowNegative_throw() {
40+
public void binarySearchFirst_whenLowNegative_throw() {
4041
final Data[] list = createList(0, 1);
41-
assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> ArrayUtils.binarySearch(list, -1, 0, 0, Data::getValue, Integer::compare));
42+
assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () ->
43+
ArrayUtils.binarySearchFirst(list, -1, 0, 0, Data::getValue, Integer::compare));
4244
}
4345

4446
@Test
45-
public void binarySearch_whenEndBeyondLength_throw() {
47+
public void binarySearchFirst_whenEndBeyondLength_throw() {
4648
final Data[] list = createList(0, 1);
47-
assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> ArrayUtils.binarySearch(list, 0, 3, 0, Data::getValue, Integer::compare));
49+
assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () ->
50+
ArrayUtils.binarySearchFirst(list, 0, 3, 0, Data::getValue, Integer::compare));
4851
}
4952

5053
@Test
51-
public void binarySearch_whenEmpty_returnM1() {
54+
public void binarySearchLast_whenLowHigherThanEnd_throw() {
55+
final Data[] list = createList(0, 1);
56+
assertThrowsExactly(IllegalArgumentException.class, () ->
57+
ArrayUtils.binarySearchLast(list, 1, 0, 0, Data::getValue, Integer::compare));
58+
}
59+
60+
@Test
61+
public void binarySearchFirst_whenEmpty_returnM1() {
5262
final Data[] list = createList();
53-
final int found = ArrayUtils.binarySearch(list, 0, Data::getValue, Integer::compare);
63+
final int found = ArrayUtils.binarySearchFirst(list, 0, Data::getValue, Integer::compare);
5464
assertEquals(-1, found);
5565
}
5666

5767
@Test
58-
public void binarySearch_whenExists_returnIndex() {
68+
public void binarySearchFirst_whenExists_returnIndex() {
5969
final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25);
60-
final int found = ArrayUtils.binarySearch(list, 9, Data::getValue, Integer::compare);
70+
final int found = ArrayUtils.binarySearchFirst(list, 9, Data::getValue, Integer::compare);
6171
assertEquals(5, found);
6272
}
6373

6474
@Test
65-
public void binarySearch_whenNotExistsMiddle_returnMinusInsertion() {
75+
@Timeout(10)
76+
public void binarySearchFirst_whenMultiple_returnFirst() {
77+
final Data[] list = createList(3, 4, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9);
78+
for (int i = 0; i < list.length; ++i) {
79+
if (i > 0 && list[i].value == list[i - 1].value) {
80+
continue;
81+
}
82+
final int found = ArrayUtils.binarySearchFirst(list, list[i].value, Data::getValue, Integer::compare);
83+
assertEquals(i, found);
84+
}
85+
}
86+
87+
@Test
88+
@Timeout(10)
89+
public void binarySearchLast_whenMultiple_returnFirst() {
90+
final Data[] list = createList(3, 4, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9);
91+
for (int i = 0; i < list.length; ++i) {
92+
if (i < list.length - 1 && list[i].value == list[i + 1].value) {
93+
continue;
94+
}
95+
final int found = ArrayUtils.binarySearchLast(list, list[i].value, Data::getValue, Integer::compare);
96+
assertEquals(i, found);
97+
}
98+
}
99+
100+
@Test
101+
public void binarySearchFirst_whenNotExistsMiddle_returnMinusInsertion() {
66102
final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25);
67-
final int found = ArrayUtils.binarySearch(list, 8, Data::getValue, Integer::compare);
103+
final int found = ArrayUtils.binarySearchFirst(list, 8, Data::getValue, Integer::compare);
68104
assertEquals(-6, found);
69105
}
70106

71107
@Test
72-
public void binarySearch_whenNotExistsBeginning_returnMinus1() {
108+
public void binarySearchFirst_whenNotExistsBeginning_returnMinus1() {
73109
final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25);
74-
final int found = ArrayUtils.binarySearch(list, -3, Data::getValue, Integer::compare);
110+
final int found = ArrayUtils.binarySearchFirst(list, -3, Data::getValue, Integer::compare);
75111
assertEquals(-1, found);
76112
}
77113

78114
@Test
79-
public void binarySearch_whenNotExistsEnd_returnMinusLength() {
115+
public void binarySearchFirst_whenNotExistsEnd_returnMinusLength() {
80116
final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25);
81-
final int found = ArrayUtils.binarySearch(list, 29, Data::getValue, Integer::compare);
117+
final int found = ArrayUtils.binarySearchFirst(list, 29, Data::getValue, Integer::compare);
82118
assertEquals(-(list.length + 1), found);
83119
}
84120

85121
@Test
86122
@Timeout(10)
87-
public void binarySearch_whenUnsorted_dontInfiniteLoop() {
123+
public void binarySearchFirst_whenUnsorted_dontInfiniteLoop() {
88124
final Data[] list = createList(7, 1, 4, 9, 11, 8);
89-
final int found = ArrayUtils.binarySearch(list, 10, Data::getValue, Integer::compare);
125+
final int found = ArrayUtils.binarySearchFirst(list, 10, Data::getValue, Integer::compare);
90126
}
91127

92128
private Data[] createList(int... values) {

0 commit comments

Comments
 (0)