Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions swiftgalaxy/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,6 +2296,158 @@ def wrap_box(self) -> None:
setattr(dataset, f"_{field_name}", field_data)
return

def get_bound_only_mask(self) -> MaskCollection:
    """
    Get a ``bound_only`` mask aligned with the currently selected particles.

    The halo catalogue's bound-only mask is evaluated immediately, with the
    extra mask temporarily disabled and every cached particle field temporarily
    cleared, so that it is expressed in the index space left after spatial
    masking only. The extra mask and the cached fields are then put back
    exactly as they were, so this call leaves no additional data cached.

    The returned collection holds one lazily evaluated boolean mask per present
    particle type. Each has one entry per particle in the selection that was
    current when this method was called, so it can be applied directly to
    arrays loaded from that selection.

    Returns
    -------
    :class:`~swiftgalaxy.masks.MaskCollection`
        The ``bound_only`` mask collection, aligned to the currently selected
        particle indices and with the current extra mask applied.

    Raises
    ------
    RuntimeError
        If no halo catalogue is associated with this :class:`SWIFTGalaxy`.
    """
    if self.halo_catalogue is None:
        raise RuntimeError(
            "Cannot get a bound_only mask without an associated halo catalogue."
        )

    group_names = list(self.metadata.present_group_names)

    # Build the bound mask in spatial-only index space by temporarily
    # disabling extra masking; the previous state is restored in ``finally``.
    original_extra_mask = self._extra_mask
    spatial_only_extra_mask = MaskCollection._blank_from_mask_types(
        self.metadata.present_group_names
    )
    saved_fields = {}
    try:
        self._extra_mask = spatial_only_extra_mask

        # Clear cached (possibly already-masked) fields so that anything the
        # halo catalogue reads while building the mask is in spatial space.
        for group_name in group_names:
            particle_dataset = getattr(self, group_name)._particle_dataset
            field_names = getattr(
                particle_dataset.metadata, f"{group_name}_properties"
            ).field_names
            saved_fields[group_name] = {
                field_name: getattr(particle_dataset, f"_{field_name}")
                for field_name in field_names
            }
            for field_name in field_names:
                setattr(particle_dataset, f"_{field_name}", None)

        bound_only_mask_spatial = self.halo_catalogue._generate_bound_only_mask(
            self, mask_loaded=False
        )
        # Accessing ``.mask`` evaluates the catalogue's mask now, while extra
        # masking is disabled; the copy detaches it from the catalogue's object.
        spatial_bound_masks = {
            group_name: deepcopy(getattr(bound_only_mask_spatial, group_name).mask)
            for group_name in group_names
        }
    finally:
        self._extra_mask = original_extra_mask

        # Restore exactly the pre-call cache state. Fields that were not cached
        # before go back to ``None``: anything read during mask generation was
        # read with extra masking disabled, so it is dropped rather than kept
        # and will be read again through the normal access path when needed.
        for group_name, fields in saved_fields.items():
            particle_dataset = getattr(self, group_name)._particle_dataset
            for field_name, field_data in fields.items():
                setattr(particle_dataset, f"_{field_name}", field_data)

    def n_spatial(group_name: str) -> int:
        """
        Get number of particles after spatial masking for one particle type.

        Parameters
        ----------
        group_name : str
            Particle group name.

        Returns
        -------
        int
            Number of particles in this group after spatial masking.
        """
        if self._spatial_mask is None:
            return int(getattr(self.metadata, f"n_{group_name}"))
        return int(
            np.sum(self._spatial_mask.get_masked_counts_offsets()[0][group_name])
        )

    def make_lazy_map(group_name, spatial_bound, current_extra):
        """
        Build the deferred mapping for one particle type.

        A factory is used so that each closure keeps its own per-group values
        instead of the loop variables' final values.

        Parameters
        ----------
        group_name : str
            Particle group name.
        spatial_bound : object
            Evaluated bound-only mask in spatial-only index space.
        current_extra : LazyMask
            Extra mask for this group at the time of the call.

        Returns
        -------
        callable
            Zero-argument function returning the boolean bound mask.
        """

        def lazy_map() -> np.ndarray:
            """Map spatial bound indices to current particle selection."""
            # Flag which particles in spatial space are bound ...
            spatial_bound_bool = np.zeros(n_spatial(group_name), dtype=bool)
            spatial_bound_bool[spatial_bound] = True
            # ... then pick out the flags of the currently selected particles.
            # Indexing with the extra mask directly yields the same values as
            # going through an explicit index array, without allocating one.
            return spatial_bound_bool[current_extra.mask]

        return lazy_map

    # Map the spatial-only bound mask onto the current selection, deferring the
    # work for each particle type until its mask is actually requested.
    current_bound_only_mask = {
        group_name: LazyMask(
            mask_function=make_lazy_map(
                group_name,
                spatial_bound_masks[group_name],
                getattr(self._extra_mask, group_name),
            )
        )
        for group_name in group_names
    }

    return MaskCollection(**current_bound_only_mask)

