Skip to content

Commit 0cfa849

Browse files
orlp, jswrenn, joshlf
committed
Add next_array and collect_array
Co-authored-by: Jack Wrenn <[email protected]> Co-authored-by: Joshua Liebow-Feeser <[email protected]>
1 parent ff0c942 commit 0cfa849

File tree

3 files changed

+339
-0
lines changed

3 files changed

+339
-0
lines changed

src/lib.rs

+45
Original file line number | Diff line number | Diff line change
@@ -209,6 +209,7 @@ mod merge_join;
209209
mod minmax;
210210
#[cfg(feature = "use_alloc")]
211211
mod multipeek_impl;
212+
mod next_array;
212213
mod pad_tail;
213214
#[cfg(feature = "use_alloc")]
214215
mod peek_nth;
@@ -1968,6 +1969,50 @@ pub trait Itertools: Iterator {
19681969
}
19691970

19701971
// non-adaptor methods
1972+
/// Advances the iterator and returns the next items grouped in an array of
1973+
/// a specific size.
1974+
///
1975+
/// If there are enough elements to be grouped in an array, then the array
1976+
/// is returned inside `Some`, otherwise `None` is returned.
1977+
///
1978+
/// ```
1979+
/// use itertools::Itertools;
1980+
///
1981+
/// let mut iter = 1..5;
1982+
///
1983+
/// assert_eq!(Some([1, 2]), iter.next_array());
1984+
/// ```
1985+
fn next_array<const N: usize>(&mut self) -> Option<[Self::Item; N]>
1986+
where
1987+
Self: Sized,
1988+
{
1989+
next_array::next_array(self)
1990+
}
1991+
1992+
/// Collects all items from the iterator into an array of a specific size.
1993+
///
1994+
/// If the number of elements inside the iterator is **exactly** equal to
1995+
/// the array size, then the array is returned inside `Some`, otherwise
1996+
/// `None` is returned.
1997+
///
1998+
/// ```
1999+
/// use itertools::Itertools;
2000+
///
2001+
/// let iter = 1..3;
2002+
///
2003+
/// if let Some([x, y]) = iter.collect_array() {
2004+
/// assert_eq!([x, y], [1, 2])
2005+
/// } else {
2006+
/// panic!("Expected two elements")
2007+
/// }
2008+
/// ```
2009+
fn collect_array<const N: usize>(mut self) -> Option<[Self::Item; N]>
2010+
where
2011+
Self: Sized,
2012+
{
2013+
self.next_array().filter(|_| self.next().is_none())
2014+
}
2015+
19712016
/// Advances the iterator and returns the next items grouped in a tuple of
19722017
/// a specific size (up to 12).
19732018
///

src/next_array.rs

+269
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,269 @@
1+
use core::mem::{self, MaybeUninit};
2+
3+
/// An array of at most `N` elements.
///
/// Elements are appended one at a time; once exactly `N` have been pushed
/// the whole array can be moved out as a `[T; N]`.
struct ArrayBuilder<T, const N: usize> {
    /// The (possibly uninitialized) elements of the `ArrayBuilder`.
    ///
    /// # Safety
    ///
    /// The elements of `arr[..len]` are valid `T`s.
    arr: [MaybeUninit<T>; N],

