|
249 | 249 | "\"\"\"\n",
|
250 | 250 | "This is a training task that tries to push a cube to a target location using all the joints and actuators on the robot.\n",
|
251 | 251 | "\"\"\"\n",
|
| 252 | + "from functools import cache\n", |
| 253 | + "\n", |
| 254 | + "\n", |
252 | 255 | "class StretchPushCubeTraining:\n",
|
253 | 256 | " def __init__(self, physics: mujoco.Physics, push_cube_by:tuple[float,float,float]):\n",
|
254 | 257 | " self.physics = physics\n",
|
|
269 | 272 | "\n",
|
270 | 273 | " self.current_distance_to_target = float('inf')\n",
|
271 | 274 | "\n",
|
| 275 | + " @cache\n", |
272 | 276 | " def _get_joints(self):\n",
|
273 |
| - " \"\"\"Gets joints, but removes unnamed ones (which we probably don't care about)\"\"\"\n", |
274 |
| - " return [name for name in self.physics.named.model.name_jntadr.axes.row.names if name != \"\"]\n", |
| 277 | + " \"\"\"Gets joint names in MJCF\"\"\"\n", |
| 278 | + " return [name for j in self._get_actuators() for name in j.get_joint_names_in_mjcf()]\n", |
| 279 | + " \n", |
| 280 | + " @cache\n", |
| 281 | + " def _get_actuator_names(self):\n", |
| 282 | + " return [j.name for j in self._get_actuators()]\n", |
| 283 | + " \n", |
| 284 | + " @cache\n", |
| 285 | + " def _get_actuators(self):\n", |
| 286 | + " return Actuators.get_actuated_joints()\n", |
275 | 287 | "\n",
|
| 288 | + " @cache\n", |
276 | 289 | " def _get_cube_id(self):\n",
|
277 | 290 | " return self.physics.model.name2id(\"object1\", \"body\")\n",
|
278 | 291 | " def _get_cube_pos(self):\n",
|
279 | 292 | " return self.physics.data.xpos[self._get_cube_id()]\n",
|
| 293 | + " \n", |
| 294 | + " @cache\n", |
280 | 295 | " def _get_cube_original_pos(self):\n",
|
281 | 296 | " return self.physics.model.body(\"object1\").pos\n",
|
282 | 297 | " \n",
|
|
320 | 335 | " time.sleep(time_until_next_step)\n",
|
321 | 336 | "\n",
|
322 | 337 | " # Apply the action to the joints\n",
|
323 |
| - " for index, name in enumerate([j.name for j in Actuators.get_arm_joints()]):\n", |
| 338 | + " for index, name in enumerate(self._get_actuator_names()):\n", |
324 | 339 | " self.physics.data.actuator(name).ctrl = action[index]\n",
|
325 | 340 | " \n",
|
326 | 341 | " # Step the simulation forward\n",
|
|
381 | 396 | "\n",
|
382 | 397 | " self.current_distance_to_target = float('inf')\n",
|
383 | 398 | "\n",
|
384 |
| - " def _get_joints(self):\n", |
385 |
| - " # As defined in stretch.xml MJCF:\n", |
386 |
| - " arm_joints = ['joint_arm_l0', 'joint_arm_l1', 'joint_arm_l2', 'joint_arm_l3', 'joint_gripper_slide', 'joint_lift', 'joint_wrist_pitch', 'joint_wrist_roll', 'joint_wrist_yaw']\n", |
387 |
| - " finger_joints = ['joint_gripper_finger_left_open', 'joint_gripper_finger_right_open',] #'rubber_left_x', 'rubber_left_y', 'rubber_right_x', 'rubber_right_y']\n", |
388 |
| - " return arm_joints" |
| 399 | + "\n", |
| 400 | + " @cache\n", |
| 401 | + " def _get_actuators(self):\n", |
| 402 | + " return Actuators.get_arm_joints()\n", |
| 403 | + " \n" |
389 | 404 | ]
|
390 | 405 | },
|
391 | 406 | {
|
|
704 | 719 | "spec = mujoco.MjSpec.from_file(xml_path)\n",
|
705 | 720 | "spec = mujoco.MjSpec.from_file(xml_path)\n",
|
706 | 721 | "spec.find_body(\"object1\").pos = stretchPushCubeTrainingArmOnly.target_position\n",
|
707 |
| - "spec.meshdir = \"../models/assets/\"\n", |
| 722 | + "spec.meshdir = \"../../stretch_mujoco/models/assets/\"\n", |
708 | 723 | "\"\"\"meshdir is relative here, it should be the same as spec.modelfiledir, but mujoco expects meshdir to be a relative dir? If you get a Not Found error, your relative path may be wrong.\"\"\"\n",
|
709 | 724 | "spec.texturedir = spec.meshdir\n",
|
710 | 725 | "spec.compile()\n",
|
|
0 commit comments