-
Notifications
You must be signed in to change notification settings - Fork 180
Expand file tree
/
Copy pathmemory.rs
More file actions
390 lines (328 loc) · 11.9 KB
/
Copy pathmemory.rs
File metadata and controls
390 lines (328 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
//! Session-scoped memory allocation for host-side buffers.
use std::any::Any;
use std::fmt::Debug;
use std::mem::size_of;
use std::sync::Arc;
use bytes::Bytes;
use vortex_buffer::Alignment;
use vortex_buffer::Buffer;
use vortex_buffer::ByteBuffer;
use vortex_buffer::ByteBufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_session::SessionExt;
use vortex_session::SessionGuard;
use vortex_session::SessionVar;
use vortex_session::VortexSession;
/// Contract for a mutable, host-resident byte buffer that [`WritableHostBuffer`] wraps.
///
/// Implementors expose their bytes for writing and can later be converted into an immutable
/// [`ByteBuffer`] via [`freeze`](Self::freeze).
pub trait HostBufferMut: Send + 'static {
    /// Number of bytes in the buffer's logical (writable) range.
    fn len(&self) -> usize;

    /// Returns `true` when [`len`](Self::len) is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Alignment this buffer reports for its storage.
    fn alignment(&self) -> Alignment;

    /// Mutably borrows the writable byte range.
    fn as_mut_slice(&mut self) -> &mut [u8];

    /// Consumes the boxed buffer and turns it into an immutable [`ByteBuffer`].
    fn freeze(self: Box<Self>) -> ByteBuffer;
}
/// Exact-size writable host buffer returned by a [`HostAllocator`].
///
/// Wraps a boxed [`HostBufferMut`] so callers can fill the bytes (directly or through a typed
/// slice) and then freeze them into an immutable [`ByteBuffer`] or [`Buffer<T>`].
pub struct WritableHostBuffer {
    inner: Box<dyn HostBufferMut>,
}

impl WritableHostBuffer {
    /// Create a writable host buffer from an implementation of [`HostBufferMut`].
    pub fn new(inner: Box<dyn HostBufferMut>) -> Self {
        Self { inner }
    }

    /// Returns the logical byte length of the buffer.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns true when the buffer has zero bytes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the alignment reported by the underlying [`HostBufferMut`].
    pub fn alignment(&self) -> Alignment {
        self.inner.alignment()
    }

