Skip to content

Adds option to define the concatenation dimension in the ObservationManager and changes counter update in CommandManager #2393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 17, 2025
Merged
2 changes: 1 addition & 1 deletion source/isaaclab/config/extension.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.39.4"
version = "0.39.5"

# Description
title = "Isaac Lab framework for Robot Learning"
Expand Down
15 changes: 15 additions & 0 deletions source/isaaclab/docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
Changelog
---------

0.39.5 (2025-05-16)
~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added support for concatenation of observations along different dimensions in :class:`~isaaclab.managers.observation_manager.ObservationManager`.

Changed
^^^^^^^

* Updated the :class:`~isaaclab.managers.command_manager.CommandManager` to update the command counter after the
resampling call.


0.39.4 (2025-05-16)
~~~~~~~~~~~~~~~~~~~

Expand Down
4 changes: 2 additions & 2 deletions source/isaaclab/isaaclab/managers/command_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ def _resample(self, env_ids: Sequence[int]):
if len(env_ids) != 0:
# resample the time left before resampling
self.time_left[env_ids] = self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range)
# increment the command counter
self.command_counter[env_ids] += 1
# resample the command
self._resample_command(env_ids)
# increment the command counter
self.command_counter[env_ids] += 1

"""
Implementation specific functions.
Expand Down
12 changes: 11 additions & 1 deletion source/isaaclab/isaaclab/managers/manager_term_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,22 @@ class ObservationGroupCfg:
concatenate_terms: bool = True
"""Whether to concatenate the observation terms in the group. Defaults to True.

If true, the observation terms in the group are concatenated along the last dimension.
If true, the observation terms in the group are concatenated along the dimension specified through :attr:`concatenate_dim`.
Otherwise, they are kept separate and returned as a dictionary.

If the observation group contains terms of different dimensions, it must be set to False.
"""

concatenate_dim: int = -1
"""Dimension along to concatenate the different observation terms. Defaults to -1, which
means the last dimension of the observation terms.

If :attr:`concatenate_terms` is True, this parameter specifies the dimension along which the observation terms are concatenated.
The indicated dimension depends on the shape of the observations. For instance, for an RGB image of shape (H, W, C), the dimension
0 means concatenating along the height, 1 along the width, and 2 along the channels. The offset due
to the batched environment is handled automatically.
"""

enable_corruption: bool = False
"""Whether to enable corruption for the observation group. Defaults to False.

Expand Down
30 changes: 26 additions & 4 deletions source/isaaclab/isaaclab/managers/observation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,18 @@ def __init__(self, cfg: object, env: ManagerBasedEnv):
# otherwise, keep the list of shapes as is
if self._group_obs_concatenate[group_name]:
try:
term_dims = [torch.tensor(dims, device="cpu") for dims in group_term_dims]
self._group_obs_dim[group_name] = tuple(torch.sum(torch.stack(term_dims, dim=0), dim=0).tolist())
term_dims = torch.stack([torch.tensor(dims, device="cpu") for dims in group_term_dims], dim=0)
if len(term_dims.shape) > 1:
if self._group_obs_concatenate_dim[group_name] >= 0:
dim = self._group_obs_concatenate_dim[group_name] - 1 # account for the batch offset
else:
dim = self._group_obs_concatenate_dim[group_name]
dim_sum = torch.sum(term_dims[:, dim], dim=0)
term_dims[0, dim] = dim_sum
term_dims = term_dims[0]
else:
term_dims = torch.sum(term_dims, dim=0)
self._group_obs_dim[group_name] = tuple(term_dims.tolist())
except RuntimeError:
raise RuntimeError(
f"Unable to concatenate observation terms in group '{group_name}'."
Expand Down Expand Up @@ -330,7 +340,8 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso

# concatenate all observations in the group together
if self._group_obs_concatenate[group_name]:
return torch.cat(list(group_obs.values()), dim=-1)
# set the concatenate dimension, account for the batch dimension if positive dimension is given
return torch.cat(list(group_obs.values()), dim=self._group_obs_concatenate_dim[group_name])
else:
return group_obs

Expand Down Expand Up @@ -370,6 +381,8 @@ def _prepare_terms(self):
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict()
self._group_obs_concatenate_dim: dict[str, int] = dict()

self._group_obs_term_history_buffer: dict[str, dict] = dict()
# create a list to store modifiers that are classes
# we store it as a separate list to only call reset on them and prevent unnecessary calls
Expand Down Expand Up @@ -407,6 +420,9 @@ def _prepare_terms(self):
group_entry_history_buffer: dict[str, CircularBuffer] = dict()
# read common config for the group
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
self._group_obs_concatenate_dim[group_name] = (
group_cfg.concatenate_dim + 1 if group_cfg.concatenate_dim >= 0 else group_cfg.concatenate_dim
)
# check if config is dict already
if isinstance(group_cfg, dict):
group_cfg_items = group_cfg.items()
Expand All @@ -415,7 +431,13 @@ def _prepare_terms(self):
# iterate over all the terms in each group
for term_name, term_cfg in group_cfg_items:
# skip non-obs settings
if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]:
if term_name in [
"enable_corruption",
"concatenate_terms",
"history_length",
"flatten_history_dim",
"concatenate_dim",
]:
continue
# check for non config
if term_cfg is None:
Expand Down
118 changes: 94 additions & 24 deletions source/isaaclab/test/managers/test_observation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,41 +667,43 @@ class CriticCfg(ObservationGroupCfg):
assert torch.min(obs_critic["term_4"]) >= -0.5
assert torch.max(obs_critic["term_4"]) <= 0.5

