Skip to content

Commit

Permalink
Added Array::retrieve_chunk_if_exists and variants
Browse files Browse the repository at this point in the history
  • Loading branch information
LDeakin committed Feb 10, 2024
1 parent 6d9bf07 commit c72757e
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 139 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `{Chunk,Array}Representation::shape_u64`
- Added `ChunkGridTraits::chunk_shape_u64` to `ChunkGridTraits`
- **Breaking** added `chunk_shape_u64_unchecked` to `ChunkGridTraits` which must be implemented by chunk grids
- Added `Array::retrieve_chunk_if_exists` and variants (`async_`, `_elements`, `_ndarray`)

### Changed
- Dependency bumps
Expand Down
38 changes: 36 additions & 2 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,31 @@ fn iter_u64_to_usize<'a, I: Iterator<Item = &'a u64>>(iter: I) -> Vec<usize> {
.collect::<Vec<_>>()
}

/// Check that the in-memory size of `T` equals the size of `data_type`.
///
/// # Errors
/// Returns [`ArrayError::IncompatibleElementSize`] holding the data type size followed by
/// the size of `T` when the two sizes differ.
fn validate_element_size<T>(data_type: &DataType) -> Result<(), ArrayError> {
    let data_type_size = data_type.size();
    let element_size = std::mem::size_of::<T>();
    if data_type_size != element_size {
        return Err(ArrayError::IncompatibleElementSize(
            data_type_size,
            element_size,
        ));
    }
    Ok(())
}

