Skip to content

Commit 5d278b5

Browse files
committed
Add a ControllerManager for multi-drone deployments. Fix tests and linting
1 parent 5a4beb6 commit 5d278b5

File tree

7 files changed

+202
-7
lines changed

7 files changed

+202
-7
lines changed
+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Asynchronous controller manager for multi-process control of multiple drones.
2+
3+
This module provides a controller manager that allows multiple controllers to run in separate
4+
processes without blocking other controllers or the main process.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import multiprocessing as mp
10+
from queue import Empty
11+
from typing import TYPE_CHECKING
12+
13+
import numpy as np
14+
15+
from lsy_drone_racing.control.controller import Controller
16+
17+
if TYPE_CHECKING:
18+
from multiprocessing.synchronize import Event
19+
20+
from numpy.typing import NDArray
21+
22+
23+
class ControllerManager:
24+
"""Multi-process safe manager class for asynchronous/non-blocking controller execution.
25+
26+
Note:
27+
The controller manager currently does not support step and episode callbacks.
28+
29+
Todo:
30+
Implement an automated return mechanism for the controllers.
31+
"""
32+
33+
def __init__(self, controllers: list[Controller], default_action: NDArray):
34+
"""Initialize the controller manager."""
35+
assert all(isinstance(c, Controller) for c in controllers), "Invalid controller type(s)!"
36+
self._controllers_cls = controllers
37+
self._obs_queues = [mp.Queue(1) for _ in controllers]
38+
self._action_queues = [mp.Queue(1) for _ in controllers]
39+
self._ready = [mp.Event() for _ in controllers]
40+
self._shutdown = [mp.Event() for _ in controllers]
41+
self._actions = np.tile(default_action, (len(controllers), 1))
42+
43+
def start(self, init_args: tuple | None = None, init_kwargs: dict | None = None):
44+
"""Start the controller manager."""
45+
for i, c in enumerate(self._controllers_cls):
46+
args = (
47+
c,
48+
tuple() if init_args is None else init_args,
49+
dict() if init_kwargs is None else init_kwargs,
50+
self._obs_queues[i],
51+
self._action_queues[i],
52+
self._ready[i],
53+
self._shutdown[i],
54+
)
55+
self._controller_procs.append(mp.Process(target=self._control_loop, args=args))
56+
self._controller_procs[-1].start()
57+
for ready in self._ready: # Wait for all controllers to be ready
58+
ready.wait()
59+
60+
def update_obs(self, obs: dict, info: dict):
61+
"""Pass the observation and info updates to all controller processes.
62+
63+
Args:
64+
obs: The observation dictionary.
65+
info: The info dictionary.
66+
"""
67+
for obs_queue in self._obs_queues:
68+
_clear_producing_queue(obs_queue)
69+
obs_queue.put((obs, info))
70+
71+
def latest_actions(self) -> NDArray:
72+
"""Get the latest actions from all controllers."""
73+
for i, action_queue in enumerate(self._action_queues):
74+
if not action_queue.empty(): # Length of queue is 1 -> action is ready
75+
# The action queue could be cleared in between the check and the get() call. Since
76+
# the controller processes immediately put the next action into the queue, this
77+
# minimum block time is acceptable.
78+
self._actions[i] = action_queue.get()
79+
return np.array(self._actions)
80+
81+
@staticmethod
82+
def _control_loop(
83+
cls: type[Controller],
84+
init_args: tuple,
85+
init_kwargs: dict,
86+
obs_queue: mp.Queue,
87+
action_queue: mp.Queue,
88+
ready: Event,
89+
shutdown: Event,
90+
):
91+
controller = cls(*init_args, **init_kwargs)
92+
ready.set()
93+
while not shutdown.is_set():
94+
obs, info = obs_queue.get() # Blocks until new observation is available
95+
action = controller.compute_control(obs, info)
96+
_clear_producing_queue(action_queue)
97+
action_queue.put_nowait(action)
98+
99+
100+
def _clear_producing_queue(queue: mp.Queue):
101+
"""Clear the queue if it is not empty and this process is the ONLY producer.
102+
103+
Warning:
104+
Only works for queues with a length of 1.
105+
"""
106+
if not queue.empty(): # There are remaining items in the queue
107+
try:
108+
queue.get_nowait()
109+
except Empty: # Another process could have consumed the last item in between
110+
pass # This is fine, the queue is empty

lsy_drone_racing/envs/race_core.py

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from crazyflow.sim.symbolic import symbolic_attitude
2020
from flax.struct import dataclass
2121
from gymnasium import spaces
22-
from jax.scipy.spatial.transform import Rotation as JaxR
2322
from scipy.spatial.transform import Rotation as R
2423