    /// Returns mutable access to the writable byte range.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.inner.as_mut_slice()
    }

    /// Returns mutable access to the buffer as a typed slice of `len() / size_of::<T>()` values.
    ///
    /// NOTE(review): `T` is unconstrained. Types with invalid bit patterns (e.g. `bool`,
    /// references, enums) would be unsound to view over arbitrary bytes. Consider a
    /// `Pod`-style bound.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidArgument` error if:
    /// - `T` is zero-sized;
    /// - the reported alignment is weaker than `T`'s;
    /// - the byte length is not a multiple of `size_of::<T>()`;
    /// - the buffer's actual start address is not aligned for `T`.
    pub fn as_mut_slice_typed<T>(&mut self) -> VortexResult<&mut [T]> {
        vortex_ensure!(
            size_of::<T>() != 0,
            InvalidArgument: "Cannot create typed mutable slice for zero-sized type {}",
            std::any::type_name::<T>()
        );
        vortex_ensure!(
            self.alignment().is_aligned_to(Alignment::of::<T>()),
            InvalidArgument: "Buffer is not sufficiently aligned for type {}",
            std::any::type_name::<T>()
        );
        let bytes = self.as_mut_slice();
        let byte_len = bytes.len();
        let type_size = size_of::<T>();
        vortex_ensure!(
            byte_len.is_multiple_of(type_size),
            InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
            type_size,
            std::any::type_name::<T>()
        );
        if byte_len == 0 {
            // An empty buffer may expose a dangling pointer that is only aligned for `u8`.
            // `from_raw_parts_mut` requires `T` alignment even for zero elements, so never
            // build the slice from that pointer.
            return Ok(&mut []);
        }
        let ptr = bytes.as_mut_ptr().cast::<T>();
        // `alignment()` is only what the `HostBufferMut` implementation claims. Verify the real
        // address too, because a misaligned pointer would make the slice below undefined.
        vortex_ensure!(
            ptr.is_aligned(),
            InvalidArgument: "Buffer pointer is not aligned to {} for {}",
            Alignment::of::<T>(),
            std::any::type_name::<T>()
        );
        // SAFETY: `ptr` comes from a slice, so it is non-null, and it is aligned for `T`
        // (checked above). It addresses `byte_len` bytes that we borrow exclusively through
        // `&mut self`. `byte_len` is an exact multiple of `size_of::<T>()`, so the slice covers
        // exactly those bytes.
        Ok(unsafe { std::slice::from_raw_parts_mut(ptr, byte_len / type_size) })
    }

    /// Freeze the writable buffer into an immutable [`ByteBuffer`].
    pub fn freeze(self) -> ByteBuffer {
        self.inner.freeze()
    }

    /// Freeze the writable buffer into a typed immutable [`Buffer<T>`].
    ///
    /// The writable buffer is consumed even when an error is returned.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidArgument` error if:
    /// - `T` is zero-sized;
    /// - the byte length is not a multiple of `size_of::<T>()`;
    /// - the frozen buffer is not aligned to `Alignment::of::<T>()`.
    pub fn freeze_typed<T>(self) -> VortexResult<Buffer<T>> {
        vortex_ensure!(
            size_of::<T>() != 0,
            InvalidArgument: "Cannot freeze typed buffer for zero-sized type {}",
            std::any::type_name::<T>()
        );
        let buffer = self.freeze();
        let byte_len = buffer.len();
        let type_size = size_of::<T>();
        let type_align = Alignment::of::<T>();
        vortex_ensure!(
            byte_len.is_multiple_of(type_size),
            InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
            type_size,
            std::any::type_name::<T>()
        );
        vortex_ensure!(
            buffer.is_aligned(type_align),
            InvalidArgument: "Buffer pointer is not aligned to {} for {}",
            type_align,
            std::any::type_name::<T>()
        );
        Ok(Buffer::from_byte_buffer(buffer))
    }
}
impl Debug for WritableHostBuffer {
    /// Shows only the buffer's length and reported alignment, never its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WritableHostBuffer");
        builder.field("len", &self.len());
        builder.field("alignment", &self.alignment());
        builder.finish()
    }
}
/// Source of exact-size writable host buffers.
pub trait HostAllocator: Debug + Send + Sync + 'static {
    /// Produces a [`WritableHostBuffer`] of `len` bytes with the requested `alignment`.
    fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer>;
}
/// Reference-counted, dynamically dispatched [`HostAllocator`] handle shared by the
/// session-scoped memory APIs.
pub type HostAllocatorRef = Arc<dyn HostAllocator>;
/// Convenience helpers available on every [`HostAllocator`] (including unsized ones).
pub trait HostAllocatorExt: HostAllocator {
    /// Allocates room for `len` values of `T`, aligned to `Alignment::of::<T>()`.
    ///
    /// # Errors
    ///
    /// Fails when `len * size_of::<T>()` overflows `usize`, or when the underlying
    /// [`HostAllocator::allocate`] call fails.
    fn allocate_typed<T>(&self, len: usize) -> VortexResult<WritableHostBuffer> {
        let element_size = size_of::<T>();
        let Some(byte_len) = len.checked_mul(element_size) else {
            return Err(vortex_err!(
                "Typed host allocation overflow for type {} and len {}",
                std::any::type_name::<T>(),
                len
            ));
        };
        self.allocate(byte_len, Alignment::of::<T>())
    }
}
impl<A: HostAllocator + ?Sized> HostAllocatorExt for A {}
/// Per-session memory settings for Vortex arrays; currently just the host allocator in use.
#[derive(Clone, Debug)]
pub struct MemorySession {
    allocator: HostAllocatorRef,
}
impl MemorySession {
    /// Builds a memory configuration backed by `allocator`.
    pub fn new(allocator: HostAllocatorRef) -> Self {
        MemorySession { allocator }
    }

