diff --git a/.gitignore b/.gitignore index 374c6a1..a4fac52 100644 --- a/.gitignore +++ b/.gitignore @@ -132,4 +132,6 @@ dmypy.json # wandb wandb/ # hydra -outputs/ \ No newline at end of file +outputs/ +# .pth files +*.pth \ No newline at end of file diff --git a/README.md b/README.md index 4b667aa..ff9629b 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ Before running experiments, please review and modify the configuration settings ### Robots -Robots utilized for our experiments. Building instructions can be found [here](https://sites.google.com/view/bricksrl/building-instructions). +Robots utilized for our experiments. Building instructions can be found [here](https://bricksrl.github.io/ProjectPage/). | ![2wheeler](https://drive.google.com/uc?export=view&id=1IxqQ1VZchPZMNXyZnTULuNy53-LMYT6W) | ![Walker](https://drive.google.com/uc?export=view&id=1ImR0f1UNjC4sUHXWWg_D06eukrh-doW9) | ![RoboArm](https://drive.google.com/uc?export=view&id=1IYCJrl5rZBvOb6xKwbSUZqYrVwKjCpJH) | |:--:|:--:|:--:| @@ -140,7 +140,7 @@ Robots utilized for our experiments. Building instructions can be found [here](h
Click me -Evaluation videos of the trained agents can be found [here](https://sites.google.com/view/bricksrl/main). +Evaluation videos of the trained agents can be found [here](https://bricksrl.github.io/ProjectPage/). ### 2Wheeler Results: @@ -159,7 +159,44 @@ Evaluation videos of the trained agents can be found [here](https://sites.google
+### Offline RL +
+ Click me
+With precollected [offline datasets](https://huggingface.co/datasets/compsciencelab/BricksRL-Datasets) we can pretrain agents with offline RL to perform a task without any real-world interaction. Such pretrained policies can be evaluated directly or fine-tuned later on the real robot.
+
+#### Datasets
+The datasets can be downloaded from Hugging Face and contain expert and random transitions for the 2Wheeler (RunAway-v0 and Spinning-v0), Walker (Walker-v0), and RoboArm (RoboArm-v0) robots.
+
+ ```bash
+ git lfs install
+ git clone git@hf.co:datasets/compsciencelab/BricksRL-Datasets
+ ```
+
+The datasets consist of TensorDicts containing expert and random transitions, which can be loaded directly into the replay buffer. When starting (pre-)training, simply provide the path to the desired TensorDict when prompted to load the replay buffer.
 ## High-Level Examples
 In the [example notebook](example_notebook.ipynb) we provide high-level training examples to train a **SAC agent** in the **RoboArmSim-v0** environment and a **TD3 agent** in the **WalkerSim-v0** environment.
+
+#### Pretrain an Agent
+
+Running an offline-training experiment works like online training, except that you run the **pretrain.py** script:
+
+ ```bash
+ python experiments/walker/pretrain.py
+ ```
+
+Trained policies can then be evaluated as before with:
+
+ ```bash
+ python experiments/walker/eval.py
+ ```
+
+Or run training to fine-tune the policy on the real robot:
+
+ ```bash
+ python experiments/walker/train.py
+ ```
+
+
+
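For concreteness, here is a minimal sketch of what loading one of these TensorDicts into a replay buffer looks like; it mirrors the `load_replaybuffer` helpers added in the agent diffs below. The dataset sub-path is a placeholder, and the buffer size/batch size follow the defaults in the agent configs.

```python
# Illustrative sketch only: load a downloaded BricksRL-Datasets TensorDict into
# a TorchRL replay buffer, as the agents' load_replaybuffer() methods do.
from tensordict import TensorDictBase
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

dataset_path = "BricksRL-Datasets/walker-v0-expert"  # placeholder sub-path inside the cloned repo

replay_buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(1_000_000, scratch_dir="./tmp"),
    batch_size=256,
)
transitions = TensorDictBase.load_memmap(dataset_path)  # memmapped TensorDict of transitions
replay_buffer.extend(transitions)                       # fill the buffer for offline training
print(f"Loaded {len(replay_buffer)} transitions")
```

When using the pretrain scripts below you are simply prompted for this path instead of calling the loader yourself.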
diff --git a/conf/agent/bc.yaml b/conf/agent/bc.yaml new file mode 100644 index 0000000..37e07ed --- /dev/null +++ b/conf/agent/bc.yaml @@ -0,0 +1,11 @@ +name: bc +lr: 3e-4 +batch_size: 256 +num_updates: 1 +prefill_episodes: 0 + + +policy_type: deterministic # stochastic or deterministic +num_cells: 256 +dropout: 0.01 +normalization: LayerNorm diff --git a/conf/agent/cql.yaml b/conf/agent/cql.yaml new file mode 100644 index 0000000..07c2082 --- /dev/null +++ b/conf/agent/cql.yaml @@ -0,0 +1,28 @@ +name: cql +lr: 3e-4 +batch_size: 256 +num_updates: 1 +prefill_episodes: 10 + +bc_steps: 1000 + +# CQL specific +num_cells: 256 +gamma: 0.99 +soft_update_eps: 0.995 +loss_function: l2 +temperature: 1.0 +min_q_weight: 1.0 +max_q_backup: False +deterministic_backup: False +num_random: 10 +with_lagrange: True +lagrange_thresh: 5.0 # tau + +normalization: None +dropout: 0.0 + +prb: 0 +buffer_size: 1000000 +pretrain: False +reset_params: False \ No newline at end of file diff --git a/conf/agent/iql.yaml b/conf/agent/iql.yaml new file mode 100644 index 0000000..d2ca4a8 --- /dev/null +++ b/conf/agent/iql.yaml @@ -0,0 +1,20 @@ +name: iql +lr: 3e-4 +batch_size: 256 +num_updates: 1 +prefill_episodes: 0 + +num_cells: 256 +gamma: 0.99 +soft_update_eps: 0.995 +loss_function: l2 +temperature: 1.0 +expectile: 0.5 + +normalization: None +dropout: 0.0 + +prb: 0 +buffer_size: 1000000 +pretrain: False +reset_params: False \ No newline at end of file diff --git a/conf/agent/td3.yaml b/conf/agent/td3.yaml index 4465a71..f341f50 100644 --- a/conf/agent/td3.yaml +++ b/conf/agent/td3.yaml @@ -15,4 +15,6 @@ dropout: 0.0 prb: 0 buffer_size: 1000000 -reset_params: False \ No newline at end of file +reset_params: False +use_bc: False +alpha: 1.0 \ No newline at end of file diff --git a/conf/config.yaml b/conf/config.yaml index 609b137..478d1c4 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -4,10 +4,10 @@ run_name: "" verbose: 0 device: "cuda" -episodes: 200 +episodes: 250 defaults: - _self_ # random, sac, td3, droq - agent: sac - - env: roboarm_sim-v0 \ No newline at end of file + - env: walker_sim-v0 \ No newline at end of file diff --git a/environments/__init__.py b/environments/__init__.py index fe070f6..e53dbf7 100644 --- a/environments/__init__.py +++ b/environments/__init__.py @@ -2,9 +2,7 @@ from torchrl.envs import ( CatFrames, Compose, - DoubleToFloat, ObservationNorm, - RewardSum, ToTensorImage, TransformedEnv, ) @@ -17,28 +15,32 @@ from environments.walker_v0.WalkerEnv import WalkerEnv_v0 from environments.walker_v0.WalkerEnvSim import WalkerEnvSim_v0 - VIDEO_LOGGING_ENVS = ["roboarm_mixed-v0", "walker_mixed-v0"] ALL_2WHEELER_ENVS = ["spinning-v0", "runaway-v0"] ALL_WALKER_ENVS = [ "walker-v0", "walker_sim-v0", ] -ALL_ROBOARM_ENVS = ["roboarm-v0", "roboarm_mixed-v0", "roboarm_sim-v0"] +ALL_ROBOARM_ENVS = [ + "roboarm-v0", + "roboarm_mixed-v0", + "roboarm_sim-v0", +] ALL_ENVS = ALL_2WHEELER_ENVS + ALL_WALKER_ENVS + ALL_ROBOARM_ENVS -def make_env(config): +def make_env(config, pretrain=False): """ Creates a new environment based on the provided configuration. Args: config: A configuration object containing the environment name and maximum episode steps. + pretrain: A boolean indicating whether the environment is for pretraining. Returns: A tuple containing the new environment, its action space, and its state space. 
""" - env = make(name=config.env.name, env_conf=config.env) + env = make(name=config.env.name, env_conf=config.env, pretrain=pretrain) observation_keys = [key for key in env.observation_spec.keys()] transforms = [] @@ -76,24 +78,27 @@ def make_env(config): return env, action_spec, state_spec -def make(name="RunAway", env_conf=None): +def make(name="RunAway", env_conf=None, pretrain=False): if name == "runaway-v0": return RunAwayEnv_v0( max_episode_steps=env_conf.max_episode_steps, min_distance=env_conf.min_distance, verbose=env_conf.verbose, + pretrain=pretrain, ) elif name == "spinning-v0": return SpinningEnv_v0( max_episode_steps=env_conf.max_episode_steps, sleep_time=env_conf.sleep_time, verbose=env_conf.verbose, + pretrain=pretrain, ) elif name == "walker-v0": return WalkerEnv_v0( max_episode_steps=env_conf.max_episode_steps, verbose=env_conf.verbose, sleep_time=env_conf.sleep_time, + pretrain=pretrain, ) elif name == "walker_sim-v0": return WalkerEnvSim_v0( @@ -109,6 +114,7 @@ def make(name="RunAway", env_conf=None): verbose=env_conf.verbose, sleep_time=env_conf.sleep_time, reward_signal=env_conf.reward_signal, + pretrain=pretrain, ) elif name == "roboarm_sim-v0": return RoboArmSimEnv_v0( @@ -125,6 +131,7 @@ def make(name="RunAway", env_conf=None): reward_signal=env_conf.reward_signal, camera_id=env_conf.camera_id, goal_radius=env_conf.goal_radius, + pretrain=pretrain, ) else: print("Environment not found") diff --git a/environments/base/base_env.py b/environments/base/base_env.py index 761a034..bc8f936 100644 --- a/environments/base/base_env.py +++ b/environments/base/base_env.py @@ -16,12 +16,15 @@ class BaseEnv(EnvBase): Args: action_dim (int): The dimensionality of the action space. state_dim (int): The dimensionality of the state space. + use_hub (bool): Whether to use the Pybricks hub for communication, if False, only the observation spec and action specs are created and can be used. + verbose (bool): Whether to print verbose output. """ def __init__( self, action_dim: int, state_dim: int, + use_hub: bool = True, verbose: bool = False, ): self.verbose = verbose @@ -36,11 +39,14 @@ def __init__( # buffer state in case of missing data self.buffered_state = np.zeros(self.state_dim, dtype=np.float32) - self.hub = PybricksHub( - state_dim=state_dim, out_format_str=self.state_format_str - ) - self.hub.connect() - print("Connected to hub.") + if use_hub: + self.hub = PybricksHub( + state_dim=state_dim, out_format_str=self.state_format_str + ) + self.hub.connect() + print("Connected to hub.") + else: + self.hub = None super().__init__(batch_size=torch.Size([1])) def send_to_hub(self, action: np.array) -> None: @@ -129,6 +135,8 @@ class BaseSimEnv(EnvBase): Args: action_dim (int): The dimensionality of the action space. state_dim (int): The dimensionality of the state space. + verbose (bool): Whether to print verbose output. + use_hub (bool): This argument is kept for compatibility but is not used in the simulation environment. 
""" def __init__( @@ -136,6 +144,7 @@ def __init__( action_dim: int, state_dim: int, verbose: bool = False, + use_hub: bool = False, ): self.verbose = verbose self.action_dim = action_dim diff --git a/environments/dummy/mixed_obs_dummy.py b/environments/dummy/mixed_obs_dummy.py index 09b9dd3..14f9b16 100644 --- a/environments/dummy/mixed_obs_dummy.py +++ b/environments/dummy/mixed_obs_dummy.py @@ -21,7 +21,7 @@ class MixedObsDummyEnv(EnvBase): observation_key = "observation" pixel_observation_key = "pixels" - def __init__(self, max_episode_steps=10): + def __init__(self, max_episode_steps=10, img_shape=(64, 64, 3)): self.max_episode_steps = max_episode_steps self._batch_size = torch.Size([1]) self.action_spec = BoundedTensorSpec( @@ -36,8 +36,8 @@ def __init__(self, max_episode_steps=10): ) pixel_observation_spec = BoundedTensorSpec( - low=torch.zeros((1,) + (64, 64, 3), dtype=torch.uint8), - high=torch.ones((1,) + (64, 64, 3), dtype=torch.uint8) * 255, + low=torch.zeros((1,) + img_shape, dtype=torch.uint8), + high=torch.ones((1,) + img_shape, dtype=torch.uint8) * 255, ) self.observation_spec = CompositeSpec(shape=(1,)) diff --git a/environments/roboarm_mixed_v0/RoboArmMixedEnv.py b/environments/roboarm_mixed_v0/RoboArmMixedEnv.py index 6a72a5d..3f5c123 100644 --- a/environments/roboarm_mixed_v0/RoboArmMixedEnv.py +++ b/environments/roboarm_mixed_v0/RoboArmMixedEnv.py @@ -68,6 +68,7 @@ def __init__( max_episode_steps: int = 50, sleep_time: float = 0.0, verbose: bool = False, + pretrain: bool = False, reward_signal: str = "dense", camera_id: int = 0, goal_radius: float = 25, @@ -131,7 +132,10 @@ def __init__( self.goal_positions = self.init_camera_position() super().__init__( - action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose + action_dim=self.action_dim, + state_dim=self.state_dim, + verbose=verbose, + use_hub=1 - pretrain, ) def init_camera_position( diff --git a/environments/roboarm_v0/RoboArmEnv.py b/environments/roboarm_v0/RoboArmEnv.py index cd8028a..ea8b176 100644 --- a/environments/roboarm_v0/RoboArmEnv.py +++ b/environments/roboarm_v0/RoboArmEnv.py @@ -32,6 +32,7 @@ def __init__( max_episode_steps: int = 50, sleep_time: float = 0.0, verbose: bool = False, + pretrain: bool = False, reward_signal: str = "dense", ): self.sleep_time = sleep_time @@ -77,7 +78,10 @@ def __init__( self.observation_spec.set(self.observation_key, observation_spec) self.observation_spec.set(self.goal_observation_key, observation_spec) super().__init__( - action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose + action_dim=self.action_dim, + state_dim=self.state_dim, + verbose=verbose, + use_hub=1 - pretrain, ) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: diff --git a/environments/roboarm_v0/RoboArmSim.py b/environments/roboarm_v0/RoboArmSim.py index d0c3ee7..cca28cb 100644 --- a/environments/roboarm_v0/RoboArmSim.py +++ b/environments/roboarm_v0/RoboArmSim.py @@ -77,7 +77,10 @@ def __init__( self.observation_spec.set(self.observation_key, observation_spec) self.observation_spec.set(self.goal_observation_key, observation_spec) super().__init__( - action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose + action_dim=self.action_dim, + state_dim=self.state_dim, + verbose=verbose, + use_hub=False, ) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: diff --git a/environments/runaway_v0/RunAwayEnv.py b/environments/runaway_v0/RunAwayEnv.py index 89297ed..7adb5ba 100644 --- a/environments/runaway_v0/RunAwayEnv.py 
+++ b/environments/runaway_v0/RunAwayEnv.py @@ -45,6 +45,7 @@ def __init__( min_distance: float = 40, sleep_time: float = 0.2, verbose: bool = False, + pretrain: bool = False, ): self.sleep_time = sleep_time self.min_distance = min_distance @@ -81,7 +82,10 @@ def __init__( ) self.verbose = verbose super().__init__( - action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose + action_dim=self.action_dim, + state_dim=self.state_dim, + verbose=verbose, + use_hub=1 - pretrain, ) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: diff --git a/environments/spinning_v0/SpinningEnv.py b/environments/spinning_v0/SpinningEnv.py index 84a86fe..eab608b 100644 --- a/environments/spinning_v0/SpinningEnv.py +++ b/environments/spinning_v0/SpinningEnv.py @@ -39,6 +39,7 @@ def __init__( max_episode_steps: int = 50, sleep_time: float = 0.2, verbose: bool = False, + pretrain: bool = False, ): self.sleep_time = sleep_time self._batch_size = torch.Size([1]) @@ -74,7 +75,10 @@ def __init__( ) super().__init__( - action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose + action_dim=self.action_dim, + state_dim=self.state_dim, + verbose=verbose, + use_hub=1 - pretrain, ) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: diff --git a/environments/walker_v0/WalkerEnv.py b/environments/walker_v0/WalkerEnv.py index 6371cfb..9d8bdb9 100644 --- a/environments/walker_v0/WalkerEnv.py +++ b/environments/walker_v0/WalkerEnv.py @@ -45,6 +45,7 @@ def __init__( max_episode_steps: int = 50, sleep_time: float = 0.0, verbose: bool = False, + pretrain: bool = False, ): self.sleep_time = sleep_time self._batch_size = torch.Size([1]) @@ -83,7 +84,10 @@ def __init__( {self.observation_key: observation_spec}, shape=(1,) ) super().__init__( - action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose + action_dim=self.action_dim, + state_dim=self.state_dim, + verbose=verbose, + use_hub=1 - pretrain, ) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: diff --git a/environments/walker_v0/WalkerEnvSim.py b/environments/walker_v0/WalkerEnvSim.py index 2ab17a2..1a1e8e4 100644 --- a/environments/walker_v0/WalkerEnvSim.py +++ b/environments/walker_v0/WalkerEnvSim.py @@ -74,7 +74,10 @@ def __init__( {self.observation_key: observation_spec}, shape=(1,) ) super().__init__( - action_dim=self.action_dim, state_dim=self.state_dim, verbose=verbose + action_dim=self.action_dim, + state_dim=self.state_dim, + verbose=verbose, + use_hub=False, ) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: diff --git a/experiments/2wheeler/eval.py b/experiments/2wheeler/eval.py index 3d4d978..7d91075 100644 --- a/experiments/2wheeler/eval.py +++ b/experiments/2wheeler/eval.py @@ -91,7 +91,6 @@ def run(cfg: DictConfig) -> None: except KeyboardInterrupt: print("Evaluation interrupted by user.") - logout(agent) env.close() diff --git a/experiments/2wheeler/pretrain.py b/experiments/2wheeler/pretrain.py new file mode 100644 index 0000000..0849e94 --- /dev/null +++ b/experiments/2wheeler/pretrain.py @@ -0,0 +1,61 @@ +import os +import sys + +import hydra +import wandb +from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm + +# Add the project root to PYTHONPATH +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from environments import make_env +from src.agents import get_agent +from src.utils import login, logout, 
setup_check, tensordict2dict + + +@hydra.main(version_base=None, config_path=project_root + "/conf", config_name="config") +def run(cfg: DictConfig) -> None: + print(OmegaConf.to_yaml(cfg)) + + # make environment. + setup_check(robot="2wheeler", config=cfg) + env, action_space, state_space = make_env(cfg, pretrain=True) + + # make agent + agent, project_name = get_agent(action_space, state_space, cfg) + login(agent) + + # initialize wandb + wandb.init(project=project_name) + wandb.config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + wandb.watch(agent.actor, log_freq=1) if agent.actor else None + + batch_size = cfg.agent.batch_size + num_updates = cfg.agent.num_updates + train_episodes = cfg.episodes + print("Start training...") + try: + for e in tqdm(range(train_episodes), desc="Training"): + + loss_info = agent.train(batch_size=batch_size, num_updates=num_updates) + + # Metrics Logging + log_dict = { + "epoch": e, + "buffer_size": agent.replay_buffer.__len__(), + } + log_dict.update(tensordict2dict(loss_info)) + wandb.log(log_dict) + + except KeyboardInterrupt: + print("Training interrupted by user.") + + logout(agent) + env.close() + + +if __name__ == "__main__": + run() diff --git a/experiments/roboarm/eval.py b/experiments/roboarm/eval.py index 8a088a8..c482ca1 100644 --- a/experiments/roboarm/eval.py +++ b/experiments/roboarm/eval.py @@ -61,6 +61,7 @@ def run(cfg: DictConfig) -> None: image_caputres.append( td.get(("next", "original_pixels")).cpu().numpy() ) + agent.add_experience(td) total_agent_step_time = time.time() - step_start_time total_step_times.append(total_agent_step_time) done = td.get(("next", "done"), False) diff --git a/experiments/roboarm/pretrain.py b/experiments/roboarm/pretrain.py new file mode 100644 index 0000000..5ff54d7 --- /dev/null +++ b/experiments/roboarm/pretrain.py @@ -0,0 +1,61 @@ +import os +import sys + +import hydra +import wandb +from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm + +# Add the project root to PYTHONPATH +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from environments import make_env +from src.agents import get_agent +from src.utils import login, logout, setup_check, tensordict2dict + + +@hydra.main(version_base=None, config_path=project_root + "/conf", config_name="config") +def run(cfg: DictConfig) -> None: + print(OmegaConf.to_yaml(cfg)) + + # make environment. 
+ setup_check(robot="roboarm", config=cfg) + env, action_space, state_space = make_env(cfg, pretrain=True) + + # make agent + agent, project_name = get_agent(action_space, state_space, cfg) + login(agent) + + # initialize wandb + wandb.init(project=project_name) + wandb.config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + wandb.watch(agent.actor, log_freq=1) if agent.actor else None + + batch_size = cfg.agent.batch_size + num_updates = cfg.agent.num_updates + train_episodes = cfg.episodes + print("Start training...") + try: + for e in tqdm(range(train_episodes), desc="Training"): + + loss_info = agent.train(batch_size=batch_size, num_updates=num_updates) + + # Metrics Logging + log_dict = { + "epoch": e, + "buffer_size": agent.replay_buffer.__len__(), + } + log_dict.update(tensordict2dict(loss_info)) + wandb.log(log_dict) + + except KeyboardInterrupt: + print("Training interrupted by user.") + + logout(agent) + env.close() + + +if __name__ == "__main__": + run() diff --git a/experiments/walker/pretrain.py b/experiments/walker/pretrain.py new file mode 100644 index 0000000..602d3e6 --- /dev/null +++ b/experiments/walker/pretrain.py @@ -0,0 +1,61 @@ +import os +import sys + +import hydra +import wandb +from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm + +# Add the project root to PYTHONPATH +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from environments import make_env +from src.agents import get_agent +from src.utils import login, logout, setup_check, tensordict2dict + + +@hydra.main(version_base=None, config_path=project_root + "/conf", config_name="config") +def run(cfg: DictConfig) -> None: + print(OmegaConf.to_yaml(cfg)) + + # make environment. 
+ setup_check(robot="walker", config=cfg) + env, action_space, state_space = make_env(cfg, pretrain=True) + + # make agent + agent, project_name = get_agent(action_space, state_space, cfg) + login(agent) + + # initialize wandb + wandb.init(project=project_name) + wandb.config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + wandb.watch(agent.actor, log_freq=1) if agent.actor else None + + batch_size = cfg.agent.batch_size + num_updates = cfg.agent.num_updates + train_episodes = cfg.episodes + print("Start training...") + try: + for e in tqdm(range(train_episodes), desc="Training"): + + loss_info = agent.train(batch_size=batch_size, num_updates=num_updates) + + # Metrics Logging + log_dict = { + "epoch": e, + "buffer_size": agent.replay_buffer.__len__(), + } + log_dict.update(tensordict2dict(loss_info)) + wandb.log(log_dict) + + except KeyboardInterrupt: + print("Training interrupted by user.") + + logout(agent) + env.close() + + +if __name__ == "__main__": + run() diff --git a/requirements.txt b/requirements.txt index 7ca509c..8b7f967 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ pybricksdev -tensordict==0.4.0 -torchrl==0.4.0 +tensordict==0.5.0 +torchrl==0.5.0 hydra-core==1.3.2 wandb==0.16.1 opencv-python==4.9.0.80 @@ -9,4 +9,5 @@ tqdm==4.66.1 pytest==8.0.2 ufmt pre-commit -numpy==1.24.1 \ No newline at end of file +numpy==1.24.1 +pynput \ No newline at end of file diff --git a/src/agents/__init__.py b/src/agents/__init__.py index cf74e48..5e5cafd 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -1,8 +1,11 @@ +from src.agents.behavior_cloning import BehavioralCloningAgent +from src.agents.cql import CQLAgent +from src.agents.iql import IQLAgent from src.agents.random import RandomAgent from src.agents.sac import SACAgent from src.agents.td3 import TD3Agent -all_agents = ["td3", "sac", "random"] +all_agents = ["td3", "sac", "iql", "cql", "bc", "random"] def get_agent(action_spec, state_spec, cfg): @@ -20,6 +23,13 @@ def get_agent(action_spec, state_spec, cfg): agent_config=cfg.agent, device=cfg.device, ) + elif cfg.agent.name == "bc": + agent = BehavioralCloningAgent( + action_spec=action_spec, + state_spec=state_spec, + agent_config=cfg.agent, + device=cfg.device, + ) elif cfg.agent.name == "random": agent = RandomAgent( action_spec=action_spec, @@ -27,6 +37,20 @@ def get_agent(action_spec, state_spec, cfg): agent_config=cfg.agent, device=cfg.device, ) + elif cfg.agent.name == "iql": + agent = IQLAgent( + action_spec=action_spec, + state_spec=state_spec, + agent_config=cfg.agent, + device=cfg.device, + ) + elif cfg.agent.name == "cql": + agent = CQLAgent( + action_spec=action_spec, + state_spec=state_spec, + agent_config=cfg.agent, + device=cfg.device, + ) else: raise NotImplementedError( f"Agent {cfg.agent.name} not implemented, please choose from {all_agents}" diff --git a/src/agents/base.py b/src/agents/base.py index d43b6d9..8605478 100644 --- a/src/agents/base.py +++ b/src/agents/base.py @@ -13,18 +13,17 @@ class BaseAgent: """Implements a base agent used to interact with the lego robots. Args: - state_space (gym.Space): The state space of the environment. - action_space (gym.Space): The action space of the environment. - device (torch.device): The device to use for computation. - observation_keys (Tuple[str]): The keys used to access the observation in the tensor dictionary. + state_spec (TensorSpec): The state specification of the environment. + action_spec (TensorSpec): The action specification of the environment. 
+ agent_name (str): The name of the agent. + device (str): The device to use for computation. Attributes: - state_space (gym.Space): The state space of the environment. - action_space (gym.Space): The action space of the environment. - state_dim (int): The dimension of the state space. - action_dim (int): The dimension of the action space. - device (torch.device): The device to use for computation. - observation_keys (Tuple[str]): The keys used to access the observation in the tensor dictionary. + name (str): The name of the agent. + observation_spec (TensorSpec): The state specification of the environment. + action_spec (TensorSpec): The action specification of the environment. + device (str): The device to use for computation. + observation_keys (List[str]): The keys used to access the observation in the tensor dictionary. """ def __init__( diff --git a/src/agents/behavior_cloning.py b/src/agents/behavior_cloning.py new file mode 100644 index 0000000..bfb39cb --- /dev/null +++ b/src/agents/behavior_cloning.py @@ -0,0 +1,153 @@ +import numpy as np +import tensordict as td +import torch +from tensordict import TensorDictBase +from torch import nn, optim +from torchrl.data import BoundedTensorSpec, TensorDictReplayBuffer + +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.envs import RenameTransform, ToTensorImage +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from src.agents.base import BaseAgent +from src.networks.networks import get_deterministic_actor, get_stochastic_actor + + +def initialize(net, std=0.02): + for p, n in net.named_parameters(): + if "weight" in p: + # nn.init.xavier_uniform_(n) + nn.init.normal_(n, mean=0, std=std) + elif "bias" in p: + nn.init.zeros_(n) + + +class BehavioralCloningAgent(BaseAgent): + def __init__(self, state_spec, action_spec, agent_config, device="cpu"): + super(BehavioralCloningAgent, self).__init__( + state_spec, action_spec, agent_config.name, device + ) + + if agent_config.policy_type == "deterministic": + self.actor = get_deterministic_actor(state_spec, action_spec, agent_config) + elif agent_config.policy_type == "stochastic": + raise NotImplementedError( + "Stochastic actor training is not implemented yet" + ) + # TODO: Implement stochastic actor training + # self.actor = get_stochastic_actor( + # state_spec, action_spec, agent_config + # ) + else: + raise ValueError( + "policy_type not recognized, choose deterministic or stochastic" + ) + self.actor.to(device) + # initialize networks + self.init_nets([self.actor]) + + self.optimizer = optim.Adam( + self.actor.parameters(), lr=agent_config.lr, weight_decay=0.0 + ) + + # create replay buffer + self.batch_size = agent_config.batch_size + self.replay_buffer = self.create_replay_buffer() + + # general stats + self.collected_transitions = 0 + self.do_pretrain = False + self.episodes = 0 + + def get_agent_statedict(self): + """Save agent""" + act_statedict = self.actor.state_dict() + return {"actor": act_statedict} + + def load_model(self, path): + """load model""" + try: + statedict = torch.load(path) + self.actor.load_state_dict(statedict["actor"]) + print("Model loaded") + except: + raise ValueError("Model not loaded") + + def load_replaybuffer(self, path): + """load replay buffer""" + try: + loaded_data = TensorDictBase.load_memmap(path) + self.replay_buffer.extend(loaded_data) + if self.replay_buffer._batch_size != self.batch_size: + Warning( + "Batch size of the loaded replay buffer is different from the agent's config batch size! 
Rewriting the batch size to match the agent's config batch size." + ) + self.replay_buffer._batch_size = self.batch_size + print("Replay Buffer loaded") + print("Replay Buffer size: ", self.replay_buffer.__len__(), "\n") + except: + raise ValueError("Replay Buffer not loaded") + + def eval(self): + """Sets the agent to evaluation mode.""" + self.actor.eval() + + @torch.no_grad() + def get_eval_action(self, td: TensorDictBase) -> TensorDictBase: + """Get eval action from actor network""" + with set_exploration_type(ExplorationType.MODE): + out_td = self.actor(td.to(self.device)) + return out_td + + def create_replay_buffer( + self, + buffer_size=1000000, + buffer_scratch_dir="./tmp", + device="cpu", + prefetch=3, + ): + """Create replay buffer""" + + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + ), + batch_size=self.batch_size, + ) + replay_buffer.append_transform(lambda x: x.to(device)) + # TODO: check if we have image in observation space if so add this transform + # replay_buffer.append_transform(ToTensorImage(from_int=True, shape_tolerant=True)) + + return replay_buffer + + @torch.no_grad() + def get_action(self, td: TensorDictBase) -> TensorDictBase: + """Get action from actor network""" + with set_exploration_type(ExplorationType.RANDOM): + out_td = self.actor(td.to(self.device)) + return out_td + + def add_experience(self, transition: td.TensorDict): + """Add experience to replay buffer""" + """Add experience to replay buffer""" + self.replay_buffer.extend(transition) + self.collected_transitions += 1 + + def train(self, batch_size=64, num_updates=1): + """Train the agent""" + log_data = {} + + for i in range(num_updates): + batch = self.replay_buffer.sample(batch_size).to(self.device) + orig_action = batch.get("action").clone() + + out_dict = self.actor(batch) + loss = torch.mean((out_dict.get("action") - orig_action) ** 2) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + log_data.update({"loss": loss}) + return log_data diff --git a/src/agents/cql.py b/src/agents/cql.py new file mode 100644 index 0000000..54d9d2d --- /dev/null +++ b/src/agents/cql.py @@ -0,0 +1,252 @@ +import tensordict as td +import torch +from tensordict import TensorDictBase +from torch import optim +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import SoftUpdate + +from torchrl.objectives.cql import CQLLoss + +from src.agents.base import BaseAgent +from src.networks.networks import get_critic, get_stochastic_actor + + +class CQLAgent(BaseAgent): + def __init__(self, state_spec, action_spec, agent_config, device="cpu"): + super(CQLAgent, self).__init__( + state_spec, action_spec, agent_config.name, device + ) + + with_lagrange = agent_config.with_lagrange + + self.actor = get_stochastic_actor(state_spec, action_spec, agent_config) + self.critic = get_critic(state_spec, agent_config) + + self.actor.to(device) + self.critic.to(device) + + # initialize networks + self.init_nets([self.actor, self.critic]) + + # define loss function + self.loss_module = CQLLoss( + actor_network=self.actor, + qvalue_network=self.critic, + loss_function=agent_config.loss_function, + temperature=agent_config.temperature, + min_q_weight=agent_config.min_q_weight, + 
max_q_backup=agent_config.max_q_backup, + deterministic_backup=agent_config.deterministic_backup, + num_random=agent_config.num_random, + with_lagrange=agent_config.with_lagrange, + lagrange_thresh=agent_config.lagrange_thresh, + ) + # Define Target Network Updater + self.target_net_updater = SoftUpdate( + self.loss_module, eps=agent_config.soft_update_eps + ) + self.target_net_updater.init_() + + # Reset weights + self.reset_params = agent_config.reset_params + + # Define Replay Buffer + self.batch_size = agent_config.batch_size + self.replay_buffer = self.create_replay_buffer( + prb=agent_config.prb, + buffer_size=agent_config.buffer_size, + device=device, + ) + + # Define Optimizer + critic_params = list( + self.loss_module.qvalue_network_params.flatten_keys().values() + ) + actor_params = list( + self.loss_module.actor_network_params.flatten_keys().values() + ) + self.optimizer_actor = optim.Adam( + actor_params, lr=agent_config.lr, weight_decay=0.0 + ) + self.optimizer_critic = optim.Adam( + critic_params, lr=agent_config.lr, weight_decay=0.0 + ) + self.optimizer_alpha = optim.Adam( + [self.loss_module.log_alpha], + lr=3.0e-4, + ) + if with_lagrange: + self.alpha_prime_optim = torch.optim.Adam( + [self.loss_module.log_alpha_prime], + lr=agent_config.lr, + ) + else: + self.alpha_prime_optim = None + # general stats + self.collected_transitions = 0 + self.total_updates = 0 + self.do_pretrain = agent_config.pretrain + self.bc_steps = agent_config.bc_steps + + def get_agent_statedict(self): + """Save agent""" + act_statedict = self.actor.state_dict() + critic_statedict = self.critic.state_dict() + return {"actor": act_statedict, "critic": critic_statedict} + + def load_model(self, path): + """load model""" + try: + statedict = torch.load(path) + self.actor.load_state_dict(statedict["actor"]) + self.critic.load_state_dict(statedict["critic"]) + print("Model loaded") + except: + raise ValueError("Model not loaded") + + def load_replaybuffer(self, path): + """load replay buffer""" + try: + loaded_data = TensorDictBase.load_memmap(path) + self.replay_buffer.extend(loaded_data) + if self.replay_buffer._batch_size != self.batch_size: + Warning( + "Batch size of the loaded replay buffer is different from the agent's config batch size! Rewriting the batch size to match the agent's config batch size." 
+ ) + self.replay_buffer._batch_size = self.batch_size + print("Replay Buffer loaded") + print("Replay Buffer size: ", self.replay_buffer.__len__(), "\n") + except: + raise ValueError("Replay Buffer not loaded") + + def reset_networks(self): + """reset network parameters""" + print("Resetting Networks!") + self.loss_module.actor_network_params.apply(self.reset_parameter) + self.loss_module.target_actor_network_params.apply(self.reset_parameter) + self.loss_module.qvalue_network_params.apply(self.reset_parameter) + self.loss_module.target_qvalue_network_params.apply(self.reset_parameter) + + def eval(self): + """Sets the agent to evaluation mode.""" + self.actor.eval() + + def td_preprocessing(self, td: TensorDictBase) -> TensorDictBase: + # TODO not ideal to have this here + td.pop("scale") + td.pop("loc") + td.pop("params") + if "vector_obs_embedding" in td.keys(): + td.pop("vector_obs_embedding") + if "image_embedding" in td.keys(): + td.pop("image_embedding") + + def create_replay_buffer( + self, + prb=False, + buffer_size=100000, + buffer_scratch_dir=None, + device="cpu", + prefetch=3, + ): + """Create replay buffer""" + # TODO: make this part of base off policy agent + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=1, + storage=LazyTensorStorage( + buffer_size, + ), + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + ), + batch_size=self.batch_size, + ) + replay_buffer.append_transform(lambda x: x.to(device)) + # TODO: check if we have image in observation space if so add this transform + # replay_buffer.append_transform(ToTensorImage(from_int=True, shape_tolerant=True)) + return replay_buffer + + @torch.no_grad() + def get_action(self, td: TensorDictBase) -> TensorDictBase: + """Get action from actor network""" + with set_exploration_type(ExplorationType.RANDOM): + out_td = self.actor(td.to(self.device)) + self.td_preprocessing(out_td) + return out_td + + @torch.no_grad() + def get_eval_action(self, td: TensorDictBase) -> TensorDictBase: + """Get eval action from actor network""" + with set_exploration_type(ExplorationType.MODE): + out_td = self.actor(td.to(self.device)) + self.td_preprocessing(out_td) + return out_td + + def add_experience(self, transition: td.TensorDict): + """Add experience to replay buffer""" + self.replay_buffer.extend(transition) + self.collected_transitions += 1 + + def train(self, batch_size=64, num_updates=1): + """Train the agent""" + self.actor.train() + for i in range(num_updates): + self.total_updates += 1 + # Sample a batch from the replay buffer + batch = self.replay_buffer.sample(batch_size) + # Compute CQL Loss + loss = self.loss_module(batch) + + # Update alpha + alpha_loss = loss["loss_alpha"] + alpha_prime_loss = loss["loss_alpha_prime"] + self.optimizer_alpha.zero_grad() + alpha_loss.backward() + self.optimizer_alpha.step() + + # Update Actpr Network + # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks + if self.total_updates >= self.bc_steps: + actor_loss = loss["loss_actor"] + else: + actor_loss = loss["loss_actor_bc"] + self.optimizer_actor.zero_grad() + actor_loss.backward() + self.optimizer_actor.step() + + if self.alpha_prime_optim is not None: + self.alpha_prime_optim.zero_grad() + alpha_prime_loss.backward(retain_graph=True) + self.alpha_prime_optim.step() + + # Update Critic Network + 
q_loss = loss["loss_qvalue"] + cql_loss = loss["loss_cql"] + + q_loss = q_loss + cql_loss + self.optimizer_critic.zero_grad() + q_loss.backward(retain_graph=False) + self.optimizer_critic.step() + + # Update Target Networks + self.target_net_updater.step() + # Update Prioritized Replay Buffer + if isinstance(self.replay_buffer, TensorDictPrioritizedReplayBuffer): + self.replay_buffer.update_priorities( + batch["indices"], + loss["critic_loss"].detach().cpu().numpy(), + ) + self.actor.eval() + return loss diff --git a/src/agents/iql.py b/src/agents/iql.py new file mode 100644 index 0000000..244bb17 --- /dev/null +++ b/src/agents/iql.py @@ -0,0 +1,254 @@ +import tensordict as td +import torch +from tensordict import TensorDictBase +from torch import optim +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage +from torchrl.envs.transforms import ToTensorImage +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import SoftUpdate + +from torchrl.objectives.iql import IQLLoss + +from src.agents.base import BaseAgent +from src.networks.networks import get_critic, get_stochastic_actor, get_value_operator + + +class IQLAgent(BaseAgent): + def __init__(self, state_spec, action_spec, agent_config, device="cpu"): + super(IQLAgent, self).__init__( + state_spec, action_spec, agent_config.name, device + ) + + self.actor = get_stochastic_actor(state_spec, action_spec, agent_config) + self.critic = get_critic(state_spec, agent_config) + + self.value = get_value_operator(state_spec, agent_config) + + self.actor.to(device) + self.critic.to(device) + self.value.to(device) + + # initialize networks + self.init_nets([self.actor, self.critic, self.value]) + + # define loss function + self.loss_module = IQLLoss( + actor_network=self.actor, + qvalue_network=self.critic, + value_network=self.value, + num_qvalue_nets=2, + temperature=agent_config.temperature, + expectile=agent_config.expectile, + loss_function=agent_config.loss_function, + ) + # Define Target Network Updater + self.target_net_updater = SoftUpdate( + self.loss_module, eps=agent_config.soft_update_eps + ) + self.target_net_updater.init_() + + # Reset weights + self.reset_params = agent_config.reset_params + + # Define Replay Buffer + self.batch_size = agent_config.batch_size + + self.replay_buffer = self.create_replay_buffer( + prb=agent_config.prb, + buffer_size=agent_config.buffer_size, + device=device, + ) + + # Define Optimizer + critic_params = list( + self.loss_module.qvalue_network_params.flatten_keys().values() + ) + value_params = list( + self.loss_module.value_network_params.flatten_keys().values() + ) + actor_params = list( + self.loss_module.actor_network_params.flatten_keys().values() + ) + self.optimizer_actor = optim.Adam( + actor_params, lr=agent_config.lr, weight_decay=0.0 + ) + self.optimizer_critic = optim.Adam( + critic_params, lr=agent_config.lr, weight_decay=0.0 + ) + self.optimizer_value = optim.Adam( + value_params, lr=agent_config.lr, weight_decay=0.0 + ) + + # general stats + self.collected_transitions = 0 + self.total_updates = 0 + self.do_pretrain = agent_config.pretrain + + def get_agent_statedict(self): + """Save agent""" + act_statedict = self.actor.state_dict() + critic_statedict = self.critic.state_dict() + value_statedict = self.value.state_dict() + return { + "actor": act_statedict, + "critic": critic_statedict, + "value": value_statedict, + } + + def 
load_model(self, path): + """load model""" + + try: + statedict = torch.load(path) + self.actor.load_state_dict(statedict["actor"]) + self.critic.load_state_dict(statedict["critic"]) + self.value.load_state_dict(statedict["value"]) + print("Model loaded") + except: + raise ValueError("Model not loaded") + + def load_replaybuffer(self, path): + """load replay buffer""" + try: + loaded_data = TensorDictBase.load_memmap(path) + self.replay_buffer.extend(loaded_data) + if self.replay_buffer._batch_size != self.batch_size: + Warning( + "Batch size of the loaded replay buffer is different from the agent's config batch size! Rewriting the batch size to match the agent's config batch size." + ) + self.replay_buffer._batch_size = self.batch_size + print("Replay Buffer loaded") + print("Replay Buffer size: ", self.replay_buffer.__len__(), "\n") + except: + raise ValueError("Replay Buffer not loaded") + + def reset_networks(self): + """reset network parameters""" + print("Resetting Networks!") + self.loss_module.actor_network_params.apply(self.reset_parameter) + self.loss_module.target_actor_network_params.apply(self.reset_parameter) + self.loss_module.qvalue_network_params.apply(self.reset_parameter) + self.loss_module.target_qvalue_network_params.apply(self.reset_parameter) + self.loss_module.value_network_params.apply(self.reset_parameter) + + def eval(self): + """Sets the agent to evaluation mode.""" + self.actor.eval() + + def td_preprocessing(self, td: TensorDictBase) -> TensorDictBase: + # TODO not ideal to have this here + td.pop("scale") + td.pop("loc") + td.pop("params") + if "vector_obs_embedding" in td.keys(): + td.pop("vector_obs_embedding") + if "image_embedding" in td.keys(): + td.pop("image_embedding") + + def create_replay_buffer( + self, + prb=False, + buffer_size=100000, + buffer_scratch_dir=None, + device="cpu", + prefetch=3, + ): + """Create replay buffer""" + # TODO: make this part of base off policy agent + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=1, + storage=LazyTensorStorage( + buffer_size, + device=device, + ), + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + ), + batch_size=self.batch_size, + ) + replay_buffer.append_transform(lambda x: x.to(device)) + # TODO: check if we have image in observation space if so add this transform + # replay_buffer.append_transform(ToTensorImage(from_int=True, shape_tolerant=True)) + + return replay_buffer + + @torch.no_grad() + def get_action(self, td: TensorDictBase) -> TensorDictBase: + """Get action from actor network""" + with set_exploration_type(ExplorationType.RANDOM): + out_td = self.actor(td.to(self.device)) + self.td_preprocessing(out_td) + return out_td + + @torch.no_grad() + def get_eval_action(self, td: TensorDictBase) -> TensorDictBase: + """Get eval action from actor network""" + with set_exploration_type(ExplorationType.MODE): + out_td = self.actor(td.to(self.device)) + self.td_preprocessing(out_td) + return out_td + + def add_experience(self, transition: td.TensorDict): + """Add experience to replay buffer""" + self.replay_buffer.extend(transition) + self.collected_transitions += 1 + + def pretrain(self, wandb, batch_size=64, num_updates=1): + """Pretrain the agent with simple behavioral cloning""" + # TODO: implement pretrain for testing + # for i in range(num_updates): + # batch = self.replay_buffer.sample(batch_size) + # 
pred, _ = self.actor(batch["observations"].float()) + # loss = torch.mean((pred - batch["actions"]) ** 2) + # self.optimizer.zero_grad() + # loss.backward() + # self.optimizer.step() + # wandb.log({"pretrain/loss": loss.item()}) + + def train(self, batch_size=64, num_updates=1): + """Train the agent""" + self.actor.train() + for i in range(num_updates): + self.total_updates += 1 + if self.reset_params and self.total_updates % self.reset_params == 0: + self.reset_networks() + # Sample a batch from the replay buffer + batch = self.replay_buffer.sample(batch_size) + # Compute IQL Loss + loss = self.loss_module(batch) + + # Update Actpr Network + self.optimizer_actor.zero_grad() + loss["loss_actor"].backward() + self.optimizer_actor.step() + # Update Critic Network + self.optimizer_critic.zero_grad() + loss["loss_qvalue"].backward() + torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5) + self.optimizer_critic.step() + # Update Value Network + self.optimizer_value.zero_grad() + loss["loss_value"].backward() + self.optimizer_value.step() + + # Update Target Networks + self.target_net_updater.step() + # Update Prioritized Replay Buffer + if isinstance(self.replay_buffer, TensorDictPrioritizedReplayBuffer): + self.replay_buffer.update_priorities( + batch["indices"], + loss["critic_loss"].detach().cpu().numpy(), + ) + self.actor.eval() + return loss diff --git a/src/agents/random.py b/src/agents/random.py index 55a37ef..f599f43 100644 --- a/src/agents/random.py +++ b/src/agents/random.py @@ -1,5 +1,7 @@ import torch from tensordict import TensorDictBase +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage from src.agents.base import BaseAgent @@ -11,7 +13,13 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"): ) self.actor = None - self.replay_buffer = {} + self.replay_buffer = self.create_replay_buffer( + batch_size=256, + prb=False, + buffer_size=1000000, + device=device, + buffer_scratch_dir="/tmp", + ) def eval(self): """Sets the agent to evaluation mode.""" @@ -30,8 +38,41 @@ def get_eval_action(self, tensordict: TensorDictBase): def add_experience(self, transition: TensorDictBase): """Add experience to replay buffer""" - pass + self.replay_buffer.extend(transition) def train(self, batch_size=64, num_updates=1): """Train the agent""" return {} + + def create_replay_buffer( + self, + batch_size=256, + prb=False, + buffer_size=100000, + buffer_scratch_dir=None, + device="cpu", + prefetch=3, + ): + """Create replay buffer""" + # TODO: make this part of base off policy agent + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=1, + storage=LazyTensorStorage( + buffer_size, + ), + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + ), + batch_size=batch_size, + ) + return replay_buffer diff --git a/src/agents/sac.py b/src/agents/sac.py index 7636c18..a87049f 100644 --- a/src/agents/sac.py +++ b/src/agents/sac.py @@ -19,10 +19,8 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"): state_spec, action_spec, agent_config.name, device ) - self.actor = get_stochastic_actor( - self.observation_keys, action_spec, agent_config - ) - self.critic = get_critic(self.observation_keys, agent_config) + self.actor = get_stochastic_actor(state_spec, action_spec, 
agent_config) + self.critic = get_critic(state_spec, agent_config) self.actor.to(device) self.critic.to(device) @@ -52,14 +50,13 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"): self.batch_size = agent_config.batch_size # Define Replay Buffer + self.buffer_batch_size = agent_config.batch_size self.replay_buffer = self.create_replay_buffer( - batch_size=self.batch_size, prb=agent_config.prb, buffer_size=agent_config.buffer_size, - device=device, buffer_scratch_dir="/tmp", + device=device, ) - # Define Optimizer critic_params = list( self.loss_module.qvalue_network_params.flatten_keys().values() @@ -101,7 +98,8 @@ def load_model(self, path): def load_replaybuffer(self, path): """load replay buffer""" try: - self.replay_buffer.load(path) + loaded_data = TensorDictBase.load_memmap(path) + self.replay_buffer.extend(loaded_data) if self.replay_buffer._batch_size != self.batch_size: Warning( "Batch size of the loaded replay buffer is different from the agent's config batch size! Rewriting the batch size to match the agent's config batch size." @@ -136,10 +134,9 @@ def td_preprocessing(self, td: TensorDictBase) -> TensorDictBase: def create_replay_buffer( self, - batch_size=256, prb=False, buffer_size=100000, - buffer_scratch_dir=None, + buffer_scratch_dir=".", device="cpu", prefetch=3, ): @@ -163,9 +160,12 @@ def create_replay_buffer( buffer_size, scratch_dir=buffer_scratch_dir, ), - batch_size=batch_size, + batch_size=self.batch_size, ) replay_buffer.append_transform(lambda x: x.to(device)) + # TODO: check if we have image in observation space if so add this transform + # replay_buffer.append_transform(ToTensorImage(from_int=True, shape_tolerant=True)) + return replay_buffer @torch.no_grad() diff --git a/src/agents/td3.py b/src/agents/td3.py index 5c1b9e5..706d821 100644 --- a/src/agents/td3.py +++ b/src/agents/td3.py @@ -10,6 +10,7 @@ from torchrl.modules import AdditiveGaussianWrapper from torchrl.objectives import SoftUpdate from torchrl.objectives.td3 import TD3Loss +from torchrl.objectives.td3_bc import TD3BCLoss from src.agents.base import BaseAgent from src.networks.networks import get_critic, get_deterministic_actor @@ -30,12 +31,13 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"): state_spec, action_spec, agent_config.name, device ) - self.actor = get_deterministic_actor( - self.observation_keys, action_spec, agent_config - ) - self.critic = get_critic(self.observation_keys, agent_config) + self.actor = get_deterministic_actor(state_spec, action_spec, agent_config) + self.critic = get_critic(state_spec, agent_config) self.model = nn.ModuleList([self.actor, self.critic]).to(device) + + print(self.actor) + print(self.critic) # initialize networks self.init_nets(self.model) @@ -48,14 +50,27 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"): ).to(device) # define loss function - self.loss_module = TD3Loss( - actor_network=self.model[0], - qvalue_network=self.model[1], - action_spec=action_spec, - num_qvalue_nets=2, - loss_function=agent_config.loss_function, - separate_losses=False, - ) + self.use_bc = agent_config.use_bc + if not self.use_bc: + self.loss_module = TD3Loss( + actor_network=self.model[0], + qvalue_network=self.model[1], + action_spec=action_spec, + num_qvalue_nets=2, + loss_function=agent_config.loss_function, + separate_losses=False, + ) + else: + self.loss_module = TD3BCLoss( + actor_network=self.model[0], + qvalue_network=self.model[1], + action_spec=action_spec, + num_qvalue_nets=2, + 
             loss_function=agent_config.loss_function,
+            separate_losses=False,
+            alpha=agent_config.alpha,
+        )
+
         # Define Target Network Updater
         self.target_net_updater = SoftUpdate(
             self.loss_module, eps=agent_config.soft_update_eps
@@ -65,7 +80,6 @@ def __init__(self, state_spec, action_spec, agent_config, device="cpu"):
         self.batch_size = agent_config.batch_size
         # Define Replay Buffer
         self.replay_buffer = self.create_replay_buffer(
-            batch_size=self.batch_size,
             prb=agent_config.prb,
             buffer_size=agent_config.buffer_size,
             device=device,
@@ -112,7 +126,8 @@ def load_model(self, path):
     def load_replaybuffer(self, path):
         """load replay buffer"""
         try:
-            self.replay_buffer.load(path)
+            loaded_data = TensorDictBase.load_memmap(path)
+            self.replay_buffer.extend(loaded_data)
             if self.replay_buffer._batch_size != self.batch_size:
                 Warning(
                     "Batch size of the loaded replay buffer is different from the agent's config batch size! Rewriting the batch size to match the agent's config batch size."
                 )
@@ -133,7 +148,6 @@ def reset_networks(self):

     def create_replay_buffer(
         self,
-        batch_size=256,
         prb=False,
         buffer_size=100000,
         buffer_scratch_dir=None,
@@ -160,9 +174,17 @@ def create_replay_buffer(
                     buffer_size,
                     scratch_dir=buffer_scratch_dir,
                 ),
-                batch_size=batch_size,
+                batch_size=self.batch_size,
             )
             replay_buffer.append_transform(lambda x: x.to(device))
+            # TODO: check if we have an image in the observation space; if so, add this transform
+            # replay_buffer.append_transform(
+            #     ToTensorImage(
+            #         from_int=True,
+            #         shape_tolerant=True,
+            #         in_keys=["pixels", ("next", "pixels")],
+            #     )
+            # )
         return replay_buffer

     def td_preprocessing(self, td: TensorDictBase) -> TensorDictBase:
@@ -215,7 +237,10 @@ def train(self, batch_size=64, num_updates=1):
             else:
                 sampled_tensordict = sampled_tensordict.clone()
             # Update Critic Network
-            q_loss, _ = self.loss_module.value_loss(sampled_tensordict)
+            if self.use_bc:
+                q_loss, _ = self.loss_module.qvalue_loss(sampled_tensordict)
+            else:
+                q_loss, _ = self.loss_module.value_loss(sampled_tensordict)
             self.optimizer_critic.zero_grad()
             q_loss.backward()
             self.optimizer_critic.step()
diff --git a/src/networks/networks.py b/src/networks/networks.py
index 4d82763..b1ff9c0 100644
--- a/src/networks/networks.py
+++ b/src/networks/networks.py
@@ -3,7 +3,6 @@
 from tensordict.nn.distributions import NormalParamExtractor

 from torchrl.modules import (
-    AdditiveGaussianWrapper,
     ConvNet,
     MLP,
     ProbabilisticActor,
@@ -12,7 +11,7 @@
     TanhModule,
     ValueOperator,
 )
-from torchrl.modules.distributions import TanhDelta, TanhNormal
+from torchrl.modules.distributions import TanhNormal


 def get_normalization(normalization):
@@ -26,7 +25,9 @@ def get_normalization(normalization):
         raise NotImplementedError(f"Normalization {normalization} not implemented")


-def get_critic(observation_keys, agent_config):
+def get_critic(observation_spec, agent_config):
+    observation_keys = [key for key in observation_spec.keys()]
+
     if "observation" in observation_keys and not "pixels" in observation_keys:
         return get_vec_critic(
             in_keys=observation_keys,
@@ -36,6 +37,17 @@ def get_critic(observation_keys, agent_config):
             normalization=agent_config.normalization,
             dropout=agent_config.dropout,
         )
+    elif "pixels" in observation_keys and not "observation" in observation_keys:
+        return get_img_only_critic(
+            img_in_keys="pixels",
+            num_cells=[agent_config.num_cells, agent_config.num_cells],
+            out_features=1,
+            activation_class=nn.ReLU,
+            normalization=agent_config.normalization,
+            dropout=agent_config.dropout,
+            img_shape=observation_spec["pixels"].shape,
+        )
+
     elif "pixels" in observation_keys and "observation" in observation_keys:
         return get_mixed_critic(
             vec_in_keys="observation",
@@ -45,11 +57,171 @@ def get_critic(observation_keys, agent_config):
             activation_class=nn.ReLU,
             normalization=agent_config.normalization,
             dropout=agent_config.dropout,
+            img_shape=observation_spec["pixels"].shape,
         )
     else:
         raise NotImplementedError("Critic for this observation space not implemented")


+def get_value_operator(observation_spec, agent_config):
+    observation_keys = [key for key in observation_spec.keys()]
+    if "observation" in observation_keys and not "pixels" in observation_keys:
+        return get_vec_value(
+            in_keys=observation_keys,
+            num_cells=[agent_config.num_cells, agent_config.num_cells],
+            out_features=1,
+            activation_class=nn.ReLU,
+            normalization=agent_config.normalization,
+            dropout=agent_config.dropout,
+        )
+    elif "pixels" in observation_keys and not "observation" in observation_keys:
+        return get_img_only_value(
+            img_in_keys="pixels",
+            num_cells=[agent_config.num_cells, agent_config.num_cells],
+            out_features=1,
+            activation_class=nn.ReLU,
+            normalization=agent_config.normalization,
+            dropout=agent_config.dropout,
+            img_shape=observation_spec["pixels"].shape,
+        )
+    elif "pixels" in observation_keys and "observation" in observation_keys:
+        return get_mixed_value(
+            vec_in_keys="observation",
+            img_in_keys="pixels",
+            num_cells=[agent_config.num_cells, agent_config.num_cells],
+            out_features=1,
+            activation_class=nn.ReLU,
+            normalization=agent_config.normalization,
+            dropout=agent_config.dropout,
+            img_shape=observation_spec["pixels"].shape,
+        )
+
+
+def get_vec_value(
+    in_keys=["observation"],
+    num_cells=[256, 256],
+    out_features=1,
+    activation_class=nn.ReLU,
+    normalization="None",
+    dropout=0.0,
+):
+    """Returns a value network"""
+    normalization = get_normalization(normalization)
+    qvalue_net = MLP(
+        num_cells=num_cells,
+        out_features=out_features,
+        activation_class=activation_class,
+        norm_class=normalization,
+        norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+        dropout=dropout,
+    )
+
+    qvalue = ValueOperator(
+        in_keys=in_keys,
+        module=qvalue_net,
+    )
+    return qvalue
+
+
+def get_img_only_value(
+    img_in_keys,
+    num_cells=[256, 256],
+    out_features=1,
+    activation_class=nn.ReLU,
+    normalization="None",
+    dropout=0.0,
+    img_shape=(3, 64, 64),
+):
+    normalization = get_normalization(normalization)
+    # image encoder
+    cnn = ConvNet(
+        activation_class=activation_class,
+        num_cells=[32, 64, 64],
+        kernel_sizes=[8, 4, 3],
+        strides=[4, 2, 1],
+    )
+    cnn_output = cnn(torch.ones(img_shape))
+    mlp = MLP(
+        in_features=cnn_output.shape[-1],
+        activation_class=activation_class,
+        out_features=128,
+        num_cells=[256],
+    )
+    image_encoder = SafeModule(
+        torch.nn.Sequential(cnn, mlp),
+        in_keys=[img_in_keys],
+        out_keys=["pixel_embedding"],
+    )
+
+    # output head
+    mlp = MLP(
+        activation_class=torch.nn.ReLU,
+        out_features=out_features,
+        num_cells=num_cells,
+        norm_class=normalization,
+        norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+        dropout=dropout,
+    )
+    v_head = ValueOperator(mlp, ["pixel_embedding"])
+    # model
+    return SafeSequential(image_encoder, v_head)
+
+
+def get_mixed_value(
+    vec_in_keys,
+    img_in_keys,
+    num_cells=[256, 256],
+    out_features=1,
+    activation_class=nn.ReLU,
+    normalization="None",
+    dropout=0.0,
+    img_shape=(3, 64, 64),
+):
+    normalization = get_normalization(normalization)
+    # image encoder
+    cnn = ConvNet(
+        activation_class=activation_class,
+        num_cells=[32, 64, 64],
+        kernel_sizes=[8, 4, 3],
+        strides=[4, 2, 1],
+    )
+    cnn_output = cnn(torch.ones(img_shape))
+    mlp = MLP(
+        in_features=cnn_output.shape[-1],
+        activation_class=activation_class,
+        out_features=128,
+        num_cells=[256],
+    )
+    image_encoder = SafeModule(
+        torch.nn.Sequential(cnn, mlp),
+        in_keys=[img_in_keys],
+        out_keys=["pixel_embedding"],
+    )
+
+    # vector_obs encoder
+    mlp = MLP(
+        activation_class=activation_class,
+        out_features=32,
+        num_cells=[128],
+    )
+    vector_obs_encoder = SafeModule(
+        mlp, in_keys=[vec_in_keys], out_keys=["obs_embedding"]
+    )
+
+    # output head
+    mlp = MLP(
+        activation_class=torch.nn.ReLU,
+        out_features=out_features,
+        num_cells=num_cells,
+        norm_class=normalization,
+        norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+        dropout=dropout,
+    )
+    v_head = ValueOperator(mlp, ["pixel_embedding", "obs_embedding"])
+    # model
+    return SafeSequential(image_encoder, vector_obs_encoder, v_head)
+
+
 def get_vec_critic(
     in_keys=["observation"],
     num_cells=[256, 256],
@@ -84,6 +256,7 @@ def get_mixed_critic(
     activation_class=nn.ReLU,
     normalization="None",
     dropout=0.0,
+    img_shape=(3, 64, 64),
 ):
     normalization = get_normalization(normalization)
     # image encoder
@@ -93,7 +266,7 @@ def get_mixed_critic(
         kernel_sizes=[8, 4, 3],
         strides=[4, 2, 1],
     )
-    cnn_output = cnn(torch.ones((3, 64, 64)))
+    cnn_output = cnn(torch.ones(img_shape))
     mlp = MLP(
         in_features=cnn_output.shape[-1],
         activation_class=activation_class,
@@ -130,7 +303,53 @@ def get_mixed_critic(
     return SafeSequential(image_encoder, vector_obs_encoder, v_head)


-def get_deterministic_actor(observation_keys, action_spec, agent_config):
+def get_img_only_critic(
+    img_in_keys,
+    num_cells=[256, 256],
+    out_features=1,
+    activation_class=nn.ReLU,
+    normalization="None",
+    dropout=0.0,
+    img_shape=(3, 64, 64),
+):
+    normalization = get_normalization(normalization)
+    # image encoder
+    cnn = ConvNet(
+        activation_class=activation_class,
+        num_cells=[32, 64, 64],
+        kernel_sizes=[8, 4, 3],
+        strides=[4, 2, 1],
+    )
+    cnn_output = cnn(torch.ones(img_shape))
+    mlp = MLP(
+        in_features=cnn_output.shape[-1],
+        activation_class=activation_class,
+        out_features=128,
+        num_cells=[256],
+    )
+    image_encoder = SafeModule(
+        torch.nn.Sequential(cnn, mlp),
+        in_keys=[img_in_keys],
+        out_keys=["pixel_embedding"],
+    )
+
+    # output head
+    mlp = MLP(
+        activation_class=torch.nn.ReLU,
+        out_features=out_features,
+        num_cells=num_cells,
+        norm_class=normalization,
+        norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+        dropout=dropout,
+    )
+    v_head = ValueOperator(mlp, ["pixel_embedding", "action"])
+    # model
+    return SafeSequential(image_encoder, v_head)
+
+
+def get_deterministic_actor(observation_spec, action_spec, agent_config):
+    observation_keys = [key for key in observation_spec.keys()]
+
     if "observation" in observation_keys and not "pixels" in observation_keys:
         return get_vec_deterministic_actor(
             action_spec=action_spec,
@@ -139,6 +358,15 @@ def get_deterministic_actor(observation_keys, action_spec, agent_config):
             num_cells=[agent_config.num_cells, agent_config.num_cells],
             activation_class=nn.ReLU,
         )
+    elif "pixels" in observation_keys and not "observation" in observation_keys:
+        return get_img_only_det_actor(
+            img_in_keys="pixels",
+            action_spec=action_spec,
+            num_cells=[agent_config.num_cells, agent_config.num_cells],
+            activation_class=nn.ReLU,
+            img_shape=observation_spec["pixels"].shape,
+        )
+
     elif "pixels" in observation_keys and "observation" in observation_keys:
         return get_mixed_deterministic_actor(
             vec_in_keys="observation",
@@ -146,6 +374,7 @@ def get_deterministic_actor(observation_keys, action_spec, agent_config):
             action_spec=action_spec,
             num_cells=[agent_config.num_cells, agent_config.num_cells],
             activation_class=nn.ReLU,
+            img_shape=observation_spec["pixels"].shape,
         )
     else:
         raise NotImplementedError("Actor for this observation space not implemented")
@@ -190,6 +419,58 @@ def get_vec_deterministic_actor(
     return actor


+def get_img_only_det_actor(
+    img_in_keys,
+    action_spec,
+    num_cells=[256, 256],
+    activation_class=nn.ReLU,
+    normalization="None",
+    dropout=0.0,
+    img_shape=(3, 64, 64),
+):
+    normalization = get_normalization(normalization)
+    # image encoder
+    cnn = ConvNet(
+        activation_class=activation_class,
+        num_cells=[32, 64, 64],
+        kernel_sizes=[8, 4, 3],
+        strides=[4, 2, 1],
+    )
+    cnn_output = cnn(torch.ones(img_shape))
+    mlp = MLP(
+        in_features=cnn_output.shape[-1],
+        activation_class=activation_class,
+        out_features=128,
+        num_cells=[256],
+    )
+    image_encoder = SafeModule(
+        torch.nn.Sequential(cnn, mlp),
+        in_keys=[img_in_keys],
+        out_keys=["pixel_embedding"],
+    )
+
+    # output head
+    mlp = MLP(
+        activation_class=torch.nn.ReLU,
+        out_features=action_spec.shape[-1],
+        num_cells=num_cells,
+        norm_class=normalization,
+        norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+        dropout=dropout,
+    )
+    combined = SafeModule(mlp, ["pixel_embedding"], out_keys=["param"])
+    out_module = TanhModule(
+        in_keys=["param"],
+        out_keys=["action"],
+        spec=action_spec,
+    )
+    return SafeSequential(
+        image_encoder,
+        combined,
+        out_module,
+    )
+
+
 def get_mixed_deterministic_actor(
     vec_in_keys,
     img_in_keys,
@@ -198,6 +479,7 @@ def get_mixed_deterministic_actor(
     activation_class=nn.ReLU,
     normalization="None",
     dropout=0.0,
+    img_shape=(3, 64, 64),
 ):
     normalization = get_normalization(normalization)
     # image encoder
@@ -207,7 +489,7 @@ def get_mixed_deterministic_actor(
         kernel_sizes=[8, 4, 3],
         strides=[4, 2, 1],
     )
-    cnn_output = cnn(torch.ones((3, 64, 64)))
+    cnn_output = cnn(torch.ones(img_shape))
     mlp = MLP(
         in_features=cnn_output.shape[-1],
         activation_class=activation_class,
@@ -253,7 +535,8 @@ def get_mixed_deterministic_actor(
     )


-def get_stochastic_actor(observation_keys, action_spec, agent_config):
+def get_stochastic_actor(observation_spec, action_spec, agent_config):
+    observation_keys = [key for key in observation_spec.keys()]
     if "observation" in observation_keys and not "pixels" in observation_keys:
         return get_vec_stochastic_actor(
             action_spec,
@@ -263,6 +546,16 @@ def get_stochastic_actor(observation_keys, action_spec, agent_config):
             dropout=agent_config.dropout,
             activation_class=nn.ReLU,
         )
+    elif "pixels" in observation_keys and not "observation" in observation_keys:
+        return get_img_only_stochastic_actor(
+            img_in_keys="pixels",
+            action_spec=action_spec,
+            num_cells=[agent_config.num_cells, agent_config.num_cells],
+            normalization=agent_config.normalization,
+            dropout=agent_config.dropout,
+            activation_class=nn.ReLU,
+            img_shape=observation_spec["pixels"].shape,
+        )
     elif "pixels" in observation_keys and "observation" in observation_keys:
         return get_mixed_stochastic_actor(
             action_spec,
@@ -272,6 +565,7 @@ def get_stochastic_actor(observation_keys, action_spec, agent_config):
             normalization=agent_config.normalization,
             dropout=agent_config.dropout,
             activation_class=nn.ReLU,
+            img_shape=observation_spec["pixels"].shape,
         )
     else:
         raise NotImplementedError("Actor for this observation space not implemented")
@@ -330,6 +624,85 @@ def get_vec_stochastic_actor(
     return actor


+def get_img_only_stochastic_actor(
+    action_spec,
+    img_in_keys,
+    num_cells=[256, 256],
+    normalization="None",
+    dropout=0.0,
+    activation_class=nn.ReLU,
+    img_shape=(3, 64, 64),
+):
+
+    normalization = get_normalization(normalization)
+    # image encoder
+    cnn = ConvNet(
+        activation_class=activation_class,
+        num_cells=[32, 64, 64],
+        kernel_sizes=[8, 4, 3],
+        strides=[4, 2, 1],
+    )
+    cnn_output = cnn(torch.ones(img_shape))
+    mlp = MLP(
+        in_features=cnn_output.shape[-1],
+        activation_class=activation_class,
+        out_features=128,
+        num_cells=[256],
+    )
+    image_encoder = SafeModule(
+        torch.nn.Sequential(cnn, mlp),
+        in_keys=[img_in_keys],
+        out_keys=["pixel_embedding"],
+    )
+
+    # output head
+    mlp = MLP(
+        activation_class=torch.nn.ReLU,
+        out_features=2 * action_spec.shape[-1],
+        num_cells=num_cells,
+        norm_class=normalization,
+        norm_kwargs={"normalized_shape": num_cells[-1]} if normalization else None,
+        dropout=dropout,
+    )
+    actor_module = SafeModule(
+        mlp,
+        in_keys=["pixel_embedding"],
+        out_keys=["params"],
+    )
+    actor_extractor = NormalParamExtractor(
+        scale_mapping=f"biased_softplus_{1.0}",
+        scale_lb=0.1,
+    )
+
+    extractor_module = SafeModule(
+        actor_extractor,
+        in_keys=["params"],
+        out_keys=[
+            "loc",
+            "scale",
+        ],
+    )
+    actor_net_combined = SafeSequential(image_encoder, actor_module, extractor_module)
+
+    dist_class = TanhNormal
+    dist_kwargs = {
+        "min": action_spec.space.low,
+        "max": action_spec.space.high,
+        "tanh_loc": False,
+    }
+    actor = ProbabilisticActor(
+        spec=action_spec,
+        in_keys=["loc", "scale"],
+        out_keys=["action"],
+        module=actor_net_combined,
+        distribution_class=dist_class,
+        distribution_kwargs=dist_kwargs,
+        default_interaction_mode="random",
+        return_log_prob=False,
+    )
+    return actor
+
+
 def get_mixed_stochastic_actor(
     action_spec,
     vec_in_keys,
@@ -338,6 +711,7 @@ def get_mixed_stochastic_actor(
     normalization="None",
     dropout=0.0,
     activation_class=nn.ReLU,
+    img_shape=(3, 64, 64),
 ):
     normalization = get_normalization(normalization)

@@ -348,7 +722,7 @@ def get_mixed_stochastic_actor(
         kernel_sizes=[8, 4, 3],
         strides=[4, 2, 1],
     )
-    cnn_output = cnn(torch.ones((3, 64, 64)))
+    cnn_output = cnn(torch.ones(img_shape))
     mlp = MLP(
         in_features=cnn_output.shape[-1],
         activation_class=activation_class,
diff --git a/src/utils.py b/src/utils.py
index 3ea4c73..50adae4 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -6,7 +6,7 @@
 from environments import ALL_2WHEELER_ENVS, ALL_ROBOARM_ENVS, ALL_WALKER_ENVS
 from moviepy.editor import concatenate_videoclips, ImageClip
 from omegaconf import DictConfig
-from tensordict import TensorDictBase
+from tensordict import TensorDict, TensorDictBase
 from torchrl.envs.utils import step_mdp
 from tqdm import tqdm

@@ -49,7 +49,11 @@ def logout(agent):
     x = input("Do you want to save the replay buffer? (y/n)")
     if x == "y":
         save_name = input("Enter the name of the file to save: ")
-        agent.replay_buffer.dump(save_name)
+        # agent.replay_buffer.dump(save_name)
+        batched_data = agent.replay_buffer.storage._storage[
+            : agent.replay_buffer.__len__()
+        ]
+        batched_data.save(save_name, copy_existing=True)


 def login(agent):
@@ -82,7 +86,6 @@ def prefill_buffer(env, agent, num_episodes=10, stop_on_done=False):
         inpt = input("Press Enter to start prefilling episode: ")
     for e in tqdm(range(num_episodes), desc="Prefilling buffer"):
         print("Prefill episode: ", e)
-
         td = env.reset()
         done = False
         truncated = False
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 0a21c6c..754a812 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -17,9 +17,9 @@ def collection_round(env, agent, max_steps=1000):
         td = step_mdp(td)


-def get_env(env):
+def get_env(env, img_shape=(64, 64, 3)):
     if env == "mixed":
-        env = MixedObsDummyEnv()
+        env = MixedObsDummyEnv(img_shape=img_shape)
         env = TransformedEnv(
             env, Compose(ToTensorImage(in_keys=["pixels"], from_int=True))
         )
@@ -49,7 +49,9 @@ def test_random_agent(env, device):
     else:
         device = "cpu"
     with initialize(config_path="../conf"):
-        cfg = compose(config_name="config", overrides=["device=" + device])
+        cfg = compose(
+            config_name="config", overrides=["device=" + device, "agent=random"]
+        )
     # Test data collection
     env = get_env(env)
     agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
@@ -171,3 +173,154 @@ def test_drq_agent(env, device):
     eval_td2 = agent.get_eval_action(td)

     assert torch.allclose(eval_td1["action"], eval_td2["action"])
+
+
+@pytest.mark.parametrize(
+    "env",
+    ["mixed", "vec", "vec_goal"],
+)
+@pytest.mark.parametrize(
+    "device",
+    ["cpu", "cuda"],
+)
+def test_iql_agent(env, device):
+    if torch.cuda.is_available() and device == "cuda":
+        device = "cuda"
+    else:
+        device = "cpu"
+    with initialize(config_path="../conf"):
+        cfg = compose(config_name="config", overrides=["agent=iql", "device=" + device])
+
+    # Test data collection
+    env = get_env(env)
+    agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
+    collection_round(env, agent, max_steps=10)
+    # Test training
+    agent.train(batch_size=1, num_updates=1)
+
+    # Test evaluation
+    td = env.reset()
+    td1 = agent.get_action(td)
+    td2 = agent.get_action(td)
+
+    assert not torch.allclose(td1["action"], td2["action"])
+
+    agent.eval()
+    td = env.reset()
+    eval_td1 = agent.get_eval_action(td)
+    eval_td2 = agent.get_eval_action(td)
+
+    assert torch.allclose(eval_td1["action"], eval_td2["action"])
+
+
+@pytest.mark.parametrize(
+    "env",
+    ["mixed", "vec", "vec_goal"],
+)
+@pytest.mark.parametrize(
+    "device",
+    ["cpu", "cuda"],
+)
+def test_cql_agent(env, device):
+    if torch.cuda.is_available() and device == "cuda":
+        device = "cuda"
+    else:
+        device = "cpu"
+    with initialize(config_path="../conf"):
+        cfg = compose(config_name="config", overrides=["agent=cql", "device=" + device])
+
+    # Test data collection
+    env = get_env(env)
+    agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
+    collection_round(env, agent, max_steps=10)
+    # Test training
+    agent.train(batch_size=1, num_updates=1)
+
+    # Test evaluation
+    td = env.reset()
+    td1 = agent.get_action(td)
+    td2 = agent.get_action(td)
+
+    assert not torch.allclose(td1["action"], td2["action"])
+
+    agent.eval()
+    td = env.reset()
+    eval_td1 = agent.get_eval_action(td)
+    eval_td2 = agent.get_eval_action(td)
+
+    assert torch.allclose(eval_td1["action"], eval_td2["action"])
+
+
+@pytest.mark.parametrize(
+    "env",
+    ["mixed", "vec", "vec_goal"],
+)
+@pytest.mark.parametrize(
+    "device",
+    ["cpu", "cuda"],
+)
+def test_bc_agent(env, device):
+    if torch.cuda.is_available() and device == "cuda":
+        device = "cuda"
+    else:
+        device = "cpu"
+    with initialize(config_path="../conf"):
+        cfg = compose(config_name="config", overrides=["agent=bc", "device=" + device])
+
+    # Test data collection
+    env = get_env(env)
+    agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
+    collection_round(env, agent, max_steps=10)
+    # Test training
+    agent.train(batch_size=1, num_updates=1)
+
+    # Test evaluation
+    agent.eval()
+    td = env.reset()
+    eval_td1 = agent.get_eval_action(td)
+    eval_td2 = agent.get_eval_action(td)
+
+    assert torch.allclose(eval_td1["action"], eval_td2["action"])
+
+
+@pytest.mark.parametrize(
+    "env",
+    ["mixed"],
+)
+@pytest.mark.parametrize(
+    "img_shape",
+    [(64, 64, 3), (128, 128, 3)],
+)
+@pytest.mark.parametrize(
+    "device",
+    ["cpu", "cuda"],
+)
+def test_mixed_obs_size_agent(env, device, img_shape):
+    if torch.cuda.is_available() and device == "cuda":
+        device = "cuda"
+    else:
+        device = "cpu"
+    with initialize(config_path="../conf"):
+        cfg = compose(config_name="config", overrides=["agent=td3", "device=" + device])
+
+    # Test data collection
+    env = get_env(env, img_shape)
+    agent, _ = get_agent(env.action_spec, env.observation_spec, cfg)
+    collection_round(env, agent, max_steps=10)
+
+    # Test training
+    agent.train(batch_size=1, num_updates=1)
+
+    # Test evaluation
+    td = env.reset()
+    td1 = agent.get_action(td)
+    td2 = agent.get_action(td)
+
+    assert not torch.allclose(td1["action"], td2["action"])
+
+    agent.eval()
+    td = env.reset()
+    eval_td1 = agent.get_eval_action(td)
+    eval_td2 = agent.get_eval_action(td)
+
+    assert torch.allclose(eval_td1["action"], eval_td2["action"])