def mask_particles(self, mask_collection: MaskCollection) -> None:
"""
Select a subset of the currently selected particles.
Expand Down
76 changes: 76 additions & 0 deletions tests/test_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,82 @@ def test_mask_combining_is_lazy(self, sg_soap):
# and check we haven't loaded the DM group IDs, just to be sure:
assert sg_soap.dark_matter._particle_dataset._group_nr_bound is None

def test_get_bound_only_mask_raises_without_halo_catalogue(self, sg_no_hf):
    """Check that asking for a bound_only mask fails without a halo catalogue."""
    expected_message = "without an associated halo catalogue"
    with pytest.raises(RuntimeError, match=expected_message):
        sg_no_hf.get_bound_only_mask()

def test_get_bound_only_mask_is_lazy(self, sg):
    """Check that building a bound_only mask leaves no particle data loaded."""
    halo_finder = ToyHF(snapfile=sg.snapshot_filename, extra_mask=None)
    sg_unbound = SWIFTGalaxy(
        sg.snapshot_filename,
        halo_finder,
        transforms_like_coordinates={"coordinates", "extra_coordinates"},
        transforms_like_velocities={"velocities", "extra_velocities"},
    )

    # Nothing is loaded before the call, and nothing after it either.
    assert_no_data_loaded(sg_unbound)
    bound_mask_collection = sg_unbound.get_bound_only_mask()
    assert_no_data_loaded(sg_unbound)

    # Evaluate one particle type to confirm the lazy mask actually resolves.
    gas_mask = bound_mask_collection.gas.mask
    assert gas_mask.size > 0

def test_get_bound_only_mask_compatible_with_current_particles(self, sg):
    """Check that the mask matches the bound particles among the current ones."""

    def build_galaxy(extra_mask):
        """Create a galaxy on the fixture's snapshot with the given extra mask."""
        return SWIFTGalaxy(
            sg.snapshot_filename,
            ToyHF(snapfile=sg.snapshot_filename, extra_mask=extra_mask),
            transforms_like_coordinates={"coordinates", "extra_coordinates"},
            transforms_like_velocities={"velocities", "extra_velocities"},
        )

    sg_unbound = build_galaxy(None)
    sg_bound = build_galaxy("bound_only")

    # Narrow the unbound galaxy to an arbitrary sub-selection first.
    manual_selection = MaskCollection(
        gas=np.s_[::3],
        dark_matter=np.s_[::-2],
        stars=np.s_[::2],
        black_holes=np.s_[...],
    )
    sg_unbound.mask_particles(manual_selection)

    bound_mask_collection = sg_unbound.get_bound_only_mask()
    for ptype in sg_unbound.metadata.present_group_names:
        selected_ids = getattr(sg_unbound, ptype).particle_ids
        bound_ids = getattr(sg_bound, ptype).particle_ids
        # A selected particle is bound iff its ID appears in the bound galaxy.
        reference = np.isin(selected_ids, bound_ids)
        mask = getattr(bound_mask_collection, ptype).mask
        assert mask.shape == selected_ids.shape
        assert np.array_equal(mask, reference)

def test_get_bound_only_mask_relative_to_current_default(self, sg):
    """Check that a bound-only SWIFTGalaxy gets an all-True bound_only mask."""
    bound_mask_collection = sg.get_bound_only_mask()
    for ptype in sg.metadata.present_group_names:
        mask = getattr(bound_mask_collection, ptype).mask
        selected_ids = getattr(sg, ptype).particle_ids
        # One boolean per currently selected particle, every one of them set.
        assert mask.shape == selected_ids.shape
        assert mask.dtype == bool
        assert mask.all()

def test_get_bound_only_after_manual_masking(self, sg):
    """Check that we can get a bound_only mask after applying a manual mask."""
    sg.mask_particles(
        MaskCollection(
            gas=np.s_[::3],
            dark_matter=np.s_[::-2],
            stars=np.s_[::2],
            black_holes=np.s_[...],
        )
    )
    current_bound_only = sg.get_bound_only_mask()
    for ptype in sg.metadata.present_group_names:
        got_mask = getattr(current_bound_only, ptype).mask
        assert got_mask.shape == getattr(sg, ptype).particle_ids.shape
        assert got_mask.dtype == bool
        # The manual mask only narrows the selection. Assuming the ``sg``
        # fixture starts out restricted to bound particles (its unmasked
        # bound_only mask is all-True), every remaining particle is still
        # bound, so the values are checked and not just shape and dtype.
        assert got_mask.all()


class TestMaskingParticleDatasets:
"""Test applying masks to particle datasets."""
Expand Down
Loading