    /// Returns a new shared handle to the current allocator.
    pub fn allocator(&self) -> HostAllocatorRef {
        Arc::clone(&self.allocator)
    }

    /// Replaces the current allocator with `allocator`.
    pub fn set_allocator(&mut self, allocator: HostAllocatorRef) {
        self.allocator = allocator;
    }
}
impl Default for MemorySession {
    /// Uses [`DefaultHostAllocator`] as the allocator.
    fn default() -> Self {
        let allocator: HostAllocatorRef = Arc::new(DefaultHostAllocator);
        MemorySession::new(allocator)
    }
}
impl SessionVar for MemorySession {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}
/// Session extension that exposes the [`MemorySession`] configuration.
pub trait MemorySessionExt: SessionExt {
    /// Returns a [`SessionGuard`] over this session's [`MemorySession`] variable.
    fn memory(&self) -> SessionGuard<'_, MemorySession> {
        self.get::<MemorySession>()
    }

    /// Returns a shared handle to this session's configured host allocator.
    fn allocator(&self) -> HostAllocatorRef {
        let memory = self.memory();
        memory.allocator()
    }

    /// Installs `allocator` as the host allocator on the session obtained from
    /// `self.session()` and returns that session so calls can be chained.
    fn with_allocator(self, allocator: HostAllocatorRef) -> VortexSession {
        let target = self.session();
        target.get_mut::<MemorySession>().set_allocator(allocator);
        target
    }
}
impl<S: SessionExt> MemorySessionExt for S {}
/// Default host allocator, backed by an aligned [`ByteBufferMut`].
///
/// Buffers are zero-filled on allocation. Freezing and reading a buffer that the caller never
/// (fully) wrote is therefore well-defined and yields zeros rather than uninitialized memory.
#[derive(Debug, Default)]
pub struct DefaultHostAllocator;
impl HostAllocator for DefaultHostAllocator {
    fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
        let mut buffer = ByteBufferMut::with_capacity_aligned(len, alignment);
        // SAFETY: the buffer was created with capacity for `len` bytes. Every one of those
        // bytes is overwritten with zero on the next statement, before the buffer leaves
        // this function.
        unsafe { buffer.set_len(len) };
        // The public API lets callers `freeze` (and then read) without writing anything, so
        // the contents must be initialized here rather than trusted to the caller.
        buffer.as_mut_slice().fill(0);
        Ok(WritableHostBuffer::new(Box::new(
            DefaultWritableHostBuffer { buffer, alignment },
        )))
    }
}
/// [`HostBufferMut`] implementation handed out by [`DefaultHostAllocator`].
#[derive(Debug)]
struct DefaultWritableHostBuffer {
    buffer: ByteBufferMut,
    alignment: Alignment,
}
/// Owner passed to `Bytes::from_owner` when a [`DefaultWritableHostBuffer`] is frozen; it
/// exposes the wrapped buffer's bytes through [`AsRef<[u8]>`].
#[derive(Debug)]
struct HostBufferOwner {
    buffer: ByteBufferMut,
}
impl AsRef<[u8]> for HostBufferOwner {
    fn as_ref(&self) -> &[u8] {
        let Self { buffer } = self;
        buffer.as_slice()
    }
}
impl HostBufferMut for DefaultWritableHostBuffer {
    fn len(&self) -> usize {
        self.buffer.len()
    }
    fn alignment(&self) -> Alignment {
        self.alignment
    }
    fn as_mut_slice(&mut self) -> &mut [u8] {
        self.buffer.as_mut_slice()
    }
    /// Wraps the allocation in `Bytes` through a [`HostBufferOwner`] and rebuilds a
    /// [`ByteBuffer`] carrying the recorded alignment.
    fn freeze(self: Box<Self>) -> ByteBuffer {
        let unboxed = *self;
        let owner = HostBufferOwner {
            buffer: unboxed.buffer,
        };
        ByteBuffer::from_bytes_aligned(Bytes::from_owner(owner), unboxed.alignment)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use super::*;

    /// Allocator that delegates to [`DefaultHostAllocator`] while counting `allocate` calls.
    #[derive(Debug)]
    struct CountingAllocator {
        allocations: Arc<AtomicUsize>,
    }

    impl HostAllocator for CountingAllocator {
        fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
            self.allocations.fetch_add(1, Ordering::Relaxed);
            DefaultHostAllocator.allocate(len, alignment)
        }
    }

