diff --git a/python/lsst/scarlet/lite/io/blend.py b/python/lsst/scarlet/lite/io/blend.py index b817848..9a2db02 100644 --- a/python/lsst/scarlet/lite/io/blend.py +++ b/python/lsst/scarlet/lite/io/blend.py @@ -2,6 +2,7 @@ import logging from dataclasses import dataclass +from functools import cached_property from typing import Any import numpy as np @@ -51,6 +52,11 @@ class ScarletBlendData(ScarletBlendBaseData): sources: dict[Any, ScarletSourceBaseData] version: str = CURRENT_SCHEMA + @cached_property + def bbox(self) -> Box: + """The bounding box of the blend""" + return Box(self.shape, origin=self.origin) + def as_dict(self) -> dict: """Return the object encoded into a dict for JSON serialization @@ -131,7 +137,7 @@ def minimal_data_to_blend( _model_psf: np.ndarray = extract_from_metadata(model_psf, self.metadata, "model_psf") _psf: np.ndarray = extract_from_metadata(psf, self.metadata, "psf") _bands: tuple[str] = extract_from_metadata(bands, self.metadata, "bands") - model_box = Box(self.shape, origin=self.origin) + model_box = self.bbox observation = Observation.empty( bands=_bands, psfs=_psf, diff --git a/python/lsst/scarlet/lite/io/blend_base.py b/python/lsst/scarlet/lite/io/blend_base.py index 551e39c..82b4dc5 100644 --- a/python/lsst/scarlet/lite/io/blend_base.py +++ b/python/lsst/scarlet/lite/io/blend_base.py @@ -6,6 +6,7 @@ from numpy.typing import DTypeLike +from ..bbox import Box from .utils import PersistenceError __all__ = ["ScarletBlendBaseData"] @@ -32,6 +33,11 @@ class ScarletBlendBaseData(ABC): metadata: dict[str, Any] | None = None version: str + @property + @abstractmethod + def bbox(self) -> Box: + """The bounding box of the blend""" + @classmethod def register(cls) -> None: """Register a new Blend type""" diff --git a/python/lsst/scarlet/lite/io/hierarchical_blend.py b/python/lsst/scarlet/lite/io/hierarchical_blend.py index 30406a0..158bf89 100644 --- a/python/lsst/scarlet/lite/io/hierarchical_blend.py +++ b/python/lsst/scarlet/lite/io/hierarchical_blend.py @@ -6,6 +6,7 @@ import numpy as np from numpy.typing import DTypeLike +from ..bbox import Box from .blend_base import ScarletBlendBaseData from .migration import PRE_SCHEMA, MigrationRegistry, migration from .utils import PersistenceError, decode_metadata, encode_metadata @@ -35,6 +36,21 @@ class HierarchicalBlendData(ScarletBlendBaseData): children: dict[int, ScarletBlendBaseData] version: str = CURRENT_SCHEMA + @property + def bbox(self) -> Box: + """The bounding box of the blend""" + # Compute the bounding box that contains all children + if not self.children: + raise ValueError("HierarchicalBlendData has no children to compute bbox from.") + bboxes = [child.bbox for child in self.children.values()] + min_y = min(bbox.origin[0] for bbox in bboxes) + min_x = min(bbox.origin[1] for bbox in bboxes) + max_y = max(bbox.origin[0] + bbox.shape[0] for bbox in bboxes) + max_x = max(bbox.origin[1] + bbox.shape[1] for bbox in bboxes) + origin = (min_y, min_x) + shape = (max_y - min_y, max_x - min_x) + return Box(shape, origin=origin) + def as_dict(self) -> dict: """Return the object encoded into a dict for JSON serialization diff --git a/tests/test_io.py b/tests/test_io.py index 7ee5f82..9e8cb4a 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -85,6 +85,7 @@ def test_json(self): self.assertEqual(len(blend.sources), len(loaded_blend.sources)) self.assertEqual(len(blend.components), len(loaded_blend.components)) self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model()) + self.assertBoxEqual(blend.bbox, blend_data.bbox) for sidx in range(len(blend.sources)): source1 = blend.sources[sidx]