/// Move `elements` into an [`ndarray::ArrayD`] with the given `shape`.
///
/// # Errors
/// Returns a [`codec::CodecError::UnexpectedChunkDecodedSize`] wrapped in
/// [`ArrayError::CodecError`] if `ndarray` rejects the shape for the supplied elements.
/// The error carries the byte length of `elements` followed by the byte length implied by `shape`.
#[cfg(feature = "ndarray")]
fn elements_to_ndarray<T>(
    shape: &[u64],
    elements: Vec<T>,
) -> Result<ndarray::ArrayD<T>, ArrayError> {
    // Record the element count up front: `elements` is moved into the array constructor.
    let num_elements = elements.len();
    let shape_usize = iter_u64_to_usize(shape.iter());
    match ndarray::ArrayD::<T>::from_shape_vec(shape_usize, elements) {
        Ok(array) => Ok(array),
        Err(_) => {
            let actual_size = num_elements * std::mem::size_of::<T>();
            let expected_size = shape.iter().product::<u64>() * std::mem::size_of::<T>() as u64;
            Err(ArrayError::CodecError(
                codec::CodecError::UnexpectedChunkDecodedSize(actual_size, expected_size),
            ))
        }
    }
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
Expand Down Expand Up @@ -765,7 +790,7 @@ mod tests {
array
.store_array_subset_elements::<f32>(
&ArraySubset::new_with_ranges(&[3..6, 3..6]),
vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
vec![1.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
)
.unwrap();

Expand All @@ -781,13 +806,22 @@ mod tests {
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // 0
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // 1
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // 2
1.0, 1.0, 1.0, 0.1, 0.2, 0.3, 1.0, 1.0, //_3____________
1.0, 1.0, 1.0, 1.0, 0.2, 0.3, 1.0, 1.0, //_3____________
1.0, 1.0, 1.0, 0.4, 0.5, 0.6, 1.0, 1.0, // 4
1.0, 1.0, 1.0, 0.7, 0.8, 0.9, 1.0, 1.0, // 5 (1, 1)
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // 6
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // 7
]
);
assert!(array
.retrieve_chunk_elements_if_exists::<f32>(&[0; 2])
.unwrap()
.is_none());
#[cfg(feature = "ndarray")]
assert!(array
.retrieve_chunk_ndarray_if_exists::<f32>(&[0; 2])
.unwrap()
.is_none());
}

fn array_subset_locking(locks: StoreLocks, expect_equal: bool) {
Expand Down
184 changes: 113 additions & 71 deletions src/array/array_async_readable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,51 +14,13 @@ use super::{
ArrayCodecTraits, ArrayToBytesCodecTraits, AsyncArrayPartialDecoderTraits,
AsyncStoragePartialDecoder,
},
transmute_from_bytes_vec,
unsafe_cell_slice::UnsafeCellSlice,
Array, ArrayCreateError, ArrayError, ArrayMetadata,
validate_element_size, Array, ArrayCreateError, ArrayError, ArrayMetadata,
};

// FIXME: Matches array_retrieve_elements with await
//
// Expands to an expression of the form `Result<_, ArrayError>`:
// - `$self` is the receiver; its `data_type.size()` is compared with `size_of::<T>()`,
// - `$func($($arg)*)` is the async method invoked on `$self` to obtain the bytes.
//
// A generic type `T` must be in scope at the expansion site, and the expansion must sit
// inside an async context whose return type accepts the `?` used on the awaited call.
macro_rules! array_async_retrieve_elements {
( $self:expr, $func:ident($($arg:tt)*) ) => {
// Refuse to reinterpret bytes as `T` when the element sizes disagree.
if $self.data_type.size() != std::mem::size_of::<T>() {
Err(ArrayError::IncompatibleElementSize(
$self.data_type.size(),
std::mem::size_of::<T>(),
))
} else {
// Await the byte-returning method, then convert the bytes into a vector of `T`.
let bytes = $self.$func($($arg)*).await?;
let elements = crate::array::transmute_from_bytes_vec::<T>(bytes);
Ok(elements)
}
};
}

// FIXME: Matches array_retrieve_ndarray with await
//
// Expands to an expression of the form `Result<ndarray::ArrayD<T>, ArrayError>`:
// - `$self` is the receiver; its `data_type.size()` is compared with `size_of::<T>()`,
// - `$shape` is the shape handed to `ndarray` (iterated as `u64` values),
// - `$func($($arg)*)` is the async method invoked on `$self` to obtain the elements.
//
// A generic type `T` must be in scope at the expansion site, and the expansion must sit
// inside an async context whose return type accepts the `?` used on the awaited call.
#[cfg(feature = "ndarray")]
macro_rules! array_async_retrieve_ndarray {
( $self:expr, $shape:expr, $func:ident($($arg:tt)*) ) => {
// Refuse to build an array of `T` when the element sizes disagree.
if $self.data_type.size() != std::mem::size_of::<T>() {
Err(ArrayError::IncompatibleElementSize(
$self.data_type.size(),
std::mem::size_of::<T>(),
))
} else {
let elements = $self.$func($($arg)*).await?;
// Capture the length before `elements` is moved into `from_shape_vec`.
let length = elements.len();
ndarray::ArrayD::<T>::from_shape_vec(
super::iter_u64_to_usize($shape.iter()),
elements,
)
// Any shape error is reported as a size mismatch: actual bytes, then bytes implied by `$shape`.
.map_err(|_| {
ArrayError::CodecError(crate::array::codec::CodecError::UnexpectedChunkDecodedSize(
length * std::mem::size_of::<T>(),
$shape.iter().product::<u64>() * std::mem::size_of::<T>() as u64,
))
})
}
};
}
use super::elements_to_ndarray;

impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
/// Create an array in `storage` at `path`. The metadata is read from the store.
Expand All @@ -78,7 +40,7 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
Self::new_with_metadata(storage, path, metadata)
}

/// Read and decode the chunk at `chunk_indices` into its bytes.
/// Read and decode the chunk at `chunk_indices` into its bytes if it exists.
///
/// # Errors
/// Returns an [`ArrayError`] if
Expand All @@ -88,7 +50,10 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
///
/// # Panics
/// Panics if the number of elements in the chunk exceeds `usize::MAX`.
pub async fn async_retrieve_chunk(&self, chunk_indices: &[u64]) -> Result<Vec<u8>, ArrayError> {
pub async fn async_retrieve_chunk_if_exists(
&self,
chunk_indices: &[u64],
) -> Result<Option<Vec<u8>>, ArrayError> {
let storage_handle = Arc::new(StorageHandle::new(&*self.storage));
let storage_transformer = self
.storage_transformers()
Expand All @@ -101,8 +66,8 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
)
.await
.map_err(ArrayError::StorageError)?;
let chunk_representation = self.chunk_array_representation(chunk_indices)?;
if let Some(chunk_encoded) = chunk_encoded {
let chunk_representation = self.chunk_array_representation(chunk_indices)?;
let chunk_decoded = self
.codecs()
.async_decode_opt(chunk_encoded, &chunk_representation, self.parallel_codecs())
Expand All @@ -111,20 +76,58 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
let chunk_decoded_size =
chunk_representation.num_elements_usize() * chunk_representation.data_type().size();
if chunk_decoded.len() == chunk_decoded_size {
Ok(chunk_decoded)
Ok(Some(chunk_decoded))
} else {
Err(ArrayError::UnexpectedChunkDecodedSize(
chunk_decoded.len(),
chunk_decoded_size,
))
}
} else {
Ok(None)
}
}