2524
from lsy_drone_racing.envs.randomize import (

lsy_drone_racing/vicon.py

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import yaml
2121
from crazyswarm.msg import StateVector
2222
from rosgraph import Master
23-
from scipy.spatial.transform import Rotation as R
2423
from tf2_msgs.msg import TFMessage
2524

2625
from lsy_drone_racing.utils.import_utils import get_ros_package_path

pyproject.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ classifiers = [
2020
dependencies = [
2121
"fire >= 0.6.0",
2222
"numpy >= 1.24.1, < 2.0.0",
23-
"PyYAML >= 6.0.1", # TODO: Remove after removing crazyswarm dependency
24-
"rospkg >= 1.5.1", # TODO: Remove after moving to cflib
23+
"PyYAML >= 6.0.1", # TODO: Remove after removing crazyswarm dependency
24+
"rospkg >= 1.5.1", # TODO: Remove after moving to cflib
2525
"scipy >= 1.10.1",
2626
"gymnasium >= 1.0.0",
2727
"toml >= 0.10.2",
28-
"ml_collections >= 1.0",
28+
"ml-collections >= 1.0",
2929
]
3030

3131
[project.optional-dependencies]

scripts/multi_deploy.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#!/usr/bin/env python
2+
"""Launch script for the real race with multiple drones.
3+
4+
Usage:
5+
6+
python deploy.py <path/to/controller.py> <path/to/config.toml>
7+
8+
"""
9+
10+
from __future__ import annotations
11+
12+
import logging
13+
import time
14+
from pathlib import Path
15+
from typing import TYPE_CHECKING
16+
17+
import fire
18+
import gymnasium
19+
import rospy
20+
21+
from lsy_drone_racing.controller_manager import ControllerManager
22+
from lsy_drone_racing.utils import load_config, load_controller
23+
24+
if TYPE_CHECKING:
25+
from lsy_drone_racing.envs.drone_racing_deploy_env import (
26+
DroneRacingAttitudeDeployEnv,
27+
DroneRacingDeployEnv,
28+
)
29+
30+
# rospy.init_node changes the default logging configuration of Python, which is bad practice at
31+
# best. As a workaround, we can create loggers under the ROS root logger `rosout`.
32+
# Also see https://github.com/ros/ros_comm/issues/1384
33+
logger = logging.getLogger("rosout." + __name__)
34+
35+
36+
def main(config: str = "multi_level3.toml"):
37+
"""Deployment script to run the controller on the real drone.
38+
39+
Args:
40+
config: Path to the competition configuration. Assumes the file is in `config/`.
41+
controller: The name of the controller file in `lsy_drone_racing/control/` or None. If None,
42+
the controller specified in the config file is used.
43+
"""
44+
config = load_config(Path(__file__).parents[1] / "config" / config)
45+
env_id = "DroneRacingAttitudeDeploy-v0" if "Thrust" in config.env.id else "DroneRacingDeploy-v0"
46+
env: DroneRacingDeployEnv | DroneRacingAttitudeDeployEnv = gymnasium.make(env_id, config=config)
47+
obs, info = env.reset()
48+
49+
module_path = Path(__file__).parents[1] / "lsy_drone_racing/control"
50+
controller_paths = [module_path / p if p.is_relative() else p for p in config.controller.files]
51+
controller_manager = ControllerManager([load_controller(p) for p in controller_paths])
52+
controller_manager.start(init_args=(obs, info))
53+
54+
try:
55+
start_time = time.perf_counter()
56+
while not rospy.is_shutdown():
57+
t_loop = time.perf_counter()
58+
obs, info = env.unwrapped.obs, env.unwrapped.info
59+
# Compute the control action asynchronously. This limits delays and prevents slow
60+
# controllers from blocking the controllers for other drones.
61+
controller_manager.update_obs(obs, info)
62+
actions = controller_manager.latest_actions()
63+
next_obs, reward, terminated, truncated, info = env.step(actions)
64+
controller_manager.step_callback(actions, next_obs, reward, terminated, truncated, info)
65+
obs = next_obs
66+
if terminated or truncated:
67+
break
68+
if (dt := (time.perf_counter() - t_loop)) < (1 / config.env.freq):
69+
time.sleep(1 / config.env.freq - dt)
70+
else:
71+
exc = dt - 1 / config.env.freq
72+
logger.warning(f"Controller execution time exceeded loop frequency by {exc:.3f}s.")
73+
ep_time = time.perf_counter() - start_time
74+
controller_manager.episode_callback()
75+
logger.info(
76+
f"Track time: {ep_time:.3f}s" if obs["target_gate"] == -1 else "Task not completed"
77+
)
78+
finally:
79+
env.close()
80+
81+
82+
if __name__ == "__main__":
83+
logging.basicConfig(level=logging.INFO)
84+
fire.Fire(main)

tests/integration/test_controllers.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import gymnasium
44
import numpy as np
55
import pytest
6+
from gymnasium.wrappers import JaxToNumpy
67

78
from lsy_drone_racing.utils import load_config, load_controller
89

@@ -59,10 +60,11 @@ def test_attitude_controller(physics: str):
5960
random_resets=config.env.random_resets,
6061
seed=config.env.seed,
6162
)
63+
env = JaxToNumpy(env)
6264
obs, info = env.reset()
6365
ctrl = ctrl_cls(obs, info, config)
6466
while True:
65-
action = ctrl.compute_control(obs, info)
67+
action = ctrl.compute_control(obs, info).astype(np.float32)
6668
obs, reward, terminated, truncated, info = env.step(action)
6769
ctrl.step_callback(action, obs, reward, terminated, truncated, info)
6870
if terminated or truncated:

tests/unit/envs/test_envs.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import gymnasium
55
import pytest
66
from gymnasium.utils.env_checker import check_env
7+
from gymnasium.wrappers import JaxToNumpy
78

89
from lsy_drone_racing.utils import load_config
910

@@ -31,4 +32,4 @@ def test_passive_checker_wrapper_warnings(action_space: str):
3132
seed=config.env.seed,
3233
disable_env_checker=False,
3334
)
34-
check_env(env.unwrapped)
35+
check_env(JaxToNumpy(env))

0 commit comments

Comments
 (0)