diff --git a/CHANGELOG.md b/CHANGELOG.md index 793922e..36a2278 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ + Implement `everyNth()` to get every `n`<sup>th</sup> element from the stream + Implement `uniquelyOccurring()` to emit stream elements that occur a single time + Implement `takeUntil()` to take from a stream until a predicate is met, including the first element that matches the predicate ++ Implement `foldIndexed()` to perform a traditional fold along with the index of each element ### 0.7.0 + Use greedy integrators where possible (Fixes #57) diff --git a/README.md b/README.md index a91c48b..2f59272 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ implementation("com.ginsberg:gatherers4j:0.8.0") | `dropLast(n)` | Keep all but the last `n` elements of the stream | | `everyNth(n)` | Limit the stream to every `n`<sup>th</sup> element | | `filterWithIndex(predicate)` | Filter the stream with the given `predicate`, which takes an `element` and its `index` | +| `foldIndexed(fn)` | Perform a fold over the input stream where each element is included along with its index | | `grouping()` | Group consecutive identical elements into lists | | `groupingBy(fn)` | Group consecutive elements that are identical according to `fn` into lists | | `interleave(iterable)` | Creates a stream of alternating objects from the input stream and the argument iterable | diff --git a/src/main/java/com/ginsberg/gatherers4j/FoldIndexedGatherer.java b/src/main/java/com/ginsberg/gatherers4j/FoldIndexedGatherer.java new file mode 100644 index 0000000..1afcf73 --- /dev/null +++ b/src/main/java/com/ginsberg/gatherers4j/FoldIndexedGatherer.java @@ -0,0 +1,70 @@ +/* + * Copyright 2025 Todd Ginsberg + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ginsberg.gatherers4j; + +import org.jspecify.annotations.Nullable; + +import java.util.function.BiConsumer; +import java.util.function.Supplier; +import java.util.stream.Gatherer; + +import static com.ginsberg.gatherers4j.GathererUtils.mustNotBeNull; + +public class FoldIndexedGatherer<INPUT extends @Nullable Object, OUTPUT extends @Nullable Object> + implements Gatherer<INPUT, FoldIndexedGatherer.State<OUTPUT>, OUTPUT> { + + private final TriFunction<Long, OUTPUT, INPUT, OUTPUT> foldFunction; + private final Supplier<OUTPUT> initialValue; + + FoldIndexedGatherer( + final Supplier<OUTPUT> initialValue, + final TriFunction<Long, OUTPUT, INPUT, OUTPUT> foldFunction + ) { + mustNotBeNull(initialValue, "Initial value supplier must not be null"); + mustNotBeNull(foldFunction, "Fold function must not be null"); + this.foldFunction = foldFunction; + this.initialValue = initialValue; + } + + @Override + public Supplier<State<OUTPUT>> initializer() { + return () -> new State<>(initialValue.get()); + } + + @Override + public Integrator<State<OUTPUT>, INPUT, OUTPUT> integrator() { + return Integrator.ofGreedy((state, element, downstream) -> { + state.carriedValue = foldFunction.apply(state.index++, state.carriedValue, element); + return !downstream.isRejecting(); + }); + } + + @Override + public BiConsumer<State<OUTPUT>, Downstream<? super OUTPUT>> finisher() { + return (outputState, downstream) -> downstream.push(outputState.carriedValue); + } + + public static class State<OUTPUT> { + @Nullable + OUTPUT carriedValue; + long index; + + private State(@Nullable final OUTPUT initialValue) { + carriedValue = initialValue; + } + } +} diff --git a/src/main/java/com/ginsberg/gatherers4j/Gatherers4j.java b/src/main/java/com/ginsberg/gatherers4j/Gatherers4j.java index 21c15fd..e9213b1 100644 --- a/src/main/java/com/ginsberg/gatherers4j/Gatherers4j.java +++ b/src/main/java/com/ginsberg/gatherers4j/Gatherers4j.java @@ -24,6 +24,7 @@ import java.util.function.BiPredicate; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.random.RandomGenerator; import java.util.stream.Stream; @@ -89,7 +90,7 @@ public abstract class Gatherers4j { /// Keep every nth element of the stream. /// - /// @param count The number of the elements to keep, must be at least 2 + /// @param count The number of the elements to keep, must be at least 2 /// @param <INPUT> Type of elements in both the input and output streams /// @return A non-null `EveryNthGatherer` public static <INPUT extends @Nullable Object> EveryNthGatherer<INPUT> everyNth(final int count) { @@ -100,7 +101,7 @@ public abstract class Gatherers4j { /// and its index. /// /// @param predicate A non-null `BiPredicate<Long,INPUT>` where the `Long` is the zero-based index of the element - /// being filtered, and the `INPUT` is the element itself. + /// being filtered, and the `INPUT` is the element itself. /// @param <INPUT> Type of elements in the input stream /// @return A non-null `FilteringWithIndexGatherer` public static <INPUT extends @Nullable Object> FilteringWithIndexGatherer<INPUT> filterWithIndex( @@ -109,6 +110,20 @@ public abstract class Gatherers4j { return new FilteringWithIndexGatherer<>(predicate); } + /// Perform a fold over every element in the input stream along with its index + /// + /// @param <INPUT> Type of elements in the input stream + /// @param <OUTPUT> Type elements are folded to (the carry value) + /// @param initialValue Initial value of the fold + /// @param foldFunction Function that performs the fold given an element, its index, and the carry value + /// @return A non-null FoldIndexedGatherer + public static <INPUT extends @Nullable Object, OUTPUT extends @Nullable Object> FoldIndexedGatherer<INPUT, OUTPUT> foldIndexed( + final Supplier<OUTPUT> initialValue, + final TriFunction<Long, OUTPUT, INPUT, OUTPUT> foldFunction + ) { + return new FoldIndexedGatherer<>(initialValue, foldFunction); + } + /// Turn a `Stream<INPUT>` into a `Stream<List<INPUT>>` where consecutive /// equal elements, where equality is measured by `Object.equals(Object)`. /// @@ -180,7 +195,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// /// @param windowSize The trailing number of elements to multiply, must be greater than 1. /// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used - /// in the moving product calculation + /// in the moving product calculation /// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction` /// @return A non-null `BigDecimalMovingProductGatherer` public static <INPUT extends @Nullable Object> BigDecimalMovingProductGatherer<INPUT> movingProductBy( @@ -204,7 +219,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// /// @param windowSize The trailing number of elements to multiply, must be greater than 1. /// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used - /// in the moving sum calculation + /// in the moving sum calculation /// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction` /// @return A non-null `BigDecimalMovingSumGatherer` public static <INPUT extends @Nullable Object> BigDecimalMovingSumGatherer<INPUT> movingSumBy( @@ -286,7 +301,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// objects mapped from a `Stream<BigDecimal>` via a `mappingFunction`. /// /// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used - /// in the standard deviation calculation + /// in the standard deviation calculation /// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction` /// @return A non-null `BigDecimalStandardDeviationGatherer` public static <INPUT extends @Nullable Object> BigDecimalStandardDeviationGatherer<INPUT> runningPopulationStandardDeviationBy( @@ -309,7 +324,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// from a `Stream<INPUT>` via a `mappingFunction`. /// /// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used - /// in the product calculation + /// in the product calculation /// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction` /// @return A non-null `BigDecimalProductGatherer` public static <INPUT extends @Nullable Object> BigDecimalProductGatherer<INPUT> runningProductBy( @@ -332,7 +347,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// from a `Stream<INPUT>` via a `mappingFunction`. /// /// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used - /// in the standard deviation calculation + /// in the standard deviation calculation /// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction` /// @return A non-null `BigDecimalStandardDeviationGatherer` public static <INPUT extends @Nullable Object> BigDecimalStandardDeviationGatherer<INPUT> runningSampleStandardDeviationBy( @@ -355,7 +370,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// from a `Stream<INPUT>` via a `mappingFunction`. /// /// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used - /// in the running sum calculation + /// in the running sum calculation /// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction` /// @return A non-null `BigDecimalSumGatherer` public static <INPUT extends @Nullable Object> BigDecimalSumGatherer<INPUT> runningSumBy( @@ -398,7 +413,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// the given function. This is useful when paired with the `withOriginal` function. /// /// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used - /// in the running average calculation + /// in the running average calculation /// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction` /// @return A non-null `BigDecimalSimpleAverageGatherer` public static <INPUT extends @Nullable Object> BigDecimalSimpleAverageGatherer<INPUT> simpleRunningAverageBy( @@ -421,7 +436,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// /// @param windowSize The number of elements to average, must be greater than 1. /// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used - /// in the moving average calculation + /// in the moving average calculation /// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction` /// @return A non-null `BigDecimalSimpleMovingAverageGatherer` public static <INPUT extends @Nullable Object> BigDecimalSimpleMovingAverageGatherer<INPUT> simpleMovingAverageBy( @@ -445,7 +460,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// Ensure the input stream is greater than `size` elements long, and emit all elements if so. /// If not, throw an `IllegalStateException`. /// - /// @param size The size the stream must be longer than + /// @param size The size the stream must be longer than /// @param <INPUT> Type of elements in both the input and output streams /// @return A non-null `SizeGatherer` /// @throws IllegalStateException when the input stream is not exactly `size` elements long @@ -456,7 +471,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// Ensure the input stream is greater than or equal to `size` elements long, and emit all elements if so. /// If not, throw an `IllegalStateException`. /// - /// @param size The minimum size of the stream + /// @param size The minimum size of the stream /// @param <INPUT> Type of elements in both the input and output streams /// @return A non-null `SizeGatherer` /// @throws IllegalStateException when the input stream is not exactly `size` elements long @@ -467,7 +482,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// Ensure the input stream is less than `size` elements long, and emit all elements if so. /// If not, throw an `IllegalStateException`. /// - /// @param size The size the stream must be shorter than + /// @param size The size the stream must be shorter than /// @param <INPUT> Type of elements in both the input and output streams /// @return A non-null `SizeGatherer` /// @throws IllegalStateException when the input stream is not exactly `size` elements long @@ -478,7 +493,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// Ensure the input stream is less than or equal to `size` elements long, and emit all elements if so. /// If not, throw an `IllegalStateException`. /// - /// @param size The maximum size the stream + /// @param size The maximum size the stream /// @param <INPUT> Type of elements in both the input and output streams /// @return A non-null `SizeGatherer` /// @throws IllegalStateException when the input stream is not exactly `size` elements long @@ -490,7 +505,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// matches the `predicate`. /// /// @param predicate A non-null predicate function - /// @param <INPUT> Type of elements in both the input and output streams + /// @param <INPUT> Type of elements in both the input and output streams /// @return A non-null `TakeUntilGatherer` public static <INPUT extends @Nullable Object> TakeUntilGatherer<INPUT> takeUntil( final Predicate<INPUT> predicate @@ -514,7 +529,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) { /// Emit only those elements that occur in the input stream a single time. /// - /// @param <INPUT> Type of elements in the input stream + /// @param <INPUT> Type of elements in the input stream /// @return A non-null `UniquelyOccurringGatherer` public static <INPUT extends @Nullable Object> UniquelyOccurringGatherer<INPUT> uniquelyOccurring() { return new UniquelyOccurringGatherer<>(); diff --git a/src/main/java/com/ginsberg/gatherers4j/TriFunction.java b/src/main/java/com/ginsberg/gatherers4j/TriFunction.java new file mode 100644 index 0000000..dbc76ab --- /dev/null +++ b/src/main/java/com/ginsberg/gatherers4j/TriFunction.java @@ -0,0 +1,35 @@ +/* + * Copyright 2025 Todd Ginsberg + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ginsberg.gatherers4j; + +import org.jspecify.annotations.Nullable; + +@FunctionalInterface +public interface TriFunction< + A extends @Nullable Object, + B extends @Nullable Object, + C extends @Nullable Object, + R extends @Nullable Object> { + + /// Applies this function to the given arguments + /// + /// @param a the first function argument + /// @param b the second function argument + /// @param c the third function argument + /// @return the function result + R apply(A a, B b, C c); +} \ No newline at end of file diff --git a/src/test/java/com/ginsberg/gatherers4j/FoldIndexedGathererTest.java b/src/test/java/com/ginsberg/gatherers4j/FoldIndexedGathererTest.java new file mode 100644 index 0000000..ab696d4 --- /dev/null +++ b/src/test/java/com/ginsberg/gatherers4j/FoldIndexedGathererTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2025 Todd Ginsberg + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ginsberg.gatherers4j; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class FoldIndexedGathererTest { + + @Test + @SuppressWarnings("DataFlowIssue") + void foldFunctionMustNotBeNull() { + assertThatThrownBy(() -> new FoldIndexedGatherer<>(() -> "A", null)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + void foldWithIndex() { + // Arrange + final Stream<String> input = Stream.of("A", "B", "C", "D"); + + // Act + final List<IndexedValue<String>> output = input + .gather( + Gatherers4j.foldIndexed( + () -> new ArrayList<IndexedValue<String>>(), + (index, carry, next) -> { + carry.add(new IndexedValue<>(index, next)); + return carry; + } + ) + ) + .toList() + .getFirst(); + + // Assert + assertThat(output) + .containsExactly( + new IndexedValue<>(0, "A"), + new IndexedValue<>(1, "B"), + new IndexedValue<>(2, "C"), + new IndexedValue<>(3, "D") + ); + } + + @Test + void foldWithIndexInteger() { + // Arrange + final Stream<Integer> input = Stream.of(1, 2, 3, 4, 5, 6); + + // Act + final int output = input + .gather( + Gatherers4j.foldIndexed( + () -> 0, + (index, carry, next) -> index % 2 == 0 ? carry + next : carry + ) + ) + .toList() + .getFirst(); + + // Assert + assertThat(output).isEqualTo(9); + } + + @Test + @SuppressWarnings("DataFlowIssue") + void initialSupplierMustNotBeNull() { + assertThatThrownBy(() -> new FoldIndexedGatherer<>(null, (_, it, _) -> it)).isInstanceOf(IllegalArgumentException.class); + } + +} \ No newline at end of file