Skip to content

Commit 0766aaf

Browse files
committed
Added test data, added type hints
1 parent ef76bd4 commit 0766aaf

File tree

4 files changed

+165
-57
lines changed

4 files changed

+165
-57
lines changed

arraypartition/partition.py

Lines changed: 132 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,12 @@ class ArrayPartition(SuperLazyArrayLike):
191191
description = "Complete Array-like object with all proper methods for data retrieval."
192192

193193
def __init__(self,
194-
filename,
195-
address,
196-
shape=None,
197-
position=None,
198-
extent=None,
199-
format=None,
194+
filename: str,
195+
address: str,
196+
shape: Union[tuple,None] = None,
197+
position: Union[tuple,None] = None,
198+
extent: Union[tuple,None] = None,
199+
format: Union[str,None] = None,
200200
**kwargs
201201
):
202202

@@ -241,6 +241,10 @@ def __init__(self,
241241
self.format = format
242242
self.position = position
243243

244+
if shape is None:
245+
# Identify shape
246+
shape = tuple(self._get_array().shape)
247+
244248
self._lock = SerializableLock()
245249

246250
super().__init__(shape, **kwargs)
@@ -258,9 +262,6 @@ def __array__(self, *args, **kwargs):
258262
defined by the ``extent`` parameter.
259263
"""
260264

261-
# Unexplained xarray behaviour:
262-
# If using xarray indexing, __array__ should not have a positional 'dtype' option.
263-
# If casting DataArray to numpy, __array__ requires a positional 'dtype' option.
264265
dtype = None
265266
if args:
266267
dtype = args[0]
@@ -270,6 +271,34 @@ def __array__(self, *args, **kwargs):
270271
'Requested datatype does not match this chunk'
271272
)
272273

274+
array = self._get_array(*args)
275+
276+
if hasattr(array, 'units'):
277+
self.units = array.units
278+
279+
if len(array.shape) != len(self._extent):
280+
self._correct_slice(array.dimensions)
281+
282+
try:
283+
var = np.array(array[tuple(self._extent)], dtype=self.dtype)
284+
except IndexError:
285+
raise ValueError(
286+
f"Unable to select required 'extent' of {self.extent} "
287+
f"from fragment {self.position} with shape {array.shape}"
288+
)
289+
290+
return self._post_process_data(var)
291+
292+
def _get_array(self, *args):
293+
"""
294+
Base private function to get the data array object.
295+
296+
Can be used to extract the shape and dtype if not known.
297+
"""
298+
# Unexplained xarray behaviour:
299+
# If using xarray indexing, __array__ should not have a positional 'dtype' option.
300+
# If casting DataArray to numpy, __array__ requires a positional 'dtype' option.
301+
273302
ds = self.open()
274303

275304
if '/' in self.address:
@@ -291,24 +320,9 @@ def __array__(self, *args, **kwargs):
291320
f"Dask Chunk at '{self.position}' does not contain "
292321
f"the variable '{varname}'."
293322
)
294-
295-
if hasattr(array, 'units'):
296-
self.units = array.units
297-
298-
if len(array.shape) != len(self._extent):
299-
self._correct_slice(array.dimensions)
323+
return array
300324

301-
try:
302-
var = np.array(array[tuple(self._extent)], dtype=self.dtype)
303-
except IndexError:
304-
raise ValueError(
305-
f"Unable to select required 'extent' of {self.extent} "
306-
f"from fragment {self.position} with shape {array.shape}"
307-
)
308-
309-
return self._post_process_data(var)
310-
311-
def _correct_slice(self, array_dims):
325+
def _correct_slice(self, array_dims: tuple):
312326
"""
313327
Drop size-1 dimensions from the set of slices if there is an issue.
314328
@@ -337,17 +351,20 @@ def _correct_slice(self, array_dims):
337351
)
338352
self._extent = extent
339353

340-
def _post_process_data(self, data):
354+
def _post_process_data(self, data: np.array):
341355
"""
342-
Perform any post-processing steps on the data here. Method to be
343-
overriden by inherrited classes (CFAPyX.CFAPartition and
344-
XarrayActive.ActivePartition)
356+
Perform any post-processing steps on the data here.
357+
358+
Method to be overriden by inherrited classes (CFAPyX.CFAPartition
359+
and XarrayActive.ActivePartition)
345360
"""
346361
return data
347362

