Skip to content

Commit 2a347cf

Browse files
committed
Fix terminated check. Replace euler angles with quaternions
1 parent 3de398a commit 2a347cf

File tree

5 files changed

+48
-66
lines changed

5 files changed

+48
-66
lines changed

lsy_drone_racing/control/attitude_controller.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def compute_control(
112112
target_thrust[2] += self.drone_mass * self.g
113113

114114
# Update z_axis to the current orientation of the drone
115-
z_axis = R.from_euler("xyz", obs["rpy"]).as_matrix()[:, 2]
115+
z_axis = R.from_quat(obs["quat"]).as_matrix()[:, 2]
116116

117117
# update current thrust
118118
thrust_desired = target_thrust.dot(z_axis)

lsy_drone_racing/envs/drone_racing_deploy_env.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self, config: dict | ConfigDict):
7676
Args:
7777
config: The configuration of the environment.
7878
"""
79+
raise NotImplementedError("The deployment environment is currently not functional.")
7980
super().__init__()
8081
self.config = config
8182
self.action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(13,))
@@ -86,12 +87,12 @@ def __init__(self, config: dict | ConfigDict):
8687
self.observation_space = spaces.Dict(
8788
{
8889
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
89-
"rpy": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
90+
"quat": spaces.Box(low=-1, high=1, shape=(4,)),
9091
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
9192
"ang_vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
9293
"target_gate": spaces.Discrete(n_gates, start=-1),
9394
"gates_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(n_gates, 3)),
94-
"gates_rpy": spaces.Box(low=-np.pi, high=np.pi, shape=(n_gates, 3)),
95+
"gates_quat": spaces.Box(low=-1, high=1, shape=(n_gates, 4)),
9596
"gates_visited": spaces.Box(low=0, high=1, shape=(n_gates,), dtype=np.bool_),
9697
"obstacles_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(n_obstacles, 3)),
9798
"obstacles_visited": spaces.Box(
@@ -199,13 +200,11 @@ def close(self):
199200
def obs(self) -> dict:
200201
"""Return the observation of the environment."""
201202
drone = self.vicon.drone_name
202-
rpy = self.vicon.rpy[drone]
203-
ang_vel = R.from_euler("xyz", rpy).inv().apply(self.vicon.ang_vel[drone])
204203
obs = {
205204
"pos": self.vicon.pos[drone].astype(np.float32),
206-
"rpy": rpy.astype(np.float32),
205+
"quat": self.vicon.quat[drone].astype(np.float32),
207206
"vel": self.vicon.vel[drone].astype(np.float32),
208-
"ang_vel": ang_vel.astype(np.float32),
207+
"ang_vel": self.vicon.ang_vel[drone].astype(np.float32),
209208
}
210209

211210
sensor_range = self.config.env.sensor_range
@@ -246,7 +245,7 @@ def obs(self) -> dict:
246245
obs["obstacles_visited"] = self.obstacles_visited
247246

248247
obs["gates_pos"] = gates_pos.astype(np.float32)
249-
obs["gates_rpy"] = gates_rpy.astype(np.float32)
248+
obs["gates_quat"] = R.from_euler("xyz", gates_rpy).as_quat().astype(np.float32)
250249
obs["obstacles_pos"] = obstacles_pos.astype(np.float32)
251250
self._obs = obs
252251
return obs
@@ -268,7 +267,7 @@ def gate_passed(self, pos: NDArray[np.floating], prev_pos: NDArray[np.floating])
268267
# Real gates measure 0.4m x 0.4m, we account for meas. error
269268
gate_size = (0.56, 0.56)
270269
gate_pos = self._obs["gates_pos"][self.target_gate]
271-
gate_rot = R.from_euler("xyz", self._obs["gates_rpy"][self.target_gate])
270+
gate_rot = R.from_quat(self._obs["gates_quat"][self.target_gate])
272271
return check_gate_pass(gate_pos, gate_rot, gate_size, pos, prev_pos)
273272
return False
274273

lsy_drone_racing/envs/race_core.py

+31-47
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ class EnvData:
6363
contact_masks: Array
6464
pos_limit_low: Array
6565
pos_limit_high: Array
66-
rpy_limit_low: Array
67-
rpy_limit_high: Array
6866
gate_mj_ids: Array
6967
obstacle_mj_ids: Array
7068
max_episode_steps: Array
@@ -84,8 +82,6 @@ def create(
8482
sensor_range: float,
8583
pos_limit_low: Array,
8684
pos_limit_high: Array,
87-
rpy_limit_low: Array,
88-
rpy_limit_high: Array,
8985
device: Device,
9086
) -> EnvData:
9187
"""Create a new environment data struct with default values."""
@@ -100,8 +96,6 @@ def create(
10096
steps=jp.zeros(n_envs, dtype=int, device=device),
10197
pos_limit_low=jp.array(pos_limit_low, dtype=np.float32, device=device),
10298
pos_limit_high=jp.array(pos_limit_high, dtype=np.float32, device=device),
103-
rpy_limit_low=jp.array(rpy_limit_low, dtype=np.float32, device=device),
104-
rpy_limit_high=jp.array(rpy_limit_high, dtype=np.float32, device=device),
10599
gate_mj_ids=jp.array(gate_mj_ids, dtype=int, device=device),
106100
obstacle_mj_ids=jp.array(obstacle_mj_ids, dtype=int, device=device),
107101
max_episode_steps=jp.array([max_episode_steps], dtype=int, device=device),
@@ -124,12 +118,12 @@ def build_observation_space(n_gates: int, n_obstacles: int) -> spaces.Dict:
124118
"""Create the observation space for the environment."""
125119
obs_spec = {
126120
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
127-
"rpy": spaces.Box(low=-np.pi, high=np.pi, shape=(3,)),
121+
"quat": spaces.Box(low=-1, high=1, shape=(4,)),
128122
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
129123
"ang_vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
130124
"target_gate": spaces.Discrete(n_gates, start=-1),
131125
"gates_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(n_gates, 3)),
132-
"gates_rpy": spaces.Box(low=-np.pi, high=np.pi, shape=(n_gates, 3)),
126+
"gates_quat": spaces.Box(low=-1, high=1, shape=(n_gates, 4)),
133127
"gates_visited": spaces.Box(low=0, high=1, shape=(n_gates,), dtype=bool),
134128
"obstacles_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(n_obstacles, 3)),
135129
"obstacles_visited": spaces.Box(low=0, high=1, shape=(n_obstacles,), dtype=bool),
@@ -160,15 +154,15 @@ class RaceCoreEnv:
160154
161155
The observation space is a dictionary with the following keys:
162156
- "pos": Drone position
163-
- "rpy": Drone orientation (roll, pitch, yaw)
157+
- "quat": Drone orientation as a quaternion (x, y, z, w)
164158
- "vel": Drone linear velocity
165159
- "ang_vel": Drone angular velocity
166-
- "gates.pos": Positions of the gates
167-
- "gates.rpy": Orientations of the gates
168-
- "gates.visited": Flags indicating if the drone already was/ is in the sensor range of the
160+
- "gates_pos": Positions of the gates
161+
- "gates_quat": Orientations of the gates
162+
- "gates_visited": Flags indicating if the drone already was/ is in the sensor range of the
169163
gates and the true position is known
170-
- "obstacles.pos": Positions of the obstacles
171-
- "obstacles.visited": Flags indicating if the drone already was/ is in the sensor range of the
164+
- "obstacles_pos": Positions of the obstacles
165+
- "obstacles_visited": Flags indicating if the drone already was/ is in the sensor range of the
172166
obstacles and the true position is known
173167
- "target_gate": The current target gate index
174168
@@ -253,7 +247,6 @@ def __init__(
253247
gate_mj_ids, obstacle_mj_ids = self.gates["mj_ids"], self.obstacles["mj_ids"]
254248
pos_limit_low = jp.array([-3, -3, 0], dtype=np.float32, device=self.device)
255249
pos_limit_high = jp.array([3, 3, 2.5], dtype=np.float32, device=self.device)
256-
rpy_limit = jp.array([jp.pi / 2, jp.pi / 2, jp.pi], dtype=jp.float32, device=self.device)
257250
self.data = EnvData.create(
258251
n_envs,
259252
n_drones,
@@ -266,8 +259,6 @@ def __init__(
266259
sensor_range,
267260
pos_limit_low,
268261
pos_limit_high,
269-
-rpy_limit,
270-
rpy_limit,
271262
self.device,
272263
)
273264

@@ -315,16 +306,14 @@ def _step(
315306
self.sim.data = self._warp_disabled_drones(self.sim.data, self.data.disabled_drones)
316307
# Apply the environment logic. Check which drones are now disabled, check which gates have
317308
# been passed, and update the target gate.
318-
drone_pos, drone_quat = self.sim.data.states.pos, self.sim.data.states.quat
309+
drone_pos = self.sim.data.states.pos
319310
mocap_pos, mocap_quat = self.sim.data.mjx_data.mocap_pos, self.sim.data.mjx_data.mocap_quat
320311
contacts = self.sim.contacts()
321312
# Get marked_for_reset before it is updated, because the autoreset needs to be based on the
322313
# previous flags, not the ones from the current step
323314
marked_for_reset = self.data.marked_for_reset
324315
# Apply the environment logic with updated simulation data.
325-
self.data = self._step_env(
326-
self.data, drone_pos, drone_quat, mocap_pos, mocap_quat, contacts
327-
)
316+
self.data = self._step_env(self.data, drone_pos, mocap_pos, mocap_quat, contacts)
328317
# Auto-reset envs. Add configuration option to disable for single-world envs
329318
if self.autoreset and marked_for_reset.any():
330319
self._reset(mask=marked_for_reset)
@@ -357,27 +346,25 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
357346
"""Return the observation of the environment."""
358347
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
359348
# use the actual pose, otherwise use the nominal pose.
360-
gates_pos, gates_rpy, obstacles_pos = self._obs(
349+
gates_pos, gates_quat, obstacles_pos = self._obs(
361350
self.sim.data.mjx_data.mocap_pos,
362351
self.sim.data.mjx_data.mocap_quat,
363352
self.data.gates_visited,
364353
self.gates["mj_ids"],
365354
self.gates["nominal_pos"],
366-
self.gates["nominal_rpy"],
355+
self.gates["nominal_quat"],
367356
self.data.obstacles_visited,
368357
self.obstacles["mj_ids"],
369358
self.obstacles["nominal_pos"],
370359
)
371-
quat = self.sim.data.states.quat
372-
rpy = R.from_quat(quat.reshape(-1, 4)).as_euler("xyz").reshape((*quat.shape[:-1], 3))
373360
obs = {
374361
"pos": np.array(self.sim.data.states.pos, dtype=np.float32),
375-
"rpy": rpy.astype(np.float32),
362+
"quat": np.array(self.sim.data.states.quat, dtype=np.float32),
376363
"vel": np.array(self.sim.data.states.vel, dtype=np.float32),
377364
"ang_vel": np.array(self.sim.data.states.ang_vel, dtype=np.float32),
378365
"target_gate": np.array(self.data.target_gate, dtype=int),
379366
"gates_pos": np.asarray(gates_pos, dtype=np.float32),
380-
"gates_rpy": np.asarray(gates_rpy, dtype=np.float32),
367+
"gates_quat": np.asarray(gates_quat, dtype=np.float32),
381368
"gates_visited": np.asarray(self.data.gates_visited, dtype=bool),
382369
"obstacles_pos": np.asarray(obstacles_pos, dtype=np.float32),
383370
"obstacles_visited": np.asarray(self.data.obstacles_visited, dtype=bool),
@@ -447,16 +434,11 @@ def _reset_env_data(data: EnvData, drone_pos: Array, mask: Array | None = None)
447434
@staticmethod
448435
@jax.jit
449436
def _step_env(
450-
data: EnvData,
451-
drone_pos: Array,
452-
drone_quat: Array,
453-
mocap_pos: Array,
454-
mocap_quat: Array,
455-
contacts: Array,
437+
data: EnvData, drone_pos: Array, mocap_pos: Array, mocap_quat: Array, contacts: Array
456438
) -> EnvData:
457439
"""Step the environment data."""
458440
n_gates = len(data.gate_mj_ids)
459-
disabled_drones = RaceCoreEnv._disabled_drones(drone_pos, drone_quat, contacts, data)
441+
disabled_drones = RaceCoreEnv._disabled_drones(drone_pos, contacts, data)
460442
gates_pos = mocap_pos[:, data.gate_mj_ids]
461443
obstacles_pos = mocap_pos[:, data.obstacle_mj_ids]
462444
# We need to convert the mocap quat from MuJoCo order to scipy order
@@ -498,27 +480,24 @@ def _obs(
498480
gates_visited: Array,
499481
gate_mocap_ids: Array,
500482
nominal_gate_pos: NDArray,
501-
nominal_gate_rpy: NDArray,
483+
nominal_gate_quat: NDArray,
502484
obstacles_visited: Array,
503485
obstacle_mocap_ids: Array,
504486
nominal_obstacle_pos: NDArray,
505487
) -> tuple[Array, Array]:
506488
"""Get the nominal or real gate positions and orientations depending on the sensor range."""
507489
mask, real_pos = gates_visited[..., None], mocap_pos[:, gate_mocap_ids]
508-
real_rpy = JaxR.from_quat(mocap_quat[:, gate_mocap_ids][..., [1, 2, 3, 0]]).as_euler("xyz")
490+
real_quat = mocap_quat[:, gate_mocap_ids][..., [1, 2, 3, 0]]
509491
gates_pos = jp.where(mask, real_pos[:, None], nominal_gate_pos[None, None])
510-
gates_rpy = jp.where(mask, real_rpy[:, None], nominal_gate_rpy[None, None])
492+
gates_quat = jp.where(mask, real_quat[:, None], nominal_gate_quat[None, None])
511493
mask, real_pos = obstacles_visited[..., None], mocap_pos[:, obstacle_mocap_ids]
512494
obstacles_pos = jp.where(mask, real_pos[:, None], nominal_obstacle_pos[None, None])
513-
return gates_pos, gates_rpy, obstacles_pos
495+
return gates_pos, gates_quat, obstacles_pos
514496

515497
@staticmethod
516-
def _disabled_drones(pos: Array, quat: Array, contacts: Array, data: EnvData) -> Array:
517-
rpy = JaxR.from_quat(quat).as_euler("xyz")
518-
disabled = jp.logical_or(data.disabled_drones, jp.all(pos < data.pos_limit_low, axis=-1))
519-
disabled = jp.logical_or(disabled, jp.all(pos > data.pos_limit_high, axis=-1))
520-
disabled = jp.logical_or(disabled, jp.all(rpy < data.rpy_limit_low, axis=-1))
521-
disabled = jp.logical_or(disabled, jp.all(rpy > data.rpy_limit_high, axis=-1))
498+
def _disabled_drones(pos: Array, contacts: Array, data: EnvData) -> Array:
499+
disabled = jp.logical_or(data.disabled_drones, jp.any(pos < data.pos_limit_low, axis=-1))
500+
disabled = jp.logical_or(disabled, jp.any(pos > data.pos_limit_high, axis=-1))
522501
disabled = jp.logical_or(disabled, data.target_gate == -1)
523502
contacts = jp.any(jp.logical_and(contacts[:, None, :], data.contact_masks), axis=-1)
524503
disabled = jp.logical_or(disabled, contacts)
@@ -539,8 +518,13 @@ def _warp_disabled_drones(data: SimData, mask: Array) -> SimData:
539518
def _load_track(self, track: dict) -> tuple[dict, dict, dict]:
540519
"""Load the track from the config file."""
541520
gate_pos = np.array([g["pos"] for g in track.gates])
542-
gate_rpy = np.array([g["rpy"] for g in track.gates])
543-
gates = {"pos": gate_pos, "rpy": gate_rpy, "nominal_pos": gate_pos, "nominal_rpy": gate_rpy}
521+
gate_quat = R.from_euler("xyz", np.array([g["rpy"] for g in track.gates])).as_quat()
522+
gates = {
523+
"pos": gate_pos,
524+
"quat": gate_quat,
525+
"nominal_pos": gate_pos,
526+
"nominal_quat": gate_quat,
527+
}
544528
obstacle_pos = np.array([o["pos"] for o in track.obstacles])
545529
obstacles = {"pos": obstacle_pos, "nominal_pos": obstacle_pos}
546530
drone_keys = ("pos", "rpy", "vel", "ang_vel")
@@ -578,7 +562,7 @@ def _load_track_into_sim(self, gate_spec: MjSpec, obstacle_spec: MjSpec):
578562
gate = frame.attach_body(gate_spec.find_body("gate"), "", f":{i}")
579563
gate.pos = self.gates["pos"][i]
580564
# Convert from scipy order to MuJoCo order
581-
gate.quat = R.from_euler("xyz", self.gates["rpy"][i]).as_quat()[[3, 0, 1, 2]]
565+
gate.quat = self.gates["quat"][i][[3, 0, 1, 2]]
582566
gate.mocap = True # Make mocap to modify the position of static bodies during sim
583567
for i in range(n_obstacles):
584568
obstacle = frame.attach_body(obstacle_spec.find_body("obstacle"), "", f":{i}")

lsy_drone_racing/utils/ros_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def check_race_track(config: ConfigDict):
3434
assert rng_info.obstacle_pos.type == "uniform", "Race track checks expect uniform distributions"
3535
for i, gate in enumerate(config.env.track.gates):
3636
name = f"gate{i + 1}"
37-
gate_pos, gate_rot = vicon.pos[name], R.from_euler("xyz", vicon.rpy[name])
37+
gate_pos, gate_rot = vicon.pos[name], R.from_quat(vicon.quat[name])
3838
check_bounds(name, gate_pos, gate.pos, rng_info.gate_pos.low, rng_info.gate_pos.high)
39-
check_rotation(name, gate_rot, R.from_euler("xyz", gate.rpy), ang_tol)
39+
check_rotation(name, gate_rot, R.from_quat(gate.quat), ang_tol)
4040

4141
for i, obstacle in enumerate(config.env.track.obstacles):
4242
name = f"obstacle{i + 1}"

lsy_drone_racing/vicon.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
self.track_names = track_names
6161
# Register the Vicon subscribers for the drone and any other tracked object
6262
self.pos: dict[str, np.ndarray] = {}
63-
self.rpy: dict[str, np.ndarray] = {}
63+
self.quat: dict[str, np.ndarray] = {}
6464
self.vel: dict[str, np.ndarray] = {}
6565
self.ang_vel: dict[str, np.ndarray] = {}
6666
self.time: dict[str, float] = {}
@@ -90,8 +90,7 @@ def estimator_callback(self, data: StateVector):
9090
if self.drone_name is None:
9191
return
9292
self.pos[self.drone_name] = np.array(data.pos)
93-
rpy = R.from_quat(data.quat).as_euler("xyz")
94-
self.rpy[self.drone_name] = np.array(rpy)
93+
self.quat[self.drone_name] = np.array(data.quat)
9594
self.vel[self.drone_name] = np.array(data.vel)
9695
self.ang_vel[self.drone_name] = np.array(data.omega_b)
9796

@@ -110,10 +109,10 @@ def tf_callback(self, data: TFMessage):
110109
continue
111110
T, Rot = tf.transform.translation, tf.transform.rotation
112111
pos = np.array([T.x, T.y, T.z])
113-
rpy = R.from_quat([Rot.x, Rot.y, Rot.z, Rot.w]).as_euler("xyz")
112+
quat = np.array([Rot.x, Rot.y, Rot.z, Rot.w])
114113
self.time[name] = time.time()
115114
self.pos[name] = pos
116-
self.rpy[name] = rpy
115+
self.quat[name] = quat
117116

118117
def pose(self, name: str) -> tuple[np.ndarray, np.ndarray]:
119118
"""Get the latest pose of a tracked object.
@@ -122,14 +121,14 @@ def pose(self, name: str) -> tuple[np.ndarray, np.ndarray]:
122121
name: The name of the object.
123122
124123
Returns:
125-
The position and rotation of the object. The rotation is in roll-pitch-yaw format.
124+
The position and orientation (as xyzw quaternion) of the object.
126125
"""
127-
return self.pos[name], self.rpy[name]
126+
return self.pos[name], self.quat[name]
128127

129128
@property
130129
def poses(self) -> tuple[np.ndarray, np.ndarray]:
131130
"""Get the latest poses of all objects."""
132-
return np.stack(self.pos.values()), np.stack(self.rpy.values())
131+
return np.stack(self.pos.values()), np.stack(self.quat.values())
133132

134133
@property
135134
def names(self) -> list[str]:

0 commit comments

Comments
 (0)