def test_serialize(setup_env):
    """Test serialize call for ManagerTermBase terms.

    Defines a class-based observation term with a custom ``serialize`` method and
    checks that :meth:`ObservationManager.serialize` returns the term's payload
    keyed by group and term name.
    """
    env = setup_env

    # payload the term's serialize() should surface through the manager
    serialize_data = {"test": 0}

    class test_serialize_term(ManagerTermBase):

        def __init__(self, cfg: RewardTermCfg, env: ManagerBasedEnv):
            super().__init__(cfg, env)

        def __call__(self, env: ManagerBasedEnv) -> torch.Tensor:
            return grilled_chicken(env)

        def serialize(self) -> dict:
            return serialize_data

    @configclass
    class MyObservationManagerCfg:
        """Test config class for observation manager."""

        @configclass
        class PolicyCfg(ObservationGroupCfg):
            """Test config class for policy observation group."""

            concatenate_terms = False
            term_1 = ObservationTermCfg(func=test_serialize_term)

        policy: ObservationGroupCfg = PolicyCfg()

    # create observation manager
    cfg = MyObservationManagerCfg()
    obs_man = ObservationManager(cfg, env)

    # check expected output
    assert obs_man.serialize() == {"policy": {"term_1": serialize_data}}


def test_modifier_invalid_config(setup_env):
Expand All @@ -728,3 +730,71 @@ class PolicyCfg(ObservationGroupCfg):

with pytest.raises(ValueError):
ObservationManager(cfg, env)


def test_concatenate_dim(setup_env):
    """Test concatenation of observations along different dimensions.

    Builds three observation groups over two identical image terms and checks that
    :attr:`ObservationGroupCfg.concatenate_dim` controls which (un-batched) dimension
    the terms are joined along, including the negative-index form. Assumes
    ``grilled_chicken_image`` yields per-env observations of shape (128, 256, channel)
    -- TODO confirm against the term's definition earlier in this file.
    """
    env = setup_env

    @configclass
    class MyObservationManagerCfg:
        """Test config class for observation manager."""

        @configclass
        class PolicyCfg(ObservationGroupCfg):
            """Test config class for policy observation group."""

            concatenate_terms = True
            concatenate_dim = 1  # Concatenate along dimension 1
            term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
            term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})

        @configclass
        class CriticCfg(ObservationGroupCfg):
            """Test config class for critic observation group."""

            concatenate_terms = True
            concatenate_dim = 2  # Concatenate along dimension 2
            term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
            term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})

        @configclass
        class CriticCfg_neg_dim(ObservationGroupCfg):
            """Test config class for critic observation group."""

            concatenate_terms = True
            concatenate_dim = -1  # Concatenate along last dimension
            term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
            term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})

        policy: ObservationGroupCfg = PolicyCfg()
        critic: ObservationGroupCfg = CriticCfg()
        critic_neg_dim: ObservationGroupCfg = CriticCfg_neg_dim()

    # create observation manager
    cfg = MyObservationManagerCfg()
    obs_man = ObservationManager(cfg, env)
    # compute observation using manager
    observations = obs_man.compute()

    # obtain the group observations
    obs_policy: torch.Tensor = observations["policy"]
    obs_critic: torch.Tensor = observations["critic"]
    obs_critic_neg_dim: torch.Tensor = observations["critic_neg_dim"]

    # check the observation shapes
    # For policy: concatenated along dim 1 (width), so the width should be doubled
    assert obs_policy.shape == (env.num_envs, 128, 512, 1)
    # For critic: concatenated along dim 2 (channels), so the channels should be doubled
    assert obs_critic.shape == (env.num_envs, 128, 256, 2)
    # For critic_neg_dim: concatenated along the last dim, so the channels should be doubled
    assert obs_critic_neg_dim.shape == (env.num_envs, 128, 256, 2)

    # verify the data is concatenated correctly
    # For policy: check that the second half (along the width) matches the first half
    torch.testing.assert_close(obs_policy[:, :, :256, :], obs_policy[:, :, 256:, :])
    # For critic: check that the second channel matches the first channel
    torch.testing.assert_close(obs_critic[:, :, :, 0], obs_critic[:, :, :, 1])

    # For critic_neg_dim: check that it is the same as critic
    torch.testing.assert_close(obs_critic_neg_dim, obs_critic)