diff --git a/swiftgalaxy/reader.py b/swiftgalaxy/reader.py index cf90e9c..f952238 100644 --- a/swiftgalaxy/reader.py +++ b/swiftgalaxy/reader.py @@ -2296,6 +2296,158 @@ def wrap_box(self) -> None: setattr(dataset, f"_{field_name}", field_data) return + def get_bound_only_mask(self) -> MaskCollection: + """ + Get a ``bound_only`` mask for this galaxy. + + The mask is aligned with the current particle selection without applying + it to loaded data. + + The returned mask is evaluated lazily and aligned to currently selected + particles so it can be applied directly to currently loaded arrays. + + Returns + ------- + :class:`~swiftgalaxy.masks.MaskCollection` + The ``bound_only`` mask collection, aligned to the currently selected + particle indices and with the current extra mask applied. + + Raises + ------ + RuntimeError + If no halo catalogue is associated with this :class:`SWIFTGalaxy`. + """ + if self.halo_catalogue is None: + raise RuntimeError( + "Cannot get a bound_only mask without an associated halo catalogue." + ) + + # Build the bound mask in spatial-only index space by temporarily + # disabling extra masking and restoring all cached field data afterwards. + original_extra_mask = self._extra_mask + spatial_only_extra_mask = MaskCollection._blank_from_mask_types( + self.metadata.present_group_names + ) + cached_particle_fields = {} + try: + self._extra_mask = spatial_only_extra_mask + + # Ensure bound-mask backends read in spatial space by clearing any + # cached, already-masked fields before evaluating lazy masks. + for group_name in self.metadata.present_group_names: + particle_dataset = getattr(self, group_name)._particle_dataset + particle_metadata = getattr( + particle_dataset.metadata, f"{group_name}_properties" + ) + cached_particle_fields[group_name] = { + field_name: getattr(particle_dataset, f"_{field_name}") + for field_name in particle_metadata.field_names + } + for field_name in particle_metadata.field_names: + setattr(particle_dataset, f"_{field_name}", None) + + bound_only_mask_spatial = self.halo_catalogue._generate_bound_only_mask( + self, mask_loaded=False + ) + bound_only_mask_spatial_values = { + group_name: deepcopy(getattr(bound_only_mask_spatial, group_name).mask) + for group_name in self.metadata.present_group_names + } + finally: + # First, restore _extra_mask + self._extra_mask = original_extra_mask + + # Now restore cached fields and fix any that were loaded during generation + for group_name, fields in cached_particle_fields.items(): + particle_dataset = getattr(self, group_name)._particle_dataset + particle_dataset_helper = getattr(self, group_name) + + for field_name, field_data_originally in fields.items(): + field_currently_cached = getattr( + particle_dataset, f"_{field_name}", None + ) + + if field_data_originally is not None: + # Field was originally loaded, restore it + setattr( + particle_dataset, f"_{field_name}", field_data_originally + ) + elif field_currently_cached is not None: + # Field was originally not loaded but got loaded during + # bound generation. + # Re-apply the mask so it is correct if the mask filters + # data. + masked_data = particle_dataset_helper._apply_data_mask( + field_currently_cached + ) + + # Only cache the masked result if the mask actually + # changed the size (i.e., filtered particles). If sizes + # are the same, the data was not supposed to be loaded, + # so do not cache it. + if ( + hasattr(masked_data, "shape") + and hasattr(field_currently_cached, "shape") + and masked_data.shape != field_currently_cached.shape + ): + # Mask filtered data, so cache the masked version + setattr(particle_dataset, f"_{field_name}", masked_data) + else: + # Mask didn't filter (or unknown), so clear the cache + # The field will be properly loaded/masked when accessed + setattr(particle_dataset, f"_{field_name}", None) + # If both original and current are None, leave it as None + + # Map spatial-only bound mask to current selection space by lazy composition. + # For each particle type, create a lazy mask that maps the spatial bound indices + # to indices in the currently selected particles. + def n_spatial(group_name: str) -> int: + """ + Get number of particles after spatial masking for one particle type. + + Parameters + ---------- + group_name : str + Particle group name. + + Returns + ------- + int + Number of particles in this group after spatial masking. + """ + if self._spatial_mask is None: + return int(getattr(self.metadata, f"n_{group_name}")) + return int( + np.sum(self._spatial_mask.get_masked_counts_offsets()[0][group_name]) + ) + + current_bound_only_mask = {} + for group_name in self.metadata.present_group_names: + spatial_bound_mask = bound_only_mask_spatial_values[group_name] + current_extra_mask = getattr(self._extra_mask, group_name) + + def lazy_map( + *, + _group_name: str = group_name, + _spatial_bound: np.ndarray = spatial_bound_mask, + _current_extra: LazyMask = current_extra_mask, + ) -> np.ndarray: + """Map spatial bound indices to current particle selection.""" + n_spatial_particles = n_spatial(_group_name) + # Get which particles in spatial space are bound + spatial_bound_bool = np.zeros(n_spatial_particles, dtype=bool) + spatial_bound_bool[_spatial_bound] = True + # Get which spatial particles are in current selection + current_spatial_indices = np.arange(n_spatial_particles)[ + _current_extra.mask + ] + # Map: which current particles are bound? + return spatial_bound_bool[current_spatial_indices] + + current_bound_only_mask[group_name] = LazyMask(mask_function=lazy_map) + + return MaskCollection(**current_bound_only_mask) + def mask_particles(self, mask_collection: MaskCollection) -> None: """ Select a subset of the currently selected particles. diff --git a/tests/test_masking.py b/tests/test_masking.py index 4afbc6c..36a33af 100644 --- a/tests/test_masking.py +++ b/tests/test_masking.py @@ -250,6 +250,82 @@ def test_mask_combining_is_lazy(self, sg_soap): # and check we haven't loaded the DM group IDs, just to be sure: assert sg_soap.dark_matter._particle_dataset._group_nr_bound is None + def test_get_bound_only_mask_raises_without_halo_catalogue(self, sg_no_hf): + """Check that getting a bound_only mask requires a halo catalogue.""" + with pytest.raises(RuntimeError, match="without an associated halo catalogue"): + sg_no_hf.get_bound_only_mask() + + def test_get_bound_only_mask_is_lazy(self, sg): + """Check that creating a bound_only mask does not trigger data loading.""" + sg_unbound = SWIFTGalaxy( + sg.snapshot_filename, + ToyHF(snapfile=sg.snapshot_filename, extra_mask=None), + transforms_like_coordinates={"coordinates", "extra_coordinates"}, + transforms_like_velocities={"velocities", "extra_velocities"}, + ) + assert_no_data_loaded(sg_unbound) + current_bound_only = sg_unbound.get_bound_only_mask() + assert_no_data_loaded(sg_unbound) + # evaluate one particle type to ensure lazy mask works + assert current_bound_only.gas.mask.size > 0 + + def test_get_bound_only_mask_compatible_with_current_particles(self, sg): + """Check that returned mask is directly applicable to currently selected data.""" + sg_unbound = SWIFTGalaxy( + sg.snapshot_filename, + ToyHF(snapfile=sg.snapshot_filename, extra_mask=None), + transforms_like_coordinates={"coordinates", "extra_coordinates"}, + transforms_like_velocities={"velocities", "extra_velocities"}, + ) + sg_bound = SWIFTGalaxy( + sg.snapshot_filename, + ToyHF(snapfile=sg.snapshot_filename, extra_mask="bound_only"), + transforms_like_coordinates={"coordinates", "extra_coordinates"}, + transforms_like_velocities={"velocities", "extra_velocities"}, + ) + + sg_unbound.mask_particles( + MaskCollection( + gas=np.s_[::3], + dark_matter=np.s_[::-2], + stars=np.s_[::2], + black_holes=np.s_[...], + ) + ) + + current_bound_only = sg_unbound.get_bound_only_mask() + for ptype in sg_unbound.metadata.present_group_names: + current_ids = getattr(sg_unbound, ptype).particle_ids + expected_bound = np.isin(current_ids, getattr(sg_bound, ptype).particle_ids) + got_mask = getattr(current_bound_only, ptype).mask + assert got_mask.shape == current_ids.shape + assert np.array_equal(got_mask, expected_bound) + + def test_get_bound_only_mask_relative_to_current_default(self, sg): + """Check default bound-only mask is all-True for bound-only SWIFTGalaxy.""" + current_bound_only = sg.get_bound_only_mask() + for ptype in sg.metadata.present_group_names: + got_mask = getattr(current_bound_only, ptype).mask + assert got_mask.shape == getattr(sg, ptype).particle_ids.shape + assert got_mask.dtype == bool + assert got_mask.all() + + def test_get_bound_only_after_manual_masking(self, sg): + """Check that we can get a bound_only mask after applying a manual mask.""" + sg.mask_particles( + MaskCollection( + gas=np.s_[::3], + dark_matter=np.s_[::-2], + stars=np.s_[::2], + black_holes=np.s_[...], + ) + ) + current_bound_only = sg.get_bound_only_mask() + for ptype in sg.metadata.present_group_names: + got_mask = getattr(current_bound_only, ptype).mask + assert got_mask.shape == getattr(sg, ptype).particle_ids.shape + assert got_mask.dtype == bool + class TestMaskingParticleDatasets: """Test applying masks to particle datasets."""