348-
def _try_openers(self, filename):
363+
def _try_openers(self, filename: str):
349364
"""
350-
Attempt to open the dataset using all possible methods. Currently only NetCDF is supported.
365+
Attempt to open the dataset using all possible methods.
366+
367+
Currently only NetCDF is supported.
351368
"""
352369
for open in [
353370
self._open_netcdf,
@@ -364,13 +381,13 @@ def _try_openers(self, filename):
364381
)
365382
return ds
366383

367-
def _open_pp(self, filename):
384+
def _open_pp(self, filename: str):
368385
raise NotImplementedError
369386

370-
def _open_um(self, filename):
387+
def _open_um(self, filename: str):
371388
raise NotImplementedError
372389

373-
def _open_netcdf(self, filename):
390+
def _open_netcdf(self, filename: str):
374391
"""
375392
Open a NetCDF file using the netCDF4 python package."""
376393
return netCDF4.Dataset(filename, mode='r')
@@ -386,12 +403,14 @@ def get_kwargs(self):
386403
'format': self.format
387404
} | super().get_kwargs()
388405

389-
def copy(self, extent=None):
406+
def copy(self, extent: Union[tuple,None] = None):
390407
"""
391-
Create a new instance of this class with all attributes of the current instance, but
392-
with a new initial extent made by combining the current instance extent with the ``newextent``.
393-
Each ArrayLike class must overwrite this class to get the best performance with multiple
394-
slicing operations.
408+
Create a new instance of this class with all attributes of the current instance.
409+
410+
The copy has annew initial extent made by combining the current instance
411+
extent with the ``newextent``.
412+
Each ArrayLike class must overwrite this class to get the best performance
413+
with multiple slicing operations.
395414
"""
396415
kwargs = self.get_kwargs()
397416
if extent:
@@ -406,7 +425,9 @@ def copy(self, extent=None):
406425

407426
def open(self):
408427
"""
409-
Open the source file for this chunk to extract data. Multiple file locations may be provided
428+
Open the source file for this chunk to extract data.
429+
430+
Multiple file locations may be provided
410431
for this object, in which case there is a priority for 'remote' sources first, followed by
411432
'local' sources - otherwise the order is as given in the fragment array variable ``location``.
412433
"""
@@ -446,7 +467,23 @@ def open(self):
446467
f'Locations tried: {filenames}.'
447468
)
448469

449-
def _identical_extents(old, new, dshape):
470+
def _identical_extents(
471+
old: slice,
472+
new: slice,
473+
dshape: int):
474+
"""
475+
Determine if two slices match precisely.
476+
477+
:param old: (slice) Current slice applied to the dimension.
478+
479+
:param new: (slice) New slice to be combined with the old slice.
480+
481+
:param dshape: (int) Total size of the given dimension.
482+
"""
483+
484+
if isinstance(new, int):
485+
new = slice(new, new+1)
486+
450487
ostart = old.start or 0
451488
ostop = old.stop or dshape
452489
ostep = old.step or 1
@@ -459,9 +496,14 @@ def _identical_extents(old, new, dshape):
459496
(ostop == nstop) and \
460497
(ostep == nstep)
461498

462-
def get_chunk_space(chunk_shape, shape):
499+
def get_chunk_space(
500+
chunk_shape: tuple,
501+
shape: tuple
502+
) -> tuple:
463503
"""
464-
Derive the chunk space from the ratio between the chunk shape and array shape in
504+
Derive the chunk space in each dimension.
505+
506+
Calculated from the ratio between the chunk shape and array shape in
465507
each dimension. Chunk space is the number of chunks in each dimension which is
466508
referred to as a ``space`` because it effectively represents the lengths of the each
467509
dimension in 'chunk space' rather than any particular chunk coordinate.
@@ -481,12 +523,22 @@ def get_chunk_space(chunk_shape, shape):
481523
...
482524
and so on.
483525
526+
:param chunk_shape: (tuple) The shape of each chunk in array space.
527+
528+
:param shape: (tuple) The total array shape in array space -
529+
alternatively the total array space size.
530+
484531
"""
485532

486533
space = tuple([math.ceil(i/j) for i, j in zip(shape, chunk_shape)])
487534
return space
488535

