Skip to content

Commit

Permalink
Add keep_bits to API so writers can round_mantissa
Browse files Browse the repository at this point in the history
  • Loading branch information
scottstanie committed Feb 3, 2025
1 parent 360ea37 commit f2f46a6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/dolphin/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def repack_rasters(
)


def round_mantissa(z: np.ndarray, keep_bits=10) -> None:
def round_mantissa(z: np.ndarray, keep_bits: int = 10) -> None:
"""Zero out mantissa bits of elements of array in place.
Drops a specified number of bits from the floating point mantissa,
Expand All @@ -251,10 +251,11 @@ def round_mantissa(z: np.ndarray, keep_bits=10) -> None:
----------
z : numpy.ndarray
Real or complex array whose mantissas are to be zeroed out
keep_bits : int, optional
Number of bits to preserve in mantissa. Defaults to 10.
keep_bits : int
Number of bits to preserve in mantissa.
Lower numbers will truncate the mantissa more and enable
more compression.
Default is 10.
References
----------
Expand Down
45 changes: 39 additions & 6 deletions src/dolphin/io/_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dolphin._types import Filename

from ._background import BackgroundWriter
from ._utils import _unpack_3d_slices
from ._utils import _unpack_3d_slices, round_mantissa

__all__ = [
"BackgroundBlockWriter",
Expand All @@ -40,8 +40,16 @@
class BackgroundBlockWriter(BackgroundWriter):
"""Class to write data to multiple files in the background using `gdal` bindings."""

def __init__(self, *, max_queue: int = 0, debug: bool = False, **kwargs):
def __init__(
self,
*,
max_queue: int = 0,
debug: bool = False,
keep_bits: int | None = None,
**kwargs,
):
super().__init__(nq=max_queue, name="Writer")
self.keep_bits = keep_bits
if debug:
# background thread. Just synchronously write data
self.notify_finished()
Expand Down Expand Up @@ -79,6 +87,9 @@ def write(
"""
from dolphin.io import write_block

if np.issubdtype(data.dtype, np.floating) and self.keep_bits is not None:
round_mantissa(data, keep_bits=self.keep_bits)

write_block(data, filename, row_start, col_start, band=band)


Expand Down Expand Up @@ -155,6 +166,8 @@ class implements Python's context manager protocol, which can be used to reliabl
"""str or Path : Path to the file to write."""
band: int = 1
"""int : Band index in the file to write."""
keep_bits: int | None = None
"""int or None : For floating point rasters, the number of mantissa bits to keep. If None, data is written without rounding."""

def __post_init__(self) -> None:
# Open the dataset.
Expand Down Expand Up @@ -183,6 +196,7 @@ def create(
transform: rasterio.transform.Affine | None = None,
*,
like_filename: Filename | None = None,
keep_bits: int | None = None,
**kwargs: Any,
) -> RasterT:
"""Create a new single-band raster dataset.
Expand Down Expand Up @@ -223,6 +237,9 @@ def create(
with the same metadata (shape, data-type, driver, CRS/geotransform, etc) as
the reference raster. All other arguments will override the corresponding
attribute of the reference raster. Defaults to None.
keep_bits : int, optional
Number of bits to preserve in the mantissa of floating point data.
Lower numbers will truncate the mantissa more and enable more compression.
Defaults to None, which writes the data without rounding the mantissa.
**kwargs : dict, optional
Additional driver-specific creation options passed to `rasterio.open`.
Expand Down Expand Up @@ -252,7 +269,7 @@ def create(
with rasterio.open(fp, mode="w+", **kwargs):
pass

return cls(fp, band=1)
return cls(fp, band=1, keep_bits=keep_bits)

@property
def dtype(self) -> np.dtype:
Expand Down Expand Up @@ -304,6 +321,8 @@ def __repr__(self) -> str:
return f"{clsname}(dataset={self.dataset!r}, band={self.band!r})"

def __setitem__(self, key: tuple[Index, ...], value: np.ndarray, /) -> None:
if np.issubdtype(value.dtype, np.floating) and self.keep_bits is not None:
round_mantissa(value, keep_bits=self.keep_bits)
with rasterio.open(
self.filename,
"r+",
Expand Down Expand Up @@ -332,7 +351,13 @@ class BackgroundRasterWriter(BackgroundWriter, DatasetWriter):
"""Class to write data to files in a background thread."""

def __init__(
self, filename: Filename, *, max_queue: int = 0, debug: bool = False, **kwargs
self,
filename: Filename,
*,
max_queue: int = 0,
debug: bool = False,
keep_bits: int | None = None,
**kwargs,
):
super().__init__(nq=max_queue, name="Writer")
if debug:
Expand All @@ -341,9 +366,9 @@ def __init__(
self.queue_write = self.write # type: ignore[assignment]

if Path(filename).exists():
self._raster = RasterWriter(filename)
self._raster = RasterWriter(filename, keep_bits=keep_bits)
else:
self._raster = RasterWriter.create(filename, **kwargs)
self._raster = RasterWriter.create(filename, keep_bits=keep_bits, **kwargs)
self.filename = filename
self.ndim = 2

Expand Down Expand Up @@ -402,6 +427,7 @@ def __init__(
like_filename: Filename | None = None,
max_queue: int = 0,
debug: bool = False,
keep_bits: int | None = None,
**file_creation_kwargs,
):
from dolphin.io import write_arr
Expand All @@ -421,6 +447,7 @@ def __init__(
)

self.file_list = file_list
self.keep_bits = keep_bits

with rasterio.open(self.file_list[0]) as src:
self.shape = (len(self.file_list), *src.shape)
Expand Down Expand Up @@ -451,11 +478,17 @@ def write(
"""
from dolphin.io import write_block

_do_round = (
np.issubdtype(data.dtype, np.floating) and self.keep_bits is not None
)
if data.ndim == 2:
data = data[None, ...]
if data.shape[0] != len(self.file_list):
raise ValueError(f"{data.shape = }, but {len(self.file_list) = }")
for fn, layer in zip(self.file_list, data):
if _do_round:
assert self.keep_bits is not None
round_mantissa(layer, keep_bits=self.keep_bits)
write_block(layer, fn, row_start, col_start, band=band)

def __setitem__(self, key, value):
Expand Down

0 comments on commit f2f46a6

Please sign in to comment.