add StoreV3 support to most convenience routines
consolidated metadata functions haven't been updated yet
grlee77 committed Dec 15, 2021
1 parent 8e3c443 · commit b13a6b3
Showing 2 changed files with 290 additions and 91 deletions.
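For orientation, a rough usage sketch of what this commit enables. File names here are hypothetical, and the v3 store machinery was still experimental at this point, so treat this as an illustration of the intended API rather than a guarantee:

    import numpy as np
    import zarr

    # v2 stores behave as before: zarr_version defaults to 2, path stays optional.
    zarr.save('data.zarr', np.arange(10))

    # v3 stores go through the same convenience routines, but a non-empty
    # path is now mandatory (enforced by the new _check_and_update_path helper).
    zarr.save('data_v3.zarr', np.arange(10), zarr_version=3, path='arr_0')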
zarr/convenience.py (191 changes: 125 additions & 66 deletions)
@@ -7,22 +7,28 @@

from zarr.core import Array
from zarr.creation import array as _create_array
-from zarr.creation import normalize_store_arg, open_array
+from zarr.creation import open_array
from zarr.errors import CopyError, PathNotFoundError
from zarr.hierarchy import Group
from zarr.hierarchy import group as _create_group
from zarr.hierarchy import open_group
from zarr.meta import json_dumps, json_loads
-from zarr.storage import contains_array, contains_group, BaseStore
+from zarr.storage import contains_array, contains_group, normalize_store_arg, BaseStore
from zarr.util import TreeViewer, buffer_size, normalize_storage_path

from typing import Union

StoreLike = Union[BaseStore, MutableMapping, str, None]


+def _check_and_update_path(store: BaseStore, path):
+    if getattr(store, '_store_version', 2) > 2 and not path:
+        raise ValueError("path must be provided for v3 stores")
+    return normalize_storage_path(path)

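As a quick illustration of the new helper, a stand-in store is enough, since the check only reads the _store_version attribute (the class below is hypothetical, not part of the commit):

    class _FakeV3Store(dict):
        _store_version = 3  # the only attribute _check_and_update_path inspects

    _check_and_update_path(_FakeV3Store(), 'group/arr')  # -> 'group/arr'
    _check_and_update_path(dict(), '')                   # v2 default -> '' (no error)
    _check_and_update_path(_FakeV3Store(), '')           # raises ValueError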

# noinspection PyShadowingBuiltins
-def open(store: StoreLike = None, mode: str = "a", **kwargs):
+def open(store: StoreLike = None, mode: str = "a", *, zarr_version=2, path=None, **kwargs):
"""Convenience function to open a group or array using file-mode-like semantics.
Parameters
@@ -34,6 +40,10 @@ def open(store: StoreLike = None, mode: str = "a", **kwargs):
read/write (must exist); 'a' means read/write (create if doesn't
exist); 'w' means create (overwrite if exists); 'w-' means create
(fail if exists).
+    zarr_version : {2, 3}
+        The zarr protocol version to use.
+    path : str
+        The path within the store to open.
**kwargs
Additional parameters are passed through to :func:`zarr.creation.open_array` or
:func:`zarr.hierarchy.open_group`.
@@ -75,15 +85,16 @@
"""

-    path = kwargs.get('path')
# handle polymorphic store arg
clobber = mode == 'w'
# we pass storage options explicitly, since normalize_store_arg might construct
# a store if the input is a fsspec-compatible URL
_store: BaseStore = normalize_store_arg(
-        store, clobber=clobber, storage_options=kwargs.pop("storage_options", {})
+        store, clobber=clobber, storage_options=kwargs.pop("storage_options", {}),
+        zarr_version=zarr_version,
)
-    path = normalize_storage_path(path)
+    path = _check_and_update_path(_store, path)
+    kwargs['path'] = path

if mode in {'w', 'w-', 'x'}:
if 'shape' in kwargs:
@@ -110,7 +121,7 @@ def _might_close(path):
return isinstance(path, (str, os.PathLike))

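A usage sketch of the extended open() signature (hypothetical store paths; assumes normalize_store_arg can already construct a v3 store from a path string at this stage):

    import zarr

    # Unchanged v2 usage.
    z2 = zarr.open('example.zarr', mode='w', shape=(100,), chunks=(10,))

    # v3 usage: zarr_version is forwarded to the store machinery, and
    # omitting path would raise ValueError for a v3 store.
    z3 = zarr.open('example_v3.zarr', mode='w', shape=(100,), chunks=(10,),
                   zarr_version=3, path='arr')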

-def save_array(store: StoreLike, arr, **kwargs):
+def save_array(store: StoreLike, arr, *, zarr_version=2, path=None, **kwargs):
"""Convenience function to save a NumPy array to the local file system, following a
similar API to the NumPy save() function.
@@ -120,6 +131,10 @@ def save_array(store: StoreLike, arr, **kwargs):
Store or path to directory in file system or name of zip file.
arr : ndarray
NumPy array with data to save.
+    zarr_version : {2, 3}
+        The zarr protocol version to use when saving.
+    path : str
+        The path within the store where the array will be saved.
kwargs
Passed through to :func:`create`, e.g., compressor.
@@ -142,16 +157,18 @@ def save_array(store: StoreLike, arr, **kwargs):
"""
may_need_closing = _might_close(store)
-    _store: BaseStore = normalize_store_arg(store, clobber=True)
+    _store: BaseStore = normalize_store_arg(store, clobber=True, zarr_version=zarr_version)
+    path = _check_and_update_path(_store, path)
try:
-        _create_array(arr, store=_store, overwrite=True, **kwargs)
+        _create_array(arr, store=_store, overwrite=True, zarr_version=zarr_version, path=path,
+                      **kwargs)
finally:
if may_need_closing:
# needed to ensure zip file records are written
_store.close()

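The same pattern applies to save_array (illustrative file names):

    import numpy as np
    import zarr

    arr = np.arange(100).reshape(10, 10)
    zarr.save_array('arr.zarr', arr)                                 # v2, unchanged
    zarr.save_array('arr_v3.zarr', arr, zarr_version=3, path='arr')  # v3, path required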

-def save_group(store: StoreLike, *args, **kwargs):
+def save_group(store: StoreLike, *args, zarr_version=2, path=None, **kwargs):
"""Convenience function to save several NumPy arrays to the local file system, following a
similar API to the NumPy savez()/savez_compressed() functions.
@@ -161,6 +178,10 @@ def save_group(store: StoreLike, *args, **kwargs):
Store or path to directory in file system or name of zip file.
args : ndarray
NumPy arrays with data to save.
+    zarr_version : {2, 3}
+        The zarr protocol version to use when saving.
+    path : str
+        Path within the store where the group will be saved.
kwargs
NumPy arrays with data to save.
@@ -213,21 +234,22 @@ def save_group(store: StoreLike, *args, **kwargs):
raise ValueError('at least one array must be provided')
# handle polymorphic store arg
may_need_closing = _might_close(store)
-    _store: BaseStore = normalize_store_arg(store, clobber=True)
+    _store: BaseStore = normalize_store_arg(store, clobber=True, zarr_version=zarr_version)
+    path = _check_and_update_path(_store, path)
try:
-        grp = _create_group(_store, overwrite=True)
+        grp = _create_group(_store, path=path, overwrite=True, zarr_version=zarr_version)
for i, arr in enumerate(args):
k = 'arr_{}'.format(i)
-            grp.create_dataset(k, data=arr, overwrite=True)
+            grp.create_dataset(k, data=arr, overwrite=True, zarr_version=zarr_version)
for k, arr in kwargs.items():
-            grp.create_dataset(k, data=arr, overwrite=True)
+            grp.create_dataset(k, data=arr, overwrite=True, zarr_version=zarr_version)
finally:
if may_need_closing:
# needed to ensure zip file records are written
_store.close()

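A sketch of the updated save_group semantics: positional arrays are stored as arr_0, arr_1, ... and keyword arrays under their names, all beneath the given path (names are hypothetical):

    import numpy as np
    import zarr

    zarr.save_group('grp_v3.zarr', np.arange(8), temperature=np.ones(4),
                    zarr_version=3, path='measurements')
    # arrays land at 'measurements/arr_0' and 'measurements/temperature'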

-def save(store: StoreLike, *args, **kwargs):
+def save(store: StoreLike, *args, zarr_version=2, path=None, **kwargs):
"""Convenience function to save an array or group of arrays to the local file system.
Parameters
@@ -236,6 +258,10 @@ def save(store: StoreLike, *args, **kwargs):
Store or path to directory in file system or name of zip file.
args : ndarray
NumPy arrays with data to save.
+    zarr_version : {2, 3}
+        The zarr protocol version to use when saving.
+    path : str
+        The path within the group where the arrays will be saved.
kwargs
NumPy arrays with data to save.
@@ -302,9 +328,10 @@ def save(store: StoreLike, *args, **kwargs):
if len(args) == 0 and len(kwargs) == 0:
raise ValueError('at least one array must be provided')
if len(args) == 1 and len(kwargs) == 0:
-        save_array(store, args[0])
+        save_array(store, args[0], zarr_version=zarr_version, path=path)
else:
-        save_group(store, *args, **kwargs)
+        save_group(store, *args, zarr_version=zarr_version, path=path,
+                   **kwargs)

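Since save() only dispatches, the new keywords thread through either branch (sketch):

    import numpy as np
    import zarr

    zarr.save('one_v3.zarr', np.arange(4), zarr_version=3, path='x')  # one array -> save_array
    zarr.save('many_v3.zarr', np.arange(4), np.ones(3),
              zarr_version=3, path='g')                               # several -> save_group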

class LazyLoader(Mapping):
@@ -337,7 +364,7 @@ def __repr__(self):
return r


-def load(store: StoreLike):
+def load(store: StoreLike, zarr_version=2, path=None):
"""Load data from an array or group into memory.
Parameters
@@ -363,11 +390,12 @@ def load(store: StoreLike):
"""
# handle polymorphic store arg
-    _store = normalize_store_arg(store)
-    if contains_array(_store, path=None):
-        return Array(store=_store, path=None)[...]
-    elif contains_group(_store, path=None):
-        grp = Group(store=_store, path=None)
+    _store = normalize_store_arg(store, zarr_version=zarr_version)
+    path = _check_and_update_path(_store, path)
+    if contains_array(_store, path=path):
+        return Array(store=_store, path=path)[...]
+    elif contains_group(_store, path=path):
+        grp = Group(store=_store, path=path)
return LazyLoader(grp)

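Reading back mirrors the saving calls (sketch, reusing the hypothetical stores above):

    import zarr

    data = zarr.load('arr_v3.zarr', zarr_version=3, path='arr')  # array -> NumPy array
    lazy = zarr.load('grp_v3.zarr', zarr_version=3,
                     path='measurements')                        # group -> LazyLoader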

@@ -601,59 +629,79 @@ def copy_store(source, dest, source_path='', dest_path='', excludes=None,
# setup counting variables
n_copied = n_skipped = n_bytes_copied = 0

+    source_store_version = getattr(source, '_store_version', 2)
+    dest_store_version = getattr(dest, '_store_version', 2)
+    if source_store_version != dest_store_version:
+        raise ValueError("zarr stores must share the same protocol version")
+    if source_store_version > 2:
+        if not source_path or not dest_path:
+            raise ValueError("v3 stores require specifying a non-empty "
+                             "source_path and dest_path")

# setup logging
with _LogWriter(log) as log:

# iterate over source keys
for source_key in sorted(source.keys()):

            # filter to keys under source path
-            if source_key.startswith(source_path):
+            if source_store_version == 2:
+                if not source_key.startswith(source_path):
+                    continue
+            elif source_store_version == 3:
+                # 'meta/root/' or 'data/root/' have length 10
+                if not source_key[10:].startswith(source_path):
+                    continue

-                # process excludes and includes
-                exclude = False
-                for prog in excludes:
-                    if prog.search(source_key):
-                        exclude = True
-                        break
-                if exclude:
-                    for prog in includes:
-                        if prog.search(source_key):
-                            exclude = False
-                            break
-                if exclude:
-                    continue
+            # process excludes and includes
+            exclude = False
+            for prog in excludes:
+                if prog.search(source_key):
+                    exclude = True
+                    break
+            if exclude:
+                for prog in includes:
+                    if prog.search(source_key):
+                        exclude = False
+                        break
+            if exclude:
+                continue

-                # map key to destination path
-                key_suffix = source_key[len(source_path):]
-                dest_key = dest_path + key_suffix
-
-                # create a descriptive label for this operation
-                descr = source_key
-                if dest_key != source_key:
-                    descr = descr + ' -> ' + dest_key
-
-                # decide what to do
-                do_copy = True
-                if if_exists != 'replace':
-                    if dest_key in dest:
-                        if if_exists == 'raise':
-                            raise CopyError('key {!r} exists in destination'
-                                            .format(dest_key))
-                        elif if_exists == 'skip':
-                            do_copy = False
-
-                # take action
-                if do_copy:
-                    log('copy {}'.format(descr))
-                    if not dry_run:
-                        data = source[source_key]
-                        n_bytes_copied += buffer_size(data)
-                        dest[dest_key] = data
-                    n_copied += 1
-                else:
-                    log('skip {}'.format(descr))
-                    n_skipped += 1
+            # map key to destination path
+            if source_store_version == 2:
+                key_suffix = source_key[len(source_path):]
+                dest_key = dest_path + key_suffix
+            elif source_store_version == 3:
+                # 10 is length of 'meta/root/' or 'data/root/'
+                key_suffix = source_key[10 + len(source_path):]
+                dest_key = source_key[:10] + dest_path + key_suffix
+
+            # create a descriptive label for this operation
+            descr = source_key
+            if dest_key != source_key:
+                descr = descr + ' -> ' + dest_key
+
+            # decide what to do
+            do_copy = True
+            if if_exists != 'replace':
+                if dest_key in dest:
+                    if if_exists == 'raise':
+                        raise CopyError('key {!r} exists in destination'
+                                        .format(dest_key))
+                    elif if_exists == 'skip':
+                        do_copy = False
+
+            # take action
+            if do_copy:
+                log('copy {}'.format(descr))
+                if not dry_run:
+                    data = source[source_key]
+                    n_bytes_copied += buffer_size(data)
+                    dest[dest_key] = data
+                n_copied += 1
+            else:
+                log('skip {}'.format(descr))
+                n_skipped += 1

# log a final message with a summary of what happened
_log_copy_summary(log, dry_run, n_copied, n_skipped, n_bytes_copied)
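The v3 branch leans on the fixed layout of v3 keys, where metadata and chunk keys live under 'meta/root/' and 'data/root/' respectively, both exactly 10 characters long. A worked example of the key rewrite, with a hypothetical chunk key:

    source_path, dest_path = 'g1', 'g2'
    source_key = 'data/root/g1/arr/c0'               # hypothetical v3 chunk key

    key_suffix = source_key[10 + len(source_path):]  # '/arr/c0'
    dest_key = source_key[:10] + dest_path + key_suffix
    assert dest_key == 'data/root/g2/arr/c0'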
@@ -908,7 +956,15 @@ def _copy(log, source, dest, name, root, shallow, without_attrs, if_exists,

# copy attributes
if not without_attrs:
-        ds.attrs.update(source.attrs)
+        if dest_h5py and 'filters' in source.attrs:
+            # No filters key in v3 metadata so it was stored in the
+            # attributes instead. We cannot copy this key to
+            # HDF5 attrs, though!
+            source_attrs = source.attrs.asdict().copy()
+            source_attrs.pop('filters', None)
+        else:
+            source_attrs = source.attrs
+        ds.attrs.update(source_attrs)

n_copied += 1

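Condensed, the attribute handling above amounts to the following (a restatement with plain dicts, not code from the commit):

    def _attrs_to_copy(source_attrs: dict, dest_h5py: bool) -> dict:
        # v3 array metadata has no dedicated 'filters' field, so filters were
        # stashed in the attributes; HDF5 attributes cannot carry that key.
        if dest_h5py and 'filters' in source_attrs:
            out = dict(source_attrs)
            out.pop('filters', None)
            return out
        return source_attrs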
@@ -1064,6 +1120,8 @@ def copy_all(source, dest, shallow=False, without_attrs=False, log=None,
# setup counting variables
n_copied = n_skipped = n_bytes_copied = 0

+    zarr_version = getattr(source, '_version', 2)

# setup logging
with _LogWriter(log) as log:

@@ -1075,15 +1133,16 @@
n_copied += c
n_skipped += s
n_bytes_copied += b
-        dest.attrs.update(**source.attrs)
+        if zarr_version == 2:
+            dest.attrs.update(**source.attrs)

# log a final message with a summary of what happened
_log_copy_summary(log, dry_run, n_copied, n_skipped, n_bytes_copied)

return n_copied, n_skipped, n_bytes_copied

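copy_all usage is unchanged; the only behavioural difference is the guard above (sketch with hypothetical stores):

    import zarr

    src = zarr.open_group('src.zarr', mode='r')
    dst = zarr.open_group('dst.zarr', mode='w')
    n_copied, n_skipped, n_bytes_copied = zarr.copy_all(src, dst)
    # Root attributes are copied only when the source reports protocol version 2.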

-def consolidate_metadata(store: StoreLike, metadata_key=".zmetadata"):
+def consolidate_metadata(store: BaseStore, metadata_key=".zmetadata"):
"""
Consolidate all metadata for groups and arrays within the given store
into a single resource and put it under the given key.