Skip to content

Commit aa052de

Browse files
Improved joint definitions. Arm training does not use the lift anymore.
1 parent dea87cb commit aa052de

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

examples/rlearning/dm_control_ppo.ipynb

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@
249249
"\"\"\"\n",
250250
"This is a training task that tries to push a cube to a target location using all the joints and actuators on the robot.\n",
251251
"\"\"\"\n",
252+
"from functools import cache\n",
253+
"\n",
254+
"\n",
252255
"class StretchPushCubeTraining:\n",
253256
" def __init__(self, physics: mujoco.Physics, push_cube_by:tuple[float,float,float]):\n",
254257
" self.physics = physics\n",
@@ -269,14 +272,26 @@
269272
"\n",
270273
" self.current_distance_to_target = float('inf')\n",
271274
"\n",
275+
" @cache\n",
272276
" def _get_joints(self):\n",
273-
" \"\"\"Gets joints, but removes unnamed ones (which we probably don't care about)\"\"\"\n",
274-
" return [name for name in self.physics.named.model.name_jntadr.axes.row.names if name != \"\"]\n",
277+
" \"\"\"Gets joint names in MJCF\"\"\"\n",
278+
" return [name for j in self._get_actuators() for name in j.get_joint_names_in_mjcf()]\n",
279+
" \n",
280+
" @cache\n",
281+
" def _get_actuator_names(self):\n",
282+
" return [j.name for j in self._get_actuators()]\n",
283+
" \n",
284+
" @cache\n",
285+
" def _get_actuators(self):\n",
286+
" return Actuators.get_actuated_joints()\n",
275287
"\n",
288+
" @cache\n",
276289
" def _get_cube_id(self):\n",
277290
" return self.physics.model.name2id(\"object1\", \"body\")\n",
278291
" def _get_cube_pos(self):\n",
279292
" return self.physics.data.xpos[self._get_cube_id()]\n",
293+
" \n",
294+
" @cache\n",
280295
" def _get_cube_original_pos(self):\n",
281296
" return self.physics.model.body(\"object1\").pos\n",
282297
" \n",
@@ -320,7 +335,7 @@
320335
" time.sleep(time_until_next_step)\n",
321336
"\n",
322337
" # Apply the action to the joints\n",
323-
" for index, name in enumerate([j.name for j in Actuators.get_arm_joints()]):\n",
338+
" for index, name in enumerate(self._get_actuator_names()):\n",
324339
" self.physics.data.actuator(name).ctrl = action[index]\n",
325340
" \n",
326341
" # Step the simulation forward\n",
@@ -381,11 +396,11 @@
381396
"\n",
382397
" self.current_distance_to_target = float('inf')\n",
383398
"\n",
384-
" def _get_joints(self):\n",
385-
" # As defined in stretch.xml MJCF:\n",
386-
" arm_joints = ['joint_arm_l0', 'joint_arm_l1', 'joint_arm_l2', 'joint_arm_l3', 'joint_gripper_slide', 'joint_lift', 'joint_wrist_pitch', 'joint_wrist_roll', 'joint_wrist_yaw']\n",
387-
" finger_joints = ['joint_gripper_finger_left_open', 'joint_gripper_finger_right_open',] #'rubber_left_x', 'rubber_left_y', 'rubber_right_x', 'rubber_right_y']\n",
388-
" return arm_joints"
399+
"\n",
400+
" @cache\n",
401+
" def _get_actuators(self):\n",
402+
" return Actuators.get_arm_joints()\n",
403+
" \n"
389404
]
390405
},
391406
{
@@ -704,7 +719,7 @@
704719
"spec = mujoco.MjSpec.from_file(xml_path)\n",
705720
"spec = mujoco.MjSpec.from_file(xml_path)\n",
706721
"spec.find_body(\"object1\").pos = stretchPushCubeTrainingArmOnly.target_position\n",
707-
"spec.meshdir = \"../models/assets/\"\n",
722+
"spec.meshdir = \"../../stretch_mujoco/models/assets/\"\n",
708723
"\"\"\"meshdir is relative here, it should be the same as spec.modelfiledir, but mujoco expects meshdir to be a relative dir? If you get a Not Found error, your relative path may be wrong.\"\"\"\n",
709724
"spec.texturedir = spec.meshdir\n",
710725
"spec.compile()\n",

stretch_mujoco/enums/actuators.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
class Actuators(Enum):
77
"""
8-
An enum for the joints defined in the URDF.
8+
An enum for the joints defined in the MJCF (stretch.xml).
99
"""
1010

1111
arm = 0
@@ -24,12 +24,33 @@ class Actuators(Enum):
2424
@classmethod
2525
def get_arm_joints(cls) -> list["Actuators"]:
2626
return [
27-
cls.lift,
2827
cls.arm,
2928
cls.wrist_pitch,
3029
cls.wrist_roll,
3130
cls.wrist_yaw,
3231
]
32+
33+
@classmethod
34+
def get_actuated_joints(cls) -> list["Actuators"]:
35+
return [actuator for actuator in cls if actuator != cls.base_rotate and actuator!= cls.base_translate]
36+
37+
def get_joint_names_in_mjcf(self):
38+
"""
39+
An actuator may have multiple joints. Return their names here.
40+
"""
41+
if self == Actuators.left_wheel_vel: return [ "joint_left_wheel"]
42+
if self == Actuators.right_wheel_vel: return ["joint_right_wheel"]
43+
if self == Actuators.lift: return ["joint_lift"]
44+
if self == Actuators.arm: return ['joint_arm_l0', 'joint_arm_l1', 'joint_arm_l2', 'joint_arm_l3']
45+
if self == Actuators.wrist_yaw: return ["joint_wrist_yaw"]
46+
if self == Actuators.wrist_pitch: return [ "joint_wrist_pitch"]
47+
if self == Actuators.wrist_roll: return ["joint_wrist_roll"]
48+
if self == Actuators.gripper: return ["joint_gripper_slide"]
49+
if self == Actuators.head_pan: return [ "joint_head_pan"]
50+
if self == Actuators.head_tilt: return ["joint_head_tilt"]
51+
52+
raise NotImplementedError(f"Joint names for {self} are not defined.")
53+
3354
def _get_status_attribute(self, is_position: bool, status: StretchStatus) -> float:
3455
attribute_name = "pos" if is_position else "vel"
3556
if self == Actuators.arm:

0 commit comments

Comments
 (0)