diff --git a/.gitignore b/.gitignore
index 374c6a1..a4fac52 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,4 +132,6 @@ dmypy.json
# wandb
wandb/
# hydra
-outputs/
\ No newline at end of file
+outputs/
+# .pth files
+*.pth
\ No newline at end of file
diff --git a/README.md b/README.md
index 4b667aa..ff9629b 100644
--- a/README.md
+++ b/README.md
@@ -117,7 +117,7 @@ Before running experiments, please review and modify the configuration settings
### Robots
-Robots utilized for our experiments. Building instructions can be found [here](https://sites.google.com/view/bricksrl/building-instructions).
+Robots utilized for our experiments. Building instructions can be found [here](https://bricksrl.github.io/ProjectPage/).
|  |  |  |
|:--:|:--:|:--:|
@@ -140,7 +140,7 @@ Robots utilized for our experiments. Building instructions can be found [here](h
Click me
-Evaluation videos of the trained agents can be found [here](https://sites.google.com/view/bricksrl/main).
+Evaluation videos of the trained agents can be found [here](https://bricksrl.github.io/ProjectPage/).
### 2Wheeler Results:
@@ -159,7 +159,44 @@ Evaluation videos of the trained agents can be found [here](https://sites.google
+### Offline RL
+
+ Click me
+Using precollected [offline datasets](https://huggingface.co/datasets/compsciencelab/BricksRL-Datasets), agents can be pretrained with offline RL to perform a task without any real-world interaction. Such pretrained policies can be evaluated directly or fine-tuned afterwards on the real robot.
+
+#### Datasets
+The datasets can be downloaded from Hugging Face and contain expert and random transitions for the 2Wheeler (RunAway-v0 and Spinning-v0), Walker (Walker-v0) and RoboArm (RoboArm-v0) robots.
+
+ ```bash
+ git lfs install
+ git clone git@hf.co:datasets/compsciencelab/BricksRL-Datasets
+ ```
+
+The datasets consist of TensorDicts containing expert and random transitions, which can be directly loaded into the replay buffer. When initiating (pre-)training, simply provide the path to the desired TensorDict when prompted to load the replay buffer.
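+
+Internally, the agents load such a TensorDict with `TensorDictBase.load_memmap` and extend their replay buffer with it. A minimal sketch for inspecting a downloaded dataset (the path below is a placeholder, point it to one of the cloned dataset folders):
+
+ ```python
+ from tensordict import TensorDictBase
+
+ # Placeholder path: any of the cloned dataset folders works here.
+ data = TensorDictBase.load_memmap("BricksRL-Datasets/<robot>/<dataset>")
+ print(data)           # keys of the stored transitions (observations, actions, rewards, ...)
+ print(data.shape[0])  # number of transitions, assuming a flat batch dimension
+ ```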
## High-Level Examples
In the [example notebook](example_notebook.ipynb) we provide high-level training examples to train a **SAC agent** in the **RoboArmSim-v0** environment and a **TD3 agent** in the **WalkerSim-v0** environment.
+
+#### Pretrain an Agent
+
+Running an offline training experiment is similar to online training, except that you run the **pretrain.py** script:
+
+ ```bash
+ python experiments/walker/pretrain.py
+ ```
+
+Trained policies can then be evaluated as before with:
+
+ ```bash
+ python experiments/walker/eval.py
+ ```
+
+Or continue training to fine-tune the pretrained policy on the real robot:
+
+ ```bash
+ python experiments/walker/train.py
+ ```
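+
+Since the experiments are configured with Hydra, the offline agent is selected via the standard override of the `agent` config group (`bc`, `cql` or `iql`, see `conf/agent/`), for example:
+
+ ```bash
+ python experiments/walker/pretrain.py agent=cql
+ ```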
+
+
+
diff --git a/conf/agent/bc.yaml b/conf/agent/bc.yaml
new file mode 100644
index 0000000..37e07ed
--- /dev/null
+++ b/conf/agent/bc.yaml
@@ -0,0 +1,11 @@
+name: bc
+lr: 3e-4
+batch_size: 256
+num_updates: 1
+prefill_episodes: 0
+
+
+policy_type: deterministic # stochastic or deterministic
+num_cells: 256
+dropout: 0.01
+normalization: LayerNorm
diff --git a/conf/agent/cql.yaml b/conf/agent/cql.yaml
new file mode 100644
index 0000000..07c2082
--- /dev/null
+++ b/conf/agent/cql.yaml
@@ -0,0 +1,28 @@
+name: cql
+lr: 3e-4
+batch_size: 256
+num_updates: 1
+prefill_episodes: 10
+
+bc_steps: 1000
+
+# CQL specific
+num_cells: 256
+gamma: 0.99
+soft_update_eps: 0.995
+loss_function: l2
+temperature: 1.0
+min_q_weight: 1.0
+max_q_backup: False
+deterministic_backup: False
+num_random: 10
+with_lagrange: True
+lagrange_thresh: 5.0 # tau
+
+normalization: None
+dropout: 0.0
+
+prb: 0
+buffer_size: 1000000
+pretrain: False
+reset_params: False
\ No newline at end of file
diff --git a/conf/agent/iql.yaml b/conf/agent/iql.yaml
new file mode 100644
index 0000000..d2ca4a8
--- /dev/null
+++ b/conf/agent/iql.yaml
@@ -0,0 +1,20 @@
+name: iql
+lr: 3e-4
+batch_size: 256
+num_updates: 1
+prefill_episodes: 0
+
+num_cells: 256
+gamma: 0.99
+soft_update_eps: 0.995
+loss_function: l2
+temperature: 1.0
+expectile: 0.5
+
+normalization: None
+dropout: 0.0
+
+prb: 0
+buffer_size: 1000000
+pretrain: False
+reset_params: False
\ No newline at end of file
diff --git a/conf/agent/td3.yaml b/conf/agent/td3.yaml
index 4465a71..f341f50 100644
--- a/conf/agent/td3.yaml
+++ b/conf/agent/td3.yaml
@@ -15,4 +15,6 @@ dropout: 0.0
prb: 0
buffer_size: 1000000
-reset_params: False
\ No newline at end of file
+reset_params: False
+use_bc: False
+alpha: 1.0
\ No newline at end of file
diff --git a/conf/config.yaml b/conf/config.yaml
index 609b137..478d1c4 100644
--- a/conf/config.yaml
+++ b/conf/config.yaml
@@ -4,10 +4,10 @@ run_name: ""
verbose: 0
device: "cuda"
-episodes: 200
+episodes: 250
defaults:
- _self_
# random, sac, td3, droq
- agent: sac
- - env: roboarm_sim-v0
\ No newline at end of file
+ - env: walker_sim-v0
\ No newline at end of file
diff --git a/environments/__init__.py b/environments/__init__.py
index fe070f6..e53dbf7 100644
--- a/environments/__init__.py
+++ b/environments/__init__.py
@@ -2,9 +2,7 @@
from torchrl.envs import (
CatFrames,
Compose,
- DoubleToFloat,
ObservationNorm,
- RewardSum,
ToTensorImage,
TransformedEnv,
)
@@ -17,28 +15,32 @@
from environments.walker_v0.WalkerEnv import WalkerEnv_v0
from environments.walker_v0.WalkerEnvSim import WalkerEnvSim_v0
-
VIDEO_LOGGING_ENVS = ["roboarm_mixed-v0", "walker_mixed-v0"]
ALL_2WHEELER_ENVS = ["spinning-v0", "runaway-v0"]
ALL_WALKER_ENVS = [
"walker-v0",
"walker_sim-v0",
]
-ALL_ROBOARM_ENVS = ["roboarm-v0", "roboarm_mixed-v0", "roboarm_sim-v0"]
+ALL_ROBOARM_ENVS = [
+ "roboarm-v0",
+ "roboarm_mixed-v0",
+ "roboarm_sim-v0",
+]
ALL_ENVS = ALL_2WHEELER_ENVS + ALL_WALKER_ENVS + ALL_ROBOARM_ENVS
-def make_env(config):
+def make_env(config, pretrain=False):
"""
Creates a new environment based on the provided configuration.
Args:
config: A configuration object containing the environment name and maximum episode steps.
+ pretrain: A boolean indicating whether the environment is for pretraining.
Returns:
A tuple containing the new environment, its action space, and its state space.
"""
- env = make(name=config.env.name, env_conf=config.env)
+ env = make(name=config.env.name, env_conf=config.env, pretrain=pretrain)
observation_keys = [key for key in env.observation_spec.keys()]
transforms = []
@@ -76,24 +78,27 @@ def make_env(config):
return env, action_spec, state_spec
-def make(name="RunAway", env_conf=None):
+def make(name="RunAway", env_conf=None, pretrain=False):
if name == "runaway-v0":
return RunAwayEnv_v0(
max_episode_steps=env_conf.max_episode_steps,
min_distance=env_conf.min_distance,
verbose=env_conf.verbose,
+ pretrain=pretrain,
)
elif name == "spinning-v0":
return SpinningEnv_v0(
max_episode_steps=env_conf.max_episode_steps,
sleep_time=env_conf.sleep_time,
verbose=env_conf.verbose,
+ pretrain=pretrain,
)
elif name == "walker-v0":
return WalkerEnv_v0(
max_episode_steps=env_conf.max_episode_steps,
verbose=env_conf.verbose,
sleep_time=env_conf.sleep_time,
+ pretrain=pretrain,
)
elif name == "walker_sim-v0":
return WalkerEnvSim_v0(
@@ -109,6 +114,7 @@ def make(name="RunAway", env_conf=None):
verbose=env_conf.verbose,
sleep_time=env_conf.sleep_time,
reward_signal=env_conf.reward_signal,
+ pretrain=pretrain,
)
elif name == "roboarm_sim-v0":
return RoboArmSimEnv_v0(
@@ -125,6 +131,7 @@ def make(name="RunAway", env_conf=None):
reward_signal=env_conf.reward_signal,
camera_id=env_conf.camera_id,
goal_radius=env_conf.goal_radius,
+ pretrain=pretrain,
)
else:
print("Environment not found")
diff --git a/environments/base/base_env.py b/environments/base/base_env.py
index 761a034..bc8f936 100644
--- a/environments/base/base_env.py
+++ b/environments/base/base_env.py
@@ -16,12 +16,15 @@ class BaseEnv(EnvBase):
Args:
action_dim (int): The dimensionality of the action space.
state_dim (int): The dimensionality of the state space.
+ use_hub (bool): Whether to use the Pybricks hub for communication. If False, no hub connection is made and only the observation and action specs are created.
+ verbose (bool): Whether to print verbose output.
"""
def __init__(
self,
action_dim: int,
state_dim: int,
+ use_hub: bool = True,
verbose: bool = False,
):
self.verbose = verbose
@@ -36,11 +39,14 @@ def __init__(
# buffer state in case of missing data
self.buffered_state = np.zeros(self.state_dim, dtype=np.float32)
- self.hub = PybricksHub(
- state_dim=state_dim, out_format_str=self.state_format_str
- )
- self.hub.connect()
- print("Connected to hub.")
+ if use_hub:
+ self.hub = PybricksHub(
+ state_dim=state_dim, out_format_str=self.state_format_str
+ )
+ self.hub.connect()
+ print("Connected to hub.")
+ else:
+ self.hub = None
super().__init__(batch_size=torch.Size([1]))
def send_to_hub(self, action: np.array) -> None:
@@ -129,6 +135,8 @@ class BaseSimEnv(EnvBase):
Args:
action_dim (int): The dimensionality of the action space.
state_dim (int): The dimensionality of the state space.
+ verbose (bool): Whether to print verbose output.
+ use_hub (bool): This argument is kept for compatibility but is not used in the simulation environment.
"""
def __init__(
@@ -136,6 +144,7 @@ def __init__(
action_dim: int,
state_dim: int,
verbose: bool = False,
+ use_hub: bool = False,
):
self.verbose = verbose
self.action_dim = action_dim
diff --git a/environments/dummy/mixed_obs_dummy.py b/environments/dummy/mixed_obs_dummy.py
index 09b9dd3..14f9b16 100644
--- a/environments/dummy/mixed_obs_dummy.py
+++ b/environments/dummy/mixed_obs_dummy.py
@@ -21,7 +21,7 @@ class MixedObsDummyEnv(EnvBase):
observation_key = "observation"
pixel_observation_key = "pixels"
- def __init__(self, max_episode_steps=10):
+ def __init__(self, max_episode_steps=10, img_shape=(64, 64, 3)):
self.max_episode_steps = max_episode_steps
self._batch_size = torch.Size([1])
self.action_spec = BoundedTensorSpec(
@@ -36,8 +36,8 @@ def __init__(self, max_episode_steps=10):
)
pixel_observation_spec = BoundedTensorSpec(
- low=torch.zeros((1,) + (64, 64, 3), dtype=torch.uint8),
- high=torch.ones((1,) + (64, 64, 3), dtype=torch.uint8) * 255,
+ low=torch.zeros((1,) + img_shape, dtype=torch.uint8),
+ high=torch.ones((1,) + img_shape, dtype=torch.uint8) * 255,
)
self.observation_spec = CompositeSpec(shape=(1,))
diff --git a/environments/roboarm_mixed_v0/RoboArmMixedEnv.py b/environments/roboarm_mixed_v0/RoboArmMixedEnv.py
index 6a72a5d..3f5c123 100644
--- a/environments/roboarm_mixed_v0/RoboArmMixedEnv.py
+++ b/environments/roboarm_mixed_v0/RoboArmMixedEnv.py
@@ -68,6 +68,7 @@ def __init__(
max_episode_steps: int = 50,
sleep_time: float = 0.0,
verbose: bool = False,
+ pretrain: bool = False,
reward_signal: str = "dense",
camera_id: int = 0,
goal_radius: float = 25,
@@ -131,7 +132,10 @@ def __init__(
self.goal_positions = self.init_camera_position()
super().__init__(
- action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose
+ action_dim=self.action_dim,
+ state_dim=self.state_dim,
+ verbose=verbose,
+ use_hub=not pretrain,
)
def init_camera_position(
diff --git a/environments/roboarm_v0/RoboArmEnv.py b/environments/roboarm_v0/RoboArmEnv.py
index cd8028a..ea8b176 100644
--- a/environments/roboarm_v0/RoboArmEnv.py
+++ b/environments/roboarm_v0/RoboArmEnv.py
@@ -32,6 +32,7 @@ def __init__(
max_episode_steps: int = 50,
sleep_time: float = 0.0,
verbose: bool = False,
+ pretrain: bool = False,
reward_signal: str = "dense",
):
self.sleep_time = sleep_time
@@ -77,7 +78,10 @@ def __init__(
self.observation_spec.set(self.observation_key, observation_spec)
self.observation_spec.set(self.goal_observation_key, observation_spec)
super().__init__(
- action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose
+ action_dim=self.action_dim,
+ state_dim=self.state_dim,
+ verbose=verbose,
+ use_hub=not pretrain,
)
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
diff --git a/environments/roboarm_v0/RoboArmSim.py b/environments/roboarm_v0/RoboArmSim.py
index d0c3ee7..cca28cb 100644
--- a/environments/roboarm_v0/RoboArmSim.py
+++ b/environments/roboarm_v0/RoboArmSim.py
@@ -77,7 +77,10 @@ def __init__(
self.observation_spec.set(self.observation_key, observation_spec)
self.observation_spec.set(self.goal_observation_key, observation_spec)
super().__init__(
- action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose
+ action_dim=self.action_dim,
+ state_dim=self.state_dim,
+ verbose=verbose,
+ use_hub=False,
)
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
diff --git a/environments/runaway_v0/RunAwayEnv.py b/environments/runaway_v0/RunAwayEnv.py
index 89297ed..7adb5ba 100644
--- a/environments/runaway_v0/RunAwayEnv.py
+++ b/environments/runaway_v0/RunAwayEnv.py
@@ -45,6 +45,7 @@ def __init__(
min_distance: float = 40,
sleep_time: float = 0.2,
verbose: bool = False,
+ pretrain: bool = False,
):
self.sleep_time = sleep_time
self.min_distance = min_distance
@@ -81,7 +82,10 @@ def __init__(
)
self.verbose = verbose
super().__init__(
- action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose
+ action_dim=self.action_dim,
+ state_dim=self.state_dim,
+ verbose=verbose,
+ use_hub=not pretrain,
)
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
diff --git a/environments/spinning_v0/SpinningEnv.py b/environments/spinning_v0/SpinningEnv.py
index 84a86fe..eab608b 100644
--- a/environments/spinning_v0/SpinningEnv.py
+++ b/environments/spinning_v0/SpinningEnv.py
@@ -39,6 +39,7 @@ def __init__(
max_episode_steps: int = 50,
sleep_time: float = 0.2,
verbose: bool = False,
+ pretrain: bool = False,
):
self.sleep_time = sleep_time
self._batch_size = torch.Size([1])
@@ -74,7 +75,10 @@ def __init__(
)
super().__init__(
- action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose
+ action_dim=self.action_dim,
+ state_dim=self.state_dim,
+ verbose=verbose,
+ use_hub=not pretrain,
)
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
diff --git a/environments/walker_v0/WalkerEnv.py b/environments/walker_v0/WalkerEnv.py
index 6371cfb..9d8bdb9 100644
--- a/environments/walker_v0/WalkerEnv.py
+++ b/environments/walker_v0/WalkerEnv.py
@@ -45,6 +45,7 @@ def __init__(
max_episode_steps: int = 50,
sleep_time: float = 0.0,
verbose: bool = False,
+ pretrain: bool = False,
):
self.sleep_time = sleep_time
self._batch_size = torch.Size([1])
@@ -83,7 +84,10 @@ def __init__(
{self.observation_key: observation_spec}, shape=(1,)
)
super().__init__(
- action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose
+ action_dim=self.action_dim,
+ state_dim=self.state_dim,
+ verbose=verbose,
+ use_hub=not pretrain,
)
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
diff --git a/environments/walker_v0/WalkerEnvSim.py b/environments/walker_v0/WalkerEnvSim.py
index 2ab17a2..1a1e8e4 100644
--- a/environments/walker_v0/WalkerEnvSim.py
+++ b/environments/walker_v0/WalkerEnvSim.py
@@ -74,7 +74,10 @@ def __init__(
{self.observation_key: observation_spec}, shape=(1,)
)
super().__init__(
- action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose
+ action_dim=self.action_dim,
+ state_dim=self.state_dim,
+ verbose=verbose,
+ use_hub=False,
)
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
diff --git a/experiments/2wheeler/eval.py b/experiments/2wheeler/eval.py
index 3d4d978..7d91075 100644
--- a/experiments/2wheeler/eval.py
+++ b/experiments/2wheeler/eval.py
@@ -91,7 +91,6 @@ def run(cfg: DictConfig) -> None:
except KeyboardInterrupt:
print("Evaluation interrupted by user.")
-
logout(agent)
env.close()
diff --git a/experiments/2wheeler/pretrain.py b/experiments/2wheeler/pretrain.py
new file mode 100644
index 0000000..0849e94
--- /dev/null
+++ b/experiments/2wheeler/pretrain.py
@@ -0,0 +1,61 @@
+import os
+import sys
+
+import hydra
+import wandb
+from omegaconf import DictConfig, OmegaConf
+from tqdm import tqdm
+
+# Add the project root to PYTHONPATH
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+if project_root not in sys.path:
+ sys.path.insert(0, project_root)
+
+from environments import make_env
+from src.agents import get_agent
+from src.utils import login, logout, setup_check, tensordict2dict
+
+
+@hydra.main(version_base=None, config_path=project_root + "/conf", config_name="config")
+def run(cfg: DictConfig) -> None:
+ print(OmegaConf.to_yaml(cfg))
+
+ # make environment.
+ setup_check(robot="2wheeler", config=cfg)
+ env, action_space, state_space = make_env(cfg, pretrain=True)
+
+ # make agent
+ agent, project_name = get_agent(action_space, state_space, cfg)
+ login(agent)
+
+ # initialize wandb
+ wandb.init(project=project_name)
+ wandb.config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
+ wandb.watch(agent.actor, log_freq=1) if agent.actor else None
+
+ batch_size = cfg.agent.batch_size
+ num_updates = cfg.agent.num_updates
+ train_episodes = cfg.episodes
+ print("Start training...")
+ try:
+ for e in tqdm(range(train_episodes), desc="Training"):
+
+ loss_info = agent.train(batch_size=batch_size, num_updates=num_updates)
+
+ # Metrics Logging
+ log_dict = {
+ "epoch": e,
+ "buffer_size": agent.replay_buffer.__len__(),
+ }
+ log_dict.update(tensordict2dict(loss_info))
+ wandb.log(log_dict)
+
+ except KeyboardInterrupt:
+ print("Training interrupted by user.")
+
+ logout(agent)
+ env.close()
+
+
+if __name__ == "__main__":
+ run()
diff --git a/experiments/roboarm/eval.py b/experiments/roboarm/eval.py
index 8a088a8..c482ca1 100644
--- a/experiments/roboarm/eval.py
+++ b/experiments/roboarm/eval.py
@@ -61,6 +61,7 @@ def run(cfg: DictConfig) -> None:
image_caputres.append(
td.get(("next", "original_pixels")).cpu().numpy()
)
+ agent.add_experience(td)
total_agent_step_time = time.time() - step_start_time
total_step_times.append(total_agent_step_time)
done = td.get(("next", "done"), False)
diff --git a/experiments/roboarm/pretrain.py b/experiments/roboarm/pretrain.py
new file mode 100644
index 0000000..5ff54d7
--- /dev/null
+++ b/experiments/roboarm/pretrain.py
@@ -0,0 +1,61 @@
+import os
+import sys
+
+import hydra
+import wandb
+from omegaconf import DictConfig, OmegaConf
+from tqdm import tqdm
+
+# Add the project root to PYTHONPATH
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+if project_root not in sys.path:
+ sys.path.insert(0, project_root)
+
+from environments import make_env
+from src.agents import get_agent
+from src.utils import login, logout, setup_check, tensordict2dict
+
+
+@hydra.main(version_base=None, config_path=project_root + "/conf", config_name="config")
+def run(cfg: DictConfig) -> None:
+ print(OmegaConf.to_yaml(cfg))
+
+ # make environment.
+ setup_check(robot="roboarm", config=cfg)
+ env, action_space, state_space = make_env(cfg, pretrain=True)
+
+ # make agent
+ agent, project_name = get_agent(action_space, state_space, cfg)
+ login(agent)
+
+ # initialize wandb
+ wandb.init(project=project_name)
+ wandb.config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
+ wandb.watch(agent.actor, log_freq=1) if agent.actor else None
+
+ batch_size = cfg.agent.batch_size
+ num_updates = cfg.agent.num_updates
+ train_episodes = cfg.episodes
+ print("Start training...")
+ try:
+ for e in tqdm(range(train_episodes), desc="Training"):
+
+ loss_info = agent.train(batch_size=batch_size, num_updates=num_updates)
+
+ # Metrics Logging
+ log_dict = {
+ "epoch": e,
+ "buffer_size": agent.replay_buffer.__len__(),
+ }
+ log_dict.update(tensordict2dict(loss_info))
+ wandb.log(log_dict)
+
+ except KeyboardInterrupt:
+ print("Training interrupted by user.")
+
+ logout(agent)
+ env.close()
+
+
+if __name__ == "__main__":
+ run()
diff --git a/experiments/walker/pretrain.py b/experiments/walker/pretrain.py
new file mode 100644
index 0000000..602d3e6
--- /dev/null
+++ b/experiments/walker/pretrain.py
@@ -0,0 +1,61 @@
+import os
+import sys
+
+import hydra
+import wandb
+from omegaconf import DictConfig, OmegaConf
+from tqdm import tqdm
+
+# Add the project root to PYTHONPATH
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+if project_root not in sys.path:
+ sys.path.insert(0, project_root)
+
+from environments import make_env
+from src.agents import get_agent
+from src.utils import login, logout, setup_check, tensordict2dict
+
+
+@hydra.main(version_base=None, config_path=project_root + "/conf", config_name="config")
+def run(cfg: DictConfig) -> None:
+ print(OmegaConf.to_yaml(cfg))
+
+ # make environment.
+ setup_check(robot="walker", config=cfg)
+ env, action_space, state_space = make_env(cfg, pretrain=True)
+
+ # make agent
+ agent, project_name = get_agent(action_space, state_space, cfg)
+ login(agent)
+
+ # initialize wandb
+ wandb.init(project=project_name)
+ wandb.config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
+ wandb.watch(agent.actor, log_freq=1) if agent.actor else None
+
+ batch_size = cfg.agent.batch_size
+ num_updates = cfg.agent.num_updates
+ train_episodes = cfg.episodes
+ print("Start training...")
+ try:
+ for e in tqdm(range(train_episodes), desc="Training"):
+
+ loss_info = agent.train(batch_size=batch_size, num_updates=num_updates)
+
+ # Metrics Logging
+ log_dict = {
+ "epoch": e,
+ "buffer_size": agent.replay_buffer.__len__(),
+ }
+ log_dict.update(tensordict2dict(loss_info))
+ wandb.log(log_dict)
+
+ except KeyboardInterrupt:
+ print("Training interrupted by user.")
+
+ logout(agent)
+ env.close()
+
+
+if __name__ == "__main__":
+ run()
diff --git a/requirements.txt b/requirements.txt
index 7ca509c..8b7f967 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
pybricksdev
-tensordict==0.4.0
-torchrl==0.4.0
+tensordict==0.5.0
+torchrl==0.5.0
hydra-core==1.3.2
wandb==0.16.1
opencv-python==4.9.0.80
@@ -9,4 +9,5 @@ tqdm==4.66.1
pytest==8.0.2
ufmt
pre-commit
-numpy==1.24.1
\ No newline at end of file
+numpy==1.24.1
+pynput
\ No newline at end of file
diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index cf74e48..5e5cafd 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -1,8 +1,11 @@
+from src.agents.behavior_cloning import BehavioralCloningAgent
+from src.agents.cql import CQLAgent
+from src.agents.iql import IQLAgent
from src.agents.random import RandomAgent
from src.agents.sac import SACAgent
from src.agents.td3 import TD3Agent
-all_agents = ["td3", "sac", "random"]
+all_agents = ["td3", "sac", "iql", "cql", "bc", "random"]
def get_agent(action_spec, state_spec, cfg):
@@ -20,6 +23,13 @@ def get_agent(action_spec, state_spec, cfg):
agent_config=cfg.agent,
device=cfg.device,
)
+ elif cfg.agent.name == "bc":
+ agent = BehavioralCloningAgent(
+ action_spec=action_spec,
+ state_spec=state_spec,
+ agent_config=cfg.agent,
+ device=cfg.device,
+ )
elif cfg.agent.name == "random":
agent = RandomAgent(
action_spec=action_spec,
@@ -27,6 +37,20 @@ def get_agent(action_spec, state_spec, cfg):
agent_config=cfg.agent,
device=cfg.device,
)
+ elif cfg.agent.name == "iql":
+ agent = IQLAgent(
+ action_spec=action_spec,
+ state_spec=state_spec,
+ agent_config=cfg.agent,
+ device=cfg.device,
+ )
+ elif cfg.agent.name == "cql":
+ agent = CQLAgent(
+ action_spec=action_spec,
+ state_spec=state_spec,
+ agent_config=cfg.agent,
+ device=cfg.device,
+ )
else:
raise NotImplementedError(
f"Agent {cfg.agent.name} not implemented, please choose from {all_agents}"
diff --git a/src/agents/base.py b/src/agents/base.py
index d43b6d9..8605478 100644
--- a/src/agents/base.py
+++ b/src/agents/base.py
@@ -13,18 +13,17 @@ class BaseAgent:
"""Implements a base agent used to interact with the lego robots.
Args:
- state_space (gym.Space): The state space of the environment.
- action_space (gym.Space): The action space of the environment.
- device (torch.device): The device to use for computation.
- observation_keys (Tuple[str]): The keys used to access the observation in the tensor dictionary.
+ state_spec (TensorSpec): The state specification of the environment.
+ action_spec (TensorSpec): The action specification of the environment.
+ agent_name (str): The name of the agent.
+ device (str): The device to use for computation.
Attributes:
- state_space (gym.Space): The state space of the environment.
- action_space (gym.Space): The action space of the environment.
- state_dim (int): The dimension of the state space.
- action_dim (int): The dimension of the action space.
- device (torch.device): The device to use for computation.
- observation_keys (Tuple[str]): The keys used to access the observation in the tensor dictionary.
+ name (str): The name of the agent.
+ observation_spec (TensorSpec): The state specification of the environment.
+ action_spec (TensorSpec): The action specification of the environment.
+ device (str): The device to use for computation.
+ observation_keys (List[str]): The keys used to access the observation in the tensor dictionary.
"""
def __init__(
diff --git a/src/agents/behavior_cloning.py b/src/agents/behavior_cloning.py
new file mode 100644
index 0000000..bfb39cb
--- /dev/null
+++ b/src/agents/behavior_cloning.py
@@ -0,0 +1,153 @@
+import warnings
+
+import numpy as np
+import tensordict as td
+import torch
+from tensordict import TensorDictBase
+from torch import nn, optim
+from torchrl.data import BoundedTensorSpec, TensorDictReplayBuffer
+
+from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+from torchrl.envs import RenameTransform, ToTensorImage
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+
+from src.agents.base import BaseAgent
+from src.networks.networks import get_deterministic_actor, get_stochastic_actor
+
+
+def initialize(net, std=0.02):
+ for name, param in net.named_parameters():
+ if "weight" in name:
+ # nn.init.xavier_uniform_(param)
+ nn.init.normal_(param, mean=0, std=std)
+ elif "bias" in name:
+ nn.init.zeros_(param)
+
+
+class BehavioralCloningAgent(BaseAgent):
+ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
+ super(BehavioralCloningAgent, self).__init__(
+ state_spec, action_spec, agent_config.name, device
+ )
+
+ if agent_config.policy_type == "deterministic":
+ self.actor = get_deterministic_actor(state_spec, action_spec, agent_config)
+ elif agent_config.policy_type == "stochastic":
+ raise NotImplementedError(
+ "Stochastic actor training is not implemented yet"
+ )
+ # TODO: Implement stochastic actor training
+ # self.actor = get_stochastic_actor(
+ # state_spec, action_spec, agent_config
+ # )
+ else:
+ raise ValueError(
+ "policy_type not recognized, choose deterministic or stochastic"
+ )
+ self.actor.to(device)
+ # initialize networks
+ self.init_nets([self.actor])
+
+ self.optimizer = optim.Adam(
+ self.actor.parameters(), lr=agent_config.lr, weight_decay=0.0
+ )
+
+ # create replay buffer
+ self.batch_size = agent_config.batch_size
+ self.replay_buffer = self.create_replay_buffer()
+
+ # general stats
+ self.collected_transitions = 0
+ self.do_pretrain = False
+ self.episodes = 0
+
+ def get_agent_statedict(self):
+ """Save agent"""
+ act_statedict = self.actor.state_dict()
+ return {"actor": act_statedict}
+
+ def load_model(self, path):
+ """load model"""
+ try:
+ statedict = torch.load(path)
+ self.actor.load_state_dict(statedict["actor"])
+ print("Model loaded")
+ except:
+ raise ValueError("Model not loaded")
+
+ def load_replaybuffer(self, path):
+ """load replay buffer"""
+ try:
+ loaded_data = TensorDictBase.load_memmap(path)
+ self.replay_buffer.extend(loaded_data)
+ if self.replay_buffer._batch_size != self.batch_size:
+ warnings.warn(
+ "Batch size of the loaded replay buffer is different from the agent's config batch size! Rewriting the batch size to match the agent's config batch size."
+ )
+ self.replay_buffer._batch_size = self.batch_size
+ print("Replay Buffer loaded")
+ print("Replay Buffer size: ", self.replay_buffer.__len__(), "\n")
+ except:
+ raise ValueError("Replay Buffer not loaded")
+
+ def eval(self):
+ """Sets the agent to evaluation mode."""
+ self.actor.eval()
+
+ @torch.no_grad()
+ def get_eval_action(self, td: TensorDictBase) -> TensorDictBase:
+ """Get eval action from actor network"""
+ with set_exploration_type(ExplorationType.MODE):
+ out_td = self.actor(td.to(self.device))
+ return out_td
+
+ def create_replay_buffer(
+ self,
+ buffer_size=1000000,
+ buffer_scratch_dir="./tmp",
+ device="cpu",
+ prefetch=3,
+ ):
+ """Create replay buffer"""
+
+ replay_buffer = TensorDictReplayBuffer(
+ pin_memory=False,
+ prefetch=prefetch,
+ storage=LazyMemmapStorage(
+ buffer_size,
+ scratch_dir=buffer_scratch_dir,
+ ),
+ batch_size=self.batch_size,
+ )
+ replay_buffer.append_transform(lambda x: x.to(device))
+ # TODO: check if we have image in observation space if so add this transform
+ # replay_buffer.append_transform(ToTensorImage(from_int=True, shape_tolerant=True))
+
+ return replay_buffer
+
+ @torch.no_grad()
+ def get_action(self, td: TensorDictBase) -> TensorDictBase:
+ """Get action from actor network"""
+ with set_exploration_type(ExplorationType.RANDOM):
+ out_td = self.actor(td.to(self.device))
+ return out_td
+
+ def add_experience(self, transition: td.TensorDict):
+ """Add experience to replay buffer"""
+ """Add experience to replay buffer"""
+ self.replay_buffer.extend(transition)
+ self.collected_transitions += 1
+
+ def train(self, batch_size=64, num_updates=1):
+ """Train the agent"""
+ log_data = {}
+
+ for i in range(num_updates):
+ batch = self.replay_buffer.sample(batch_size).to(self.device)
+ orig_action = batch.get("action").clone()
+
+ out_dict = self.actor(batch)
+ loss = torch.mean((out_dict.get("action") - orig_action) ** 2)
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+ log_data.update({"loss": loss})
+ return log_data
diff --git a/src/agents/cql.py b/src/agents/cql.py
new file mode 100644
index 0000000..54d9d2d
--- /dev/null
+++ b/src/agents/cql.py
@@ -0,0 +1,252 @@
+import warnings
+
+import tensordict as td
+import torch
+from tensordict import TensorDictBase
+from torch import optim
+from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import SoftUpdate
+
+from torchrl.objectives.cql import CQLLoss
+
+from src.agents.base import BaseAgent
+from src.networks.networks import get_critic, get_stochastic_actor
+
+
+class CQLAgent(BaseAgent):
+ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
+ super(CQLAgent, self).__init__(
+ state_spec, action_spec, agent_config.name, device
+ )
+
+ with_lagrange = agent_config.with_lagrange
+
+ self.actor = get_stochastic_actor(state_spec, action_spec, agent_config)
+ self.critic = get_critic(state_spec, agent_config)
+
+ self.actor.to(device)
+ self.critic.to(device)
+
+ # initialize networks
+ self.init_nets([self.actor, self.critic])
+
+ # define loss function
+ self.loss_module = CQLLoss(
+ actor_network=self.actor,
+ qvalue_network=self.critic,
+ loss_function=agent_config.loss_function,
+ temperature=agent_config.temperature,
+ min_q_weight=agent_config.min_q_weight,
+ max_q_backup=agent_config.max_q_backup,
+ deterministic_backup=agent_config.deterministic_backup,
+ num_random=agent_config.num_random,
+ with_lagrange=agent_config.with_lagrange,
+ lagrange_thresh=agent_config.lagrange_thresh,
+ )
+ # Define Target Network Updater
+ self.target_net_updater = SoftUpdate(
+ self.loss_module, eps=agent_config.soft_update_eps
+ )
+ self.target_net_updater.init_()
+
+ # Reset weights
+ self.reset_params = agent_config.reset_params
+
+ # Define Replay Buffer
+ self.batch_size = agent_config.batch_size
+ self.replay_buffer = self.create_replay_buffer(
+ prb=agent_config.prb,
+ buffer_size=agent_config.buffer_size,
+ device=device,
+ )
+
+ # Define Optimizer
+ critic_params = list(
+ self.loss_module.qvalue_network_params.flatten_keys().values()
+ )
+ actor_params = list(
+ self.loss_module.actor_network_params.flatten_keys().values()
+ )
+ self.optimizer_actor = optim.Adam(
+ actor_params, lr=agent_config.lr, weight_decay=0.0
+ )
+ self.optimizer_critic = optim.Adam(
+ critic_params, lr=agent_config.lr, weight_decay=0.0
+ )
+ self.optimizer_alpha = optim.Adam(
+ [self.loss_module.log_alpha],
+ lr=3.0e-4,
+ )
+ if with_lagrange:
+ self.alpha_prime_optim = torch.optim.Adam(
+ [self.loss_module.log_alpha_prime],
+ lr=agent_config.lr,
+ )
+ else:
+ self.alpha_prime_optim = None
+ # general stats
+ self.collected_transitions = 0
+ self.total_updates = 0
+ self.do_pretrain = agent_config.pretrain
+ self.bc_steps = agent_config.bc_steps
+
+ def get_agent_statedict(self):
+ """Save agent"""
+ act_statedict = self.actor.state_dict()
+ critic_statedict = self.critic.state_dict()
+ return {"actor": act_statedict, "critic": critic_statedict}
+
+ def load_model(self, path):
+ """load model"""
+ try:
+ statedict = torch.load(path)
+ self.actor.load_state_dict(statedict["actor"])
+ self.critic.load_state_dict(statedict["critic"])
+ print("Model loaded")
+ except:
+ raise ValueError("Model not loaded")
+
+ def load_replaybuffer(self, path):
+ """load replay buffer"""
+ try:
+ loaded_data = TensorDictBase.load_memmap(path)
+ self.replay_buffer.extend(loaded_data)
+ if self.replay_buffer._batch_size != self.batch_size:
+ warnings.warn(
+ "Batch size of the loaded replay buffer is different from the agent's config batch size! Rewriting the batch size to match the agent's config batch size."
+ )
+ self.replay_buffer._batch_size = self.batch_size
+ print("Replay Buffer loaded")
+ print("Replay Buffer size: ", self.replay_buffer.__len__(), "\n")
+ except:
+ raise ValueError("Replay Buffer not loaded")
+
+ def reset_networks(self):
+ """reset network parameters"""
+ print("Resetting Networks!")
+ self.loss_module.actor_network_params.apply(self.reset_parameter)
+ self.loss_module.target_actor_network_params.apply(self.reset_parameter)
+ self.loss_module.qvalue_network_params.apply(self.reset_parameter)
+ self.loss_module.target_qvalue_network_params.apply(self.reset_parameter)
+
+ def eval(self):
+ """Sets the agent to evaluation mode."""
+ self.actor.eval()
+
+ def td_preprocessing(self, td: TensorDictBase) -> TensorDictBase:
+ # TODO not ideal to have this here
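+ # Strip the distribution parameters (loc/scale/params) and optional embeddings that the actor writes into the tensordict, so only the plain transition is returned.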
+ td.pop("scale")
+ td.pop("loc")
+ td.pop("params")
+ if "vector_obs_embedding" in td.keys():
+ td.pop("vector_obs_embedding")
+ if "image_embedding" in td.keys():
+ td.pop("image_embedding")
+
+ def create_replay_buffer(
+ self,
+ prb=False,
+ buffer_size=100000,
+ buffer_scratch_dir=None,
+ device="cpu",
+ prefetch=3,
+ ):
+ """Create replay buffer"""
+ # TODO: make this part of base off policy agent
+ if prb:
+ replay_buffer = TensorDictPrioritizedReplayBuffer(
+ alpha=0.7,
+ beta=0.5,
+ pin_memory=False,
+ prefetch=1,
+ storage=LazyTensorStorage(
+ buffer_size,
+ ),
+ )
+ else:
+ replay_buffer = TensorDictReplayBuffer(
+ pin_memory=False,
+ prefetch=prefetch,
+ storage=LazyMemmapStorage(
+ buffer_size,
+ scratch_dir=buffer_scratch_dir,
+ ),
+ batch_size=self.batch_size,
+ )
+ replay_buffer.append_transform(lambda x: x.to(device))
+ # TODO: check if we have image in observation space if so add this transform
+ # replay_buffer.append_transform(ToTensorImage(from_int=True, shape_tolerant=True))
+ return replay_buffer
+
+ @torch.no_grad()
+ def get_action(self, td: TensorDictBase) -> TensorDictBase:
+ """Get action from actor network"""
+ with set_exploration_type(ExplorationType.RANDOM):
+ out_td = self.actor(td.to(self.device))
+ self.td_preprocessing(out_td)
+ return out_td
+
+ @torch.no_grad()
+ def get_eval_action(self, td: TensorDictBase) -> TensorDictBase:
+ """Get eval action from actor network"""
+ with set_exploration_type(ExplorationType.MODE):
+ out_td = self.actor(td.to(self.device))
+ self.td_preprocessing(out_td)
+ return out_td
+
+ def add_experience(self, transition: td.TensorDict):
+ """Add experience to replay buffer"""
+ self.replay_buffer.extend(transition)
+ self.collected_transitions += 1
+
+ def train(self, batch_size=64, num_updates=1):
+ """Train the agent"""
+ self.actor.train()
+ for i in range(num_updates):
+ self.total_updates += 1
+ # Sample a batch from the replay buffer
+ batch = self.replay_buffer.sample(batch_size)
+ # Compute CQL Loss
+ loss = self.loss_module(batch)
+
+ # Update alpha
+ alpha_loss = loss["loss_alpha"]
+ alpha_prime_loss = loss["loss_alpha_prime"]
+ self.optimizer_alpha.zero_grad()
+ alpha_loss.backward()
+ self.optimizer_alpha.step()
+
+ # Update Actor Network
+ # The official CQL implementation uses a behavior cloning loss for the first few update steps, as it helps on some tasks.
+ if self.total_updates >= self.bc_steps:
+ actor_loss = loss["loss_actor"]
+ else:
+ actor_loss = loss["loss_actor_bc"]
+ self.optimizer_actor.zero_grad()
+ actor_loss.backward()
+ self.optimizer_actor.step()
+
+ if self.alpha_prime_optim is not None:
+ self.alpha_prime_optim.zero_grad()
+ alpha_prime_loss.backward(retain_graph=True)
+ self.alpha_prime_optim.step()
+
+ # Update Critic Network
+ q_loss = loss["loss_qvalue"]
+ cql_loss = loss["loss_cql"]
+
+ q_loss = q_loss + cql_loss
+ self.optimizer_critic.zero_grad()
+ q_loss.backward(retain_graph=False)
+ self.optimizer_critic.step()
+
+ # Update Target Networks
+ self.target_net_updater.step()
+ # Update Prioritized Replay Buffer
+ if isinstance(self.replay_buffer, TensorDictPrioritizedReplayBuffer):
+ self.replay_buffer.update_priorities(
+ batch["indices"],
+ loss["critic_loss"].detach().cpu().numpy(),
+ )
+ self.actor.eval()
+ return loss
diff --git a/src/agents/iql.py b/src/agents/iql.py
new file mode 100644
index 0000000..244bb17
--- /dev/null
+++ b/src/agents/iql.py
@@ -0,0 +1,254 @@
+import warnings
+
+import tensordict as td
+import torch
+from tensordict import TensorDictBase
+from torch import optim
+from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage
+from torchrl.envs.transforms import ToTensorImage
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import SoftUpdate
+
+from torchrl.objectives.iql import IQLLoss
+
+from src.agents.base import BaseAgent
+from src.networks.networks import get_critic, get_stochastic_actor, get_value_operator
+
+
+class IQLAgent(BaseAgent):
+ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
+ super(IQLAgent, self).__init__(
+ state_spec, action_spec, agent_config.name, device
+ )
+
+ self.actor = get_stochastic_actor(state_spec, action_spec, agent_config)
+ self.critic = get_critic(state_spec, agent_config)
+
+ self.value = get_value_operator(state_spec, agent_config)
+
+ self.actor.to(device)
+ self.critic.to(device)
+ self.value.to(device)
+
+ # initialize networks
+ self.init_nets([self.actor, self.critic, self.value])
+
+ # define loss function
+ self.loss_module = IQLLoss(
+ actor_network=self.actor,
+ qvalue_network=self.critic,
+ value_network=self.value,
+ num_qvalue_nets=2,
+ temperature=agent_config.temperature,
+ expectile=agent_config.expectile,
+ loss_function=agent_config.loss_function,
+ )
+ # Define Target Network Updater
+ self.target_net_updater = SoftUpdate(
+ self.loss_module, eps=agent_config.soft_update_eps
+ )
+ self.target_net_updater.init_()
+
+ # Reset weights
+ self.reset_params = agent_config.reset_params
+
+ # Define Replay Buffer
+ self.batch_size = agent_config.batch_size
+
+ self.replay_buffer = self.create_replay_buffer(
+ prb=agent_config.prb,
+ buffer_size=agent_config.buffer_size,
+ device=device,
+ )
+
+ # Define Optimizer
+ critic_params = list(
+ self.loss_module.qvalue_network_params.flatten_keys().values()
+ )
+ value_params = list(
+ self.loss_module.value_network_params.flatten_keys().values()
+ )
+ actor_params = list(
+ self.loss_module.actor_network_params.flatten_keys().values()
+ )
+ self.optimizer_actor = optim.Adam(
+ actor_params, lr=agent_config.lr, weight_decay=0.0
+ )
+ self.optimizer_critic = optim.Adam(
+ critic_params, lr=agent_config.lr, weight_decay=0.0
+ )
+ self.optimizer_value = optim.Adam(
+ value_params, lr=agent_config.lr, weight_decay=0.0
+ )
+
+ # general stats
+ self.collected_transitions = 0
+ self.total_updates = 0
+ self.do_pretrain = agent_config.pretrain
+
+ def get_agent_statedict(self):
+ """Save agent"""
+ act_statedict = self.actor.state_dict()
+ critic_statedict = self.critic.state_dict()
+ value_statedict = self.value.state_dict()
+ return {
+ "actor": act_statedict,
+ "critic": critic_statedict,
+ "value": value_statedict,
+ }
+
+ def load_model(self, path):
+ """load model"""
+
+ try:
+ statedict = torch.load(path)
+ self.actor.load_state_dict(statedict["actor"])
+ self.critic.load_state_dict(statedict["critic"])
+ self.value.load_state_dict(statedict["value"])
+ print("Model loaded")
+ except:
+ raise ValueError("Model not loaded")
+
+ def load_replaybuffer(self, path):
+ """load replay buffer"""
+ try:
+ loaded_data = TensorDictBase.load_memmap(path)
+ self.replay_buffer.extend(loaded_data)
+ if self.replay_buffer._batch_size != self.batch_size:
+ warnings.warn(
+ "Batch size of the loaded replay buffer is different from the agent's config batch size! Rewriting the batch size to match the agent's config batch size."
+ )
+ self.replay_buffer._batch_size = self.batch_size
+ print("Replay Buffer loaded")
+ print("Replay Buffer size: ", self.replay_buffer.__len__(), "\n")
+ except:
+ raise ValueError("Replay Buffer not loaded")
+
+ def reset_networks(self):
+ """reset network parameters"""
+ print("Resetting Networks!")
+ self.loss_module.actor_network_params.apply(self.reset_parameter)
+ self.loss_module.target_actor_network_params.apply(self.reset_parameter)
+ self.loss_module.qvalue_network_params.apply(self.reset_parameter)
+ self.loss_module.target_qvalue_network_params.apply(self.reset_parameter)
+ self.loss_module.value_network_params.apply(self.reset_parameter)
+
+ def eval(self):
+ """Sets the agent to evaluation mode."""
+ self.actor.eval()
+
+ def td_preprocessing(self, td: TensorDictBase) -> TensorDictBase:
+ # TODO not ideal to have this here
+ td.pop("scale")
+ td.pop("loc")
+ td.pop("params")
+ if "vector_obs_embedding" in td.keys():
+ td.pop("vector_obs_embedding")
+ if "image_embedding" in td.keys():
+ td.pop("image_embedding")
+
+ def create_replay_buffer(
+ self,
+ prb=False,
+ buffer_size=100000,
+ buffer_scratch_dir=None,
+ device="cpu",
+ prefetch=3,
+ ):
+ """Create replay buffer"""
+ # TODO: make this part of base off policy agent
+ if prb:
+ replay_buffer = TensorDictPrioritizedReplayBuffer(
+ alpha=0.7,
+ beta=0.5,
+ pin_memory=False,
+ prefetch=1,
+ storage=LazyTensorStorage(
+ buffer_size,
+ device=device,
+ ),
+ )
+ else:
+ replay_buffer = TensorDictReplayBuffer(
+ pin_memory=False,
+ prefetch=prefetch,
+ storage=LazyMemmapStorage(
+ buffer_size,
+ scratch_dir=buffer_scratch_dir,
+ ),
+ batch_size=self.batch_size,
+ )
+ replay_buffer.append_transform(lambda x: x.to(device))
+ # TODO: check if we have image in observation space if so add this transform
+ # replay_buffer.append_transform(ToTensorImage(from_int=True, shape_tolerant=True))
+
+ return replay_buffer
+
+ @torch.no_grad()
+ def get_action(self, td: TensorDictBase) -> TensorDictBase:
+ """Get action from actor network"""
+ with set_exploration_type(ExplorationType.RANDOM):
+ out_td = self.actor(td.to(self.device))
+ self.td_preprocessing(out_td)
+ return out_td
+
+ @torch.no_grad()
+ def get_eval_action(self, td: TensorDictBase) -> TensorDictBase:
+ """Get eval action from actor network"""
+ with set_exploration_type(ExplorationType.MODE):
+ out_td = self.actor(td.to(self.device))
+ self.td_preprocessing(out_td)
+ return out_td
+
+ def add_experience(self, transition: td.TensorDict):
+ """Add experience to replay buffer"""
+ self.replay_buffer.extend(transition)
+ self.collected_transitions += 1
+
+ def pretrain(self, wandb, batch_size=64, num_updates=1):
+ """Pretrain the agent with simple behavioral cloning"""
+ # TODO: implement pretrain for testing
+ # for i in range(num_updates):
+ # batch = self.replay_buffer.sample(batch_size)
+ # pred, _ = self.actor(batch["observations"].float())
+ # loss = torch.mean((pred - batch["actions"]) ** 2)
+ # self.optimizer.zero_grad()
+ # loss.backward()
+ # self.optimizer.step()
+ # wandb.log({"pretrain/loss": loss.item()})
+
+ def train(self, batch_size=64, num_updates=1):
+ """Train the agent"""
+ self.actor.train()
+ for i in range(num_updates):
+ self.total_updates += 1
+ if self.reset_params and self.total_updates % self.reset_params == 0:
+ self.reset_networks()
+ # Sample a batch from the replay buffer
+ batch = self.replay_buffer.sample(batch_size)
+ # Compute IQL Loss
+ loss = self.loss_module(batch)
+
+ # Update Actor Network
+ self.optimizer_actor.zero_grad()
+ loss["loss_actor"].backward()
+ self.optimizer_actor.step()
+ # Update Critic Network
+ self.optimizer_critic.zero_grad()
+ loss["loss_qvalue"].backward()
+ torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
+ self.optimizer_critic.step()
+ # Update Value Network
+ self.optimizer_value.zero_grad()
+ loss["loss_value"].backward()
+ self.optimizer_value.step()
+
+ # Update Target Networks
+ self.target_net_updater.step()
+ # Update Prioritized Replay Buffer
+ if isinstance(self.replay_buffer, TensorDictPrioritizedReplayBuffer):
+ self.replay_buffer.update_priorities(
+ batch["indices"],
+ loss["critic_loss"].detach().cpu().numpy(),
+ )
+ self.actor.eval()
+ return loss
diff --git a/src/agents/random.py b/src/agents/random.py
index 55a37ef..f599f43 100644
--- a/src/agents/random.py
+++ b/src/agents/random.py
@@ -1,5 +1,7 @@
import torch
from tensordict import TensorDictBase
+from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage
from src.agents.base import BaseAgent
@@ -11,7 +13,13 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
)
self.actor = None
- self.replay_buffer = {}
+ self.replay_buffer = self.create_replay_buffer(
+ batch_size=256,
+ prb=False,
+ buffer_size=1000000,
+ device=device,
+ buffer_scratch_dir="/tmp",
+ )
def eval(self):
"""Sets the agent to evaluation mode."""
@@ -30,8 +38,41 @@ def get_eval_action(self, tensordict: TensorDictBase):
def add_experience(self, transition: TensorDictBase):
"""Add experience to replay buffer"""
- pass
+ self.replay_buffer.extend(transition)
def train(self, batch_size=64, num_updates=1):
"""Train the agent"""
return {}
+
+ def create_replay_buffer(
+ self,
+ batch_size=256,
+ prb=False,
+ buffer_size=100000,
+ buffer_scratch_dir=None,
+ device="cpu",
+ prefetch=3,
+ ):
+ """Create replay buffer"""
+ # TODO: make this part of base off policy agent
+ if prb:
+ replay_buffer = TensorDictPrioritizedReplayBuffer(
+ alpha=0.7,
+ beta=0.5,
+ pin_memory=False,
+ prefetch=1,
+ storage=LazyTensorStorage(
+ buffer_size,
+ ),
+ )
+ else:
+ replay_buffer = TensorDictReplayBuffer(
+ pin_memory=False,
+ prefetch=prefetch,
+ storage=LazyMemmapStorage(
+ buffer_size,
+ scratch_dir=buffer_scratch_dir,
+ ),
+ batch_size=batch_size,
+ )
+ return replay_buffer
diff --git a/src/agents/sac.py b/src/agents/sac.py
index 7636c18..a87049f 100644
--- a/src/agents/sac.py
+++ b/src/agents/sac.py
@@ -19,10 +19,8 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
state_spec, action_spec, agent_config.name, device
)
- self.actor = get_stochastic_actor(
- self.observation_keys, action_spec, agent_config
- )
- self.critic = get_critic(self.observation_keys, agent_config)
+ self.actor = get_stochastic_actor(state_spec, action_spec, agent_config)
+ self.critic = get_critic(state_spec, agent_config)
self.actor.to(device)
self.critic.to(device)
@@ -52,14 +50,13 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
self.batch_size = agent_config.batch_size
# Define Replay Buffer
+ self.buffer_batch_size = agent_config.batch_size
self.replay_buffer = self.create_replay_buffer(
- batch_size=self.batch_size,
prb=agent_config.prb,
buffer_size=agent_config.buffer_size,
- device=device,
buffer_scratch_dir="/tmp",
+ device=device,
)
-
# Define Optimizer
critic_params = list(
self.loss_module.qvalue_network_params.flatten_keys().values()
@@ -101,7 +98,8 @@ def load_model(self, path):
def load_replaybuffer(self, path):
"""load replay buffer"""
try:
- self.replay_buffer.load(path)
+ loaded_data = TensorDictBase.load_memmap(path)
+ self.replay_buffer.extend(loaded_data)
if self.replay_buffer._batch_size != self.batch_size:
Warning(
"Batch size of the loaded replay buffer is different from the agent's config batch size! Rewriting the batch size to match the agent's config batch size."
@@ -136,10 +134,9 @@ def td_preprocessing(self, td: TensorDictBase) -> TensorDictBase:
def create_replay_buffer(
self,
- batch_size=256,
prb=False,
buffer_size=100000,
- buffer_scratch_dir=None,
+ buffer_scratch_dir=".",
device="cpu",
prefetch=3,
):
@@ -163,9 +160,12 @@ def create_replay_buffer(
buffer_size,
scratch_dir=buffer_scratch_dir,
),
- batch_size=batch_size,
+ batch_size=self.batch_size,
)
replay_buffer.append_transform(lambda x: x.to(device))
+ # TODO: check if we have image in observation space if so add this transform
+ # replay_buffer.append_transform(ToTensorImage(from_int=True, shape_tolerant=True))
+
return replay_buffer
@torch.no_grad()
diff --git a/src/agents/td3.py b/src/agents/td3.py
index 5c1b9e5..706d821 100644
--- a/src/agents/td3.py
+++ b/src/agents/td3.py
@@ -10,6 +10,7 @@
from torchrl.modules import AdditiveGaussianWrapper
from torchrl.objectives import SoftUpdate
from torchrl.objectives.td3 import TD3Loss
+from torchrl.objectives.td3_bc import TD3BCLoss
from src.agents.base import BaseAgent
from src.networks.networks import get_critic, get_deterministic_actor
@@ -30,12 +31,13 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
state_spec, action_spec, agent_config.name, device
)
- self.actor = get_deterministic_actor(
- self.observation_keys, action_spec, agent_config
- )
- self.critic = get_critic(self.observation_keys, agent_config)
+ self.actor = get_deterministic_actor(state_spec, action_spec, agent_config)
+ self.critic = get_critic(state_spec, agent_config)
self.model = nn.ModuleList([self.actor, self.critic]).to(device)
+
+ print(self.actor)
+ print(self.critic)
# initialize networks
self.init_nets(self.model)
@@ -48,14 +50,27 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
).to(device)
# define loss function
- self.loss_module = TD3Loss(
- actor_network=self.model[0],
- qvalue_network=self.model[1],
- action_spec=action_spec,
- num_qvalue_nets=2,
- loss_function=agent_config.loss_function,
- separate_losses=False,
- )
+ self.use_bc = agent_config.use_bc
+ if not self.use_bc:
+ self.loss_module = TD3Loss(
+ actor_network=self.model[0],
+ qvalue_network=self.model[1],
+ action_spec=action_spec,
+ num_qvalue_nets=2,
+ loss_function=agent_config.loss_function,
+ separate_losses=False,
+ )
+ else:
+ self.loss_module = TD3BCLoss(
+ actor_network=self.model[0],
+ qvalue_network=self.model[1],
+ action_spec=action_spec,
+ num_qvalue_nets=2,
+ loss_function=agent_config.loss_function,
+ separate_losses=False,
+ alpha=agent_config.alpha,
+ )
+
# Define Target Network Updater
self.target_net_updater = SoftUpdate(
self.loss_module, eps=agent_config.soft_update_eps
@@ -65,7 +80,6 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
self.batch_size = agent_config.batch_size
# Define Replay Buffer
self.replay_buffer = self.create_replay_buffer(
- batch_size=self.batch_size,
prb=agent_config.prb,
buffer_size=agent_config.buffer_size,
device=device,
@@ -112,7 +126,8 @@ def load_model(self, path):
def load_replaybuffer(self, path):
"""load replay buffer"""
try:
- self.replay_buffer.load(path)
+ loaded_data = TensorDictBase.load_memmap(path)
+ self.replay_buffer.extend(loaded_data)
if self.replay_buffer._batch_size != self.batch_size:
Warning(
"Batch size of the loaded replay buffer is different from the agent's config batch size! Rewriting the batch size to match the agent's config batch size."
@@ -133,7 +148,6 @@ def reset_networks(self):
def create_replay_buffer(
self,
- batch_size=256,
prb=False,
buffer_size=100000,
buffer_scratch_dir=None,
@@ -160,9 +174,17 @@ def create_replay_buffer(
buffer_size,
scratch_dir=buffer_scratch_dir,
),
- batch_size=batch_size,
+ batch_size=self.batch_size,
)
replay_buffer.append_transform(lambda x: x.to(device))
+ # TODO: check if we have image in observation space if so add this transform
+ # replay_buffer.append_transform(
+ # ToTensorImage(
+ # from_int=True,
+ # shape_tolerant=True,
+ # in_keys=["pixels", ("next", "pixels")],
+ # )
+ # )
return replay_buffer
def td_preprocessing(self, td: TensorDictBase) -> TensorDictBase:
@@ -215,7 +237,10 @@ def train(self, batch_size=64, num_updates=1):
else:
sampled_tensordict = sampled_tensordict.clone()
# Update Critic Network
- q_loss, _ = self.loss_module.value_loss(sampled_tensordict)
+ if self.use_bc:
+ q_loss, _ = self.loss_module.qvalue_loss(sampled_tensordict)
+ else:
+ q_loss, _ = self.loss_module.value_loss(sampled_tensordict)
self.optimizer_critic.zero_grad()
q_loss.backward()
self.optimizer_critic.step()
diff --git a/src/networks/networks.py b/src/networks/networks.py
index 4d82763..b1ff9c0 100644
--- a/src/networks/networks.py
+++ b/src/networks/networks.py
@@ -3,7 +3,6 @@
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import (
- AdditiveGaussianWrapper,
ConvNet,
MLP,
ProbabilisticActor,
@@ -12,7 +11,7 @@
TanhModule,
ValueOperator,
)
-from torchrl.modules.distributions import TanhDelta, TanhNormal
+from torchrl.modules.distributions import TanhNormal
def get_normalization(normalization):
@@ -26,7 +25,9 @@ def get_normalization(normalization):
raise NotImplementedError(f"Normalization {normalization} not implemented")
-def get_critic(observation_keys, agent_config):
+def get_critic(observation_spec, agent_config):
+ observation_keys = [key for key in observation_spec.keys()]
+
if "observation" in observation_keys and not "pixels" in observation_keys:
return get_vec_critic(
in_keys=observation_keys,
@@ -36,6 +37,17 @@ def get_critic(observation_keys, agent_config):
normalization=agent_config.normalization,
dropout=agent_config.dropout,
)
+ elif "pixels" in observation_keys and not "observation" in observation_keys:
+ return get_img_only_critic(
+ img_in_keys="pixels",
+ num_cells=[agent_config.num_cells, agent_config.num_cells],
+ out_features=1,
+ activation_class=nn.ReLU,
+ normalization=agent_config.normalization,
+ dropout=agent_config.dropout,
+ img_shape=observation_spec["pixels"].shape,
+ )
+
elif "pixels" in observation_keys and "observation" in observation_keys:
return get_mixed_critic(
vec_in_keys="observation",
@@ -45,11 +57,171 @@ def get_critic(observation_keys, agent_config):
activation_class=nn.ReLU,
normalization=agent_config.normalization,
dropout=agent_config.dropout,
+ img_shape=observation_spec["pixels"].shape,
)
else:
raise NotImplementedError("Critic for this observation space not implemented")
+def get_value_operator(observation_spec, agent_config):
+ observation_keys = [key for key in observation_spec.keys()]
+ if "observation" in observation_keys and not "pixels" in observation_keys:
+ return get_vec_value(
+ in_keys=observation_keys,
+ num_cells=[agent_config.num_cells, agent_config.num_cells],
+ out_features=1,
+ activation_class=nn.ReLU,
+ normalization=agent_config.normalization,
+ dropout=agent_config.dropout,
+ )
+ elif "pixels" in observation_keys and not "observation" in observation_keys:
+ return get_img_only_value(
+ img_in_keys="pixels",
+ num_cells=[agent_config.num_cells, agent_config.num_cells],
+ out_features=1,
+ activation_class=nn.ReLU,
+ normalization=agent_config.normalization,
+ dropout=agent_config.dropout,
+ img_shape=observation_spec["pixels"].shape,
+ )
+ elif "pixels" in observation_keys and "observation" in observation_keys:
+ return get_mixed_value(
+ vec_in_keys="observation",
+ img_in_keys="pixels",
+ num_cells=[agent_config.num_cells, agent_config.num_cells],
+ out_features=1,
+ activation_class=nn.ReLU,
+ normalization=agent_config.normalization,
+ dropout=agent_config.dropout,
+ img_shape=observation_spec["pixels"].shape,
+ )
+
+
+def get_vec_value(
+ in_keys=["observation"],
+ num_cells=[256, 256],
+ out_features=1,
+ activation_class=nn.ReLU,
+ normalization="None",
+ dropout=0.0,
+):
+ """Returns a critic network"""
+ normalization = get_normalization(normalization)
+ qvalue_net = MLP(
+ num_cells=num_cells,
+ out_features=out_features,
+ activation_class=activation_class,
+ norm_class=normalization,
+ norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+ dropout=dropout,
+ )
+
+ qvalue = ValueOperator(
+ in_keys=in_keys,
+ module=qvalue_net,
+ )
+ return qvalue
+
+
+def get_img_only_value(
+ img_in_keys,
+ num_cells=[256, 256],
+ out_features=1,
+ activation_class=nn.ReLU,
+ normalization="None",
+ dropout=0.0,
+ img_shape=(3, 64, 64),
+):
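+ """Pixel-only state-value network: a CNN+MLP encoder maps pixels to an embedding, followed by an MLP value head."""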
+ normalization = get_normalization(normalization)
+ # image encoder
+ cnn = ConvNet(
+ activation_class=activation_class,
+ num_cells=[32, 64, 64],
+ kernel_sizes=[8, 4, 3],
+ strides=[4, 2, 1],
+ )
+ cnn_output = cnn(torch.ones(img_shape))
+ mlp = MLP(
+ in_features=cnn_output.shape[-1],
+ activation_class=activation_class,
+ out_features=128,
+ num_cells=[256],
+ )
+ image_encoder = SafeModule(
+ torch.nn.Sequential(cnn, mlp),
+ in_keys=[img_in_keys],
+ out_keys=["pixel_embedding"],
+ )
+
+ # output head
+ mlp = MLP(
+ activation_class=torch.nn.ReLU,
+ out_features=out_features,
+ num_cells=num_cells,
+ norm_class=normalization,
+ norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+ dropout=dropout,
+ )
+ v_head = ValueOperator(mlp, ["pixel_embedding"])
+ # model
+ return SafeSequential(image_encoder, v_head)
+
+
+def get_mixed_value(
+ vec_in_keys,
+ img_in_keys,
+ num_cells=[256, 256],
+ out_features=1,
+ activation_class=nn.ReLU,
+ normalization="None",
+ dropout=0.0,
+ img_shape=(3, 64, 64),
+):
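+ """Mixed-observation state-value network: separate pixel and vector encoders feed a shared MLP value head."""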
+ normalization = get_normalization(normalization)
+ # image encoder
+ cnn = ConvNet(
+ activation_class=activation_class,
+ num_cells=[32, 64, 64],
+ kernel_sizes=[8, 4, 3],
+ strides=[4, 2, 1],
+ )
+ cnn_output = cnn(torch.ones(img_shape))
+ mlp = MLP(
+ in_features=cnn_output.shape[-1],
+ activation_class=activation_class,
+ out_features=128,
+ num_cells=[256],
+ )
+ image_encoder = SafeModule(
+ torch.nn.Sequential(cnn, mlp),
+ in_keys=[img_in_keys],
+ out_keys=["pixel_embedding"],
+ )
+
+ # vector_obs encoder
+ mlp = MLP(
+ activation_class=activation_class,
+ out_features=32,
+ num_cells=[128],
+ )
+ vector_obs_encoder = SafeModule(
+ mlp, in_keys=[vec_in_keys], out_keys=["obs_embedding"]
+ )
+
+ # output head
+ mlp = MLP(
+ activation_class=torch.nn.ReLU,
+ out_features=out_features,
+ num_cells=num_cells,
+ norm_class=normalization,
+ norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+ dropout=dropout,
+ )
+ v_head = ValueOperator(mlp, ["pixel_embedding", "obs_embedding"])
+ # model
+ return SafeSequential(image_encoder, vector_obs_encoder, v_head)
+
+
def get_vec_critic(
in_keys=["observation"],
num_cells=[256, 256],
@@ -84,6 +256,7 @@ def get_mixed_critic(
activation_class=nn.ReLU,
normalization="None",
dropout=0.0,
+ img_shape=(3, 64, 64),
):
normalization = get_normalization(normalization)
# image encoder
@@ -93,7 +266,7 @@ def get_mixed_critic(
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
)
- cnn_output = cnn(torch.ones((3, 64, 64)))
+ cnn_output = cnn(torch.ones(img_shape))
mlp = MLP(
in_features=cnn_output.shape[-1],
activation_class=activation_class,
@@ -130,7 +303,53 @@ def get_mixed_critic(
return SafeSequential(image_encoder, vector_obs_encoder, v_head)
-def get_deterministic_actor(observation_keys, action_spec, agent_config):
+def get_img_only_critic(
+ img_in_keys,
+ num_cells=[256, 256],
+ out_features=1,
+ activation_class=nn.ReLU,
+ normalization="None",
+ dropout=0.0,
+ img_shape=(3, 64, 64),
+):
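+ """Pixel-only Q-value critic: the CNN+MLP pixel embedding is combined with the action in the MLP Q-head."""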
+ normalization = get_normalization(normalization)
+ # image encoder
+ cnn = ConvNet(
+ activation_class=activation_class,
+ num_cells=[32, 64, 64],
+ kernel_sizes=[8, 4, 3],
+ strides=[4, 2, 1],
+ )
+ cnn_output = cnn(torch.ones(img_shape))
+ mlp = MLP(
+ in_features=cnn_output.shape[-1],
+ activation_class=activation_class,
+ out_features=128,
+ num_cells=[256],
+ )
+ image_encoder = SafeModule(
+ torch.nn.Sequential(cnn, mlp),
+ in_keys=[img_in_keys],
+ out_keys=["pixel_embedding"],
+ )
+
+ # output head
+ mlp = MLP(
+ activation_class=torch.nn.ReLU,
+ out_features=out_features,
+ num_cells=num_cells,
+ norm_class=normalization,
+ norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+ dropout=dropout,
+ )
+ v_head = ValueOperator(mlp, ["pixel_embedding", "action"])
+ # model
+ return SafeSequential(image_encoder, v_head)
+
+
+def get_deterministic_actor(observation_spec, action_spec, agent_config):
+ observation_keys = [key for key in observation_spec.keys()]
+
if "observation" in observation_keys and not "pixels" in observation_keys:
return get_vec_deterministic_actor(
action_spec=action_spec,
@@ -139,6 +358,15 @@ def get_deterministic_actor(observation_keys, action_spec, agent_config):
activation_class=nn.ReLU,
)
+ elif "pixels" in observation_keys and not "observation" in observation_keys:
+ return get_img_only_det_actor(
+ img_in_keys="pixels",
+ action_spec=action_spec,
+ num_cells=[agent_config.num_cells, agent_config.num_cells],
+ activation_class=nn.ReLU,
+ img_shape=observation_spec["pixels"].shape,
+ )
+
elif "pixels" in observation_keys and "observation" in observation_keys:
return get_mixed_deterministic_actor(
vec_in_keys="observation",
@@ -146,6 +374,7 @@ def get_deterministic_actor(observation_keys, action_spec, agent_config):
action_spec=action_spec,
num_cells=[agent_config.num_cells, agent_config.num_cells],
activation_class=nn.ReLU,
+ img_shape=observation_spec["pixels"].shape,
)
else:
raise NotImplementedError("Actor for this observation space not implemented")
@@ -190,6 +419,58 @@ def get_vec_deterministic_actor(
return actor
+def get_img_only_det_actor(
+ img_in_keys,
+ action_spec,
+ num_cells=[256, 256],
+ activation_class=nn.ReLU,
+ normalization="None",
+ dropout=0.0,
+ img_shape=(3, 64, 64),
+):
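+ """Pixel-only deterministic actor: CNN+MLP encoder, an MLP head producing action parameters, and a TanhModule mapping them into the action-spec bounds."""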
+ normalization = get_normalization(normalization)
+ # image encoder
+ cnn = ConvNet(
+ activation_class=activation_class,
+ num_cells=[32, 64, 64],
+ kernel_sizes=[8, 4, 3],
+ strides=[4, 2, 1],
+ )
+ cnn_output = cnn(torch.ones(img_shape))
+ mlp = MLP(
+ in_features=cnn_output.shape[-1],
+ activation_class=activation_class,
+ out_features=128,
+ num_cells=[256],
+ )
+ image_encoder = SafeModule(
+ torch.nn.Sequential(cnn, mlp),
+ in_keys=[img_in_keys],
+ out_keys=["pixel_embedding"],
+ )
+
+ # output head
+ mlp = MLP(
+ activation_class=torch.nn.ReLU,
+ out_features=action_spec.shape[-1],
+ num_cells=num_cells,
+ norm_class=normalization,
+ norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+ dropout=dropout,
+ )
+ combined = SafeModule(mlp, ["pixel_embedding"], out_keys=["param"])
+ out_module = TanhModule(
+ in_keys=["param"],
+ out_keys=["action"],
+ spec=action_spec,
+ )
+ return SafeSequential(
+ image_encoder,
+ combined,
+ out_module,
+ )
+
+
def get_mixed_deterministic_actor(
vec_in_keys,
img_in_keys,
@@ -198,6 +479,7 @@ def get_mixed_deterministic_actor(
activation_class=nn.ReLU,
normalization="None",
dropout=0.0,
+ img_shape=(3, 64, 64),
):
normalization = get_normalization(normalization)
# image encoder
@@ -207,7 +489,7 @@ def get_mixed_deterministic_actor(
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
)
- cnn_output = cnn(torch.ones((3, 64, 64)))
+ cnn_output = cnn(torch.ones(img_shape))
mlp = MLP(
in_features=cnn_output.shape[-1],
activation_class=activation_class,
@@ -253,7 +535,8 @@ def get_mixed_deterministic_actor(
)
-def get_stochastic_actor(observation_keys, action_spec, agent_config):
+def get_stochastic_actor(observation_spec, action_spec, agent_config):
+ observation_keys = [key for key in observation_spec.keys()]
if "observation" in observation_keys and not "pixels" in observation_keys:
return get_vec_stochastic_actor(
action_spec,
@@ -263,6 +546,16 @@ def get_stochastic_actor(observation_keys, action_spec, agent_config):
dropout=agent_config.dropout,
activation_class=nn.ReLU,
)
+ elif "pixels" in observation_keys and not "observation" in observation_keys:
+ return get_img_only_stochastic_actor(
+ img_in_keys="pixels",
+ action_spec=action_spec,
+ num_cells=[agent_config.num_cells, agent_config.num_cells],
+ normalization=agent_config.normalization,
+ dropout=agent_config.dropout,
+ activation_class=nn.ReLU,
+ img_shape=observation_spec["pixels"].shape,
+ )
elif "pixels" in observation_keys and "observation" in observation_keys:
return get_mixed_stochastic_actor(
action_spec,
@@ -272,6 +565,7 @@ def get_stochastic_actor(observation_keys, action_spec, agent_config):
normalization=agent_config.normalization,
dropout=agent_config.dropout,
activation_class=nn.ReLU,
+ img_shape=observation_spec["pixels"].shape,
)
else:
raise NotImplementedError("Actor for this observation space not implemented")
@@ -330,6 +624,85 @@ def get_vec_stochastic_actor(
return actor
+def get_img_only_stochastic_actor(
+ action_spec,
+ img_in_keys,
+ num_cells=[256, 256],
+ normalization="None",
+ dropout=0.0,
+ activation_class=nn.ReLU,
+ img_shape=(3, 64, 64),
+):
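+ """Pixel-only stochastic actor: CNN+MLP encoder, an MLP head producing loc/scale parameters, and a TanhNormal ProbabilisticActor bounded by the action spec."""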
+
+ normalization = get_normalization(normalization)
+ # image encoder
+ cnn = ConvNet(
+ activation_class=activation_class,
+ num_cells=[32, 64, 64],
+ kernel_sizes=[8, 4, 3],
+ strides=[4, 2, 1],
+ )
+ cnn_output = cnn(torch.ones(img_shape))
+ mlp = MLP(
+ in_features=cnn_output.shape[-1],
+ activation_class=activation_class,
+ out_features=128,
+ num_cells=[256],
+ )
+ image_encoder = SafeModule(
+ torch.nn.Sequential(cnn, mlp),
+ in_keys=[img_in_keys],
+ out_keys=["pixel_embedding"],
+ )
+
+ # output head
+ mlp = MLP(
+ activation_class=torch.nn.ReLU,
+ out_features=2 * action_spec.shape[-1],
+ num_cells=num_cells,
+ norm_class=normalization,
+ norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+ dropout=dropout,
+ )
+ actor_module = SafeModule(
+ mlp,
+ in_keys=["pixel_embedding"],
+ out_keys=["params"],
+ )
+ actor_extractor = NormalParamExtractor(
+ scale_mapping=f"biased_softplus_{1.0}",
+ scale_lb=0.1,
+ )
+
+ extractor_module = SafeModule(
+ actor_extractor,
+ in_keys=["params"],
+ out_keys=[
+ "loc",
+ "scale",
+ ],
+ )
+ actor_net_combined = SafeSequential(image_encoder, actor_module, extractor_module)
+
+ dist_class = TanhNormal
+ dist_kwargs = {
+ "min": action_spec.space.low,
+ "max": action_spec.space.high,
+ "tanh_loc": False,
+ }
+ actor = ProbabilisticActor(
+ spec=action_spec,
+ in_keys=["loc", "scale"],
+ out_keys=["action"],
+ module=actor_net_combined,
+ distribution_class=dist_class,
+ distribution_kwargs=dist_kwargs,
+ default_interaction_mode="random",
+ return_log_prob=False,
+ )
+ return actor
+
+
def get_mixed_stochastic_actor(
action_spec,
vec_in_keys,
@@ -338,6 +711,7 @@ def get_mixed_stochastic_actor(
normalization="None",
dropout=0.0,
activation_class=nn.ReLU,
+ img_shape=(3, 64, 64),
):
normalization = get_normalization(normalization)
@@ -348,7 +722,7 @@ def get_mixed_stochastic_actor(
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
)
- cnn_output = cnn(torch.ones((3, 64, 64)))
+ cnn_output = cnn(torch.ones(img_shape))
mlp = MLP(
in_features=cnn_output.shape[-1],
activation_class=activation_class,
diff --git a/src/utils.py b/src/utils.py
index 3ea4c73..50adae4 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -6,7 +6,7 @@
from environments import ALL_2WHEELER_ENVS, ALL_ROBOARM_ENVS, ALL_WALKER_ENVS
from moviepy.editor import concatenate_videoclips, ImageClip
from omegaconf import DictConfig
-from tensordict import TensorDictBase
+from tensordict import TensorDict, TensorDictBase
from torchrl.envs.utils import step_mdp
from tqdm import tqdm
@@ -49,7 +49,11 @@ def logout(agent):
x = input("Do you want to save the replay buffer? (y/n)")
if x == "y":
save_name = input("Enter the name of the file to save: ")
- agent.replay_buffer.dump(save_name)
+ # Save only the filled portion of the buffer storage as a TensorDict on disk,
+ # so it can later be loaded directly as an offline dataset.
+ batched_data = agent.replay_buffer.storage._storage[: len(agent.replay_buffer)]
+ batched_data.save(save_name, copy_existing=True)
def login(agent):
@@ -82,7 +86,6 @@ def prefill_buffer(env, agent, num_episodes=10, stop_on_done=False):
inpt = input("Press Enter to start prefilling episode: ")
for e in tqdm(range(num_episodes), desc="Prefilling buffer"):
print("Prefill episode: ", e)
-
td = env.reset()
done = False
truncated = False
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 0a21c6c..754a812 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -17,9 +17,9 @@ def collection_round(env, agent, max_steps=1000):
td = step_mdp(td)
-def get_env(env):
+def get_env(env, img_shape=(64, 64, 3)):
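+ # img_shape is forwarded to the mixed-observation dummy env to test non-default image resolutions.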
if env == "mixed":
- env = MixedObsDummyEnv()
+ env = MixedObsDummyEnv(img_shape=img_shape)
env = TransformedEnv(
env, Compose(ToTensorImage(in_keys=["pixels"], from_int=True))
)
@@ -49,7 +49,9 @@ def test_random_agent(env, device):
else:
device = "cpu"
with initialize(config_path="../conf"):
- cfg = compose(config_name="config", overrides=["device=" + device])
+ cfg = compose(
+ config_name="config", overrides=["device=" + device, "agent=random"]
+ )
# Test data collection
env = get_env(env)
agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
@@ -171,3 +173,154 @@ def test_drq_agent(env, device):
eval_td2 = agent.get_eval_action(td)
assert torch.allclose(eval_td1["action"], eval_td2["action"])
+
+
+@pytest.mark.parametrize(
+ "env",
+ ["mixed", "vec", "vec_goal"],
+)
+@pytest.mark.parametrize(
+ "device",
+ ["cpu", "cuda"],
+)
+def test_iql_agent(env, device):
+ if torch.cuda.is_available() and device == "cuda":
+ device = "cuda"
+ else:
+ device = "cpu"
+ with initialize(config_path="../conf"):
+ cfg = compose(config_name="config", overrides=["agent=iql", "device=" + device])
+
+ # Test data collection
+ env = get_env(env)
+ agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
+ collection_round(env, agent, max_steps=10)
+ # Test training
+ agent.train(batch_size=1, num_updates=1)
+
+ # Test evaluation
+ td = env.reset()
+ td1 = agent.get_action(td)
+ td2 = agent.get_action(td)
+
+ assert not torch.allclose(td1["action"], td2["action"])
+
+ agent.eval()
+ td = env.reset()
+ eval_td1 = agent.get_eval_action(td)
+ eval_td2 = agent.get_eval_action(td)
+
+ assert torch.allclose(eval_td1["action"], eval_td2["action"])
+
+
+@pytest.mark.parametrize(
+ "env",
+ ["mixed", "vec", "vec_goal"],
+)
+@pytest.mark.parametrize(
+ "device",
+ ["cpu", "cuda"],
+)
+def test_cql_agent(env, device):
+ if torch.cuda.is_available() and device == "cuda":
+ device = "cuda"
+ else:
+ device = "cpu"
+ with initialize(config_path="../conf"):
+ cfg = compose(config_name="config", overrides=["agent=cql", "device=" + device])
+
+ # Test data collection
+ env = get_env(env)
+ agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
+ collection_round(env, agent, max_steps=10)
+ # Test training
+ agent.train(batch_size=1, num_updates=1)
+
+ # Test evaluation
+ td = env.reset()
+ td1 = agent.get_action(td)
+ td2 = agent.get_action(td)
+
+ assert not torch.allclose(td1["action"], td2["action"])
+
+ agent.eval()
+ td = env.reset()
+ eval_td1 = agent.get_eval_action(td)
+ eval_td2 = agent.get_eval_action(td)
+
+ assert torch.allclose(eval_td1["action"], eval_td2["action"])
+
+
+@pytest.mark.parametrize(
+ "env",
+ ["mixed", "vec", "vec_goal"],
+)
+@pytest.mark.parametrize(
+ "device",
+ ["cpu", "cuda"],
+)
+def test_bc_agent(env, device):
+ if torch.cuda.is_available() and device == "cuda":
+ device = "cuda"
+ else:
+ device = "cpu"
+ with initialize(config_path="../conf"):
+ cfg = compose(config_name="config", overrides=["agent=bc", "device=" + device])
+
+ # Test data collection
+ env = get_env(env)
+ agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
+ collection_round(env, agent, max_steps=10)
+ # Test training
+ agent.train(batch_size=1, num_updates=1)
+
+ # Test evaluation
+ agent.eval()
+ td = env.reset()
+ eval_td1 = agent.get_eval_action(td)
+ eval_td2 = agent.get_eval_action(td)
+
+ assert torch.allclose(eval_td1["action"], eval_td2["action"])
+
+
+@pytest.mark.parametrize(
+ "env",
+ ["mixed"],
+)
+@pytest.mark.parametrize(
+ "img_shape",
+ [(64, 64, 3), (128, 128, 3)],
+)
+@pytest.mark.parametrize(
+ "device",
+ ["cpu", "cuda"],
+)
+def test_mixed_obs_size_agent(env, device, img_shape):
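+ # Checks that mixed-observation agents handle different image resolutions, since the CNN input size is derived from the observation spec.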
+ if torch.cuda.is_available() and device == "cuda":
+ device = "cuda"
+ else:
+ device = "cpu"
+ with initialize(config_path="../conf"):
+ cfg = compose(config_name="config", overrides=["agent=td3", "device=" + device])
+
+ # Test data collection
+ env = get_env(env, img_shape)
+ agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
+ collection_round(env, agent, max_steps=10)
+
+ # Test training
+ agent.train(batch_size=1, num_updates=1)
+
+ # Test evaluation
+ td = env.reset()
+ td1 = agent.get_action(td)
+ td2 = agent.get_action(td)
+
+ assert not torch.allclose(td1["action"], td2["action"])
+
+ agent.eval()
+ td = env.reset()
+ eval_td1 = agent.get_eval_action(td)
+ eval_td2 = agent.get_eval_action(td)
+
+ assert torch.allclose(eval_td1["action"], eval_td2["action"])