489-
def get_chunk_shape(chunks, shape, dims, chunk_limits=True):
536+
def get_chunk_shape(
537+
chunks: dict,
538+
shape: tuple,
539+
dims: tuple,
540+
chunk_limits: bool = True
541+
) -> tuple:
490542
"""
491543
Calculate the chunk shape from the user-provided ``chunks`` parameter,
492544
the array shape and named dimensions, and apply chunk limits if enabled.
@@ -532,10 +584,19 @@ def get_chunk_shape(chunks, shape, dims, chunk_limits=True):
532584

533585
return tuple(chunk_shape)
534586

535-
def get_chunk_positions(chunk_space):
587+
def get_chunk_positions(
588+
chunk_space: tuple
589+
) -> list[tuple]:
536590
"""
537-
Get the list of chunk positions in ``chunk space`` given the size
538-
of the space.
591+
Get the list of chunk positions in ``chunk space``.
592+
593+
Given the size of the space, list all possible positions
594+
within the space. A space of ``(1,1)`` has a single possible
595+
position; ``(0,0)``, whereas a space of ``(2,2)`` has four
596+
positions: ``(0,0)``,``(0,1)``,``(1,0)`` and ``(1,1)``.
597+
598+
:param chunk_space: (tuple) The total size of the space in
599+
all dimensions
539600
"""
540601
origin = [0 for i in chunk_space]
541602

@@ -547,10 +608,23 @@ def get_chunk_positions(chunk_space):
547608

548609
return positions
549610

550-
def get_chunk_extent(position, shape, chunk_space):
611+
def get_chunk_extent(
612+
position: tuple,
613+
shape: tuple,
614+
chunk_space: tuple
615+
) -> tuple:
551616
"""
552-
Get the extent of a particular chunk within the space given its position,
553-
the array shape and the extent of the chunk space.
617+
Get the extent of a particular chunk within the space.
618+
619+
Given its position, the array shape and the extent of the
620+
chunk space, find the extent of a particular chunk.
621+
622+
:param position: (tuple) The position of the chunk in chunk space.
623+
624+
:param shape: (tuple) The total array shape for the whole chunk space.
625+
626+
:param chunk_space: (tuple) The size of the chunk space (number of chunks
627+
in each dimension).
554628
"""
555629
extent = []
556630
for dim in range(len(position)):
@@ -568,11 +642,12 @@ def get_chunk_extent(position, shape, chunk_space):
568642
return extent
569643

570644
def get_dask_chunks(
571-
array_space,
572-
fragment_space,
573-
extent,
574-
dtype,
575-
explicit_shapes=None):
645+
array_space: tuple,
646+
fragment_space: tuple,
647+
extent: tuple,
648+
dtype: np.dtype,
649+
explicit_shapes: Union[tuple,None] = None
650+
) -> tuple:
576651
"""
577652
Define the `chunks` array passed to Dask when creating a Dask Array. This is an array of fragment sizes
578653
per dimension for each of the relevant dimensions. Copied from cf-python version 3.14.0 onwards.
@@ -719,7 +794,7 @@ def combine_slices(
719794
else:
720795
for dim in range(len(newslice)):
721796
if not _identical_extents(extent[dim], newslice[dim], shape[dim]):
722-
extent[dim] = combine_sliced_dim(extent[dim], newslice[dim], dim)
797+
extent[dim] = combine_sliced_dim(extent[dim], newslice[dim], shape, dim)
723798
return extent
724799

725800
def normalize_partition_chunks(
1 MB
Binary file not shown.

arraypartition/tests/test_array.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
__author__ = "Daniel Westwood"
2+
__contact__ = "[email protected]"
3+
__copyright__ = "Copyright 2024 United Kingdom Research and Innovation"
4+
5+
import numpy as np
6+
7+
from arraypartition import ArrayPartition
8+
9+
class TestArray:
10+
def test_array(self):
11+
12+
# Example Array Partition
13+
14+
array_part = ArrayPartition(
15+
'arraypartition/tests/data/example1.0.nc',
16+
'p',
17+
)
18+
19+
assert array_part.shape == (2,180,360), "Shape Source Error"
20+
21+
ap = array_part[0, slice(100,120), slice(100,120)]
22+
23+
assert ap.shape == (1,20,20), "Shape Error"
24+
25+
assert (np.array(array_part)[0][5][0] - 0.463) < 0.001, "Data Error"
26+
assert (np.array(ap)[0][5][0] - 0.803) < 0.001, "Data Error"
27+
28+
if __name__ == '__main__':
29+
TestArray().test_array()

arraypartition/tests/test_consistency.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
__author__ = "Daniel Westwood"
2+
__contact__ = "[email protected]"
3+
__copyright__ = "Copyright 2024 United Kingdom Research and Innovation"
4+
15
class TestConsistency:
26

37
def test_core(self):

0 commit comments

Comments
 (0)