@@ -191,12 +191,12 @@ class ArrayPartition(SuperLazyArrayLike):
191191 description = "Complete Array-like object with all proper methods for data retrieval."
192192
193193 def __init__ (self ,
194- filename ,
195- address ,
196- shape = None ,
197- position = None ,
198- extent = None ,
199- format = None ,
194+ filename : str ,
195+ address : str ,
196+ shape : Union [ tuple , None ] = None ,
197+ position : Union [ tuple , None ] = None ,
198+ extent : Union [ tuple , None ] = None ,
199+ format : Union [ str , None ] = None ,
200200 ** kwargs
201201 ):
202202
@@ -241,6 +241,10 @@ def __init__(self,
241241 self .format = format
242242 self .position = position
243243
244+ if shape is None :
245+ # Identify shape
246+ shape = tuple (self ._get_array ().shape )
247+
244248 self ._lock = SerializableLock ()
245249
246250 super ().__init__ (shape , ** kwargs )
@@ -258,9 +262,6 @@ def __array__(self, *args, **kwargs):
258262 defined by the ``extent`` parameter.
259263 """
260264
261- # Unexplained xarray behaviour:
262- # If using xarray indexing, __array__ should not have a positional 'dtype' option.
263- # If casting DataArray to numpy, __array__ requires a positional 'dtype' option.
264265 dtype = None
265266 if args :
266267 dtype = args [0 ]
@@ -270,6 +271,34 @@ def __array__(self, *args, **kwargs):
270271 'Requested datatype does not match this chunk'
271272 )
272273
274+ array = self ._get_array (* args )
275+
276+ if hasattr (array , 'units' ):
277+ self .units = array .units
278+
279+ if len (array .shape ) != len (self ._extent ):
280+ self ._correct_slice (array .dimensions )
281+
282+ try :
283+ var = np .array (array [tuple (self ._extent )], dtype = self .dtype )
284+ except IndexError :
285+ raise ValueError (
286+ f"Unable to select required 'extent' of { self .extent } "
287+ f"from fragment { self .position } with shape { array .shape } "
288+ )
289+
290+ return self ._post_process_data (var )
291+
292+ def _get_array (self , * args ):
293+ """
294+ Base private function to get the data array object.
295+
296+ Can be used to extract the shape and dtype if not known.
297+ """
298+ # Unexplained xarray behaviour:
299+ # If using xarray indexing, __array__ should not have a positional 'dtype' option.
300+ # If casting DataArray to numpy, __array__ requires a positional 'dtype' option.
301+
273302 ds = self .open ()
274303
275304 if '/' in self .address :
@@ -291,24 +320,9 @@ def __array__(self, *args, **kwargs):
291320 f"Dask Chunk at '{ self .position } ' does not contain "
292321 f"the variable '{ varname } '."
293322 )
294-
295- if hasattr (array , 'units' ):
296- self .units = array .units
297-
298- if len (array .shape ) != len (self ._extent ):
299- self ._correct_slice (array .dimensions )
323+ return array
300324
301- try :
302- var = np .array (array [tuple (self ._extent )], dtype = self .dtype )
303- except IndexError :
304- raise ValueError (
305- f"Unable to select required 'extent' of { self .extent } "
306- f"from fragment { self .position } with shape { array .shape } "
307- )
308-
309- return self ._post_process_data (var )
310-
311- def _correct_slice (self , array_dims ):
325+ def _correct_slice (self , array_dims : tuple ):
312326 """
313327 Drop size-1 dimensions from the set of slices if there is an issue.
314328
@@ -337,17 +351,20 @@ def _correct_slice(self, array_dims):
337351 )
338352 self ._extent = extent
339353
340- def _post_process_data (self , data ):
354+ def _post_process_data (self , data : np . array ):
341355 """
342- Perform any post-processing steps on the data here. Method to be
343- overriden by inherrited classes (CFAPyX.CFAPartition and
344- XarrayActive.ActivePartition)
356+ Perform any post-processing steps on the data here.
357+
358+ Method to be overriden by inherrited classes (CFAPyX.CFAPartition
359+ and XarrayActive.ActivePartition)
345360 """
346361 return data
347362
348- def _try_openers (self , filename ):
363+ def _try_openers (self , filename : str ):
349364 """
350- Attempt to open the dataset using all possible methods. Currently only NetCDF is supported.
365+ Attempt to open the dataset using all possible methods.
366+
367+ Currently only NetCDF is supported.
351368 """
352369 for open in [
353370 self ._open_netcdf ,
@@ -364,13 +381,13 @@ def _try_openers(self, filename):
364381 )
365382 return ds
366383
367- def _open_pp (self , filename ):
384+ def _open_pp (self , filename : str ):
368385 raise NotImplementedError
369386
370- def _open_um (self , filename ):
387+ def _open_um (self , filename : str ):
371388 raise NotImplementedError
372389
373- def _open_netcdf (self , filename ):
390+ def _open_netcdf (self , filename : str ):
374391 """
375392 Open a NetCDF file using the netCDF4 python package."""
376393 return netCDF4 .Dataset (filename , mode = 'r' )
@@ -386,12 +403,14 @@ def get_kwargs(self):
386403 'format' : self .format
387404 } | super ().get_kwargs ()
388405
389- def copy (self , extent = None ):
406+ def copy (self , extent : Union [ tuple , None ] = None ):
390407 """
391- Create a new instance of this class with all attributes of the current instance, but
392- with a new initial extent made by combining the current instance extent with the ``newextent``.
393- Each ArrayLike class must overwrite this class to get the best performance with multiple
394- slicing operations.
408+ Create a new instance of this class with all attributes of the current instance.
409+
410+ The copy has annew initial extent made by combining the current instance
411+ extent with the ``newextent``.
412+ Each ArrayLike class must overwrite this class to get the best performance
413+ with multiple slicing operations.
395414 """
396415 kwargs = self .get_kwargs ()
397416 if extent :
@@ -406,7 +425,9 @@ def copy(self, extent=None):
406425
407426 def open (self ):
408427 """
409- Open the source file for this chunk to extract data. Multiple file locations may be provided
428+ Open the source file for this chunk to extract data.
429+
430+ Multiple file locations may be provided
410431 for this object, in which case there is a priority for 'remote' sources first, followed by
411432 'local' sources - otherwise the order is as given in the fragment array variable ``location``.
412433 """
@@ -446,7 +467,23 @@ def open(self):
446467 f'Locations tried: { filenames } .'
447468 )
448469
449- def _identical_extents (old , new , dshape ):
470+ def _identical_extents (
471+ old : slice ,
472+ new : slice ,
473+ dshape : int ):
474+ """
475+ Determine if two slices match precisely.
476+
477+ :param old: (slice) Current slice applied to the dimension.
478+
479+ :param new: (slice) New slice to be combined with the old slice.
480+
481+ :param dshape: (int) Total size of the given dimension.
482+ """
483+
484+ if isinstance (new , int ):
485+ new = slice (new , new + 1 )
486+
450487 ostart = old .start or 0
451488 ostop = old .stop or dshape
452489 ostep = old .step or 1
@@ -459,9 +496,14 @@ def _identical_extents(old, new, dshape):
459496 (ostop == nstop ) and \
460497 (ostep == nstep )
461498
462- def get_chunk_space (chunk_shape , shape ):
499+ def get_chunk_space (
500+ chunk_shape : tuple ,
501+ shape : tuple
502+ ) -> tuple :
463503 """
464- Derive the chunk space from the ratio between the chunk shape and array shape in
504+ Derive the chunk space in each dimension.
505+
506+ Calculated from the ratio between the chunk shape and array shape in
465507 each dimension. Chunk space is the number of chunks in each dimension which is
466508 referred to as a ``space`` because it effectively represents the lengths of the each
467509 dimension in 'chunk space' rather than any particular chunk coordinate.
@@ -481,12 +523,22 @@ def get_chunk_space(chunk_shape, shape):
481523 ...
482524 and so on.
483525
526+ :param chunk_shape: (tuple) The shape of each chunk in array space.
527+
528+ :param shape: (tuple) The total array shape in array space -
529+ alternatively the total array space size.
530+
484531 """
485532
486533 space = tuple ([math .ceil (i / j ) for i , j in zip (shape , chunk_shape )])
487534 return space
488535
489- def get_chunk_shape (chunks , shape , dims , chunk_limits = True ):
536+ def get_chunk_shape (
537+ chunks : dict ,
538+ shape : tuple ,
539+ dims : tuple ,
540+ chunk_limits : bool = True
541+ ) -> tuple :
490542 """
491543 Calculate the chunk shape from the user-provided ``chunks`` parameter,
492544 the array shape and named dimensions, and apply chunk limits if enabled.
@@ -532,10 +584,19 @@ def get_chunk_shape(chunks, shape, dims, chunk_limits=True):
532584
533585 return tuple (chunk_shape )
534586
535- def get_chunk_positions (chunk_space ):
587+ def get_chunk_positions (
588+ chunk_space : tuple
589+ ) -> list [tuple ]:
536590 """
537- Get the list of chunk positions in ``chunk space`` given the size
538- of the space.
591+ Get the list of chunk positions in ``chunk space``.
592+
593+ Given the size of the space, list all possible positions
594+ within the space. A space of ``(1,1)`` has a single possible
595+ position; ``(0,0)``, whereas a space of ``(2,2)`` has four
596+ positions: ``(0,0)``,``(0,1)``,``(1,0)`` and ``(1,1)``.
597+
598+ :param chunk_space: (tuple) The total size of the space in
599+ all dimensions
539600 """
540601 origin = [0 for i in chunk_space ]
541602
@@ -547,10 +608,23 @@ def get_chunk_positions(chunk_space):
547608
548609 return positions
549610
550- def get_chunk_extent (position , shape , chunk_space ):
611+ def get_chunk_extent (
612+ position : tuple ,
613+ shape : tuple ,
614+ chunk_space : tuple
615+ ) -> tuple :
551616 """
552- Get the extent of a particular chunk within the space given its position,
553- the array shape and the extent of the chunk space.
617+ Get the extent of a particular chunk within the space.
618+
619+ Given its position, the array shape and the extent of the
620+ chunk space, find the extent of a particular chunk.
621+
622+ :param position: (tuple) The position of the chunk in chunk space.
623+
624+ :param shape: (tuple) The total array shape for the whole chunk space.
625+
626+ :param chunk_space: (tuple) The size of the chunk space (number of chunks
627+ in each dimension).
554628 """
555629 extent = []
556630 for dim in range (len (position )):
@@ -568,11 +642,12 @@ def get_chunk_extent(position, shape, chunk_space):
568642 return extent
569643
570644def get_dask_chunks (
571- array_space ,
572- fragment_space ,
573- extent ,
574- dtype ,
575- explicit_shapes = None ):
645+ array_space : tuple ,
646+ fragment_space : tuple ,
647+ extent : tuple ,
648+ dtype : np .dtype ,
649+ explicit_shapes : Union [tuple ,None ] = None
650+ ) -> tuple :
576651 """
577652 Define the `chunks` array passed to Dask when creating a Dask Array. This is an array of fragment sizes
578653 per dimension for each of the relevant dimensions. Copied from cf-python version 3.14.0 onwards.
@@ -719,7 +794,7 @@ def combine_slices(
719794 else :
720795 for dim in range (len (newslice )):
721796 if not _identical_extents (extent [dim ], newslice [dim ], shape [dim ]):
722- extent [dim ] = combine_sliced_dim (extent [dim ], newslice [dim ], dim )
797+ extent [dim ] = combine_sliced_dim (extent [dim ], newslice [dim ], shape , dim )
723798 return extent
724799
725800def normalize_partition_chunks (
0 commit comments