Skip to content

Commit

Permalink
Use chunk_shape_u64 internally
Browse files Browse the repository at this point in the history
  • Loading branch information
LDeakin committed Feb 10, 2024
1 parent 382c403 commit 6d9bf07
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 79 deletions.
13 changes: 5 additions & 8 deletions src/array/array_async_readable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use crate::{
};

use super::{
chunk_shape_to_array_shape,
codec::{
ArrayCodecTraits, ArrayToBytesCodecTraits, AsyncArrayPartialDecoderTraits,
AsyncStoragePartialDecoder,
Expand Down Expand Up @@ -158,12 +157,10 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
&self,
chunk_indices: &[u64],
) -> Result<ndarray::ArrayD<T>, ArrayError> {
let shape = chunk_shape_to_array_shape(
&self
.chunk_grid()
.chunk_shape(chunk_indices, self.shape())?
.ok_or_else(|| ArrayError::InvalidChunkGridIndicesError(chunk_indices.to_vec()))?,
);
let shape = &self
.chunk_grid()
.chunk_shape_u64(chunk_indices, self.shape())?
.ok_or_else(|| ArrayError::InvalidChunkGridIndicesError(chunk_indices.to_vec()))?;
array_async_retrieve_ndarray!(self, shape, async_retrieve_chunk_elements(chunk_indices))
}

Expand Down Expand Up @@ -520,7 +517,7 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits> Array<TStorage> {
chunk_subset: &ArraySubset,
) -> Result<Vec<u8>, ArrayError> {
let chunk_representation = self.chunk_array_representation(chunk_indices)?;
if !chunk_subset.inbounds(&chunk_shape_to_array_shape(chunk_representation.shape())) {
if !chunk_subset.inbounds(&chunk_representation.shape_u64()) {
return Err(ArrayError::InvalidArraySubset(
chunk_subset.clone(),
self.shape().to_vec(),
Expand Down
8 changes: 5 additions & 3 deletions src/array/array_async_readable_writable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
storage::{data_key, AsyncReadableWritableStorageTraits},
};

use super::{chunk_shape_to_array_shape, Array, ArrayError};
use super::{Array, ArrayError};

impl<TStorage: ?Sized + AsyncReadableWritableStorageTraits> Array<TStorage> {
/// Encode `subset_bytes` and store in `array_subset`.
Expand Down Expand Up @@ -195,8 +195,10 @@ impl<TStorage: ?Sized + AsyncReadableWritableStorageTraits> Array<TStorage> {
chunk_subset_bytes: Vec<u8>,
) -> Result<(), ArrayError> {
// Validation
if let Some(chunk_shape) = self.chunk_grid().chunk_shape(chunk_indices, self.shape())? {
let chunk_shape = chunk_shape_to_array_shape(&chunk_shape);
if let Some(chunk_shape) = self
.chunk_grid()
.chunk_shape_u64(chunk_indices, self.shape())?
{
if std::iter::zip(chunk_subset.end_exc(), &chunk_shape)
.any(|(end_exc, shape)| end_exc > *shape)
{
Expand Down
13 changes: 5 additions & 8 deletions src/array/array_sync_readable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::{
};

use super::{
chunk_shape_to_array_shape,
codec::{
ArrayCodecTraits, ArrayPartialDecoderTraits, ArrayToBytesCodecTraits, StoragePartialDecoder,
},
Expand Down Expand Up @@ -152,12 +151,10 @@ impl<TStorage: ?Sized + ReadableStorageTraits> Array<TStorage> {
&self,
chunk_indices: &[u64],
) -> Result<ndarray::ArrayD<T>, ArrayError> {
let shape = crate::array::chunk_shape_to_array_shape(
&self
.chunk_grid()
.chunk_shape(chunk_indices, self.shape())?
.ok_or_else(|| ArrayError::InvalidChunkGridIndicesError(chunk_indices.to_vec()))?,
);
let shape = self
.chunk_grid()
.chunk_shape_u64(chunk_indices, self.shape())?
.ok_or_else(|| ArrayError::InvalidChunkGridIndicesError(chunk_indices.to_vec()))?;
array_retrieve_ndarray!(self, shape, retrieve_chunk_elements(chunk_indices))
}

Expand Down Expand Up @@ -600,7 +597,7 @@ impl<TStorage: ?Sized + ReadableStorageTraits> Array<TStorage> {
chunk_subset: &ArraySubset,
) -> Result<Vec<u8>, ArrayError> {
let chunk_representation = self.chunk_array_representation(chunk_indices)?;
if !chunk_subset.inbounds(&chunk_shape_to_array_shape(chunk_representation.shape())) {
if !chunk_subset.inbounds(&chunk_representation.shape_u64()) {
return Err(ArrayError::InvalidArraySubset(
chunk_subset.clone(),
self.shape().to_vec(),
Expand Down
12 changes: 5 additions & 7 deletions src/array/array_sync_readable_writable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
storage::{data_key, ReadableWritableStorageTraits},
};

use super::{chunk_shape_to_array_shape, unravel_index, Array, ArrayError};
use super::{unravel_index, Array, ArrayError};

impl<TStorage: ?Sized + ReadableWritableStorageTraits> Array<TStorage> {
/// Encode `subset_bytes` and store in `array_subset`.
Expand Down Expand Up @@ -259,12 +259,10 @@ impl<TStorage: ?Sized + ReadableWritableStorageTraits> Array<TStorage> {
chunk_subset: &ArraySubset,
chunk_subset_bytes: Vec<u8>,
) -> Result<(), ArrayError> {
let chunk_shape = chunk_shape_to_array_shape(
&self
.chunk_grid()
.chunk_shape(chunk_indices, self.shape())?
.ok_or_else(|| ArrayError::InvalidChunkGridIndicesError(chunk_indices.to_vec()))?,
);
let chunk_shape = self
.chunk_grid()
.chunk_shape_u64(chunk_indices, self.shape())?
.ok_or_else(|| ArrayError::InvalidChunkGridIndicesError(chunk_indices.to_vec()))?;

// Validation
if std::iter::zip(chunk_subset.end_exc(), &chunk_shape)
Expand Down
4 changes: 1 addition & 3 deletions src/array/chunk_grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ pub use regular::{RegularChunkGrid, RegularChunkGridConfiguration};
use derive_more::{Deref, From};

use crate::{
array::chunk_shape_to_array_shape,
array_subset::{ArraySubset, IncompatibleDimensionalityError},
metadata::Metadata,
plugin::{Plugin, PluginCreateError},
Expand Down Expand Up @@ -403,9 +402,8 @@ pub trait ChunkGridTraits: dyn_clone::DynClone + core::fmt::Debug + Send + Sync
debug_assert_eq!(self.dimensionality(), chunk_indices.len());
if let (Some(chunk_origin), Some(chunk_shape)) = (
self.chunk_origin_unchecked(chunk_indices, array_shape),
self.chunk_shape_unchecked(chunk_indices, array_shape),
self.chunk_shape_u64_unchecked(chunk_indices, array_shape),
) {
let chunk_shape = chunk_shape_to_array_shape(&chunk_shape);
Some(ArraySubset::new_with_start_shape_unchecked(
chunk_origin,
chunk_shape,
Expand Down
10 changes: 10 additions & 0 deletions src/array/chunk_grid/regular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ impl RegularChunkGrid {
/// Return the chunk shape as a slice of [`NonZeroU64`], one extent per dimension.
///
/// The extents are non-zero by construction; use `chunk_shape_u64` for a plain
/// `Vec<u64>` view of the same shape.
pub fn chunk_shape(&self) -> &[NonZeroU64] {
    self.chunk_shape.as_slice()
}

/// Return the chunk shape as an [`ArrayShape`] ([`Vec<u64>`]).
#[must_use]
pub fn chunk_shape_u64(&self) -> Vec<u64> {
self.chunk_shape
.iter()
.copied()
.map(NonZeroU64::get)
.collect::<Vec<_>>()
}
}

impl ChunkGridTraits for RegularChunkGrid {
Expand Down
7 changes: 3 additions & 4 deletions src/array/codec/array_to_array/transpose/transpose_codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use thiserror::Error;

use crate::{
array::{
chunk_shape_to_array_shape,
codec::{
ArrayCodecTraits, ArrayPartialDecoderTraits, ArrayToArrayCodecTraits, CodecError,
CodecTraits,
Expand Down Expand Up @@ -138,7 +137,7 @@ impl ArrayCodecTraits for TransposeCodec {
calculate_order_encode(&self.order, decoded_representation.shape().len());
transpose_array(
&order_encode,
&chunk_shape_to_array_shape(decoded_representation.shape()),
&decoded_representation.shape_u64(),
decoded_representation.element_size(),
&decoded_value,
)
Expand All @@ -158,10 +157,10 @@ impl ArrayCodecTraits for TransposeCodec {
) -> Result<Vec<u8>, CodecError> {
let order_decode =
calculate_order_decode(&self.order, decoded_representation.shape().len());
let transposed_shape = permute(decoded_representation.shape(), &self.order);
let transposed_shape = permute(&decoded_representation.shape_u64(), &self.order);
transpose_array(
&order_decode,
&chunk_shape_to_array_shape(&transposed_shape),
&transposed_shape,
decoded_representation.element_size(),
&encoded_value,
)
Expand Down
5 changes: 2 additions & 3 deletions src/array/codec/array_to_bytes/bytes/bytes_partial_decoder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::{
array::{
chunk_shape_to_array_shape,
codec::{ArrayPartialDecoderTraits, ArraySubset, BytesPartialDecoderTraits, CodecError},
ChunkRepresentation,
},
Expand Down Expand Up @@ -41,7 +40,7 @@ impl ArrayPartialDecoderTraits for BytesPartialDecoder<'_> {
parallel: bool,
) -> Result<Vec<Vec<u8>>, CodecError> {
let mut bytes = Vec::with_capacity(decoded_regions.len());
let chunk_shape = chunk_shape_to_array_shape(self.decoded_representation.shape());
let chunk_shape = self.decoded_representation.shape_u64();
for array_subset in decoded_regions {
// Get byte ranges
let byte_ranges = array_subset
Expand Down Expand Up @@ -118,7 +117,7 @@ impl AsyncArrayPartialDecoderTraits for AsyncBytesPartialDecoder<'_> {
parallel: bool,
) -> Result<Vec<Vec<u8>>, CodecError> {
let mut bytes = Vec::with_capacity(decoded_regions.len());
let chunk_shape = chunk_shape_to_array_shape(self.decoded_representation.shape());
let chunk_shape = self.decoded_representation.shape_u64();
for array_subset in decoded_regions {
// Get byte ranges
let byte_ranges = array_subset
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::{
array::{
chunk_shape_to_array_shape,
codec::{ArrayPartialDecoderTraits, ArraySubset, BytesPartialDecoderTraits, CodecError},
ChunkRepresentation, DataType,
},
Expand Down Expand Up @@ -35,7 +34,7 @@ fn do_partial_decode(
decoded_representation: &ChunkRepresentation,
) -> Result<Vec<Vec<u8>>, CodecError> {
let mut decoded_bytes = Vec::with_capacity(decoded_regions.len());
let chunk_shape = chunk_shape_to_array_shape(decoded_representation.shape());
let chunk_shape = decoded_representation.shape_u64();
match decoded {
None => {
for array_subset in decoded_regions {
Expand Down
22 changes: 11 additions & 11 deletions src/array/codec/array_to_bytes/sharding/sharding_codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ impl ShardingCodec {
let shard_slice = unsafe {
std::slice::from_raw_parts_mut(shard.as_mut_ptr().cast::<u8>(), shard.len())
};
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
for (chunk_index, (_chunk_indices, chunk_subset)) in unsafe {
ArraySubset::new_with_shape(shard_shape.clone())
.iter_chunks_unchecked(self.chunk_shape.as_slice())
Expand Down Expand Up @@ -468,7 +468,7 @@ impl ShardingCodec {
ShardingIndexLocation::Start => index_encoded_size,
ShardingIndexLocation::End => 0,
};
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
for (chunk_index, (_chunk_indices, chunk_subset)) in unsafe {
ArraySubset::new_with_shape(shard_shape.clone())
.iter_chunks_unchecked(self.chunk_shape.as_slice())
Expand Down Expand Up @@ -562,7 +562,7 @@ impl ShardingCodec {
};
let shard_slice = UnsafeCellSlice::new(shard_slice);
let shard_index_slice = UnsafeCellSlice::new(&mut shard_index);
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
(0..chunks_per_shard
.as_slice()
.iter()
Expand Down Expand Up @@ -660,7 +660,7 @@ impl ShardingCodec {
let index_encoded_size = usize::try_from(index_encoded_size).unwrap();

// Find chunks that are not entirely the fill value and collect their decoded bytes
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
let encoded_chunks: Vec<(u64, Vec<u8>)> = (0..chunks_per_shard
.as_slice()
.iter()
Expand Down Expand Up @@ -780,7 +780,7 @@ impl ShardingCodec {
ShardingIndexLocation::Start => index_encoded_size,
ShardingIndexLocation::End => 0,
};
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
for (chunk_index, (_chunk_indices, chunk_subset)) in unsafe {
ArraySubset::new_with_shape(shard_shape.clone())
.iter_chunks_unchecked(self.chunk_shape.as_slice())
Expand Down Expand Up @@ -884,7 +884,7 @@ impl ShardingCodec {
let shard_slice = unsafe {
std::slice::from_raw_parts_mut(shard.as_mut_ptr().cast::<u8>(), shard.len())
};
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
for (chunk_index, (_chunk_indices, chunk_subset)) in unsafe {
ArraySubset::new_with_shape(shard_shape.clone())
.iter_chunks_unchecked(self.chunk_shape.as_slice())
Expand Down Expand Up @@ -998,7 +998,7 @@ impl ShardingCodec {
let shard_slice = UnsafeCellSlice::new(shard_slice);
let shard_index_slice = UnsafeCellSlice::new(&mut shard_index);
let chunks_per_shard = &chunks_per_shard;
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
let futures = (0..chunks_per_shard
.as_slice()
.iter()
Expand Down Expand Up @@ -1119,7 +1119,7 @@ impl ShardingCodec {
let index_encoded_size = usize::try_from(index_encoded_size).unwrap();

// Encode the chunks
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
let encoded_chunks = futures::future::join_all(
(0..chunks_per_shard
.as_slice()
Expand Down Expand Up @@ -1332,7 +1332,7 @@ impl ShardingCodec {
)
};

let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
if parallel {
let chunks_per_shard = calculate_chunks_per_shard(
shard_representation.shape(),
Expand Down Expand Up @@ -1387,7 +1387,7 @@ impl ShardingCodec {
})?;
} else {
let element_size = chunk_representation.element_size() as u64;
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
for (chunk_index, (_chunk_indices, chunk_subset)) in unsafe {
ArraySubset::new_with_shape(shard_shape.clone())
.iter_chunks_unchecked(self.chunk_shape.as_slice())
Expand Down Expand Up @@ -1453,7 +1453,7 @@ impl ShardingCodec {
};

// Decode chunks
let shard_shape = chunk_shape_to_array_shape(shard_representation.shape());
let shard_shape = shard_representation.shape_u64();
if parallel {
let chunks_per_shard = calculate_chunks_per_shard(
shard_representation.shape(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,9 +681,7 @@ impl AsyncArrayPartialDecoderTraits for AsyncShardingPartialDecoder<'_> {
})
.collect::<Vec<_>>();
if !filled_chunks.is_empty() {
let chunk_array_ss = ArraySubset::new_with_shape(chunk_shape_to_array_shape(
self.chunk_grid.chunk_shape(),
));
let chunk_array_ss = ArraySubset::new_with_shape(self.chunk_grid.chunk_shape_u64());
let filled_chunk = self
.decoded_representation
.fill_value()
Expand Down
5 changes: 2 additions & 3 deletions src/array/codec/array_to_bytes/zfp/zfp_partial_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use zfp_sys::zfp_type;

use crate::{
array::{
chunk_shape_to_array_shape,
codec::{ArrayPartialDecoderTraits, BytesPartialDecoderTraits, CodecError},
ChunkRepresentation,
},
Expand Down Expand Up @@ -56,7 +55,7 @@ impl ArrayPartialDecoderTraits for ZfpPartialDecoder<'_> {
) -> Result<Vec<Vec<u8>>, CodecError> {
let encoded_value = self.input_handle.decode_opt(parallel)?;
let mut out = Vec::with_capacity(decoded_regions.len());
let chunk_shape = chunk_shape_to_array_shape(self.decoded_representation.shape());
let chunk_shape = self.decoded_representation.shape_u64();
match encoded_value {
Some(encoded_value) => {
let decoded_value = zfp_decode(
Expand Down Expand Up @@ -136,7 +135,7 @@ impl AsyncArrayPartialDecoderTraits for AsyncZfpPartialDecoder<'_> {
parallel: bool,
) -> Result<Vec<Vec<u8>>, CodecError> {
let encoded_value = self.input_handle.decode_opt(parallel).await?;
let chunk_shape = chunk_shape_to_array_shape(self.decoded_representation.shape());
let chunk_shape = self.decoded_representation.shape_u64();
let mut out = Vec::with_capacity(decoded_regions.len());
match encoded_value {
Some(encoded_value) => {
Expand Down
Loading

0 comments on commit 6d9bf07

Please sign in to comment.