    #[test]
    fn writable_host_buffer_freeze_round_trip() {
        let align8 = Alignment::new(8);
        let mut writable = DefaultHostAllocator.allocate(16, align8).unwrap();
        writable
            .as_mut_slice()
            .iter_mut()
            .zip(0u8..)
            .for_each(|(slot, value)| *slot = value);
        let frozen = writable.freeze();
        let expected: Vec<u8> = (0u8..16).collect();
        assert_eq!(frozen.len(), 16);
        assert!(frozen.is_aligned(align8));
        assert_eq!(frozen.as_slice(), expected.as_slice());
    }

    #[test]
    fn memory_session_replaces_allocator() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut session = MemorySession::default();
        session.set_allocator(Arc::new(CountingAllocator {
            allocations: Arc::clone(&counter),
        }));
        let buffer = session.allocator().allocate(4, Alignment::none()).unwrap();
        drop(buffer);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn typed_allocation_uses_type_alignment() {
        let writable = DefaultHostAllocator.allocate_typed::<u64>(4).unwrap();
        assert_eq!(writable.len(), size_of::<u64>() * 4);
        assert_eq!(writable.alignment(), Alignment::of::<u64>());
    }

    #[test]
    fn typed_mut_slice_round_trip() {
        let mut writable = DefaultHostAllocator.allocate_typed::<u64>(4).unwrap();
        let typed = writable.as_mut_slice_typed::<u64>().unwrap();
        typed.copy_from_slice(&[10, 20, 30, 40]);
        let frozen = writable.freeze();
        let word_count = frozen.len() / size_of::<u64>();
        // SAFETY: the buffer was allocated for `u64` values (size and alignment), every word
        // was written above, and `frozen` outlives `values`.
        let values = unsafe {
            std::slice::from_raw_parts(frozen.as_slice().as_ptr().cast::<u64>(), word_count)
        };
        assert_eq!(values, [10, 20, 30, 40]);
    }

    #[test]
    fn typed_mut_slice_rejects_length_mismatch() {
        let mut writable = DefaultHostAllocator.allocate(7, Alignment::none()).unwrap();
        let outcome = writable.as_mut_slice_typed::<u32>();
        assert!(outcome.is_err());
    }

    #[test]
    fn freeze_typed_round_trip() {
        let mut writable = DefaultHostAllocator.allocate_typed::<u64>(4).unwrap();
        let typed = writable.as_mut_slice_typed::<u64>().unwrap();
        typed.copy_from_slice(&[1, 3, 5, 7]);
        let frozen = writable.freeze_typed::<u64>().unwrap();
        assert_eq!(frozen.as_slice(), [1, 3, 5, 7]);
    }

    #[test]
    fn freeze_typed_rejects_length_mismatch() {
        let writable = DefaultHostAllocator.allocate(7, Alignment::none()).unwrap();
        let err = writable.freeze_typed::<u32>().unwrap_err();
        assert!(err.to_string().contains("not a multiple of"));
    }
}