Skip to content

Commit

Permalink
Add in-memory icechunk tests to existing roundtrip tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jsignell committed Jan 30, 2025
1 parent 61e3cff commit bd95d60
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 65 deletions.
9 changes: 9 additions & 0 deletions virtualizarr/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _importorskip(


has_astropy, requires_astropy = _importorskip("astropy")
has_icechunk, requires_icechunk = _importorskip("icechunk")
has_kerchunk, requires_kerchunk = _importorskip("kerchunk")
has_s3fs, requires_s3fs = _importorskip("s3fs")
has_scipy, requires_scipy = _importorskip("scipy")
Expand Down Expand Up @@ -119,3 +120,11 @@ def open_dataset_kerchunk(
"reference", fo=filename_or_obj, **(storage_options or {})
).get_mapper()
return xr.open_dataset(m, engine="zarr", consolidated=False, **kwargs)


def in_memory_icechunk_session():
    """Create a writable icechunk session backed by an in-memory store.

    Returns a session on the ``main`` branch of a freshly created
    repository, suitable for writing test data without touching disk.
    """
    # Imported lazily so this module stays importable when icechunk is
    # not installed (callers gate on ``has_icechunk``).
    from icechunk import Repository, Storage

    storage = Storage.new_in_memory()
    return Repository.create(storage=storage).writable_session("main")
124 changes: 59 additions & 65 deletions virtualizarr/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from virtualizarr import open_virtual_dataset
from virtualizarr.manifests import ChunkManifest, ManifestArray
from virtualizarr.tests import (
has_icechunk,
has_kerchunk,
in_memory_icechunk_session,
open_dataset_kerchunk,
parametrize_over_hdf_backends,
requires_kerchunk,
Expand Down Expand Up @@ -40,16 +42,16 @@ def test_kerchunk_roundtrip_in_memory_no_concat():
),
chunkmanifest=manifest,
)
ds = xr.Dataset({"a": (["x", "y"], marr)})
vds = xr.Dataset({"a": (["x", "y"], marr)})

# Use accessor to write it out to kerchunk reference dict
ds_refs = ds.virtualize.to_kerchunk(format="dict")
ds_refs = vds.virtualize.to_kerchunk(format="dict")

# Use dataset_from_kerchunk_refs to reconstruct the dataset
roundtrip = dataset_from_kerchunk_refs(ds_refs)

# Assert equal to original dataset
xrt.assert_equal(roundtrip, ds)
xrt.assert_equal(roundtrip, vds)


