Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 81 additions & 23 deletions swiftgalaxy/demo_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,10 @@ def _generate_spatial_mask(self, snapshot_filename: str) -> SWIFTMask:
return swift_mask

def _generate_bound_only_mask(
self, sg: SWIFTGalaxy, mask_loaded: bool = True
self,
sg: SWIFTGalaxy,
mask_loaded: bool = True,
load_masked: bool = False,
) -> MaskCollection:
"""
Evaluate the extra mask selecting particles belonging to the target galaxy.
Expand All @@ -524,13 +527,17 @@ def _generate_bound_only_mask(
Whether to mask any data loaded while creating the mask. The iterator wants to
switch this off.

load_masked : :obj:`bool`
If ``True``, any data loaded to evaluate the mask are loaded in their masked
state, otherwise masking is bypassed during loading.

Returns
-------
:class:`~swiftgalaxy.masks.MaskCollection`
The extra mask.
"""

def generate_lazy_mask(group_name: str, mask_loaded: bool) -> LazyMask:
def generate_lazy_mask(group_name: str) -> LazyMask:
"""
Generate a function that evaluates a mask for bound particles.

Expand All @@ -541,10 +548,6 @@ def generate_lazy_mask(group_name: str, mask_loaded: bool) -> LazyMask:
group_name : :obj:`str`
The particle type to evaluate a mask for.

mask_loaded : :obj:`bool`
Whether to mask the data loaded while constructing the mask. The iterator
wants to switch this off.

Returns
-------
Callable
Expand All @@ -555,10 +558,6 @@ def lazy_mask() -> Union[NDArray, slice, EllipsisType]:
"""
"Evaluate" a mask that selects bound particles.

In reality we know what the mask is a priori. We pretend that we need to
load the particle ids so that we can test the behaviour of a dataset
loaded while constructing the mask.

This function must optionally mask the data (``particle_ids``) that it
has loaded.

Expand All @@ -567,19 +566,78 @@ def lazy_mask() -> Union[NDArray, slice, EllipsisType]:
:class:`~numpy.ndarray`, :obj:`slice` or :obj:`Ellipsis`
The mask that selects bound particles.
"""
getattr(
getattr(sg, group_name)._particle_dataset,
sg.id_particle_dataset_name,
) # load the ids
assert isinstance(self._mask_index, int) # placate mypy
mask = {
"gas": (np.s_[-_n_g_1:], np.s_[-_n_g_2:])[self._mask_index],
"dark_matter": (np.s_[-_n_dm_1:], np.s_[-_n_dm_2:])[
self._mask_index
],
"stars": np.s_[...],
"black_holes": np.s_[...],
}[group_name]
if group_name == "gas":
bound_ids = [
# gas IDs group 0
np.arange(1 + _n_g_b // 2, 1 + _n_g_b // 2 + _n_g_1, dtype=int),
# gas IDs group 1
np.arange(1 + _n_g_b + _n_g_1, 1 + _n_g_all, dtype=int),
][self._mask_index]
elif group_name == "dark_matter":
bound_ids = [
# dm IDs group 0
np.arange(
1 + _n_g_all + _n_dm_b // 2,
1 + _n_g_all + _n_dm_b // 2 + _n_dm_1,
dtype=int,
),
# dm IDs group 1
np.arange(
1 + _n_g_all + _n_dm_b + _n_dm_1,
1 + _n_g_all + _n_dm_all,
dtype=int,
),
][self._mask_index]
elif group_name == "stars":
bound_ids = [
# star IDs group 0
np.arange(
1 + _n_g_all + _n_dm_all,
1 + _n_g_all + _n_dm_all + _n_s_1,
dtype=int,
),
# star IDs group 1
np.arange(
1 + _n_g_all + _n_dm_all + _n_s_1,
1 + _n_g_all + _n_dm_all + _n_s_1 + _n_s_2,
dtype=int,
),
][self._mask_index]
elif group_name == "black_holes":
bound_ids = [
# bh IDs group 0
np.arange(
1 + _n_g_all + _n_dm_all + _n_s_1 + _n_s_2,
1 + _n_g_all + _n_dm_all + _n_s_1 + _n_s_2 + _n_bh_1,
dtype=int,
),
# bh IDs group 1
np.arange(
1 + _n_g_all + _n_dm_all + _n_s_1 + _n_s_2 + _n_bh_1,
1
+ _n_g_all
+ _n_dm_all
+ _n_s_1
+ _n_s_2
+ _n_bh_1
+ _n_bh_2,
dtype=int,
),
][self._mask_index]
if load_masked:
particle_ids = getattr(
getattr(sg, group_name), sg.id_particle_dataset_name
)
else:
particle_ids = getattr(
getattr(sg, group_name)._particle_dataset,
sg.id_particle_dataset_name,
)
mask = np.isin(
particle_ids.to_physical_value(u.dimensionless),
bound_ids,
)
if mask_loaded:
# mask the particle_ids
setattr(
Expand All @@ -601,7 +659,7 @@ def lazy_mask() -> Union[NDArray, slice, EllipsisType]:

return MaskCollection(
**{
group_name: generate_lazy_mask(group_name, mask_loaded)
group_name: generate_lazy_mask(group_name)
for group_name in sg.metadata.present_group_names
}
)
Expand Down
Loading
Loading