Skip to content

Commit cf0a160

Browse files
committed
next_array: Revise safety comments, Drop, and push
1 parent 2c3af5c commit cf0a160

File tree

3 files changed

+202
-29
lines changed

3 files changed

+202
-29
lines changed

src/lib.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1928,9 +1928,9 @@ pub trait Itertools: Iterator {
19281928
///
19291929
/// assert_eq!(Some([1, 2]), iter.next_array());
19301930
/// ```
1931-
fn next_array<T, const N: usize>(&mut self) -> Option<[T; N]>
1931+
fn next_array<const N: usize>(&mut self) -> Option<[Self::Item; N]>
19321932
where
1933-
Self: Sized + Iterator<Item = T>,
1933+
Self: Sized,
19341934
{
19351935
next_array::next_array(self)
19361936
}
@@ -1952,9 +1952,9 @@ pub trait Itertools: Iterator {
19521952
/// panic!("Expected two elements")
19531953
/// }
19541954
/// ```
1955-
fn collect_array<T, const N: usize>(mut self) -> Option<[T; N]>
1955+
fn collect_array<const N: usize>(mut self) -> Option<[Self::Item; N]>
19561956
where
1957-
Self: Sized + Iterator<Item = T>,
1957+
Self: Sized,
19581958
{
19591959
self.next_array().filter(|_| self.next().is_none())
19601960
}

src/next_array.rs

+195-22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use core::mem::{self, MaybeUninit};
2-
use core::ptr;
32

43
/// An array of at most `N` elements.
54
struct ArrayBuilder<T, const N: usize> {
@@ -17,7 +16,7 @@ struct ArrayBuilder<T, const N: usize> {
1716
impl<T, const N: usize> ArrayBuilder<T, N> {
1817
/// Initializes a new, empty `ArrayBuilder`.
1918
pub fn new() -> Self {
20-
// SAFETY: the validity invariant trivially hold for a zero-length array.
19+
// SAFETY: The safety invariant of `arr` trivially holds for `len = 0`.
2120
Self {
2221
arr: [(); N].map(|_| MaybeUninit::uninit()),
2322
len: 0,
@@ -28,54 +27,228 @@ impl<T, const N: usize> ArrayBuilder<T, N> {
2827
///
2928
/// # Panics
3029
///
31-
/// This panics if `self.len() >= N`.
30+
/// This panics if `self.len >= N`.
31+
#[inline(always)]
3232
pub fn push(&mut self, value: T) {
33-
// SAFETY: we maintain the invariant here that arr[..len] is valid.
34-
// Indexing with self.len also ensures self.len < N, and thus <= N after
35-
// the increment.
36-
self.arr[self.len] = MaybeUninit::new(value);
33+
// PANICS: This will panic if `self.len >= N`.
34+
let place = &mut self.arr[self.len];
35+
// SAFETY: The safety invariant of `self.arr` applies to elements at
36+
// indices `0..self.len` — not to the element at `self.len`. Writing to
37+
// the element at index `self.len` therefore does not violate the safety
38+
// invariant of `self.arr`. Even if this line panics, we have not
39+
// created any intermediate invalid state.
40+
*place = MaybeUninit::new(value);
41+
// PANICS: This cannot panic, since `self.len < N <= usize::MAX`.
42+
// `0..self.len` are valid. Due to the above write, the element at
43+
// `self.len` is now also valid. Consequently, all elements at indicies
44+
// `0..(self.len + 1)` are valid, and `self.len` can be safely
45+
// incremented without violating `self.arr`'s invariant. It is fine if
46+
// this increment panics, as we have not created any intermediate
47+
// invalid state.
3748
self.len += 1;
3849
}
3950

40-
/// Consumes the elements in the `ArrayBuilder` and returns them as an array `[T; N]`.
51+
/// Consumes the elements in the `ArrayBuilder` and returns them as an array
52+
/// `[T; N]`.
4153
///
4254
/// If `self.len() < N`, this returns `None`.
4355
pub fn take(&mut self) -> Option<[T; N]> {
4456
if self.len == N {
45-
// Take the array, resetting our length back to zero.
57+
// SAFETY: Decreasing the value of `self.len` cannot violate the
58+
// safety invariant on `self.arr`.
4659
self.len = 0;
60+
61+
// SAFETY: Since `self.len` is 0, `self.arr` may safely contain
62+
// uninitialized elements.
4763
let arr = mem::replace(&mut self.arr, [(); N].map(|_| MaybeUninit::uninit()));
4864

49-
// SAFETY: we had len == N, so all elements in arr are valid.
50-
Some(unsafe { arr.map(|v| v.assume_init()) })
65+
Some(arr.map(|v| {
66+
// SAFETY: We know that all elements of `arr` are valid because
67+
// we checked that `len == N`.
68+
unsafe { v.assume_init() }
69+
}))
5170
} else {
5271
None
5372
}
5473
}
5574
}
5675

76+
impl<T, const N: usize> AsMut<[T]> for ArrayBuilder<T, N> {
77+
fn as_mut(&mut self) -> &mut [T] {
78+
let valid = &mut self.arr[..self.len];
79+
// SAFETY: By invariant on `self.arr`, the elements of `self.arr` at
80+
// indices `0..self.len` are in a valid state. Since `valid` references
81+
// only these elements, the safety precondition of
82+
// `slice_assume_init_mut` is satisfied.
83+
unsafe { slice_assume_init_mut(valid) }
84+
}
85+
}
86+
5787
impl<T, const N: usize> Drop for ArrayBuilder<T, N> {
88+
// We provide a non-trivial `Drop` impl, because the trivial impl would be a
89+
// no-op; `MaybeUninit<T>` has no innate awareness of its own validity, and
90+
// so it can only forget its contents. By leveraging the safety invariant of
91+
// `self.arr`, we do know which elements of `self.arr` are valid, and can
92+
// selectively run their destructors.
5893
fn drop(&mut self) {
59-
unsafe {
60-
// SAFETY: arr[..len] is valid, so must be dropped. First we create
61-
// a pointer to this valid slice, then drop that slice in-place.
62-
// The cast from *mut MaybeUninit<T> to *mut T is always sound by
63-
// the layout guarantees of MaybeUninit.
64-
let ptr_to_first: *mut MaybeUninit<T> = self.arr.as_mut_ptr();
65-
let ptr_to_slice = ptr::slice_from_raw_parts_mut(ptr_to_first.cast::<T>(), self.len);
66-
ptr::drop_in_place(ptr_to_slice);
67-
}
94+
// SAFETY:
95+
// - by invariant on `&[T]`, `self.as_mut()` is:
96+
// - valid for reads and writes
97+
// - properly aligned
98+
// - non-null
99+
// - the dropped `T` are valid for dropping; they do not have any
100+
// additional library invariants that we've violated
101+
// - no other pointers to `valid` exist (since we're in the context of
102+
// `drop`)
103+
unsafe { core::ptr::drop_in_place(self.as_mut()) }
68104
}
69105
}
70106

107+
/// Assuming all the elements are initialized, get a mutable slice to them.
108+
///
109+
/// # Safety
110+
///
111+
/// The caller guarantees that the elements `T` referenced by `slice` are in a
112+
/// valid state.
113+
unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
114+
// SAFETY: Casting `&mut [MaybeUninit<T>]` to `&mut [T]` is sound, because
115+
// `MaybeUninit<T>` is guaranteed to have the same size, alignment and ABI
116+
// as `T`, and because the caller has guaranteed that `slice` is in the
117+
// valid state.
118+
unsafe { &mut *(slice as *mut [MaybeUninit<T>] as *mut [T]) }
119+
}
120+
71121
/// Equivalent to `it.next_array()`.
72-
pub fn next_array<I, T, const N: usize>(it: &mut I) -> Option<[T; N]>
122+
pub(crate) fn next_array<I, const N: usize>(it: &mut I) -> Option<[I::Item; N]>
73123
where
74-
I: Iterator<Item = T>,
124+
I: Iterator,
75125
{
76126
let mut builder = ArrayBuilder::new();
77127
for _ in 0..N {
78128
builder.push(it.next()?);
79129
}
80130
builder.take()
81131
}
132+
133+
#[cfg(test)]
134+
mod test {
135+
use super::ArrayBuilder;
136+
137+
#[test]
138+
fn zero_len_take() {
139+
let mut builder = ArrayBuilder::<(), 0>::new();
140+
let taken = builder.take();
141+
assert_eq!(taken, Some([(); 0]));
142+
}
143+
144+
#[test]
145+
#[should_panic]
146+
fn zero_len_push() {
147+
let mut builder = ArrayBuilder::<(), 0>::new();
148+
builder.push(());
149+
}
150+
151+
#[test]
152+
fn push_4() {
153+
let mut builder = ArrayBuilder::<(), 4>::new();
154+
assert_eq!(builder.take(), None);
155+
156+
builder.push(());
157+
assert_eq!(builder.take(), None);
158+
159+
builder.push(());
160+
assert_eq!(builder.take(), None);
161+
162+
builder.push(());
163+
assert_eq!(builder.take(), None);
164+
165+
builder.push(());
166+
assert_eq!(builder.take(), Some([(); 4]));
167+
}
168+
169+
#[test]
170+
fn tracked_drop() {
171+
use std::panic::{catch_unwind, AssertUnwindSafe};
172+
use std::sync::atomic::{AtomicU16, Ordering};
173+
174+
static DROPPED: AtomicU16 = AtomicU16::new(0);
175+
176+
#[derive(Debug, PartialEq)]
177+
struct TrackedDrop;
178+
179+
impl Drop for TrackedDrop {
180+
fn drop(&mut self) {
181+
DROPPED.fetch_add(1, Ordering::Relaxed);
182+
}
183+
}
184+
185+
{
186+
let builder = ArrayBuilder::<TrackedDrop, 0>::new();
187+
assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
188+
drop(builder);
189+
assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
190+
}
191+
192+
{
193+
let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
194+
builder.push(TrackedDrop);
195+
assert_eq!(builder.take(), None);
196+
assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
197+
drop(builder);
198+
assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 1);
199+
}
200+
201+
{
202+
let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
203+
builder.push(TrackedDrop);
204+
builder.push(TrackedDrop);
205+
assert!(matches!(builder.take(), Some(_)));
206+
assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 2);
207+
drop(builder);
208+
assert_eq!(DROPPED.load(Ordering::Relaxed), 0);
209+
}
210+
211+
{
212+
let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
213+
214+
builder.push(TrackedDrop);
215+
builder.push(TrackedDrop);
216+
217+
assert!(catch_unwind(AssertUnwindSafe(|| {
218+
builder.push(TrackedDrop);
219+
}))
220+
.is_err());
221+
222+
assert_eq!(DROPPED.load(Ordering::Relaxed), 1);
223+
224+
drop(builder);
225+
226+
assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 3);
227+
}
228+
229+
{
230+
let mut builder = ArrayBuilder::<TrackedDrop, 2>::new();
231+
232+
builder.push(TrackedDrop);
233+
builder.push(TrackedDrop);
234+
235+
assert!(catch_unwind(AssertUnwindSafe(|| {
236+
builder.push(TrackedDrop);
237+
}))
238+
.is_err());
239+
240+
assert_eq!(DROPPED.load(Ordering::Relaxed), 1);
241+
242+
assert!(matches!(builder.take(), Some(_)));
243+
244+
assert_eq!(DROPPED.load(Ordering::Relaxed), 3);
245+
246+
builder.push(TrackedDrop);
247+
builder.push(TrackedDrop);
248+
249+
assert!(matches!(builder.take(), Some(_)));
250+
251+
assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 5);
252+
}
253+
}
254+
}

tests/test_core.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ fn next_array() {
380380
assert_eq!(iter.next_array(), Some([]));
381381
assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([1, 2]));
382382
assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([3, 4]));
383-
assert_eq!(iter.next_array::<_, 2>(), None);
383+
assert_eq!(iter.next_array::<2>(), None);
384384
}
385385

386386
#[test]
@@ -391,9 +391,9 @@ fn collect_array() {
391391

392392
let v = [1];
393393
let iter = v.iter().cloned();
394-
assert_eq!(iter.collect_array::<_, 2>(), None);
394+
assert_eq!(iter.collect_array::<2>(), None);
395395

396396
let v = [1, 2, 3];
397397
let iter = v.iter().cloned();
398-
assert_eq!(iter.collect_array::<_, 2>(), None);
398+
assert_eq!(iter.collect_array::<2>(), None);
399399
}

0 commit comments

Comments
 (0)