Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add image-label convenience functions #178

Merged
merged 13 commits into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
- name: Test
run: tox -e 'py37-coverage'

- uses: codecov/codecov-action@v1
- uses: codecov/codecov-action@v2
with:
file: ./coverage.xml
fail_ci_if_error: true
128 changes: 118 additions & 10 deletions ome_zarr/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def write_multiscale(
axes: Union[str, List[str], List[Dict[str, str]]] = None,
coordinate_transformations: List[List[Dict[str, Any]]] = None,
storage_options: Union[JSONDict, List[JSONDict]] = None,
**metadata: Union[str, JSONDict, List[JSONDict]],
) -> None:
"""
Write a pyramid with multiscale metadata to disk.
Expand Down Expand Up @@ -232,14 +233,15 @@ def write_multiscale(
for dataset, transform in zip(datasets, coordinate_transformations):
dataset["coordinateTransformations"] = transform

write_multiscales_metadata(group, datasets, fmt, axes)
write_multiscales_metadata(group, datasets, fmt, axes, **metadata)


def write_multiscales_metadata(
group: zarr.Group,
datasets: List[dict],
fmt: Format = CurrentFormat(),
axes: Union[str, List[str], List[Dict[str, str]]] = None,
**metadata: Union[str, JSONDict, List[JSONDict]],
) -> None:
"""
Write the multiscales metadata in the group.
Expand Down Expand Up @@ -268,11 +270,15 @@ def write_multiscales_metadata(
if axes is not None:
ndim = len(axes)

# note: we construct the multiscale metadata via dict(), rather than {}
# to avoid duplication of protected keys like 'version' in **metadata
# (for {} this would silently over-write it, with dict() it explicitly fails)
multiscales = [
{
"version": fmt.version,
"datasets": _validate_datasets(datasets, ndim, fmt),
}
dict(
version=fmt.version,
datasets=_validate_datasets(datasets, ndim, fmt),
**metadata,
)
]
if axes is not None:
multiscales[0]["axes"] = axes
Expand Down Expand Up @@ -363,7 +369,7 @@ def write_image(
axes: Union[str, List[str], List[Dict[str, str]]] = None,
coordinate_transformations: List[List[Dict[str, Any]]] = None,
storage_options: Union[JSONDict, List[JSONDict]] = None,
**metadata: JSONDict,
**metadata: Union[str, JSONDict, List[JSONDict]],
) -> None:
"""Writes an image to the zarr store according to ome-zarr specification

Expand Down Expand Up @@ -419,21 +425,123 @@ def write_image(
"Can't downsample if size of x or y dimension is 1. "
"Shape: %s" % (image.shape,)
)
image = scaler.nearest(image)
mip = scaler.nearest(image)
else:
LOGGER.debug("disabling pyramid")
image = [image]
mip = [image]

write_multiscale(
image,
mip,
group,
chunks=chunks,
fmt=fmt,
axes=axes,
coordinate_transformations=coordinate_transformations,
storage_options=storage_options,
**metadata,
)


def write_label_metadata(
    group: zarr.Group,
    name: str,
    colors: List[JSONDict] = None,
    properties: List[JSONDict] = None,
    **metadata: Union[List[JSONDict], JSONDict, str],
) -> None:
    """
    Write image-label metadata to the group.

    The label data must have been written to a sub-group,
    with the same name as the second argument.

    group: zarr.Group
      the top-level label group within the zarr store
    name: str
      the name of the label sub-group
    colors: list of JSONDict
      Fixed colors for (a subset of) the label values.
      Each dict specifies the color for one label and must contain the fields
      "label-value" and "rgba".
    properties: list of JSONDict
      Additional properties for (a subset of) the label values.
      Each dict specifies additional properties for one label.
      It must contain the field "label-value"
      and may contain arbitrary additional properties.
    **metadata:
      any additional key/value pairs are stored verbatim in the
      "image-label" attribute of the label sub-group
    """
    label_group = group[name]
    # Start from the free-form metadata so the explicit keyword arguments
    # below take precedence over identically-named keys in **metadata.
    image_label_metadata = {**metadata}
    if colors is not None:
        image_label_metadata["colors"] = colors
    if properties is not None:
        image_label_metadata["properties"] = properties
    label_group.attrs["image-label"] = image_label_metadata

    # Register this label in the parent group's "labels" list.
    # Guard against duplicates: re-writing metadata for an existing label
    # must not list the same name twice (the previous unconditional append
    # produced an invalid "labels" attribute on repeated calls).
    label_list = group.attrs.get("labels", [])
    if name not in label_list:
        label_list.append(name)
    group.attrs["labels"] = label_list


def write_multiscale_labels(
    pyramid: List,
    group: zarr.Group,
    name: str,
    chunks: Union[Tuple[Any, ...], int] = None,
    fmt: Format = CurrentFormat(),
    axes: Union[str, List[str], List[Dict[str, str]]] = None,
    coordinate_transformations: List[List[Dict[str, Any]]] = None,
    storage_options: Union[JSONDict, List[JSONDict]] = None,
    label_metadata: JSONDict = None,
    **metadata: JSONDict,
) -> None:
    """
    Write pyramidal image labels to disk.

    Writes both the multiscales and the image-label metadata, and creates
    the label data in the sub-group "labels/{name}".

    pyramid: List of np.ndarray
      the image label data to save. Largest level first.
      All image arrays MUST be up to 5-dimensional with dimensions
      ordered (t, c, z, y, x)
    group: zarr.Group
      the group within the zarr store to store the data in
    name: str
      the name of this label data
    chunks: int or tuple of ints,
      size of the saved chunks to store the image
    fmt: Format
      The format of the ome_zarr data which should be used.
      Defaults to the most current.
    axes: str or list of str or list of dict
      List of axes dicts, or names. Not needed for v0.1 or v0.2
      or if 2D. Otherwise this must be provided
    coordinate_transformations: 2D list of dict
      For each path, we have a List of transformation Dicts.
      Each list of dicts are added to each datasets in order
      and must include a 'scale' transform.
    storage_options: dict or list of dict
      Options to be passed on to the storage backend. A list would need to match
      the number of datasets in a multiresolution pyramid. One can provide
      different chunk size for each level of a pyramid using this option.
    label_metadata: JSONDict
      image label metadata. See 'write_label_metadata' for details
    """
    # The label pyramid lives in its own sub-group next to the image data.
    label_group = group.require_group(f"labels/{name}")
    write_multiscale(
        pyramid,
        label_group,
        chunks=chunks,
        fmt=fmt,
        axes=axes,
        coordinate_transformations=coordinate_transformations,
        storage_options=storage_options,
        name=name,
        **metadata,
    )
    # Record image-label metadata and register the name in "labels".
    extra = label_metadata if label_metadata is not None else {}
    write_label_metadata(group["labels"], name, **extra)
    # Mirror the free-form metadata on the parent group as well.
    group.attrs.update(metadata)


def _retuple(
Expand Down
161 changes: 161 additions & 0 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ome_zarr.writer import (
_get_valid_axes,
write_image,
write_multiscale_labels,
write_multiscales_metadata,
write_plate_metadata,
write_well_metadata,
Expand Down Expand Up @@ -823,3 +824,163 @@ def test_unspecified_images_keys(self):
assert "well" in self.root.attrs
assert self.root.attrs["well"]["images"] == images
assert self.root.attrs["well"]["version"] == CurrentFormat().version


class TestLabelWriter:
    """Integration tests for write_multiscale_labels.

    Each test writes image data plus one or more label pyramids into a
    temporary on-disk zarr store, then reads the store back and verifies
    the multiscales, "labels" and "image-label" metadata.
    """

    @pytest.fixture(autouse=True)
    def initdir(self, tmpdir):
        # Fresh on-disk zarr store per test; autouse so every test method
        # starts with self.path / self.store / self.root ready.
        self.path = pathlib.Path(tmpdir.mkdir("data.ome.zarr"))
        self.store = parse_url(self.path, mode="w").store
        self.root = zarr.group(store=self.store)

    def create_image_data(self, shape, scaler, version, axes, transformations):
        # Write the root-level image that the labels will annotate.
        # Seeded RNG keeps the image data deterministic across runs.
        rng = np.random.default_rng(0)
        data = rng.poisson(10, size=shape).astype(np.uint8)
        write_image(
            image=data,
            group=self.root,
            chunks=(128, 128),
            scaler=scaler,
            fmt=version,
            axes=axes,
            coordinate_transformations=transformations,
        )

    @pytest.fixture(
        params=(
            (1, 2, 1, 256, 256),
            (3, 512, 512),
            (256, 256),
        ),
        ids=["5D", "3D", "2D"],
    )
    def shape(self, request):
        # Parametrized image/label shape: 5D (tczyx), 3D (zyx) and 2D (yx).
        return request.param

    @pytest.fixture(params=[True, False], ids=["scale", "noop"])
    def scaler(self, request):
        # Either build a real pyramid (Scaler) or write a single level (None).
        if request.param:
            return Scaler()
        else:
            return None

    @pytest.mark.parametrize(
        "format_version",
        (
            pytest.param(FormatV01, id="V01"),
            pytest.param(FormatV02, id="V02"),
            pytest.param(FormatV03, id="V03"),
            pytest.param(FormatV04, id="V04"),
        ),
    )
    def test_multiscale_label_writer(self, shape, scaler, format_version):
        """Round-trip a single label pyramid for every spec version."""
        version = format_version()
        axes = "tczyx"[-len(shape) :]
        transformations = []
        for dataset_transfs in TRANSFORMATIONS:
            transf = dataset_transfs[0]
            # e.g. slice [1, 1, z, x, y] -> [z, x, y] for 3D
            transformations.append(
                [{"type": "scale", "scale": transf["scale"][-len(shape) :]}]
            )

        # create the actual label data
        label_data = np.random.randint(0, 1000, size=shape)
        if version.version in ("0.1", "0.2"):
            # v0.1 and v0.2 require 5d
            expand_dims = (np.s_[None],) * (5 - len(shape))
            label_data = label_data[expand_dims]
            assert label_data.ndim == 5
        label_name = "my-labels"
        if scaler is None:
            # no pyramid: keep a single level and a single transformation
            transformations = [transformations[0]]
            labels_mip = [label_data]
        else:
            labels_mip = scaler.nearest(label_data)

        # create the root level image data
        self.create_image_data(shape, scaler, version, axes, transformations)

        write_multiscale_labels(
            labels_mip,
            self.root,
            name=label_name,
            fmt=version,
            axes=axes,
            coordinate_transformations=transformations,
        )

        # Verify image data
        reader = Reader(parse_url(f"{self.path}/labels/{label_name}"))
        node = list(reader())[0]
        assert Multiscales.matches(node.zarr)
        if version.version in ("0.1", "0.2"):
            # v0.1 and v0.2 MUST be 5D
            assert node.data[0].ndim == 5
        else:
            assert node.data[0].shape == shape

        # coordinateTransformations are only written from v0.4 onwards
        if version.version not in ("0.1", "0.2", "0.3"):
            for transf, expected in zip(
                node.metadata["coordinateTransformations"], transformations
            ):
                assert transf == expected
            assert len(node.metadata["coordinateTransformations"]) == len(node.data)
        # the highest-resolution level must round-trip the original labels
        assert np.allclose(label_data, node.data[0][...].compute())

        # Verify label metadata
        label_root = zarr.open(f"{self.path}/labels", "r")
        assert "labels" in label_root.attrs
        assert label_name in label_root.attrs["labels"]

        label_group = zarr.open(f"{self.path}/labels/{label_name}", "r")
        assert "image-label" in label_group.attrs

        # Verify multiscale metadata
        name = label_group.attrs["multiscales"][0].get("name", "")
        assert label_name == name

    def test_two_label_images(self):
        """Two label pyramids under one image must both be registered."""
        axes = "tczyx"
        transformations = []
        for dataset_transfs in TRANSFORMATIONS:
            transf = dataset_transfs[0]
            transformations.append([{"type": "scale", "scale": transf["scale"]}])

        # create the root level image data
        shape = (1, 2, 1, 256, 256)
        scaler = Scaler()
        self.create_image_data(
            shape,
            scaler,
            axes=axes,
            version=FormatV04(),
            transformations=transformations,
        )

        label_names = ("first_labels", "second_labels")
        for label_name in label_names:
            label_data = np.random.randint(0, 1000, size=shape)
            labels_mip = scaler.nearest(label_data)

            write_multiscale_labels(
                labels_mip,
                self.root,
                name=label_name,
                axes=axes,
                coordinate_transformations=transformations,
            )

            label_group = zarr.open(f"{self.path}/labels/{label_name}", "r")
            assert "image-label" in label_group.attrs
            name = label_group.attrs["multiscales"][0].get("name", "")
            assert label_name == name

        # Verify label metadata: the parent "labels" group lists both names
        label_root = zarr.open(f"{self.path}/labels", "r")
        assert "labels" in label_root.attrs
        assert len(label_root.attrs["labels"]) == len(label_names)
        assert all(
            label_name in label_root.attrs["labels"] for label_name in label_names
        )