    /// The number of leading elements of `arr` that are valid `T`s.
    ///
    /// # Safety
    ///
    /// `len <= N`.
    len: usize,
}
15+
16+
impl<T, const N: usize> ArrayBuilder<T, N> {
17+
/// Initializes a new, empty `ArrayBuilder`.
18+
pub fn new() -> Self {
19+
// SAFETY: The safety invariant of `arr` trivially holds for `len = 0`.
20+
Self {
21+
arr: [(); N].map(|_| MaybeUninit::uninit()),
22+
len: 0,
23+
}
24+
}
25+
26+
/// Pushes `value` onto the end of the array.
27+
///
28+
/// # Panics
29+
///
30+
/// This panics if `self.len >= N`.
31+
#[inline(always)]
32+
pub fn push(&mut self, value: T) {
33+
// PANICS: This will panic if `self.len >= N`.
34+
let place = &mut self.arr[self.len];
35+
// SAFETY: The safety invariant of `self.arr` applies to elements at
36+
// indices `0..self.len` — not to the element at `self.len`. Writing to
37+
// the element at index `self.len` therefore does not violate the safety
38+
// invariant of `self.arr`. Even if this line panics, we have not
39+
// created any intermediate invalid state.
40+
*place = MaybeUninit::new(value);
41+
// Lemma: `self.len < N`. By invariant, `self.len <= N`. Above, we index
42+
// into `self.arr`, which has size `N`, at index `self.len`. If `self.len == N`
43+
// at that point, that index would be out-of-bounds, and the index
44+
// operation would panic. Thus, `self.len != N`, and since `self.len <= N`,
45+
// that means that `self.len < N`.
46+
//
47+
// PANICS: Since `self.len < N`, and since `N <= usize::MAX`,
48+
// `self.len + 1 <= usize::MAX`, and so `self.len += 1` will not
49+
// overflow. Overflow is the only panic condition of `+=`.
50+
//
51+
// SAFETY:
52+
// - We are required to uphold the invariant that `self.len <= N`.
53+
// Since, by the preceding lemma, `self.len < N` at this point in the
54+
// code, `self.len += 1` results in `self.len <= N`.
55+
// - We are required to uphold the invariant that `self.arr[..self.len]`
56+
// are valid instances of `T`. Since this invariant already held when
57+
// this method was called, and since we only increment `self.len`
58+
// by 1 here, we only need to prove that the element at
59+
// `self.arr[self.len]` (using the value of `self.len` before incrementing)
60+
// is valid. Above, we construct `place` to point to `self.arr[self.len]`,
61+
// and then initialize `*place` to `MaybeUninit::new(value)`, which is
62+
// a valid `T` by construction.
63+
self.len += 1;
64+
}
65+
66+
/// Consumes the elements in the `ArrayBuilder` and returns them as an array
67+
/// `[T; N]`.
68+
///
69+
/// If `self.len() < N`, this returns `None`.
70+
pub fn take(&mut self) -> Option<[T; N]> {
71+
if self.len == N {
72+
// SAFETY: Decreasing the value of `self.len` cannot violate the
73+
// safety invariant on `self.arr`.
74+
self.len = 0;
75+
76+
// SAFETY: Since `self.len` is 0, `self.arr` may safely contain
77+
// uninitialized elements.
78+
let arr = mem::replace(&mut self.arr, [(); N].map(|_| MaybeUninit::uninit()));
79+
80+
Some(arr.map(|v| {
81+
// SAFETY: We know that all elements of `arr` are valid because
82+
// we checked that `len == N`.
83+
unsafe { v.assume_init() }
84+
}))
85+
} else {
86+
None
87+
}
88+
}
89+
}
90+
91+
impl<T, const N: usize> AsMut<[T]> for ArrayBuilder<T, N> {
92+
fn as_mut(&mut self) -> &mut [T] {
93+
let valid = &mut self.arr[..self.len];
94+
// SAFETY: By invariant on `self.arr`, the elements of `self.arr` at
95+
// indices `0..self.len` are in a valid state. Since `valid` references
96+
// only these elements, the safety precondition of
97+
// `slice_assume_init_mut` is satisfied.
98+
unsafe { slice_assume_init_mut(valid) }
99+
}
100+
}
101+
102+
impl<T, const N: usize> Drop for ArrayBuilder<T, N> {
103+
// We provide a non-trivial `Drop` impl, because the trivial impl would be a
104+
// no-op; `MaybeUninit<T>` has no innate awareness of its own validity, and
105+
// so it can only forget its contents. By leveraging the safety invariant of
106+
// `self.arr`, we do know which elements of `self.arr` are valid, and can
107+
// selectively run their destructors.
108+
fn drop(&mut self) {
109+
// SAFETY:
110+
// - by invariant on `&mut [T]`, `self.as_mut()` is:
111+
// - valid for reads and writes
112+
// - properly aligned
113+
// - non-null
114+
// - the dropped `T` are valid for dropping; they do not have any
115+
// additional library invariants that we've violated
116+
// - no other pointers to `valid` exist (since we're in the context of
117+
// `drop`)
118+
unsafe { core::ptr::drop_in_place(self.as_mut()) }
119+
}
120+
}
121+
122+
/// Assuming all the elements are initialized, get a mutable slice to them.
///
/// The returned slice covers the same memory and has the same length as
/// `slice`; only the element type changes.
///
/// # Safety
///
/// The caller guarantees that the elements `T` referenced by `slice` are in a
/// valid state.
unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    // SAFETY: Casting `&mut [MaybeUninit<T>]` to `&mut [T]` is sound, because
    // `MaybeUninit<T>` is guaranteed to have the same size, alignment and ABI
    // as `T`, and because the caller has guaranteed that `slice` is in the
    // valid state.
    unsafe { &mut *(slice as *mut [MaybeUninit<T>] as *mut [T]) }
}
135+
136+
/// Equivalent to `it.next_array()`.
137+
pub(crate) fn next_array<I, const N: usize>(it: &mut I) -> Option<[I::Item; N]>
138+
where
139+
I: Iterator,
140+
{
141+
let mut builder = ArrayBuilder::new();
142+
for _ in 0..N {
143+
builder.push(it.next()?);
144+
}
145+
builder.take()
146+
}
147+
148+
#[cfg(test)]
mod test {
    use super::ArrayBuilder;

