Skip to content

Commit 9e68f0e

Browse files
authored
Resolve version conflicts and overwritten attribute in copying (#87)
* Don't deepcopy scipy objects. * Bump to v3. * Compare safely in any supported scipy version. * Skip more steps in init during copy. * Reorder operations during copyinit. * Reorder operations during copyinit. * Fix bug in test. * Add a regression test. * Remove unused if statements, add clarifying comment.
1 parent 3b63f14 commit 9e68f0e

5 files changed

Lines changed: 65 additions & 21 deletions

File tree

codemeta.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@
1919
"codeRepository": [
2020
"https://github.com/SWIFTSIM/swiftgalaxy",
2121
],
22-
"version": "2.5.0",
22+
"version": "3.0.0",
2323
"license": "https://spdx.org/licenses/GPL-3.0-only.html",
2424
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "swiftgalaxy"
7-
version="2.5.0"
7+
version="3.0.0"
88
authors = [
99
{ name="Kyle Oman", email="kyle.a.oman@durham.ac.uk" },
1010
]

swiftgalaxy/halo_catalogues.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1853,12 +1853,14 @@ def lazy_mask() -> Union[NDArray, slice]:
18531853
if not hasattr(cat, list_name):
18541854
return null_slice
18551855
mask = getattr(cat, list_name)
1856-
mask = mask[in_one_of_ranges(mask, getattr(sg.mask, group_name))]
1856+
mask = mask[
1857+
in_one_of_ranges(mask, getattr(sg._spatial_mask, group_name))
1858+
]
18571859
mask = np.isin(
18581860
np.concatenate(
18591861
[
18601862
np.arange(start, end)
1861-
for start, end in getattr(sg.mask, group_name)
1863+
for start, end in getattr(sg._spatial_mask, group_name)
18621864
]
18631865
),
18641866
mask,

swiftgalaxy/reader.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,20 @@ def _apply_box_wrap(
7171
:class:`~swiftsimio.objects.cosmo_array`
7272
The coordinates wrapped to lie within the box dimensions.
7373
"""
74+
rotation_is_identity = (
75+
True
76+
if current_transform is None
77+
else current_transform.rotation.approx_equal(Rotation.identity())
78+
)
79+
# in scipy 1.16 approx_equal returns bool, in 1.17 returns array of bool, so:
80+
rotation_is_identity = (
81+
rotation_is_identity.all()
82+
if hasattr(rotation_is_identity, "all")
83+
else rotation_is_identity
84+
)
7485
if boxsize is None:
7586
return coords
76-
elif (
77-
current_transform is None
78-
or (current_transform.rotation.approx_equal(Rotation.identity())).squeeze()
79-
):
87+
elif current_transform is None or rotation_is_identity:
8088
return (coords + offset_frac * boxsize) % boxsize - offset_frac * boxsize
8189
else:
8290
return _apply_rotation(
@@ -1566,10 +1574,8 @@ def __init__(
15661574
self.id_particle_dataset_name = id_particle_dataset_name
15671575
self.coordinates_dataset_name = coordinates_dataset_name
15681576
self.velocities_dataset_name = velocities_dataset_name
1569-
if not hasattr(self, "_coordinate_like_transform"):
1570-
self._coordinate_like_transform = RigidTransform.identity()
1571-
if not hasattr(self, "_velocity_like_transform"):
1572-
self._velocity_like_transform = RigidTransform.identity()
1577+
self._coordinate_like_transform = RigidTransform.identity()
1578+
self._velocity_like_transform = RigidTransform.identity()
15731579
if self.halo_catalogue is None:
15741580
# in server mode we don't have a halo_catalogue yet
15751581
self._spatial_mask = getattr(self, "_spatial_mask", None)
@@ -1790,15 +1796,10 @@ def _copyinit(
17901796
"""
17911797
sg = cls.__new__(cls)
17921798
sg._spatial_mask = _spatial_mask
1793-
sg._extra_mask = _extra_mask
1794-
if _coordinate_like_transform is not None:
1795-
sg._coordinate_like_transform = _coordinate_like_transform
1796-
if _velocity_like_transform is not None:
1797-
sg._velocity_like_transform = _velocity_like_transform
17981799
SWIFTGalaxy.__init__(
17991800
sg,
18001801
snapshot_filename,
1801-
halo_catalogue,
1802+
halo_catalogue=None,
18021803
auto_recentre=auto_recentre,
18031804
transforms_like_coordinates=transforms_like_coordinates,
18041805
transforms_like_velocities=transforms_like_velocities,
@@ -1808,6 +1809,13 @@ def _copyinit(
18081809
coordinate_frame_from=coordinate_frame_from,
18091810
_data_server=_data_server,
18101811
)
1812+
if _extra_mask is not None:
1813+
sg._extra_mask = _extra_mask
1814+
if _coordinate_like_transform is not None:
1815+
sg._coordinate_like_transform = _coordinate_like_transform
1816+
if _velocity_like_transform is not None:
1817+
sg._velocity_like_transform = _velocity_like_transform
1818+
sg.halo_catalogue = halo_catalogue
18111819
return sg
18121820

18131821
def __str__(self) -> str:
@@ -1930,8 +1938,8 @@ def _data_copy(
19301938
velocities_dataset_name=deepcopy(self.velocities_dataset_name),
19311939
_spatial_mask=self._spatial_mask,
19321940
_extra_mask=deepcopy(self._extra_mask),
1933-
_coordinate_like_transform=deepcopy(self._coordinate_like_transform),
1934-
_velocity_like_transform=deepcopy(self._velocity_like_transform),
1941+
_coordinate_like_transform=copy(self._coordinate_like_transform),
1942+
_velocity_like_transform=copy(self._velocity_like_transform),
19351943
_data_server=_data_server,
19361944
)
19371945
for particle_name in sg.metadata.present_group_names:

tests/test_masking.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import pytest
44
from copy import copy, deepcopy
55
import numpy as np
6+
import unyt as u
67
from unyt.testing import assert_allclose_units
8+
from swiftsimio import cosmo_quantity
79
from swiftgalaxy import SWIFTGalaxy, MaskCollection
810
from swiftgalaxy.demo_data import (
911
ToyHF,
@@ -188,6 +190,38 @@ def test_bool_mask(self, sg, particle_name, before_load):
188190
ids = masked_dataset.particle_ids
189191
assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0)
190192

193+
def test_chaining_masks(self, sg):
194+
"""
195+
Check that we can mask a particle dataset after masking the swiftgalaxy.
196+
197+
This is a regression test, but with no associated github issue.
198+
"""
199+
sg.mask_particles(
200+
MaskCollection(
201+
gas=sg.gas.spherical_coordinates.r
202+
< cosmo_quantity(
203+
3,
204+
u.kpc,
205+
comoving=True,
206+
scale_factor=sg.metadata.scale_factor,
207+
scale_exponent=1,
208+
)
209+
)
210+
)
211+
# this had previously caused a crash in version <=2.4.1:
212+
# IndexError: boolean index did not match indexed array along axis 0;
213+
# size of axis is 5000 but size of corresponding boolean axis is 1480
214+
sg.gas[
215+
sg.gas.spherical_coordinates.r
216+
> cosmo_quantity(
217+
1,
218+
u.kpc,
219+
comoving=True,
220+
scale_factor=sg.metadata.scale_factor,
221+
scale_exponent=1,
222+
)
223+
]
224+
191225

192226
class TestMaskingNamedColumnDatasets:
193227
"""Test applying masks to named column datasets."""
@@ -219,7 +253,7 @@ def test_reordering_int_mask(self, sg, before_load):
219253
if before_load:
220254
sg.gas.hydrogen_ionization_fractions._neutral = None
221255
del sg._extra_mask.gas._mask
222-
sg._extra_mask.gas._mask = False
256+
sg._extra_mask.gas._evaluated = False
223257
masked_namedcolumnsdataset = sg.gas.hydrogen_ionization_fractions[mask]
224258
fractions = masked_namedcolumnsdataset.neutral
225259
assert_allclose_units(

0 commit comments

Comments
 (0)