Skip to content

Adds option to define the concatenation dimension in the ObservationManager and change counter update in CommandManager #2393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/isaaclab/config/extension.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.36.23"
version = "0.36.24"

# Description
title = "Isaac Lab framework for Robot Learning"
Expand Down
17 changes: 17 additions & 0 deletions source/isaaclab/docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
Changelog
---------


0.36.24 (2025-04-28)
~~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added support for concatenation of observations along different dimensions in :class:`~isaaclab.managers.observation_manager.ObservationManager`.

Changed
^^^^^^^

* Updated the :class:`~isaaclab.managers.command_manager.CommandManager` to update the command counter after the
resampling call.



0.36.23 (2025-04-24)
~~~~~~~~~~~~~~~~~~~~

Expand Down
4 changes: 2 additions & 2 deletions source/isaaclab/isaaclab/managers/command_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ def _resample(self, env_ids: Sequence[int]):
if len(env_ids) != 0:
# resample the time left before resampling
self.time_left[env_ids] = self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range)
# increment the command counter
self.command_counter[env_ids] += 1
# resample the command
self._resample_command(env_ids)
# increment the command counter
self.command_counter[env_ids] += 1

"""
Implementation specific functions.
Expand Down
11 changes: 10 additions & 1 deletion source/isaaclab/isaaclab/managers/manager_term_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,21 @@ class ObservationGroupCfg:
concatenate_terms: bool = True
"""Whether to concatenate the observation terms in the group. Defaults to True.

If true, the observation terms in the group are concatenated along the last dimension.
If true, the observation terms in the group are concatenated along the dimension specified in concatenate_dim.
Otherwise, they are kept separate and returned as a dictionary.

If the observation group contains terms of different dimensions, it must be set to False.
"""

concatenate_dim: int = -1
"""Dimension along which to concatenate the different observation terms. Defaults to -1.

If concatenate_terms is True, this specifies the dimension along which the observation terms are concatenated.
The dimension is indexed with respect to a single observation, i.e. for an RGB image observation with
shape (H, W, C), dimension 0 concatenates along the height, 1 along the width, and 2 along the channels.
The offset introduced by the batched environment dimension is handled automatically.
"""

enable_corruption: bool = False
"""Whether to enable corruption for the observation group. Defaults to False.

Expand Down
29 changes: 25 additions & 4 deletions source/isaaclab/isaaclab/managers/observation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,14 @@ def __init__(self, cfg: object, env: ManagerBasedEnv):
# otherwise, keep the list of shapes as is
if self._group_obs_concatenate[group_name]:
try:
term_dims = [torch.tensor(dims, device="cpu") for dims in group_term_dims]
self._group_obs_dim[group_name] = tuple(torch.sum(torch.stack(term_dims, dim=0), dim=0).tolist())
term_dims = torch.stack([torch.tensor(dims, device="cpu") for dims in group_term_dims], dim=0)
if len(term_dims.shape) > 1:
dim_sum = torch.sum(term_dims[:, self._group_obs_concatenate_dim[group_name]], dim=0)
term_dims[0, self._group_obs_concatenate_dim[group_name]] = dim_sum
term_dims = term_dims[0]
else:
term_dims = torch.sum(term_dims, dim=0)
self._group_obs_dim[group_name] = tuple(term_dims.tolist())
except RuntimeError:
raise RuntimeError(
f"Unable to concatenate observation terms in group '{group_name}'."
Expand Down Expand Up @@ -330,7 +336,13 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso

# concatenate all observations in the group together
if self._group_obs_concatenate[group_name]:
return torch.cat(list(group_obs.values()), dim=-1)
# set the concatenate dimension, account for the batch dimension if positive dimension is given
dim = (
self._group_obs_concatenate_dim[group_name] + 1
if self._group_obs_concatenate_dim[group_name] >= 0
else self._group_obs_concatenate_dim[group_name]
)
return torch.cat(list(group_obs.values()), dim=dim)
else:
return group_obs

Expand All @@ -347,6 +359,8 @@ def _prepare_terms(self):
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict()
self._group_obs_concatenate_dim: dict[str, int] = dict()

self._group_obs_term_history_buffer: dict[str, dict] = dict()
# create a list to store modifiers that are classes
# we store it as a separate list to only call reset on them and prevent unnecessary calls
Expand Down Expand Up @@ -384,6 +398,7 @@ def _prepare_terms(self):
group_entry_history_buffer: dict[str, CircularBuffer] = dict()
# read common config for the group
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
self._group_obs_concatenate_dim[group_name] = group_cfg.concatenate_dim
# check if config is dict already
if isinstance(group_cfg, dict):
group_cfg_items = group_cfg.items()
Expand All @@ -392,7 +407,13 @@ def _prepare_terms(self):
# iterate over all the terms in each group
for term_name, term_cfg in group_cfg_items:
# skip non-obs settings
if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]:
if term_name in [
"enable_corruption",
"concatenate_terms",
"history_length",
"flatten_history_dim",
"concatenate_dim",
]:
continue
# check for non config
if term_cfg is None:
Expand Down
66 changes: 66 additions & 0 deletions source/isaaclab/test/managers/test_observation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,72 @@ class PolicyCfg(ObservationGroupCfg):
with self.assertRaises(ValueError):
self.obs_man = ObservationManager(cfg, self.env)

