2121from warnings import warn
2222from copy import copy , deepcopy
2323import numpy as np
24- from scipy .spatial .transform import Rotation
24+ from scipy .spatial .transform import Rotation , RigidTransform
2525import unyt
2626from swiftsimio import metadata as swiftsimio_metadata
2727from swiftsimio .reader import (
4040def _apply_box_wrap (
4141 coords : cosmo_array ,
4242 boxsize : Optional [cosmo_array ],
43- current_transform : Optional [np . ndarray ],
43+ current_transform : Optional [RigidTransform ],
4444 offset_frac : float = 0.5 ,
4545) -> cosmo_array :
4646 """
@@ -59,8 +59,8 @@ def _apply_box_wrap(
5959 boxsize : :class:`~swiftsimio.objects.cosmo_array` or ``None``
6060 The dimensions of the box to wrap (3 elements).
6161
62- current_transform : :class:`~numpy.ndarray `
63- The currently active 4x4 transformation matrix .
62+ current_transform : :class:`~scipy.spatial.transform.RigidTransform `
63+ The currently active transformation.
6464
6565 offset_frac : :obj:`float`, default: ``0.5``
6666 The fraction of the box to offset by. The default it to wrap to [-Lbox/2, Lbox/2].
@@ -71,15 +71,23 @@ def _apply_box_wrap(
7171 :class:`~swiftsimio.objects.cosmo_array`
7272 The coordinates wrapped to lie within the box dimensions.
7373 """
74- rot = current_transform [:3 , :3 ] if current_transform is not None else np .eye (3 )
75- return (
76- (
77- (coords .dot (rot .T ) + offset_frac * boxsize ) % boxsize
78- - offset_frac * boxsize
79- ).dot (rot )
80- if boxsize is not None
81- else coords
82- )
74+ if boxsize is None :
75+ return coords
76+ elif (
77+ current_transform is None
78+ or (current_transform .rotation .approx_equal (Rotation .identity ())).squeeze ()
79+ ):
80+ return (coords + offset_frac * boxsize ) % boxsize - offset_frac * boxsize
81+ else :
82+ return _apply_rotation (
83+ (
84+ _apply_rotation (coords , current_transform .rotation .inv ())
85+ + offset_frac * boxsize
86+ )
87+ % boxsize
88+ - offset_frac * boxsize ,
89+ current_transform .rotation ,
90+ )
8391
8492
8593def _apply_translation (coords : cosmo_array , offset : cosmo_array ) -> cosmo_array :
@@ -114,9 +122,9 @@ def _apply_translation(coords: cosmo_array, offset: cosmo_array) -> cosmo_array:
114122 return coords + offset
115123
116124
117- def _apply_rotmat (coords : cosmo_array , rotation_matrix : np . ndarray ) -> cosmo_array :
125+ def _apply_rotation (coords : cosmo_array , rotation : Rotation ) -> cosmo_array :
118126 """
119- Apply a rotation matrix to a coordinate array.
127+ Apply a rotation to a coordinate array.
120128
121129 Applies a rotation in-place using a view through a :class:`numpy.ndarray`, then
122130 restores units and metadata of the :class:`~swiftsimio.objects.cosmo_array`.
@@ -125,24 +133,26 @@ def _apply_rotmat(coords: cosmo_array, rotation_matrix: np.ndarray) -> cosmo_arr
125133 ----------
126134 coords : :class:`~swiftsimio.objects.cosmo_array`
127135 The coordinate array to be rotated.
128- rotation_matrix : :class:`~numpy.ndarray `
129- The rotation matrix (3x3) .
136+ rotation : :class:`~scipy.spatial.transform.Rotation `
137+ The rotation to apply .
130138
131139 Returns
132140 -------
133141 :class:`~swiftsimio.objects.cosmo_array`
134142 The coordinate array with rotation applied.
135143 """
136144 return cosmo_array (
137- coords .view (np .ndarray ). dot ( rotation_matrix ),
145+ rotation . apply ( coords .view (np .ndarray )),
138146 units = coords .units ,
139147 cosmo_factor = coords .cosmo_factor ,
140148 comoving = coords .comoving ,
141149 )
142150
143151
144- def _apply_4transform (
145- coords : cosmo_array , transform : np .ndarray , transform_units : unyt .unit_object .Unit
152+ def _apply_rigid_transform (
153+ coords : cosmo_array ,
154+ rigid_transform : RigidTransform ,
155+ transform_units : unyt .unit_object .Unit ,
146156) -> cosmo_array :
147157 """
148158 Apply an affine coordinate transformation to a coordinate array.
@@ -157,8 +167,8 @@ def _apply_4transform(
157167 coords : :class:`~swiftsimio.objects.cosmo_array`
158168 The coordinate array to be transformed.
159169
160- transform : :class:`~numpy.ndarray `
161- The 4x4 transformation matrix .
170+ rigid_transform : :class:`~scipy.spatial.transform.RigidTransform `
171+ The transformation.
162172
163173 transform_units : :class:`unyt.unit_object.Unit`
164174 The units assumed in the translation portion of the transformation matrix.
@@ -169,20 +179,14 @@ def _apply_4transform(
169179 The coordinate array with transformation applied.
170180 """
171181 retval = cosmo_array (
172- np .hstack (
173- (
174- coords .to_comoving ().to_value (transform_units ),
175- np .ones (coords .shape [0 ])[:, np .newaxis ],
176- )
177- ).dot (transform )[:, :3 ],
182+ rigid_transform .apply (coords .to_comoving_value (transform_units )),
178183 units = transform_units ,
179184 comoving = True ,
180185 cosmo_factor = coords .cosmo_factor ,
181186 )
182- if coords .comoving :
183- return retval .to_comoving ()
184- else :
185- return retval .to_physical ()
187+ if not coords .comoving :
188+ retval .convert_to_physical ()
189+ return retval
186190
187191
188192def _data_read_wrapper (prop : str ) -> Callable :
@@ -941,7 +945,7 @@ def _apply_transforms(self, data: cosmo_array, dataset_name: str) -> cosmo_array
941945 else :
942946 transform = None
943947 if transform is not None :
944- data = _apply_4transform (data , transform , transform_units )
948+ data = _apply_rigid_transform (data , transform , transform_units )
945949 boxsize = getattr (self ._particle_dataset .metadata , "boxsize" , None )
946950 if dataset_name in self ._swiftgalaxy .transforms_like_coordinates :
947951 data = _apply_box_wrap (data , boxsize , transform )
@@ -1563,9 +1567,9 @@ def __init__(
15631567 self .coordinates_dataset_name = coordinates_dataset_name
15641568 self .velocities_dataset_name = velocities_dataset_name
15651569 if not hasattr (self , "_coordinate_like_transform" ):
1566- self ._coordinate_like_transform = np . eye ( 4 )
1570+ self ._coordinate_like_transform = RigidTransform . identity ( )
15671571 if not hasattr (self , "_velocity_like_transform" ):
1568- self ._velocity_like_transform = np . eye ( 4 )
1572+ self ._velocity_like_transform = RigidTransform . identity ( )
15691573 if self .halo_catalogue is None :
15701574 # in server mode we don't have a halo_catalogue yet
15711575 self ._spatial_mask = getattr (self , "_spatial_mask" , None )
@@ -2011,23 +2015,23 @@ def rotate(self, rotation: Rotation) -> None:
20112015 :class:`~scipy.spatial.transform.Rotation` supports several input
20122016 formats, including axis-angle, rotation matrices, and others.
20132017 """
2014- rotation_matrix = rotation .as_matrix ()
20152018 rotatable = self .transforms_like_coordinates | self .transforms_like_velocities
20162019 for particle_name in self .metadata .present_group_names :
20172020 dataset = getattr (self , particle_name )._particle_dataset
20182021 for field_name in rotatable :
20192022 field_data = getattr (dataset , f"_{ field_name } " )
20202023 if field_data is not None :
2021- field_data = _apply_rotmat (field_data , rotation_matrix )
2024+ field_data = _apply_rotation (field_data , rotation )
20222025 setattr (dataset , f"_{ field_name } " , field_data )
2023- rotmat4 = np .eye (4 )
2024- rotmat4 [:3 , :3 ] = rotation_matrix
2025- self ._append_to_coordinate_like_transform (rotmat4 )
2026- self ._append_to_velocity_like_transform (rotmat4 )
2026+
2027+ self ._append_to_coordinate_like_transform (
2028+ RigidTransform .from_rotation (rotation )
2029+ )
2030+ self ._append_to_velocity_like_transform (RigidTransform .from_rotation (rotation ))
20272031 self .wrap_box ()
20282032 return
20292033
2030- def _transform (self , transform4 : cosmo_array , boost : bool = False ) -> None :
2034+ def _transform (self , rigid_transform : RigidTransform , boost : bool = False ) -> None :
20312035 """
20322036 Apply a 4x4 transformation matrix to either the spatial or velocity coordinates.
20332037
@@ -2036,7 +2040,7 @@ def _transform(self, transform4: cosmo_array, boost: bool = False) -> None:
20362040
20372041 Parameters
20382042 ----------
2039- transform4 : :class:`~numpy.ndarray `
2043+ rigid_transform : :class:`~scipy.spatial.transform.RigidTransform `
20402044 The transformation to be applied.
20412045 boost : :obj:`bool`
20422046 If ``True``, translate the velocity coordinates, else translate the spatial
@@ -2059,14 +2063,14 @@ def _transform(self, transform4: cosmo_array, boost: bool = False) -> None:
20592063 for field_name in transformable :
20602064 field_data = getattr (dataset , f"_{ field_name } " )
20612065 if field_data is not None :
2062- field_data = _apply_4transform (
2063- field_data , transform4 , transform_units
2066+ field_data = _apply_rigid_transform (
2067+ field_data , rigid_transform , transform_units
20642068 )
20652069 setattr (dataset , f"_{ field_name } " , field_data )
20662070 if boost :
2067- self ._append_to_velocity_like_transform (transform4 )
2071+ self ._append_to_velocity_like_transform (rigid_transform )
20682072 else :
2069- self ._append_to_coordinate_like_transform (transform4 )
2073+ self ._append_to_coordinate_like_transform (rigid_transform )
20702074 if not boost :
20712075 self .wrap_box ()
20722076
@@ -2101,19 +2105,22 @@ def _translate(self, translation: cosmo_array, boost: bool = False) -> None:
21012105 transform_units = self .metadata .units .length / self .metadata .units .time
21022106 else :
21032107 transform_units = self .metadata .units .length
2104- transform4 = np .eye (4 )
21052108 if hasattr (translation , "comoving" ):
2106- transform4 [3 , :3 ] = translation .to_comoving ().to_value (transform_units )
2109+ rigid_transform = RigidTransform .from_translation (
2110+ translation .to_comoving_value (transform_units )
2111+ )
21072112 else :
2108- transform4 [3 , :3 ] = translation .to_value (transform_units )
2113+ rigid_transform = RigidTransform .from_translation (
2114+ translation .to_value (transform_units )
2115+ )
21092116 warn (
21102117 "Translation assumed to be in comoving (not physical) coordinates." ,
21112118 category = RuntimeWarning ,
21122119 )
21132120 if boost :
2114- self ._append_to_velocity_like_transform (transform4 )
2121+ self ._append_to_velocity_like_transform (rigid_transform )
21152122 else :
2116- self ._append_to_coordinate_like_transform (transform4 )
2123+ self ._append_to_coordinate_like_transform (rigid_transform )
21172124 if not boost :
21182125 self .wrap_box ()
21192126 return
@@ -2131,17 +2138,16 @@ def centre(self) -> cosmo_array:
21312138 The origin of the coordinate reference frame.
21322139 """
21332140 transform_units = self .metadata .units .length
2134- transform = np .linalg .inv (self ._coordinate_like_transform )
21352141 return _apply_box_wrap (
2136- _apply_4transform (
2142+ _apply_rigid_transform (
21372143 cosmo_array (
21382144 np .zeros ((1 , 3 )),
21392145 units = transform_units ,
21402146 comoving = True ,
21412147 scale_factor = self .metadata .scale_factor ,
21422148 scale_exponent = 1 ,
21432149 ),
2144- transform ,
2150+ self . _coordinate_like_transform . inv () ,
21452151 transform_units ,
21462152 ).squeeze (),
21472153 self .metadata .boxsize ,
@@ -2162,16 +2168,15 @@ def velocity_centre(self) -> cosmo_array:
21622168 The origin of the velocity reference frame.
21632169 """
21642170 transform_units = self .metadata .units .length / self .metadata .units .time
2165- transform = np .linalg .inv (self ._velocity_like_transform )
2166- return _apply_4transform (
2171+ return _apply_rigid_transform (
21672172 cosmo_array (
21682173 np .zeros ((1 , 3 )),
21692174 units = transform_units ,
21702175 comoving = True ,
21712176 scale_factor = self .metadata .scale_factor ,
21722177 scale_exponent = 0 ,
21732178 ),
2174- transform ,
2179+ self . _velocity_like_transform . inv () ,
21752180 transform_units ,
21762181 ).squeeze ()
21772182
@@ -2185,7 +2190,7 @@ def rotation(self) -> Rotation:
21852190 :class:`scipy.spatial.transform.Rotation`
21862191 The current rotation.
21872192 """
2188- return Rotation . from_matrix ( self ._coordinate_like_transform [: 3 , : 3 ])
2193+ return self ._coordinate_like_transform . rotation
21892194
21902195 def translate (self , translation : cosmo_array ) -> None :
21912196 """
@@ -2343,37 +2348,43 @@ def mask_particles(self, mask_collection: MaskCollection) -> None:
23432348 getattr (self , particle_name )._mask_dataset (mask )
23442349 return
23452350
2346- def _append_to_coordinate_like_transform (self , transform : np .ndarray ) -> None :
2351+ def _append_to_coordinate_like_transform (
2352+ self , rigid_transform : RigidTransform
2353+ ) -> None :
23472354 """
23482355 Add a new transformation to the sequence for the spatial-like coordinates.
23492356
2350- The cumulative transformation is stored as a single 4x4 transformation matrix ,
2351- so we update the current transformation using a dot product . This voids any
2352- derived (spherical/cylindrical) coordinates.
2357+ The cumulative transformation is stored as a single transformation object ,
2358+ so we update the current transformation. This voids any derived
2359+ (spherical/cylindrical) coordinates.
23532360
23542361 Parameters
23552362 ----------
2356- transform : :class:`~numpy.ndarray `
2363+ rigid_transform : :class:`~scipy.spatial.transform.RigidTransform `
23572364 The transform to add to the cumulative coordinate transformation.
23582365 """
2359- self ._coordinate_like_transform = self ._coordinate_like_transform .dot (transform )
2366+ self ._coordinate_like_transform = (
2367+ rigid_transform * self ._coordinate_like_transform
2368+ )
23602369 self ._void_derived_coordinates ()
23612370 return
23622371
2363- def _append_to_velocity_like_transform (self , transform : np .ndarray ) -> None :
2372+ def _append_to_velocity_like_transform (
2373+ self , rigid_transform : RigidTransform
2374+ ) -> None :
23642375 """
23652376 Add a new transformation to the sequence for the velocity-like coordinates.
23662377
2367- The cumulative transformation is stored as a single 4x4 transformation matrix ,
2368- so we update the current transformation using a dot product . This voids any
2369- derived (spherical/cylindrical) coordinates.
2378+ The cumulative transformation is stored as a single transformation object ,
2379+ so we update the current transformation. This voids any derived
2380+ (spherical/cylindrical) coordinates.
23702381
23712382 Parameters
23722383 ----------
2373- transform : :class:`~numpy.ndarray `
2384+ rigid_transform : :class:`~scipy.spatial.transform.RigidTransform `
23742385 The transform to add to the cumulative velocity transformation.
23752386 """
2376- self ._velocity_like_transform = self ._velocity_like_transform . dot ( transform )
2387+ self ._velocity_like_transform = rigid_transform * self ._velocity_like_transform
23772388 self ._void_derived_coordinates ()
23782389 return
23792390
0 commit comments