Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 87 additions & 31 deletions mesa_geo/raster_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __init__(
Origin is at upper left corner of the grid. Use rowcol instead.
:param rowcol: Indices of the cell in (row, col) format.
Origin is at upper left corner of the grid
:param xy: Cell center coordinates in the CRS.
:param xy: Geographic/projected (x, y) coordinates of the cell center in the CRS.
"""

super().__init__(model)
Expand Down Expand Up @@ -263,7 +263,7 @@ def rowcol(self) -> Coordinate | None:
@property
def xy(self) -> FloatCoordinate | None:
"""
Cell center coordinates in the CRS.
Geographic/projected (x, y) coordinates of the cell center in the CRS.
"""
return self._xy

Expand Down Expand Up @@ -426,53 +426,104 @@ def coord_iter(self) -> Iterator[tuple[Cell, int, int]]:
for col in range(self.height):
yield self.cells[row][col], row, col # cell, x, y

def apply_raster(self, data: np.ndarray, attr_name: str | None = None) -> None:
def apply_raster(
self, data: np.ndarray, attr_name: str | Sequence[str] | None = None
) -> None:
"""
Apply raster data to the cells.

:param np.ndarray data: 2D numpy array with shape (1, height, width).
:param str | None attr_name: Name of the attribute to be added to the cells.
If None, a random name will be generated. Default is None.
:raises ValueError: If the shape of the data is not (1, height, width).
:param np.ndarray data: 3D numpy array with shape (bands, height, width).
:param str | Sequence[str] | None attr_name: Attribute name(s) to be added to the
cells. For multi-band rasters, pass a list of names with length equal to
the number of bands, or a single base name to be suffixed per band. If None,
names are generated. Default is None.
:raises ValueError: If the shape of the data does not match the raster.
"""

if data.shape != (1, self.height, self.width):
if data.ndim != 3 or data.shape[1:] != (self.height, self.width):
raise ValueError(
f"Data shape does not match raster shape. "
f"Expected {(1, self.height, self.width)}, received {data.shape}."
f"Expected (*, {self.height}, {self.width}), received {data.shape}."
)
if attr_name is None:
attr_name = f"attribute_{len(self.cell_cls.__dict__)}"
self._attributes.add(attr_name)
for grid_x in range(self.width):
for grid_y in range(self.height):
setattr(
self.cells[grid_x][grid_y],
attr_name,
data[0, self.height - grid_y - 1, grid_x],
)
num_bands = data.shape[0]

if num_bands == 1:
if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
if len(attr_name) != 1:
raise ValueError(
"attr_name sequence length must match the number of raster bands; "
f"expected {num_bands} band names, got {len(attr_name)}."
)
names = [attr_name[0]]
else:
names = [cast(str | None, attr_name)]
else:
if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
if len(attr_name) != num_bands:
raise ValueError(
"attr_name sequence length must match the number of raster bands; "
f"expected {num_bands} band names, got {len(attr_name)}."
)
names = list(attr_name)
elif isinstance(attr_name, str):
names = [f"{attr_name}_{band_idx + 1}" for band_idx in range(num_bands)]
else:
names = [None] * num_bands

def _default_attr_name() -> str:
base = f"attribute_{len(self.cell_cls.__dict__)}"
if base not in self._attributes:
return base
suffix = 1
candidate = f"{base}_{suffix}"
while candidate in self._attributes:
suffix += 1
candidate = f"{base}_{suffix}"
return candidate

for band_idx, name in enumerate(names):
attr = _default_attr_name() if name is None else name
self._attributes.add(attr)
for grid_x in range(self.width):
for grid_y in range(self.height):
setattr(
self.cells[grid_x][grid_y],
attr,
data[band_idx, self.height - grid_y - 1, grid_x],
)

def get_raster(self, attr_name: str | None = None) -> np.ndarray:
def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray:
"""
Return the values of given attribute.

:param str | None attr_name: Name of the attribute to be returned. If None,
returns all attributes. Default is None.
:return: The values of given attribute as a 2D numpy array with shape (1, height, width).
:param str | Sequence[str] | None attr_name: Name(s) of attributes to be returned.
If None, returns all attributes. Default is None.
:return: The values of given attribute(s) as a numpy array with shape
(bands, height, width).
:rtype: np.ndarray
"""

if attr_name is not None and attr_name not in self.attributes:
if isinstance(attr_name, str) and attr_name not in self.attributes:
raise ValueError(
f"Attribute {attr_name} does not exist. "
f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all."
)
if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
missing = [name for name in attr_name if name not in self.attributes]
if missing:
raise ValueError(
f"Attribute {missing[0]} does not exist. "
f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all."
)
if attr_name is None:
num_bands = len(self.attributes)
attr_names = self.attributes
elif isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
num_bands = len(attr_name)
attr_names = list(attr_name)
else:
num_bands = 1
attr_names = {attr_name}
attr_names = [attr_name]
data = np.empty((num_bands, self.height, self.width))
for ind, name in enumerate(attr_names):
for grid_x in range(self.width):
Expand Down Expand Up @@ -684,16 +735,18 @@ def from_file(
raster_file: str,
model: Model,
cell_cls: type[Cell] = Cell,
attr_name: str | None = None,
attr_name: str | Sequence[str] | None = None,
rio_opener: Callable | None = None,
) -> RasterLayer:
"""
Creates a RasterLayer from a raster file.

:param str raster_file: Path to the raster file.
:param Type[Cell] cell_cls: The class of the cells in the layer.
:param str | None attr_name: The name of the attribute to use for the cell values.
If None, a random name will be generated. Default is None.
:param str | Sequence[str] | None attr_name: Attribute name(s) to use for the cell
values. For multi-band rasters, pass a list of names with length equal to
the number of bands, or a single base name to be suffixed per band. If None,
names are generated. Default is None.
:param Callable | None rio_opener: A callable passed to Rasterio open() function.
"""

Expand All @@ -713,14 +766,17 @@ def from_file(
return obj

def to_file(
self, raster_file: str, attr_name: str | None = None, driver: str = "GTiff"
self,
raster_file: str,
attr_name: str | Sequence[str] | None = None,
driver: str = "GTiff",
) -> None:
"""
Writes a raster layer to a file.

:param str raster_file: The path to the raster file to write to.
:param str | None attr_name: The name of the attribute to write to the raster.
If None, all attributes are written. Default is None.
:param str | Sequence[str] | None attr_name: The name(s) of attributes to write
to the raster. If None, all attributes are written. Default is None.
:param str driver: The GDAL driver to use for writing the raster file.
Default is 'GTiff'. See GDAL docs at https://gdal.org/drivers/raster/index.html.
"""
Expand Down
Loading
Loading