diff --git a/zarr/convenience.py b/zarr/convenience.py index 1ed0f92ff3..80cf7fffd4 100644 --- a/zarr/convenience.py +++ b/zarr/convenience.py @@ -1,6 +1,7 @@ """Convenience functions for storing and loading data.""" import io import itertools +import os import re from collections.abc import Mapping @@ -100,6 +101,10 @@ def open(store=None, mode='a', **kwargs): raise PathNotFoundError(path) +def _might_close(path): + return isinstance(path, (str, os.PathLike)) + + def save_array(store, arr, **kwargs): """Convenience function to save a NumPy array to the local file system, following a similar API to the NumPy save() function. @@ -131,7 +136,7 @@ def save_array(store, arr, **kwargs): array([ 0, 1, 2, ..., 9997, 9998, 9999]) """ - may_need_closing = isinstance(store, str) + may_need_closing = _might_close(store) store = normalize_store_arg(store, clobber=True) try: _create_array(arr, store=store, overwrite=True, **kwargs) @@ -202,7 +207,7 @@ def save_group(store, *args, **kwargs): if len(args) == 0 and len(kwargs) == 0: raise ValueError('at least one array must be provided') # handle polymorphic store arg - may_need_closing = isinstance(store, str) + may_need_closing = _might_close(store) store = normalize_store_arg(store, clobber=True) try: grp = _create_group(store, overwrite=True) diff --git a/zarr/creation.py b/zarr/creation.py index 73e10adff1..213fa248ac 100644 --- a/zarr/creation.py +++ b/zarr/creation.py @@ -1,3 +1,4 @@ +import os from warnings import warn import numpy as np @@ -148,7 +149,9 @@ def create(shape, chunks=True, dtype=None, compressor='default', def normalize_store_arg(store, clobber=False, storage_options=None, mode='w'): if store is None: return dict() - elif isinstance(store, str): + if isinstance(store, os.PathLike): + store = os.fspath(store) + if isinstance(store, str): mode = mode if clobber else "r" if "://" in store or "::" in store: return FSStore(store, mode=mode, **(storage_options or {})) diff --git a/zarr/tests/conftest.py b/zarr/tests/conftest.py new file mode 100644 index 0000000000..aa73b8691e --- /dev/null +++ b/zarr/tests/conftest.py @@ -0,0 +1,8 @@ +import pathlib + +import pytest + + +@pytest.fixture(params=[str, pathlib.Path]) +def path_type(request): + return request.param diff --git a/zarr/tests/test_convenience.py b/zarr/tests/test_convenience.py index 20cd25027c..a5ac40e371 100644 --- a/zarr/tests/test_convenience.py +++ b/zarr/tests/test_convenience.py @@ -27,10 +27,11 @@ atexit_rmtree, getsize) -def test_open_array(): +def test_open_array(path_type): store = tempfile.mkdtemp() atexit.register(atexit_rmtree, store) + store = path_type(store) # open array, create if doesn't exist z = open(store, mode='a', shape=100) @@ -53,10 +54,11 @@ def test_open_array(): open('doesnotexist', mode='r') -def test_open_group(): +def test_open_group(path_type): store = tempfile.mkdtemp() atexit.register(atexit_rmtree, store) + store = path_type(store) # open group, create if doesn't exist g = open(store, mode='a') diff --git a/zarr/tests/test_storage.py b/zarr/tests/test_storage.py index e9b997b335..d3f3b0e770 100644 --- a/zarr/tests/test_storage.py +++ b/zarr/tests/test_storage.py @@ -2,6 +2,7 @@ import atexit import json import os +import pathlib import sys import pickle import shutil @@ -836,6 +837,11 @@ def test_filesystem_path(self): with pytest.raises(ValueError): DirectoryStore(f.name) + def test_init_pathlib(self): + path = tempfile.mkdtemp() + atexit.register(atexit_rmtree, path) + DirectoryStore(pathlib.Path(path)) + def test_pickle_ext(self): store = self.create_store() store2 = pickle.loads(pickle.dumps(store))