28
28
29
29
import gymnasium
30
30
import numpy as np
31
+ from crazyflow .control .control import thrust_curve
32
+ from crazyflow .sim .symbolic import symbolic_attitude
31
33
from gymnasium import spaces
32
34
from scipy .spatial .transform import Rotation as R
33
35
34
36
from lsy_drone_racing .control .closing_controller import ClosingController
35
- from lsy_drone_racing .sim .drone import Drone
36
- from lsy_drone_racing .sim .sim import Sim
37
37
from lsy_drone_racing .utils import check_gate_pass
38
38
from lsy_drone_racing .utils .import_utils import get_ros_package_path , pycrazyswarm
39
39
from lsy_drone_racing .utils .ros_utils import check_drone_start_pos , check_race_track
@@ -111,17 +111,7 @@ def __init__(self, config: dict | Munch):
111
111
names += [f"gate{ g } " for g in range (1 , len (config .env .track .gates ) + 1 )]
112
112
names += [f"obstacle{ g } " for g in range (1 , len (config .env .track .obstacles ) + 1 )]
113
113
self .vicon = Vicon (track_names = names , timeout = 5 )
114
- self .symbolic = None
115
- if config .env .symbolic :
116
- sim = Sim (
117
- track = config .env .track ,
118
- sim_freq = config .sim .sim_freq ,
119
- ctrl_freq = config .sim .ctrl_freq ,
120
- disturbances = getattr (config .sim , "disturbances" , {}),
121
- randomization = getattr (config .env , "randomization" , {}),
122
- physics = config .sim .physics ,
123
- )
124
- self .symbolic = sim .symbolic ()
114
+ self .symbolic = symbolic_attitude (config .env .freq ) if config .env .symbolic else None
125
115
self ._last_pos = np .zeros (3 )
126
116
127
117
self .gates_visited = np .array ([False ] * len (config .env .track .gates ))
@@ -153,7 +143,7 @@ def reset(
153
143
info ["low_level_ctrl_freq" ] = self .config .sim .ctrl_freq
154
144
info ["env_freq" ] = self .config .env .freq
155
145
info ["drone_mass" ] = 0.033 # Crazyflie 2.1 mass in kg
156
- return self .obs , info
146
+ return self .obs () , info
157
147
158
148
def step (
159
149
self , action : NDArray [np .floating ]
@@ -175,7 +165,7 @@ def step(
175
165
if self .target_gate >= len (self .config .env .track .gates ):
176
166
self .target_gate = - 1
177
167
terminated = self .target_gate == - 1
178
- return self .obs , - 1.0 , terminated , False , self .info
168
+ return self .obs () , - 1.0 , terminated , False , self .info
179
169
180
170
def close (self ):
181
171
"""Close the environment by stopping the drone and landing back at the starting position."""
@@ -189,7 +179,7 @@ def close(self):
189
179
self .config .env .freq = freq_new
190
180
t_step_ctrl = 1 / self .config .env .freq
191
181
192
- obs = self .obs
182
+ obs = self .obs ()
193
183
obs ["acc" ] = np .array (
194
184
[0 , 0 , 0 ]
195
185
) # TODO, use actual value when avaiable or do one step to calculate from velocity
@@ -211,15 +201,14 @@ def close(self):
211
201
action [10 :],
212
202
)
213
203
self .cf .cmdFullState (pos , vel , acc , yaw , rpy_rate )
214
- obs = self .obs
204
+ obs = self .obs ()
215
205
obs ["acc" ] = np .array ([0 , 0 , 0 ])
216
206
controller .step_callback (action , obs , 0 , True , False , info )
217
207
time .sleep (t_step_ctrl )
218
208
219
209
self .cf .notifySetpointsStop ()
220
210
self .cf .land (0.05 , 2.0 )
221
211
222
- @property
223
212
def obs (self ) -> dict :
224
213
"""Return the observation of the environment."""
225
214
drone = self .vicon .drone_name
@@ -312,7 +301,6 @@ def __init__(self, config: dict | Munch):
312
301
"""
313
302
super ().__init__ (config )
314
303
self .action_space = gymnasium .spaces .Box (low = - 1 , high = 1 , shape = (4 ,))
315
- self .drone = Drone ("mellinger" )
316
304
317
305
def step (
318
306
self , action : NDArray [np .floating ]
@@ -329,12 +317,11 @@ def step(
329
317
assert action .shape == self .action_space .shape , f"Invalid action shape: { action .shape } "
330
318
collective_thrust , rpy = action [0 ], action [1 :]
331
319
rpy_deg = np .rad2deg (rpy )
332
- collective_thrust = self .drone ._thrust_to_pwms (collective_thrust )
333
- self .cf .cmdVel (* rpy_deg , collective_thrust )
320
+ self .cf .cmdVel (* rpy_deg , thrust_curve (collective_thrust ))
334
321
current_pos = self .vicon .pos [self .vicon .drone_name ]
335
322
self .target_gate += self .gate_passed (current_pos , self ._last_pos )
336
323
self ._last_pos [:] = current_pos
337
324
if self .target_gate >= len (self .config .env .track .gates ):
338
325
self .target_gate = - 1
339
326
terminated = self .target_gate == - 1
340
- return self .obs , - 1.0 , terminated , False , self .info
327
+ return self .obs () , - 1.0 , terminated , False , self .info
0 commit comments