
Commit

Merge branch 'main' into offlinerl
BY571 authored Sep 30, 2024
2 parents 61c6b95 + 94b84ba commit 0cfa784
Showing 4 changed files with 670 additions and 2 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -4,6 +4,7 @@
![Python](https://img.shields.io/badge/python-3.8%20%7C%203.9-blue)
[![arXiv](https://img.shields.io/badge/arXiv-2406.17490-b31b1b.svg)](https://arxiv.org/abs/2406.17490)
[![Website](https://img.shields.io/badge/Website-Visit%20Now-blue)](https://bricksrl.github.io/ProjectPage/)
[![Discord](https://img.shields.io/badge/Join_our_Discord-7289da?logo=discord&logoColor=ffffff&labelColor=7289da)](https://discord.gg/qdTsFaVfZm)


BricksRL allows the training of custom LEGO robots using deep reinforcement learning. By integrating [PyBricks](https://pybricks.com/) and [TorchRL](https://pytorch.org/rl/stable/index.html), it facilitates efficient real-world training via Bluetooth communication between LEGO hubs and a local computing device. Check out our [paper](https://arxiv.org/abs/2406.17490)!
@@ -173,6 +174,8 @@ The datasets can be downloaded from huggingface and contain expert and random tr

The datasets consist of TensorDicts containing expert and random transitions, which can be directly loaded into the replay buffer. When initiating (pre-)training, simply provide the path to the desired TensorDict when prompted to load the replay buffer.
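
For illustration, here is a minimal sketch (not taken from the repository; the keys, sizes, and buffer configuration are stand-ins) of pushing such a TensorDict of transitions into a TorchRL replay buffer:

```python
import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

# Stand-in for the downloaded dataset; in practice this would be
# data = torch.load("<path-to-downloaded-tensordict>.pt")
data = TensorDict(
    {
        "observation": torch.randn(1000, 4),
        "action": torch.randn(1000, 2),
        "next": {
            "observation": torch.randn(1000, 4),
            "reward": torch.randn(1000, 1),
            "done": torch.zeros(1000, 1, dtype=torch.bool),
        },
    },
    batch_size=[1000],
)

buffer = TensorDictReplayBuffer(storage=LazyMemmapStorage(max_size=100_000))
buffer.extend(data)          # pre-fill the buffer with the loaded transitions
batch = buffer.sample(256)   # sample a training batch as a TensorDict
```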

## High-Level Examples
In the [example notebook](example_notebook.ipynb) we provide high-level examples of training a **SAC agent** in the **RoboArmSim-v0** environment and a **TD3 agent** in the **WalkerSim-v0** environment.


#### Pretrain an Agent
4 changes: 2 additions & 2 deletions environments/roboarm_mixed_v0/RoboArmMixedEnv.py
@@ -226,7 +226,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
self.observation_key: torch.tensor(observation, dtype=torch.float32),
self.pixels_observation_key: torch.from_numpy(resized_frame)[
None, :
].float(),
].to(torch.uint8),
self.original_pixels_key: torch.from_numpy(frame)[None, :].to(
torch.uint8
),
@@ -324,7 +324,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
),
self.pixels_observation_key: torch.from_numpy(resized_frame)[
None, :
].float(),
].to(torch.uint8),
self.original_pixels_key: torch.from_numpy(frame)[None, :].to(
torch.uint8
),
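The two hunks above switch the resized pixel observation from `float32` to `uint8`, matching how the frames under `original_pixels_key` were already stored and cutting the memory per stored frame by 4x. A minimal, self-contained sketch of the before/after conversion (the frame shape is a stand-in, not the environment's actual resolution):

```python
import numpy as np
import torch

frame = np.zeros((64, 64, 3), dtype=np.uint8)   # stand-in for a camera frame

old = torch.from_numpy(frame)[None, :].float()           # before: float32, 4 bytes/px
new = torch.from_numpy(frame)[None, :].to(torch.uint8)   # after:  uint8,   1 byte/px

print(old.element_size(), new.element_size())   # 4 vs 1 bytes per element
```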
664 changes: 664 additions & 0 deletions example_notebook.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/roboarm/eval.py
@@ -56,6 +56,7 @@ def run(cfg: DictConfig) -> None:
step_start_time = time.time()
td = agent.get_eval_action(td)
td = env.step(td)
agent.add_experience(td)
if env_name in VIDEO_LOGGING_ENVS:
image_caputres.append(
td.get(("next", "original_pixels")).cpu().numpy()
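The single added line presumably stores each evaluation transition as it is collected, so eval rollouts can be reused as data (for example for the offline-RL branch this commit merges into). A rough sketch of the surrounding loop; the function wrapper, `max_steps`, and the done/advance handling are hypothetical scaffolding, not the file's actual structure:

```python
from torchrl.envs.utils import step_mdp


def eval_rollout(env, agent, max_steps: int = 200):
    """Roll out the eval policy while also handing each transition to the agent."""
    td = env.reset()
    for _ in range(max_steps):
        td = agent.get_eval_action(td)   # evaluation action from the policy
        td = env.step(td)                # execute on the (real or simulated) robot
        agent.add_experience(td)         # newly added: keep the eval transition
        if td.get(("next", "done")).any():
            break
        td = step_mdp(td)                # advance: "next" becomes the current state
    return td
```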
