From acf9436c1bbb20fef17d0670b06c6fe966dc8d91 Mon Sep 17 00:00:00 2001
From: Orson Peters <orsonpeters@gmail.com>
Date: Wed, 5 Jun 2024 02:29:51 +0200
Subject: [PATCH] Add `next_array` and `collect_array`

Co-authored-by: Jack Wrenn <jack@wrenn.fyi>
Co-authored-by: Joshua Liebow-Feeser <joshlf@users.noreply.github.com>
---
 CHANGELOG.md       |   1 +
 src/lib.rs         |  45 ++++++++
 src/next_array.rs  | 269 +++++++++++++++++++++++++++++++++++++++++++++
 tests/test_core.rs |  25 +++++
 4 files changed, 340 insertions(+)
 create mode 100644 src/next_array.rs

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 11392488a..6b08f6830 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@
 ### Added
 - Added `array_combinations` (#991)
 - Added `k_smallest_relaxed` and variants (#925)
+- Added `next_array` and `collect_array` (#560)
 - Implemented `DoubleEndedIterator` for `FilterOk` (#948)
 - Implemented `DoubleEndedIterator` for `FilterMapOk` (#950)
 
diff --git a/src/lib.rs b/src/lib.rs
index 37c3b1f41..20226d88a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -209,6 +209,7 @@ mod merge_join;
 mod minmax;
 #[cfg(feature = "use_alloc")]
 mod multipeek_impl;
+mod next_array;
 mod pad_tail;
 #[cfg(feature = "use_alloc")]
 mod peek_nth;
@@ -1968,6 +1969,50 @@ pub trait Itertools: Iterator {
     }
 
     // non-adaptor methods
+    /// Advances the iterator and returns the next items grouped in an array of
+    /// a specific size.
+    ///
+    /// If there are enough elements to be grouped in an array, then the array
+    /// is returned inside `Some`, otherwise `None` is returned.
+    ///
+    /// ```
+    /// use itertools::Itertools;
+    ///
+    /// let mut iter = 1..5;
+    ///
+    /// assert_eq!(Some([1, 2]), iter.next_array());
+    /// ```
+    fn next_array<const N: usize>(&mut self) -> Option<[Self::Item; N]>
+    where
+        Self: Sized,
+    {
+        next_array::next_array(self)
+    }
+
+    /// Collects all items from the iterator into an array of a specific size.
+    ///
+    /// If the number of elements inside the iterator is **exactly** equal to
+    /// the array size, then the array is returned inside `Some`, otherwise
+    /// `None` is returned.
+    ///
+    /// ```
+    /// use itertools::Itertools;
+    ///
+    /// let iter = 1..3;
+    ///
+    /// if let Some([x, y]) = iter.collect_array() {
+    ///     assert_eq!([x, y], [1, 2])
+    /// } else {
+    ///     panic!("Expected two elements")
+    /// }
+    /// ```
+    fn collect_array<const N: usize>(mut self) -> Option<[Self::Item; N]>
+    where
+        Self: Sized,
+    {
+        self.next_array().filter(|_| self.next().is_none())
+    }
+
     /// Advances the iterator and returns the next items grouped in a tuple of
     /// a specific size (up to 12).
     ///
diff --git a/src/next_array.rs b/src/next_array.rs
new file mode 100644
index 000000000..86480b197
--- /dev/null
+++ b/src/next_array.rs
@@ -0,0 +1,269 @@
+use core::mem::{self, MaybeUninit};
+
+/// An array of at most `N` elements.
+struct ArrayBuilder<T, const N: usize> {
+    /// The (possibly uninitialized) elements of the `ArrayBuilder`.
+    ///
+    /// # Safety
+    ///
+    /// The elements of `arr[..len]` are valid `T`s.
+    arr: [MaybeUninit<T>; N],
+
+    /// The number of leading elements of `arr` that are valid `T`s, len <= N.
+    len: usize,
+}
+
+impl<T, const N: usize> ArrayBuilder<T, N> {
+    /// Initializes a new, empty `ArrayBuilder`.
+    pub fn new() -> Self {
+        // SAFETY: The safety invariant of `arr` trivially holds for `len = 0`.
+        Self {
+            arr: [(); N].map(|_| MaybeUninit::uninit()),
+            len: 0,
+        }
+    }
+
+    /// Pushes `value` onto the end of the array.
+    ///
+    /// # Panics
+    ///
+    /// This panics if `self.len >= N`.
+    #[inline(always)]
+    pub fn push(&mut self, value: T) {
+        // PANICS: This will panic if `self.len >= N`.
+        let place = &mut self.arr[self.len];
+        // SAFETY: The safety invariant of `self.arr` applies to elements at
+        // indices `0..self.len` — not to the element at `self.len`. Writing to
+        // the element at index `self.len` therefore does not violate the safety
+        // invariant of `self.arr`. Even if this line panics, we have not
+        // created any intermediate invalid state.
+        *place = MaybeUninit::new(value);
+        // Lemma: `self.len < N`. By invariant, `self.len <= N`. Above, we index
+        // into `self.arr`, which has size `N`, at index `self.len`. If `self.len == N`
+        // at that point, that index would be out-of-bounds, and the index
+        // operation would panic. Thus, `self.len != N`, and since `self.len <= N`,
+        // that means that `self.len < N`.
+        //
+        // PANICS: Since `self.len < N`, and since `N <= usize::MAX`,
+        // `self.len + 1 <= usize::MAX`, and so `self.len += 1` will not
+        // overflow. Overflow is the only panic condition of `+=`.
+        //
+        // SAFETY:
+        // - We are required to uphold the invariant that `self.len <= N`.
+        //   Since, by the preceding lemma, `self.len < N` at this point in the
+        //   code, `self.len += 1` results in `self.len <= N`.
+        // - We are required to uphold the invariant that `self.arr[..self.len]`
+        //   are valid instances of `T`. Since this invariant already held when
+        //   this method was called, and since we only increment `self.len`
+        //   by 1 here, we only need to prove that the element at
+        //   `self.arr[self.len]` (using the value of `self.len` before incrementing)
+        //   is valid. Above, we construct `place` to point to `self.arr[self.len]`,
+        //   and then initialize `*place` to `MaybeUninit::new(value)`, which is
+        //   a valid `T` by construction.
+        self.len += 1;
+    }
+
+    /// Consumes the elements in the `ArrayBuilder` and returns them as an array
+    /// `[T; N]`.
+    ///
+    /// If `self.len() < N`, this returns `None`.
+    pub fn take(&mut self) -> Option<[T; N]> {
+        if self.len == N {
+            // SAFETY: Decreasing the value of `self.len` cannot violate the
+            // safety invariant on `self.arr`.
+            self.len = 0;
+
+            // SAFETY: Since `self.len` is 0, `self.arr` may safely contain
+            // uninitialized elements.
+            let arr = mem::replace(&mut self.arr, [(); N].map(|_| MaybeUninit::uninit()));
+
+            Some(arr.map(|v| {
+                // SAFETY: We know that all elements of `arr` are valid because
+                // we checked that `len == N`.
+                unsafe { v.assume_init() }
+            }))
+        } else {
+            None
+        }
+    }
+}
+
+impl<T, const N: usize> AsMut<[T]> for ArrayBuilder<T, N> {
+    fn as_mut(&mut self) -> &mut [T] {
+        let valid = &mut self.arr[..self.len];
+        // SAFETY: By invariant on `self.arr`, the elements of `self.arr` at
+        // indices `0..self.len` are in a valid state. Since `valid` references
+        // only these elements, the safety precondition of
+        // `slice_assume_init_mut` is satisfied.
+        unsafe { slice_assume_init_mut(valid) }
+    }
+}
+
+impl<T, const N: usize> Drop for ArrayBuilder<T, N> {
+    // We provide a non-trivial `Drop` impl, because the trivial impl would be a
+    // no-op; `MaybeUninit<T>` has no innate awareness of its own validity, and
+    // so it can only forget its contents. By leveraging the safety invariant of
+    // `self.arr`, we do know which elements of `self.arr` are valid, and can
+    // selectively run their destructors.
+    fn drop(&mut self) {
+        // SAFETY:
+        // - by invariant on `&mut [T]`, `self.as_mut()` is:
+        //   - valid for reads and writes
+        //   - properly aligned
+        //   - non-null
+        // - the dropped `T` are valid for dropping; they do not have any
+        //   additional library invariants that we've violated
+        // - no other pointers to `valid` exist (since we're in the context of
+        //   `drop`)
+        unsafe { core::ptr::drop_in_place(self.as_mut()) }
+    }
+}
+
+/// Assuming all the elements are initialized, get a mutable slice to them.
+///
+/// # Safety
+///
+/// The caller guarantees that the elements `T` referenced by `slice` are in a
+/// valid state.
+unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
+    // SAFETY: Casting `&mut [MaybeUninit<T>]` to `&mut [T]` is sound, because
+    // `MaybeUninit<T>` is guaranteed to have the same size, alignment and ABI
+    // as `T`, and because the caller has guaranteed that `slice` is in the
+    // valid state.
+    unsafe { &mut *(slice as *mut [MaybeUninit<T>] as *mut [T]) }
+}
+
+/// Equivalent to `it.next_array()`.
+pub(crate) fn next_array<I, const N: usize>(it: &mut I) -> Option<[I::Item; N]>
+where
+    I: Iterator,
+{
+    let mut builder = ArrayBuilder::new();
+    for _ in 0..N {
+        builder.push(it.next()?);
+    }
+    builder.take()
+}
+
+#[cfg(test)]
+mod test {
+    use super::ArrayBuilder;
+
+    #[test]
+    fn zero_len_take() {
+        let mut builder = ArrayBuilder::<(), 0>::new();
+        let taken = builder.take();
+        assert_eq!(taken, Some([(); 0]));
+    }
+
+    #[test]
+    #[should_panic]
+    fn zero_len_push() {
+        let mut builder = ArrayBuilder::<(), 0>::new();
+        builder.push(());
+    }
+
+    #[test]
+    fn push_4() {
+        let mut builder = ArrayBuilder::<(), 4>::new();
+        assert_eq!(builder.take(), None);
+
+        builder.push(());
+        assert_eq!(builder.take(), None);
+
+        builder.push(());
+        assert_eq!(builder.take(), None);
+
+        builder.push(());
+        assert_eq!(builder.take(), None);
+
+        builder.push(());
+        assert_eq!(builder.take(), Some([(); 4]));
+    }
+
+    #[test]
+    fn tracked_drop() {
+        use std::panic::{catch_unwind, AssertUnwindSafe};
+        use std::sync::atomic::{AtomicU16, Ordering};
+
+        static DROPPED: AtomicU16 = AtomicU16::new(0);
+
+        #[derive(Debug, PartialEq)]
+        struct TrackedDrop;
+
+        impl Drop for TrackedDrop {
+            fn drop(&mut self) {
+                DROPPED.fetch_add(1, Ordering::Relaxed);
+            }
+        }
+
+        {
+            let builder = ArrayBuilder::<TrackedDrop, 0>::new();
+            assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
+            drop(builder);
+            assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
+        }
+
+        {
+            let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
+            builder.push(TrackedDrop);
+            assert_eq!(builder.take(), None);
+            assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
+            drop(builder);
+            assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 1);
+        }
+
+        {
+            let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
+            builder.push(TrackedDrop);
+            builder.push(TrackedDrop);
+            assert!(matches!(builder.take(), Some(_)));
+            assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 2);
+            drop(builder);
+            assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
+        }
+
+        {
+            let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
+
+            builder.push(TrackedDrop);
+            builder.push(TrackedDrop);
+
+            assert!(catch_unwind(AssertUnwindSafe(|| {
+                builder.push(TrackedDrop);
+            }))
+            .is_err());
+
+            assert_eq!(DROPPED.load(Ordering::Relaxed), 1);
+
+            drop(builder);
+
+            assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 3);
+        }
+
+        {
+            let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
+
+            builder.push(TrackedDrop);
+            builder.push(TrackedDrop);
+
+            assert!(catch_unwind(AssertUnwindSafe(|| {
+                builder.push(TrackedDrop);
+            }))
+            .is_err());
+
+            assert_eq!(DROPPED.load(Ordering::Relaxed), 1);
+
+            assert!(matches!(builder.take(), Some(_)));
+
+            assert_eq!(DROPPED.load(Ordering::Relaxed), 3);
+
+            builder.push(TrackedDrop);
+            builder.push(TrackedDrop);
+
+            assert!(matches!(builder.take(), Some(_)));
+
+            assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 5);
+        }
+    }
+}
diff --git a/tests/test_core.rs b/tests/test_core.rs
index 32af246c0..493616085 100644
--- a/tests/test_core.rs
+++ b/tests/test_core.rs
@@ -372,3 +372,28 @@ fn product1() {
     assert_eq!(v[1..3].iter().cloned().product1::<i32>(), Some(2));
     assert_eq!(v[1..5].iter().cloned().product1::<i32>(), Some(24));
 }
+
+#[test]
+fn next_array() {
+    let v = [1, 2, 3, 4, 5];
+    let mut iter = v.iter();
+    assert_eq!(iter.next_array(), Some([]));
+    assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([1, 2]));
+    assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([3, 4]));
+    assert_eq!(iter.next_array::<2>(), None);
+}
+
+#[test]
+fn collect_array() {
+    let v = [1, 2];
+    let iter = v.iter().cloned();
+    assert_eq!(iter.collect_array(), Some([1, 2]));
+
+    let v = [1];
+    let iter = v.iter().cloned();
+    assert_eq!(iter.collect_array::<2>(), None);
+
+    let v = [1, 2, 3];
+    let iter = v.iter().cloned();
+    assert_eq!(iter.collect_array::<2>(), None);
+}