    /// A zero-capacity builder is "full" from the start.
    #[test]
    fn zero_len_take() {
        let mut builder = ArrayBuilder::<(), 0>::new();
        let taken = builder.take();
        assert_eq!(taken, Some([(); 0]));
    }

    /// Pushing into a zero-capacity builder must panic.
    #[test]
    #[should_panic]
    fn zero_len_push() {
        let mut builder = ArrayBuilder::<(), 0>::new();
        builder.push(());
    }

    /// `take` only succeeds once all four slots have been filled.
    #[test]
    fn push_4() {
        let mut builder = ArrayBuilder::<(), 4>::new();
        assert_eq!(builder.take(), None);

        builder.push(());
        assert_eq!(builder.take(), None);

        builder.push(());
        assert_eq!(builder.take(), None);

        builder.push(());
        assert_eq!(builder.take(), None);

        builder.push(());
        assert_eq!(builder.take(), Some([(); 4]));
    }

    /// Checks that every pushed element is dropped exactly once, whether it
    /// leaves through `take`, through the builder's `Drop`, or through a
    /// panicking `push`.
    #[test]
    fn tracked_drop() {
        use std::panic::{catch_unwind, AssertUnwindSafe};
        use std::sync::atomic::{AtomicU16, Ordering};

        // Running count of `TrackedDrop` destructor calls; each scenario
        // below resets it to zero with `swap` before the next one starts.
        static DROPPED: AtomicU16 = AtomicU16::new(0);

        #[derive(Debug, PartialEq)]
        struct TrackedDrop;

        impl Drop for TrackedDrop {
            fn drop(&mut self) {
                DROPPED.fetch_add(1, Ordering::Relaxed);
            }
        }

        // An empty builder drops nothing.
        {
            let builder = ArrayBuilder::<TrackedDrop, 0>::new();
            assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
            drop(builder);
            assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
        }

        // A partially filled builder drops its one element when it is dropped;
        // a failed `take` drops nothing.
        {
            let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
            builder.push(TrackedDrop);
            assert_eq!(builder.take(), None);
            assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
            drop(builder);
            assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 1);
        }

        // After a successful `take` the array owns the elements: they are
        // dropped with the temporary array, and the emptied builder drops
        // nothing further.
        {
            let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
            builder.push(TrackedDrop);
            builder.push(TrackedDrop);
            assert!(matches!(builder.take(), Some(_)));
            assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 2);
            drop(builder);
            assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
        }

        // A `push` past capacity panics and drops only the rejected value;
        // the two stored elements are dropped later with the builder.
        {
            let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();

            builder.push(TrackedDrop);
            builder.push(TrackedDrop);

            assert!(catch_unwind(AssertUnwindSafe(|| {
                builder.push(TrackedDrop);
            }))
            .is_err());

            assert_eq!(DROPPED.load(Ordering::Relaxed), 1);

            drop(builder);

            assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 3);
        }

        // The builder stays usable after a panicking `push`: it can be taken
        // from, refilled, and taken from again.
        {
            let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();

            builder.push(TrackedDrop);
            builder.push(TrackedDrop);

            assert!(catch_unwind(AssertUnwindSafe(|| {
                builder.push(TrackedDrop);
            }))
            .is_err());

            assert_eq!(DROPPED.load(Ordering::Relaxed), 1);

            assert!(matches!(builder.take(), Some(_)));

            assert_eq!(DROPPED.load(Ordering::Relaxed), 3);

            builder.push(TrackedDrop);
            builder.push(TrackedDrop);

            assert!(matches!(builder.take(), Some(_)));

            assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 5);
        }
    }
}

tests/test_core.rs

+25
Original file line number | Diff line number | Diff line change
@@ -372,3 +372,28 @@ fn product1() {
372372
assert_eq!(v[1..3].iter().cloned().product1::<i32>(), Some(2));
373373
assert_eq!(v[1..5].iter().cloned().product1::<i32>(), Some(24));
374374
}
375+
376+
#[test]
fn next_array() {
    let v = [1, 2, 3, 4, 5];
    let mut iter = v.iter();
    // A zero-length array is always available and consumes nothing.
    assert_eq!(iter.next_array(), Some([]));
    assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([1, 2]));
    assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([3, 4]));
    // Only one element is left, which is not enough for a pair.
    assert_eq!(iter.next_array::<2>(), None);
}
385+
386+
#[test]
fn collect_array() {
    // Exactly as many items as the array holds.
    let v = [1, 2];
    let iter = v.iter().cloned();
    assert_eq!(iter.collect_array(), Some([1, 2]));

    // Too few items.
    let v = [1];
    let iter = v.iter().cloned();
    assert_eq!(iter.collect_array::<2>(), None);

    // Too many items.
    let v = [1, 2, 3];
    let iter = v.iter().cloned();
    assert_eq!(iter.collect_array::<2>(), None);
}

0 commit comments

Comments
 (0)