@requires_kerchunk
Expand Down Expand Up @@ -90,13 +92,55 @@ def test_numpy_arrays_to_inlined_kerchunk_refs(
assert refs["refs"]["time/0"] == expected["refs"]["time/0"]


def roundtrip_as_kerchunk_dict(vds: xr.Dataset, tmpdir, **kwargs):
    """Round-trip *vds* through an in-memory kerchunk reference dict.

    ``tmpdir`` is unused but kept so all roundtrip helpers share one
    call signature.
    """
    # Serialize the virtual dataset to kerchunk references held in memory.
    refs = vds.virtualize.to_kerchunk(format="dict")
    # Re-open those references as a regular xarray dataset via fsspec.
    return open_dataset_kerchunk(refs, **kwargs)


def roundtrip_as_kerchunk_json(vds: xr.Dataset, tmpdir, **kwargs):
    """Round-trip *vds* through a kerchunk JSON reference file on disk."""
    refs_path = f"{tmpdir}/refs.json"
    # Write the references to disk in kerchunk's JSON format.
    vds.virtualize.to_kerchunk(refs_path, format="json")
    # Re-open the on-disk references as a regular xarray dataset via fsspec.
    return open_dataset_kerchunk(refs_path, **kwargs)


def roundtrip_as_kerchunk_parquet(vds: xr.Dataset, tmpdir, **kwargs):
    """Round-trip *vds* through a kerchunk parquet reference store on disk."""
    refs_path = f"{tmpdir}/refs.parquet"
    # Write the references to disk in kerchunk's parquet format.
    vds.virtualize.to_kerchunk(refs_path, format="parquet")
    # Re-open the on-disk references as a regular xarray dataset via fsspec.
    return open_dataset_kerchunk(refs_path, **kwargs)


def roundtrip_as_in_memory_icechunk(vds: xr.Dataset, tmpdir, **kwargs):
    """Round-trip *vds* through an in-memory icechunk store.

    ``tmpdir`` is unused but kept so all roundtrip helpers share one
    call signature.
    """
    session = in_memory_icechunk_session()
    # Write the virtual references into the store, then snapshot them.
    vds.virtualize.to_icechunk(session.store)
    session.commit("add data")
    # Re-open the store as a regular xarray dataset.
    # NOTE(review): this reads from the same session's store after commit —
    # presumably icechunk permits that; confirm against the icechunk docs.
    return xr.open_zarr(
        session.store,
        zarr_format=3,
        consolidated=False,
        **kwargs,
    )


@requires_zarr_python
@pytest.mark.parametrize(
"format", ["dict", "json", "parquet"] if has_kerchunk else ["dict", "json"]
"roundtrip_func",
[
roundtrip_as_kerchunk_dict,
roundtrip_as_kerchunk_json,
*([roundtrip_as_kerchunk_parquet] if has_kerchunk else []),
*([roundtrip_as_in_memory_icechunk] if has_icechunk else []),
],
)
class TestKerchunkRoundtrip:
class TestRoundtrip:
@parametrize_over_hdf_backends
def test_kerchunk_roundtrip_no_concat(self, tmpdir, format, hdf_backend):
def test_roundtrip_no_concat(self, tmpdir, roundtrip_func, hdf_backend):
# set up example xarray dataset
ds = xr.tutorial.open_dataset("air_temperature", decode_times=False)

Expand All @@ -106,20 +150,7 @@ def test_kerchunk_roundtrip_no_concat(self, tmpdir, format, hdf_backend):
# use open_dataset_via_kerchunk to read it as references
vds = open_virtual_dataset(f"{tmpdir}/air.nc", indexes={}, backend=hdf_backend)

if format == "dict":
# write those references to an in-memory kerchunk-formatted references dictionary
ds_refs = vds.virtualize.to_kerchunk(format=format)

# use fsspec to read the dataset from the kerchunk references dict
roundtrip = open_dataset_kerchunk(ds_refs, decode_times=False)
else:
# write those references to disk as kerchunk references format
vds.virtualize.to_kerchunk(f"{tmpdir}/refs.{format}", format=format)

# use fsspec to read the dataset from disk via the kerchunk references
roundtrip = open_dataset_kerchunk(
f"{tmpdir}/refs.{format}", decode_times=False
)
roundtrip = roundtrip_func(vds, tmpdir, decode_times=False)

# assert all_close to original dataset
xrt.assert_allclose(roundtrip, ds)
Expand All @@ -131,7 +162,7 @@ def test_kerchunk_roundtrip_no_concat(self, tmpdir, format, hdf_backend):
@parametrize_over_hdf_backends
@pytest.mark.parametrize("decode_times,time_vars", [(False, []), (True, ["time"])])
def test_kerchunk_roundtrip_concat(
self, tmpdir, format, hdf_backend, decode_times, time_vars
self, tmpdir, roundtrip_func, hdf_backend, decode_times, time_vars
):
# set up example xarray dataset
ds = xr.tutorial.open_dataset("air_temperature", decode_times=decode_times)
Expand Down Expand Up @@ -167,20 +198,7 @@ def test_kerchunk_roundtrip_concat(
# concatenate virtually along time
vds = xr.concat([vds1, vds2], dim="time", coords="minimal", compat="override")

if format == "dict":
# write those references to an in-memory kerchunk-formatted references dictionary
ds_refs = vds.virtualize.to_kerchunk(format=format)

# use fsspec to read the dataset from the kerchunk references dict
roundtrip = open_dataset_kerchunk(ds_refs, decode_times=decode_times)
else:
# write those references to disk as kerchunk references format
vds.virtualize.to_kerchunk(f"{tmpdir}/refs.{format}", format=format)

# use fsspec to read the dataset from disk via the kerchunk references
roundtrip = open_dataset_kerchunk(
f"{tmpdir}/refs.{format}", decode_times=decode_times
)
roundtrip = roundtrip_func(vds, tmpdir, decode_times=decode_times)

if decode_times is False:
# assert all_close to original dataset
Expand All @@ -197,7 +215,7 @@ def test_kerchunk_roundtrip_concat(
assert roundtrip.time.encoding["calendar"] == ds.time.encoding["calendar"]

@parametrize_over_hdf_backends
def test_non_dimension_coordinates(self, tmpdir, format, hdf_backend):
def test_non_dimension_coordinates(self, tmpdir, roundtrip_func, hdf_backend):
# regression test for GH issue #105

if hdf_backend:
Expand All @@ -215,20 +233,7 @@ def test_non_dimension_coordinates(self, tmpdir, format, hdf_backend):
assert "lat" in vds.coords
assert "coordinates" not in vds.attrs

if format == "dict":
# write those references to an in-memory kerchunk-formatted references dictionary
ds_refs = vds.virtualize.to_kerchunk(format=format)

# use fsspec to read the dataset from the kerchunk references dict
roundtrip = open_dataset_kerchunk(ds_refs, decode_times=False)
else:
# write those references to disk as kerchunk references format
vds.virtualize.to_kerchunk(f"{tmpdir}/refs.{format}", format=format)

# use fsspec to read the dataset from disk via the kerchunk references
roundtrip = open_dataset_kerchunk(
f"{tmpdir}/refs.{format}", decode_times=False
)
roundtrip = roundtrip_func(vds, tmpdir)

# assert equal to original dataset
xrt.assert_allclose(roundtrip, ds)
Expand All @@ -237,7 +242,7 @@ def test_non_dimension_coordinates(self, tmpdir, format, hdf_backend):
for coord in ds.coords:
assert ds.coords[coord].attrs == roundtrip.coords[coord].attrs

def test_datetime64_dtype_fill_value(self, tmpdir, format):
def test_datetime64_dtype_fill_value(self, tmpdir, roundtrip_func):
chunks_dict = {
"0.0.0": {"path": "/foo.nc", "offset": 100, "length": 100},
}
Expand All @@ -255,7 +260,7 @@ def test_datetime64_dtype_fill_value(self, tmpdir, format):
zarr_format=2,
)
marr1 = ManifestArray(zarray=zarray, chunkmanifest=manifest)
ds = xr.Dataset(
vds = xr.Dataset(
{
"a": xr.DataArray(
marr1,
Expand All @@ -266,20 +271,9 @@ def test_datetime64_dtype_fill_value(self, tmpdir, format):
}
)

if format == "dict":
# write those references to an in-memory kerchunk-formatted references dictionary
ds_refs = ds.virtualize.to_kerchunk(format=format)

# use fsspec to read the dataset from the kerchunk references dict
roundtrip = open_dataset_kerchunk(ds_refs)
else:
# write those references to disk as kerchunk references format
ds.virtualize.to_kerchunk(f"{tmpdir}/refs.{format}", format=format)

# use fsspec to read the dataset from disk via the kerchunk references
roundtrip = open_dataset_kerchunk(f"{tmpdir}/refs.{format}")
roundtrip = roundtrip_func(vds, tmpdir)

assert roundtrip.a.attrs == ds.a.attrs
assert roundtrip.a.attrs == vds.a.attrs


@parametrize_over_hdf_backends
Expand Down

0 comments on commit bd95d60

Please sign in to comment.