def test_concatenate_dim(self):
    """Check that observation terms can be concatenated along a user-chosen dimension."""

    @configclass
    class ObsManagerCfg:
        """Observation manager config exercising the ``concatenate_dim`` option."""

        @configclass
        class PolicyCfg(ObservationGroupCfg):
            """Group concatenated along the width dimension of the image observation."""

            concatenate_terms = True
            concatenate_dim = 1
            term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
            term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})

        @configclass
        class CriticCfg(ObservationGroupCfg):
            """Group concatenated along the channel dimension (index 2)."""

            concatenate_terms = True
            concatenate_dim = 2
            term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
            term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})

        @configclass
        class CriticNegDimCfg(ObservationGroupCfg):
            """Group concatenated along the last dimension (index -1)."""

            concatenate_terms = True
            concatenate_dim = -1
            term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
            term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})

        policy: ObservationGroupCfg = PolicyCfg()
        critic: ObservationGroupCfg = CriticCfg()
        critic_neg_dim: ObservationGroupCfg = CriticNegDimCfg()

    # build the manager from the config and compute one set of observations
    cfg = ObsManagerCfg()
    self.obs_man = ObservationManager(cfg, self.env)
    observations = self.obs_man.compute()

    # pull out each group's concatenated tensor
    policy_obs: torch.Tensor = observations["policy"]
    critic_obs: torch.Tensor = observations["critic"]
    critic_neg_dim_obs: torch.Tensor = observations["critic_neg_dim"]

    # each term produces a (num_envs, 128, 256, 1) image; the batch dimension is
    # handled automatically, so the configured dim indexes the observation itself
    # dim 1 -> width is doubled
    self.assertEqual((self.env.num_envs, 128, 512, 1), policy_obs.shape)
    # dim 2 -> channels are doubled
    self.assertEqual((self.env.num_envs, 128, 256, 2), critic_obs.shape)
    # dim -1 resolves to the channel dimension as well
    self.assertEqual((self.env.num_envs, 128, 256, 2), critic_neg_dim_obs.shape)

    # both terms are identical, so the two halves along the concatenation
    # dimension must match exactly
    torch.testing.assert_close(policy_obs[:, :, :256, :], policy_obs[:, :, 256:, :])
    torch.testing.assert_close(critic_obs[:, :, :, 0], critic_obs[:, :, :, 1])
    # dim=-1 must behave identically to dim=2 for this observation shape
    torch.testing.assert_close(critic_neg_dim_obs, critic_obs)


# run the module's test suite when executed as a script
if __name__ == "__main__":
run_tests()