/// Read and decode the chunk at `chunk_indices` into its bytes.
///
/// If the chunk does not exist, the returned bytes are the fill value repeated once per chunk element.
///
/// # Errors
/// Returns an [`ArrayError`] if
/// - `chunk_indices` are invalid,
/// - there is a codec decoding error, or
/// - an underlying store error.
///
/// # Panics
/// Panics if the number of elements in the chunk exceeds `usize::MAX`.
pub async fn async_retrieve_chunk(&self, chunk_indices: &[u64]) -> Result<Vec<u8>, ArrayError> {
    match self.async_retrieve_chunk_if_exists(chunk_indices).await? {
        Some(chunk_bytes) => Ok(chunk_bytes),
        None => {
            // Missing chunk: synthesise it from the fill value.
            let representation = self.chunk_array_representation(chunk_indices)?;
            let fill_bytes = representation.fill_value().as_ne_bytes();
            Ok(fill_bytes.repeat(representation.num_elements_usize()))
        }
    }
}

/// Read and decode the chunk at `chunk_indices` into a vector of its elements.
/// Read and decode the chunk at `chunk_indices` into a vector of its elements if it exists.
///
/// # Errors
/// Returns an [`ArrayError`] if
/// - the size of `T` does not match the data type size,
/// - the decoded bytes cannot be transmuted,
/// - `chunk_indices` are invalid,
/// - there is a codec decoding error, or
/// - an underlying store error.
pub async fn async_retrieve_chunk_elements_if_exists<T: bytemuck::Pod + Send + Sync>(
&self,
chunk_indices: &[u64],
) -> Result<Option<Vec<T>>, ArrayError> {
validate_element_size::<T>(self.data_type())?;
let bytes = self.async_retrieve_chunk_if_exists(chunk_indices).await?;
Ok(bytes.map(|bytes| transmute_from_bytes_vec::<T>(bytes)))
}

/// Read and decode the chunk at `chunk_indices` into a vector of its elements or the fill value if it does not exist.
///
/// # Errors
/// Returns an [`ArrayError`] if
Expand All @@ -137,11 +140,45 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
&self,
chunk_indices: &[u64],
) -> Result<Vec<T>, ArrayError> {
array_async_retrieve_elements!(self, async_retrieve_chunk(chunk_indices))
validate_element_size::<T>(self.data_type())?;
let bytes = self.async_retrieve_chunk(chunk_indices).await?;
Ok(transmute_from_bytes_vec::<T>(bytes))
}

#[cfg(feature = "ndarray")]
/// Read and decode the chunk at `chunk_indices` into an [`ndarray::ArrayD`].
/// Read and decode the chunk at `chunk_indices` into an [`ndarray::ArrayD`] if it exists.
///
/// # Errors
/// Returns an [`ArrayError`] if:
/// - the size of `T` does not match the data type size,
/// - the decoded bytes cannot be transmuted,
/// - the chunk indices are invalid,
/// - there is a codec decoding error, or
/// - an underlying store error.
///
/// # Panics
/// Will panic if a chunk dimension is larger than `usize::MAX`.
pub async fn async_retrieve_chunk_ndarray_if_exists<T: bytemuck::Pod + Send + Sync>(
    &self,
    chunk_indices: &[u64],
) -> Result<Option<ndarray::ArrayD<T>>, ArrayError> {
    // The size of `T` is validated inside `async_retrieve_chunk_elements_if_exists`.
    let chunk_shape = match self
        .chunk_grid()
        .chunk_shape_u64(chunk_indices, self.shape())?
    {
        Some(chunk_shape) => chunk_shape,
        None => {
            return Err(ArrayError::InvalidChunkGridIndicesError(
                chunk_indices.to_vec(),
            ))
        }
    };
    // An absent chunk stays `None`; a present chunk is reshaped, surfacing any shape error.
    self.async_retrieve_chunk_elements_if_exists::<T>(chunk_indices)
        .await?
        .map(|elements| elements_to_ndarray(&chunk_shape, elements))
        .transpose()
}

