diff --git a/swiftgalaxy/demo_data.py b/swiftgalaxy/demo_data.py index cae95ef..edc85ab 100644 --- a/swiftgalaxy/demo_data.py +++ b/swiftgalaxy/demo_data.py @@ -507,7 +507,10 @@ def _generate_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: return swift_mask def _generate_bound_only_mask( - self, sg: SWIFTGalaxy, mask_loaded: bool = True + self, + sg: SWIFTGalaxy, + mask_loaded: bool = True, + load_masked: bool = False, ) -> MaskCollection: """ Evaluate the extra mask selecting particles belonging to the target galaxy. @@ -524,13 +527,17 @@ def _generate_bound_only_mask( Whether to mask any data loaded while creating the mask. The iterator wants to switch this off. + load_masked : :obj:`bool` + If ``True``, any data loaded to evaluate the mask are loaded in their masked + state, otherwise masking is bypassed during loading. + Returns ------- :class:`~swiftgalaxy.masks.MaskCollection` The extra mask. """ - def generate_lazy_mask(group_name: str, mask_loaded: bool) -> LazyMask: + def generate_lazy_mask(group_name: str) -> LazyMask: """ Generate a function that evaluates a mask for bound particles. @@ -541,10 +548,6 @@ def generate_lazy_mask(group_name: str, mask_loaded: bool) -> LazyMask: group_name : :obj:`str` The particle type to evaluate a mask for. - mask_loaded : :obj:`bool` - Whether to mask the data loaded while constructing the mask. The iterator - wants to switch this off. - Returns ------- Callable @@ -555,10 +558,6 @@ def lazy_mask() -> Union[NDArray, slice, EllipsisType]: """ "Evaluate" a mask that selects bound particles. - In reality we know what the mask is a priori. We pretend that we need to - load the particle ids so that we can test the behaviour of a dataset - loaded while constructing the mask. - This function must optionally mask the data (``particle_ids``) that it has loaded. 
@@ -567,19 +566,78 @@ def lazy_mask() -> Union[NDArray, slice, EllipsisType]: :class:`~numpy.ndarray`, :obj:`slice` or :obj:`Ellipsis` The mask that selects bound particles. """ - getattr( - getattr(sg, group_name)._particle_dataset, - sg.id_particle_dataset_name, - ) # load the ids assert isinstance(self._mask_index, int) # placate mypy - mask = { - "gas": (np.s_[-_n_g_1:], np.s_[-_n_g_2:])[self._mask_index], - "dark_matter": (np.s_[-_n_dm_1:], np.s_[-_n_dm_2:])[ - self._mask_index - ], - "stars": np.s_[...], - "black_holes": np.s_[...], - }[group_name] + if group_name == "gas": + bound_ids = [ + # gas IDs group 0 + np.arange(1 + _n_g_b // 2, 1 + _n_g_b // 2 + _n_g_1, dtype=int), + # gas IDs group 1 + np.arange(1 + _n_g_b + _n_g_1, 1 + _n_g_all, dtype=int), + ][self._mask_index] + elif group_name == "dark_matter": + bound_ids = [ + # dm IDs group 0 + np.arange( + 1 + _n_g_all + _n_dm_b // 2, + 1 + _n_g_all + _n_dm_b // 2 + _n_dm_1, + dtype=int, + ), + # dm IDs group 1 + np.arange( + 1 + _n_g_all + _n_dm_b + _n_dm_1, + 1 + _n_g_all + _n_dm_all, + dtype=int, + ), + ][self._mask_index] + elif group_name == "stars": + bound_ids = [ + # star IDs group 0 + np.arange( + 1 + _n_g_all + _n_dm_all, + 1 + _n_g_all + _n_dm_all + _n_s_1, + dtype=int, + ), + # star IDs group 1 + np.arange( + 1 + _n_g_all + _n_dm_all + _n_s_1, + 1 + _n_g_all + _n_dm_all + _n_s_1 + _n_s_2, + dtype=int, + ), + ][self._mask_index] + elif group_name == "black_holes": + bound_ids = [ + # bh IDs group 0 + np.arange( + 1 + _n_g_all + _n_dm_all + _n_s_1 + _n_s_2, + 1 + _n_g_all + _n_dm_all + _n_s_1 + _n_s_2 + _n_bh_1, + dtype=int, + ), + # bh IDs group 1 + np.arange( + 1 + _n_g_all + _n_dm_all + _n_s_1 + _n_s_2 + _n_bh_1, + 1 + + _n_g_all + + _n_dm_all + + _n_s_1 + + _n_s_2 + + _n_bh_1 + + _n_bh_2, + dtype=int, + ), + ][self._mask_index] + if load_masked: + particle_ids = getattr( + getattr(sg, group_name), sg.id_particle_dataset_name + ) + else: + particle_ids = getattr( + getattr(sg, 
group_name)._particle_dataset, + sg.id_particle_dataset_name, + ) + mask = np.isin( + particle_ids.to_physical_value(u.dimensionless), + bound_ids, + ) if mask_loaded: # mask the particle_ids setattr( @@ -601,7 +659,7 @@ def lazy_mask() -> Union[NDArray, slice, EllipsisType]: return MaskCollection( **{ - group_name: generate_lazy_mask(group_name, mask_loaded) + group_name: generate_lazy_mask(group_name) for group_name in sg.metadata.present_group_names } ) diff --git a/swiftgalaxy/halo_catalogues.py b/swiftgalaxy/halo_catalogues.py index 8cfd33d..1151654 100644 --- a/swiftgalaxy/halo_catalogues.py +++ b/swiftgalaxy/halo_catalogues.py @@ -264,7 +264,7 @@ def _get_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: return self._generate_spatial_mask(snapshot_filename) def _get_extra_mask( - self, sg: "SWIFTGalaxy", mask_loaded: bool = True + self, sg: "SWIFTGalaxy", mask_loaded: bool = True, load_masked: bool = False ) -> MaskCollection: """ Evaluate the extra (in the sense of in addition to spatial masking) mask. @@ -283,6 +283,10 @@ def _get_extra_mask( Whether to mask any data loaded while creating the mask. The iterator wants to switch this off. + load_masked : :obj:`bool` + If ``True``, any data loaded to evaluate the mask are loaded in their masked + state, otherwise masking is bypassed during loading. + Returns ------- :class:`~swiftgalaxy.masks.MaskCollection` @@ -293,7 +297,9 @@ def _get_extra_mask( raise RuntimeError( "Halo catalogue has multiple galaxies and is not currently masked." 
) - return self._generate_bound_only_mask(sg, mask_loaded=mask_loaded) + return self._generate_bound_only_mask( + sg, mask_loaded=mask_loaded, load_masked=load_masked + ) elif self.extra_mask is None: return MaskCollection._blank_from_mask_types( sg.metadata.present_group_names @@ -455,7 +461,7 @@ def _generate_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: @abstractmethod def _generate_bound_only_mask( - self, sg: "SWIFTGalaxy", mask_loaded: bool = True + self, sg: "SWIFTGalaxy", mask_loaded: bool = True, load_masked: bool = False ) -> MaskCollection: """ Abstract method. @@ -477,6 +483,10 @@ def _generate_bound_only_mask( Whether to mask any data loaded while creating the mask. The iterator wants to switch this off. + load_masked : :obj:`bool` + If ``True``, any data loaded to evaluate the mask are loaded in their masked + state, otherwise masking is bypassed during loading. + Returns ------- :class:`~swiftgalaxy.masks.MaskCollection` @@ -784,7 +794,7 @@ def _generate_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: return sm def _generate_bound_only_mask( - self, sg: "SWIFTGalaxy", mask_loaded: bool = True + self, sg: "SWIFTGalaxy", mask_loaded: bool = True, load_masked: bool = False ) -> MaskCollection: """ Evaluate the mask to select gravitationally bound particles. @@ -803,6 +813,10 @@ def _generate_bound_only_mask( Whether to mask any data loaded while creating the mask. The iterator wants to switch this off. + load_masked : :obj:`bool` + If ``True``, any data loaded to evaluate the mask are loaded in their masked + state, otherwise masking is bypassed during loading. + Returns ------- :class:`~swiftgalaxy.masks.MaskCollection` @@ -810,7 +824,7 @@ def _generate_bound_only_mask( set of particles. """ - def generate_lazy_mask(group_name: str, mask_loaded: bool) -> LazyMask: + def generate_lazy_mask(group_name: str) -> LazyMask: """ Generate a function that evaluates a mask for bound particles. 
@@ -821,10 +835,6 @@ def generate_lazy_mask(group_name: str, mask_loaded: bool) -> LazyMask: group_name : :obj:`str` The particle type to evaluate a mask for. - mask_loaded : :obj:`bool` - Whether to mask the data loaded while constructing the mask. The iterator - wants to switch this off. - Returns ------- Callable @@ -845,11 +855,16 @@ def lazy_mask() -> NDArray: :class:`~numpy.ndarray` The mask that selects bound particles. """ - mask = getattr( - sg, group_name - )._particle_dataset.group_nr_bound.to_value( + group_nr_bound = ( + getattr(sg, group_name).group_nr_bound + if load_masked + else getattr(sg, group_name)._particle_dataset.group_nr_bound + ) + mask = group_nr_bound.to_physical_value( u.dimensionless - ) == self.input_halos.halo_catalogue_index.to_value(u.dimensionless) + ) == self.input_halos.halo_catalogue_index.to_physical_value( + u.dimensionless + ) if mask_loaded: # mask the group_nr_bound array that we loaded getattr(sg, group_name)._particle_dataset._group_nr_bound = getattr( @@ -861,7 +876,7 @@ def lazy_mask() -> NDArray: return MaskCollection( **{ - group_name: generate_lazy_mask(group_name, mask_loaded) + group_name: generate_lazy_mask(group_name) for group_name in sg.metadata.present_group_names } ) @@ -1153,7 +1168,7 @@ def _generate_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: ) def _generate_bound_only_mask( - self, sg: "SWIFTGalaxy", mask_loaded: bool = True + self, sg: "SWIFTGalaxy", mask_loaded: bool = True, load_masked: bool = False ) -> MaskCollection: """ Evaluate the mask to select gravitationally bound particles. @@ -1173,6 +1188,10 @@ def _generate_bound_only_mask( Whether to mask any data loaded while creating the mask. The iterator wants to switch this off. + load_masked : :obj:`bool` + If ``True``, any data loaded to evaluate the mask are loaded in their masked + state, otherwise masking is bypassed during loading. 
+ Returns ------- :class:`~swiftgalaxy.masks.MaskCollection` @@ -1183,7 +1202,7 @@ def _generate_bound_only_mask( # because we need a lazy version and to bypass swiftgalaxy masking on read # while we construct the mask - def generate_lazy_mask(group_name: str, mask_loaded: bool) -> LazyMask: + def generate_lazy_mask(group_name: str) -> LazyMask: """ Generate a function that evaluates a mask for bound particles. @@ -1194,10 +1213,6 @@ def generate_lazy_mask(group_name: str, mask_loaded: bool) -> LazyMask: group_name : :obj:`str` The particle type to evaluate a mask for. - mask_loaded : :obj:`bool` - Whether to mask the data loaded while constructing the mask. The iterator - wants to switch this off. - Returns ------- Callable @@ -1229,8 +1244,13 @@ def lazy_mask() -> NDArray: if not particles.groups_instance.catalogue.units.comoving else 1.0 ) + particle_ids = ( + getattr(sg, group_name).particle_ids + if load_masked + else getattr(sg, group_name)._particle_dataset.particle_ids + ) mask = np.isin( - getattr(sg, group_name)._particle_dataset.particle_ids, + particle_ids, cosmo_array( particles.particle_ids, comoving=False, @@ -1249,7 +1269,7 @@ def lazy_mask() -> NDArray: return MaskCollection( **{ - group_name: generate_lazy_mask(group_name, mask_loaded) + group_name: generate_lazy_mask(group_name) for group_name in sg.metadata.present_group_names } ) @@ -1737,7 +1757,7 @@ def _generate_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: return sm def _generate_bound_only_mask( - self, sg: "SWIFTGalaxy", mask_loaded: bool = True + self, sg: "SWIFTGalaxy", mask_loaded: bool = True, load_masked: bool = False ) -> MaskCollection: """ Evaluate the mask to select gravitationally bound particles. @@ -1757,6 +1777,10 @@ def _generate_bound_only_mask( Whether to mask any data loaded while creating the mask. The iterator wants to switch this off. 
+ load_masked : :obj:`bool` + If ``True``, any data loaded to evaluate the mask are loaded in their masked + state, otherwise masking is bypassed during loading. + Returns ------- :class:`~swiftgalaxy.masks.MaskCollection` @@ -1806,9 +1830,7 @@ def in_one_of_ranges( "black_holes": "bhlist", } - def generate_lazy_mask( - group_name: str, list_name: str, mask_loaded: bool - ) -> LazyMask: + def generate_lazy_mask(group_name: str) -> LazyMask: """ Generate a function that evaluates a mask for bound particles. @@ -1819,14 +1841,6 @@ def generate_lazy_mask( group_name : :obj:`str` The particle type to evaluate a mask for. - list_name : :obj:`str` - The name of the list in the caesar catalogue that stores the membership - information. - - mask_loaded : :obj:`bool` - Whether to mask the data loaded while constructing the mask. The iterator - wants to switch this off. - Returns ------- Callable @@ -1847,30 +1861,41 @@ def lazy_mask() -> Union[NDArray, slice]: :class:`~numpy.ndarray` The mask that selects bound particles. """ - if not hasattr(cat, list_name): + if not hasattr(cat, list_names[group_name]): return null_slice - mask = getattr(cat, list_name) + mask = getattr(cat, list_names[group_name]) mask = mask[ in_one_of_ranges(mask, getattr(sg._spatial_mask, group_name)) ] + spatial_ranges = np.concatenate( + [ + np.arange(start, end) + for start, end in getattr(sg._spatial_mask, group_name) + ] + ) mask = np.isin( - np.concatenate( - [ - np.arange(start, end) - for start, end in getattr(sg._spatial_mask, group_name) - ] - ), + spatial_ranges, mask, ) - return mask + if load_masked: + # RecursionError unless the _extra_mask was previously evaluated, + # but if we've combined masks before that happens it's too late now. + if not getattr(sg._extra_mask, group_name)._evaluated: + raise RuntimeError( + f"Cannot evaluate mask for {group_name}. Try loading at " + "least one particle property (e.g. " + f"`{group_name}.particle_ids`) before calling " + "`get_bound_only_mask`." 
+ ) + return mask[getattr(sg._extra_mask, group_name).mask] + else: + return mask return LazyMask(mask_function=lazy_mask) return MaskCollection( **{ - group_name: generate_lazy_mask( - group_name, list_names[group_name], mask_loaded - ) + group_name: generate_lazy_mask(group_name) for group_name in sg.metadata.present_group_names } ) @@ -2276,7 +2301,7 @@ def _generate_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: return sm def _generate_bound_only_mask( - self, sg: "SWIFTGalaxy", mask_loaded: bool = True + self, sg: "SWIFTGalaxy", mask_loaded: bool = True, load_masked: bool = False ) -> MaskCollection: """ Undefined for :class:`~swiftgalaxy.halo_catalogues.Standalone`. @@ -2295,6 +2320,10 @@ def _generate_bound_only_mask( Whether to mask any data loaded while creating the mask. The iterator wants to switch this off. + load_masked : :obj:`bool` + If ``True``, any data loaded to evaluate the mask are loaded in their masked + state, otherwise masking is bypassed during loading. + Raises ------ NotImplementedError : always raised if this function is called. diff --git a/swiftgalaxy/masks.py b/swiftgalaxy/masks.py index 579a9e2..5ebf87c 100644 --- a/swiftgalaxy/masks.py +++ b/swiftgalaxy/masks.py @@ -83,7 +83,7 @@ def _evaluate(self) -> None: self._mask = self._mask_function() self._evaluated = True - def _make_combinable(self, *, sg: "SWIFTGalaxy", mask_type: str) -> None: + def _ensure_combinable(self, *, sg: "SWIFTGalaxy", mask_type: str) -> None: """ Ensure that the mask can have an arbitrary second mask applied to combine them. @@ -100,6 +100,8 @@ def _make_combinable(self, *, sg: "SWIFTGalaxy", mask_type: str) -> None: The :mod:`swiftsimio` group name that this mask is for (e.g. ``"gas"``, ``"dark_matter"``, etc.), used to look up particle count metadata. 
""" + if self._combinable: + return # need to convert to an integer mask to combine # (boolean is insufficient in case of re-ordering masks) if sg._spatial_mask is None: @@ -145,6 +147,7 @@ def _combined_with( ~swiftgalaxy.masks.LazyMask The combined mask. """ + self._ensure_combinable(sg=sg, mask_type=mask_type) # may as well always defer evaluating combination until it's asked for def lazy_mask() -> NDArray: @@ -156,8 +159,6 @@ def lazy_mask() -> NDArray: :class:`~numpy.ndarray` The combined mask. """ - if not self._combinable: - self._make_combinable(sg=sg, mask_type=mask_type) assert isinstance(self.mask, np.ndarray) # placate mypy assert self.mask.dtype == int return self.mask[other_mask.mask] diff --git a/swiftgalaxy/reader.py b/swiftgalaxy/reader.py index cf90e9c..29da04c 100644 --- a/swiftgalaxy/reader.py +++ b/swiftgalaxy/reader.py @@ -71,16 +71,16 @@ def _apply_box_wrap( :class:`~swiftsimio.objects.cosmo_array` The coordinates wrapped to lie within the box dimensions. """ - rotation_is_identity = ( + _rotation_is_identity = ( True if current_transform is None else current_transform.rotation.approx_equal(Rotation.identity()) ) # in scipy 1.16 approx_equal returns bool, in 1.17 returns array of bool, so: rotation_is_identity = ( - rotation_is_identity.all() - if hasattr(rotation_is_identity, "all") - else rotation_is_identity + _rotation_is_identity.all() + if hasattr(_rotation_is_identity, "all") + else _rotation_is_identity ) if boxsize is None: return coords @@ -1690,8 +1690,8 @@ def _copyinit( coordinate_frame_from: Optional["SWIFTGalaxy"] = None, _spatial_mask: Optional[SWIFTMask] = None, _extra_mask: Optional[MaskCollection] = None, - _coordinate_like_transform: Optional[np.ndarray] = None, - _velocity_like_transform: Optional[np.ndarray] = None, + _coordinate_like_transform: Optional[RigidTransform] = None, + _velocity_like_transform: Optional[RigidTransform] = None, _data_server: Optional["SWIFTGalaxy"] = None, ) -> "SWIFTGalaxy": """ @@ -2296,6 
+2296,35 @@ def wrap_box(self) -> None: setattr(dataset, f"_{field_name}", field_data) return + def get_bound_only_mask(self) -> MaskCollection: + """ + Get a ``bound_only`` mask for this galaxy. + + The mask is aligned with the current particle selection without applying + it to loaded data. + + The returned mask is evaluated lazily and aligned to currently selected + particles so it can be applied directly to currently loaded arrays. + + Returns + ------- + :class:`~swiftgalaxy.masks.MaskCollection` + The ``bound_only`` mask collection, aligned to the currently selected + particle indices and with the current extra mask applied. + + Raises + ------ + RuntimeError + If no halo catalogue is associated with this :class:`SWIFTGalaxy`. + """ + if self.halo_catalogue is None: + raise RuntimeError( + "Cannot get a bound_only mask without an associated halo catalogue." + ) + return self.halo_catalogue._generate_bound_only_mask( + self, mask_loaded=False, load_masked=True + ) + def mask_particles(self, mask_collection: MaskCollection) -> None: """ Select a subset of the currently selected particles. 
diff --git a/tests/conftest.py b/tests/conftest.py index 97b6920..26e015c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1031,33 +1031,36 @@ def sg_hf(request: FixtureRequest, tmp_path_factory: TempPathFactory) -> SWIFTGa ) _remove_toyvr(filebase=toyvr_filebase) elif request.param == "sa": - yield Standalone( - extra_mask=MaskCollection( - gas=np.s_[_n_g_b // 2 :], - dark_matter=np.s_[_n_dm_b // 2 :], - stars=None, - black_holes=None, - ), - centre=cosmo_array( - [_centre_1, _centre_1, _centre_1], - u.Mpc, - comoving=True, - scale_factor=1.0, - scale_exponent=1, - ), - velocity_centre=cosmo_array( - [_vcentre_1, _vcentre_1, _vcentre_1], - u.km / u.s, - comoving=True, - scale_factor=1.0, - scale_exponent=0, - ), - spatial_offsets=cosmo_array( - [[-1, 1], [-1, 1], [-1, 1]], - u.Mpc, - comoving=True, - scale_factor=1.0, - scale_exponent=1, + yield SWIFTGalaxy( + toysnap_filename, + Standalone( + extra_mask=MaskCollection( + gas=np.s_[_n_g_b // 2 :], + dark_matter=np.s_[_n_dm_b // 2 :], + stars=None, + black_holes=None, + ), + centre=cosmo_array( + [_centre_1, _centre_1, _centre_1], + u.Mpc, + comoving=True, + scale_factor=1.0, + scale_exponent=1, + ), + velocity_centre=cosmo_array( + [_vcentre_1, _vcentre_1, _vcentre_1], + u.km / u.s, + comoving=True, + scale_factor=1.0, + scale_exponent=0, + ), + spatial_offsets=cosmo_array( + [[-1, 1], [-1, 1], [-1, 1]], + u.Mpc, + comoving=True, + scale_factor=1.0, + scale_exponent=1, + ), ), ) _remove_toysnap(snapfile=toysnap_filename) diff --git a/tests/test_halo_catalogues.py b/tests/test_halo_catalogues.py index 2133c50..7675e71 100644 --- a/tests/test_halo_catalogues.py +++ b/tests/test_halo_catalogues.py @@ -653,7 +653,9 @@ def test_with_swiftgalaxies(self, sgs_vr): ), ) for ptype in _present_particle_types.values(): - getattr(sg._extra_mask, ptype)._make_combinable(sg=sg, mask_type=ptype) + getattr(sg._extra_mask, ptype)._ensure_combinable( + sg=sg, mask_type=ptype + ) assert np.all( 
getattr(sg_from_sgs._extra_mask, ptype).mask == getattr(sg._extra_mask, ptype).mask @@ -893,7 +895,9 @@ def test_with_swiftgalaxies(self, sgs_caesar): ), ) for ptype in _present_particle_types.values(): - getattr(sg._extra_mask, ptype)._make_combinable(sg=sg, mask_type=ptype) + getattr(sg._extra_mask, ptype)._ensure_combinable( + sg=sg, mask_type=ptype + ) assert np.all( getattr(sg_from_sgs._extra_mask, ptype).mask == getattr(sg._extra_mask, ptype).mask @@ -1354,7 +1358,9 @@ def test_with_swiftgalaxies(self, sgs_soap): ), ) for ptype in _present_particle_types.values(): - getattr(sg._extra_mask, ptype)._make_combinable(sg=sg, mask_type=ptype) + getattr(sg._extra_mask, ptype)._ensure_combinable( + sg=sg, mask_type=ptype + ) assert np.all( getattr(sg_from_sgs._extra_mask, ptype).mask == getattr(sg._extra_mask, ptype).mask diff --git a/tests/test_iterator.py b/tests/test_iterator.py index 09cd0e7..d3c31ba 100644 --- a/tests/test_iterator.py +++ b/tests/test_iterator.py @@ -238,7 +238,9 @@ def test_iterate(self, sgs): ), ) for ptype in _present_particle_types.values(): - getattr(sg._extra_mask, ptype)._make_combinable(sg=sg, mask_type=ptype) + getattr(sg._extra_mask, ptype)._ensure_combinable( + sg=sg, mask_type=ptype + ) assert np.all( getattr(sg_from_sgs._extra_mask, ptype).mask == getattr(sg._extra_mask, ptype).mask diff --git a/tests/test_masking.py b/tests/test_masking.py index 4afbc6c..2954f9a 100644 --- a/tests/test_masking.py +++ b/tests/test_masking.py @@ -7,6 +7,7 @@ from unyt.testing import assert_allclose_units from swiftsimio import cosmo_quantity from swiftgalaxy import MaskCollection, SWIFTGalaxy +from swiftgalaxy.halo_catalogues import Standalone, Caesar from swiftgalaxy.demo_data import ( ToyHF, _present_particle_types, @@ -60,12 +61,12 @@ def test_getattr_masking(self, sg, particle_name): @pytest.mark.parametrize("before_load", (True, False)) def test_reordering_slice_mask(self, sg, particle_name, before_load): """Test whether a slice mask that 
re-orders elements works.""" + em_before = sg._extra_mask mask = np.s_[::-1] ids_before = getattr(sg, particle_name).particle_ids if before_load: getattr(sg, particle_name)._particle_dataset._particle_ids = None - del getattr(sg._extra_mask, particle_name)._mask - getattr(sg._extra_mask, particle_name)._evaluated = False + sg._extra_mask = em_before sg.mask_particles(MaskCollection(**{particle_name: mask})) ids = getattr(sg, particle_name).particle_ids assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0) @@ -74,6 +75,7 @@ def test_reordering_slice_mask(self, sg, particle_name, before_load): @pytest.mark.parametrize("before_load", (True, False)) def test_reordering_int_mask(self, sg, particle_name, before_load): """Test an integer array mask that re-orders elements and changes the length.""" + em_before = sg._extra_mask ids_before = getattr(sg, particle_name).particle_ids mask = np.arange(ids_before.size) # randomize order (in-place operation) @@ -82,8 +84,7 @@ def test_reordering_int_mask(self, sg, particle_name, before_load): mask = mask[: mask.size // 2] if before_load: getattr(sg, particle_name)._particle_dataset._particle_ids = None - del getattr(sg._extra_mask, particle_name)._mask - getattr(sg._extra_mask, particle_name)._evaluated = False + sg._extra_mask = em_before sg.mask_particles(MaskCollection(**{particle_name: mask})) ids = getattr(sg, particle_name).particle_ids assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0) @@ -92,26 +93,41 @@ def test_reordering_int_mask(self, sg, particle_name, before_load): @pytest.mark.parametrize("before_load", (True, False)) def test_bool_mask(self, sg, particle_name, before_load): """Test whether a boolean array mask works.""" + em_before = sg._extra_mask ids_before = getattr(sg, particle_name).particle_ids # randomly keep about half of particles mask = np.random.rand(ids_before.size) > 0.5 if before_load: getattr(sg, particle_name)._particle_dataset._particle_ids = None - del getattr(sg._extra_mask, 
particle_name)._mask - getattr(sg._extra_mask, particle_name)._evaluated = False + sg._extra_mask = em_before sg.mask_particles(MaskCollection(**{particle_name: mask})) ids = getattr(sg, particle_name).particle_ids assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0) + @pytest.mark.parametrize("before_load", (True, False)) + def test_data_masked(self, sg, before_load): + """Test that data get masked.""" + em_before = sg._extra_mask + masses_before = sg.gas.masses + mask = np.random.rand(masses_before.size) > 0.5 + if before_load: + sg.gas._particle_dataset._masses = None + sg._extra_mask = em_before + sg.mask_particles(MaskCollection(**{"gas": mask})) + masses = sg.gas.masses + assert_allclose_units( + masses_before[mask], masses, rtol=reltol_nd, atol=abstol_nd + ) + @pytest.mark.parametrize("before_load", (True, False)) def test_namedcolumn_masked(self, sg, before_load): """Test that named columns get masked too.""" + em_before = sg._extra_mask neutral_before = sg.gas.hydrogen_ionization_fractions.neutral mask = np.random.rand(neutral_before.size) > 0.5 if before_load: sg.gas.hydrogen_ionization_fractions._named_column_dataset._neutral = None - del sg._extra_mask.gas._mask - sg._extra_mask.gas._evaluated = False + sg._extra_mask = em_before sg.mask_particles(MaskCollection(**{"gas": mask})) neutral = sg.gas.hydrogen_ionization_fractions.neutral assert_allclose_units( @@ -205,11 +221,11 @@ def test_chained_masking(self, sg, before_load): Check both the case with (sg) and without (sg_no_hf) a spatial mask. 
""" + em_before = sg._extra_mask ids_before_sg = sg.gas.particle_ids if before_load: sg.gas._particle_dataset._particle_ids = None - del sg._extra_mask.gas._mask - sg._extra_mask.gas._evaluated = False + sg._extra_mask = em_before sg.mask_particles(MaskCollection(gas=np.s_[::2])) sg.mask_particles(MaskCollection(gas=np.s_[::2])) assert_allclose_units(ids_before_sg[::2][::2], sg.gas.particle_ids) @@ -250,6 +266,102 @@ def test_mask_combining_is_lazy(self, sg_soap): # and check we haven't loaded the DM group IDs, just to be sure: assert sg_soap.dark_matter._particle_dataset._group_nr_bound is None + @pytest.mark.parametrize("load_before", (True, False)) + def test_get_bound_only_mask(self, sg_hf, load_before): + """Check applying a bound_only mask from any catalogue (not Standalone).""" + if load_before: + for ptype in sg_hf.metadata.present_group_names: + getattr(sg_hf, ptype).masses + if isinstance(sg_hf.halo_catalogue, Standalone): + with pytest.raises(NotImplementedError): + sg_hf.get_bound_only_mask() + return + sg_hf.mask_particles(sg_hf.get_bound_only_mask()) + if isinstance(sg_hf.halo_catalogue, Caesar) and not load_before: + with pytest.raises(RuntimeError): + for ptype in sg_hf.metadata.present_group_names: + getattr(sg_hf, ptype).masses + return + for ptype in sg_hf.metadata.present_group_names: + getattr(sg_hf, ptype).masses + + def test_get_bound_only_mask_raises_without_halo_catalogue(self, sg_no_hf): + """Check that getting a bound_only mask requires a halo catalogue.""" + with pytest.raises(RuntimeError, match="without an associated halo catalogue"): + sg_no_hf.get_bound_only_mask() + + def test_get_bound_only_mask_is_lazy(self, sg): + """Check that creating a bound_only mask does not trigger data loading.""" + sg_unbound = SWIFTGalaxy( + sg.snapshot_filename, + ToyHF(snapfile=sg.snapshot_filename, extra_mask=None), + transforms_like_coordinates={"coordinates", "extra_coordinates"}, + transforms_like_velocities={"velocities", "extra_velocities"}, + ) 
+ assert_no_data_loaded(sg_unbound) + current_bound_only = sg_unbound.get_bound_only_mask() + assert_no_data_loaded(sg_unbound) + # evaluate one particle type to ensure lazy mask works + assert current_bound_only.gas.mask.size > 0 + + def test_get_bound_only_mask_compatible_with_current_particles(self, sg): + """Check that returned mask is directly applicable to currently selected data.""" + sg_unbound = SWIFTGalaxy( + sg.snapshot_filename, + ToyHF(snapfile=sg.snapshot_filename, extra_mask=None), + transforms_like_coordinates={"coordinates", "extra_coordinates"}, + transforms_like_velocities={"velocities", "extra_velocities"}, + ) + sg_bound = SWIFTGalaxy( + sg.snapshot_filename, + ToyHF(snapfile=sg.snapshot_filename, extra_mask="bound_only"), + transforms_like_coordinates={"coordinates", "extra_coordinates"}, + transforms_like_velocities={"velocities", "extra_velocities"}, + ) + + sg_unbound.mask_particles( + MaskCollection( + gas=np.s_[::3], + dark_matter=np.s_[::-2], + stars=np.s_[::2], + black_holes=np.s_[...], + ) + ) + + current_bound_only = sg_unbound.get_bound_only_mask() + for ptype in sg_unbound.metadata.present_group_names: + current_ids = getattr(sg_unbound, ptype).particle_ids + expected_bound = np.isin(current_ids, getattr(sg_bound, ptype).particle_ids) + got_mask = getattr(current_bound_only, ptype).mask + assert got_mask.shape == current_ids.shape + assert np.array_equal(got_mask, expected_bound) + + def test_get_bound_only_mask_relative_to_current_default(self, sg): + """Check default bound-only mask is all-True for bound-only SWIFTGalaxy.""" + current_bound_only = sg.get_bound_only_mask() + for ptype in sg.metadata.present_group_names: + got_mask = getattr(current_bound_only, ptype).mask + assert got_mask.shape == getattr(sg, ptype).particle_ids.shape + assert got_mask.dtype == bool + assert got_mask.all() + + def test_get_bound_only_after_manual_masking(self, sg): + """Check that we can get a bound_only mask after applying a manual mask.""" + 
sg.mask_particles( + MaskCollection( + gas=np.s_[::3], + dark_matter=np.s_[::-2], + stars=np.s_[::2], + black_holes=np.s_[...], + ) + ) + current_bound_only = sg.get_bound_only_mask() + for ptype in sg.metadata.present_group_names: + got_mask = getattr(current_bound_only, ptype).mask + assert got_mask.dtype == bool + assert got_mask.shape == getattr(sg, ptype).particle_ids.shape + assert got_mask.all() + class TestMaskingParticleDatasets: """Test applying masks to particle datasets.""" @@ -258,12 +370,12 @@ class TestMaskingParticleDatasets: @pytest.mark.parametrize("before_load", (True, False)) def test_reordering_slice_mask(self, sg, particle_name, before_load): """Test whether a slice mask that re-orders elements works.""" + em_before = sg._extra_mask mask = np.s_[::-1] ids_before = getattr(sg, particle_name).particle_ids if before_load: getattr(sg, particle_name)._particle_dataset._particle_ids = None - del getattr(sg._extra_mask, particle_name)._mask - getattr(sg._extra_mask, particle_name)._evaluated = False + sg._extra_mask = em_before masked_dataset = getattr(sg, particle_name)[mask] ids = masked_dataset.particle_ids assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0) @@ -272,6 +384,7 @@ def test_reordering_slice_mask(self, sg, particle_name, before_load): @pytest.mark.parametrize("before_load", (True, False)) def test_reordering_int_mask(self, sg, particle_name, before_load): """Test masking with an integer array: re-orders elements and changes length.""" + em_before = sg._extra_mask ids_before = getattr(sg, particle_name).particle_ids mask = np.arange(ids_before.size) # randomize order (in-place operation) @@ -280,8 +393,7 @@ def test_reordering_int_mask(self, sg, particle_name, before_load): mask = mask[: mask.size // 2] if before_load: getattr(sg, particle_name)._particle_dataset._particle_ids = None - del getattr(sg._extra_mask, particle_name)._mask - getattr(sg._extra_mask, particle_name)._evaluated = False + sg._extra_mask = em_before masked_dataset = getattr(sg,
particle_name)[mask] ids = masked_dataset.particle_ids assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0) @@ -290,13 +402,13 @@ def test_reordering_int_mask(self, sg, particle_name, before_load): @pytest.mark.parametrize("before_load", (True, False)) def test_bool_mask(self, sg, particle_name, before_load): """Test whether a boolean array mask works.""" + em_before = sg._extra_mask ids_before = getattr(sg, particle_name).particle_ids # randomly keep about half of particles mask = np.random.rand(ids_before.size) > 0.5 if before_load: getattr(sg, particle_name)._particle_dataset._particle_ids = None - del getattr(sg._extra_mask, particle_name)._mask - getattr(sg._extra_mask, particle_name)._evaluated = False + sg._extra_mask = em_before masked_dataset = getattr(sg, particle_name)[mask] ids = masked_dataset.particle_ids assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0) @@ -340,12 +452,12 @@ class TestMaskingNamedColumnDatasets: @pytest.mark.parametrize("before_load", (True, False)) def test_reordering_slice_mask(self, sg, before_load): """Test whether a slice mask that re-orders elements works.""" + em_before = sg._extra_mask mask = np.s_[::-1] fractions_before = sg.gas.hydrogen_ionization_fractions.neutral if before_load: sg.gas.hydrogen_ionization_fractions._neutral = None - del sg._extra_mask.gas._mask - sg._extra_mask.gas._evaluated = False + sg._extra_mask = em_before masked_namedcolumnsdataset = sg.gas.hydrogen_ionization_fractions[mask] fractions = masked_namedcolumnsdataset.neutral assert_allclose_units( @@ -355,6 +467,7 @@ def test_reordering_slice_mask(self, sg, before_load): @pytest.mark.parametrize("before_load", (True, False)) def test_reordering_int_mask(self, sg, before_load): """Test masking with an integer array: re-orders and changes the length.""" + em_before = sg._extra_mask fractions_before = sg.gas.hydrogen_ionization_fractions.neutral mask = np.arange(fractions_before.size) # randomize order (in-place operation) @@ -363,8 
+476,7 @@ def test_reordering_int_mask(self, sg, before_load): mask = mask[: mask.size // 2] if before_load: sg.gas.hydrogen_ionization_fractions._neutral = None - del sg._extra_mask.gas._mask - sg._extra_mask.gas._evaluated = False + sg._extra_mask = em_before masked_namedcolumnsdataset = sg.gas.hydrogen_ionization_fractions[mask] fractions = masked_namedcolumnsdataset.neutral assert_allclose_units( @@ -374,13 +486,13 @@ def test_reordering_int_mask(self, sg, before_load): @pytest.mark.parametrize("before_load", (True, False)) def test_bool_mask(self, sg, before_load): """Test whether a boolean array mask works.""" + em_before = sg._extra_mask fractions_before = sg.gas.hydrogen_ionization_fractions.neutral # randomly keep about half of particles mask = np.random.rand(fractions_before.size) > 0.5 if before_load: sg.gas.hydrogen_ionization_fractions._neutral = None - del sg._extra_mask.gas._mask - sg._extra_mask.gas._evaluated = False + sg._extra_mask = em_before masked_namedcolumnsdataset = sg.gas.hydrogen_ionization_fractions[mask] fractions = masked_namedcolumnsdataset.neutral assert_allclose_units( @@ -576,23 +688,23 @@ def test_compare_nonemask(self): assert lm == lm assert not lm != lm - def test_make_combinable_evaluated(self, sg): + def test_ensure_combinable_evaluated(self, sg): """Test that making a LazyMask 'combinable' results in an integer index array.""" lm = LazyMask(np.s_[:10]) assert not isinstance(lm.mask, np.ndarray) assert not lm._combinable - lm._make_combinable(sg=sg, mask_type="gas") + lm._ensure_combinable(sg=sg, mask_type="gas") assert lm._evaluated assert isinstance(lm.mask, np.ndarray) assert lm.mask.dtype == int assert lm._combinable assert len(lm.mask) == 10 - def test_make_combinable_unevaluated(self, sg): + def test_ensure_combinable_unevaluated(self, sg): """Test that making a LazyMask 'combinable' results in an integer index array.""" lm = LazyMask(mask_function=lambda: np.s_[:10]) assert not lm._combinable - 
lm._make_combinable(sg=sg, mask_type="gas") + lm._ensure_combinable(sg=sg, mask_type="gas") assert lm._combinable assert not lm._evaluated assert isinstance(lm.mask, np.ndarray) # triggers evaluation