#[cfg(feature = "ndarray")]
/// Read and decode the chunk at `chunk_indices` into an [`ndarray::ArrayD`]. It is filled with the fill value if it does not exist.
///
/// # Errors
/// Returns an [`ArrayError`] if:
Expand All @@ -157,11 +194,13 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
&self,
chunk_indices: &[u64],
) -> Result<ndarray::ArrayD<T>, ArrayError> {
let shape = &self
// validate_element_size::<T>(self.data_type())?; // in async_retrieve_chunk_elements
let shape = self
.chunk_grid()
.chunk_shape_u64(chunk_indices, self.shape())?
.ok_or_else(|| ArrayError::InvalidChunkGridIndicesError(chunk_indices.to_vec()))?;
array_async_retrieve_ndarray!(self, shape, async_retrieve_chunk_elements(chunk_indices))
let elements = self.async_retrieve_chunk_elements(chunk_indices).await?;
elements_to_ndarray(&shape, elements)
}

/// Read and decode the chunk at `chunk_indices` into its bytes.
Expand Down Expand Up @@ -231,7 +270,9 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
&self,
chunks: &ArraySubset,
) -> Result<Vec<T>, ArrayError> {
array_async_retrieve_elements!(self, async_retrieve_chunks(chunks))
validate_element_size::<T>(self.data_type())?;
let bytes = self.async_retrieve_chunks(chunks).await?;
Ok(transmute_from_bytes_vec::<T>(bytes))
}

/// Read and decode the chunks at `chunks` into an [`ndarray::ArrayD`].
Expand All @@ -242,12 +283,10 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
&self,
chunks: &ArraySubset,
) -> Result<ndarray::ArrayD<T>, ArrayError> {
validate_element_size::<T>(self.data_type())?;
let array_subset = self.chunks_subset(chunks)?;
array_async_retrieve_ndarray!(
self,
array_subset.shape(),
async_retrieve_chunks_elements(chunks)
)
let elements = self.async_retrieve_chunks_elements(chunks).await?;
elements_to_ndarray(array_subset.shape(), elements)
}

async fn _async_decode_chunk_into_array_subset(
Expand Down Expand Up @@ -475,7 +514,9 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
&self,
array_subset: &ArraySubset,
) -> Result<Vec<T>, ArrayError> {
array_async_retrieve_elements!(self, async_retrieve_array_subset(array_subset))
validate_element_size::<T>(self.data_type())?;
let bytes = self.async_retrieve_array_subset(array_subset).await?;
Ok(transmute_from_bytes_vec::<T>(bytes))
}

#[cfg(feature = "ndarray")]
Expand All @@ -493,11 +534,11 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
&self,
array_subset: &ArraySubset,
) -> Result<ndarray::ArrayD<T>, ArrayError> {
array_async_retrieve_ndarray!(
self,
array_subset.shape(),
async_retrieve_array_subset_elements(array_subset)
)
// validate_element_size::<T>(self.data_type())?; // in async_retrieve_array_subset_elements
let elements = self
.async_retrieve_array_subset_elements(array_subset)
.await?;
elements_to_ndarray(array_subset.shape(), elements)
}

/// Read and decode the `chunk_subset` of the chunk at `chunk_indices` into its bytes.
Expand Down Expand Up @@ -565,10 +606,11 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
chunk_indices: &[u64],
chunk_subset: &ArraySubset,
) -> Result<Vec<T>, ArrayError> {
array_async_retrieve_elements!(
self,
async_retrieve_chunk_subset(chunk_indices, chunk_subset)
)
validate_element_size::<T>(self.data_type())?;
let bytes = self
.async_retrieve_chunk_subset(chunk_indices, chunk_subset)
.await?;
Ok(transmute_from_bytes_vec::<T>(bytes))
}

#[cfg(feature = "ndarray")]
Expand All @@ -588,11 +630,11 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
chunk_indices: &[u64],
chunk_subset: &ArraySubset,
) -> Result<ndarray::ArrayD<T>, ArrayError> {
array_async_retrieve_ndarray!(
self,
chunk_subset.shape(),
async_retrieve_chunk_subset_elements(chunk_indices, chunk_subset)
)
// validate_element_size::<T>(self.data_type())?; // in async_retrieve_chunk_subset_elements
let elements = self
.async_retrieve_chunk_subset_elements(chunk_indices, chunk_subset)
.await?;
elements_to_ndarray(chunk_subset.shape(), elements)
}

/// Initialises a partial decoder for the chunk at `chunk_indices` with optional parallelism.
Expand Down
Loading

0 comments on commit c72757e

Please sign in to comment.