
Commit be41bb0

Authored by pascal-roth, Mayankm96, and kellyguo11
Adds option to define the concatenation dimension in the ObservationManager and changes the counter update in the CommandManager (#2393)
# Description

Added support for concatenation of observations along different dimensions in `ObservationManager`. Updates the position where the command counter is increased to allow checking for reset environments in the resample call of the `CommandManager`.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there

---------

Signed-off-by: Pascal Roth <[email protected]>
Signed-off-by: Kelly Guo <[email protected]>
Signed-off-by: Kelly Guo <[email protected]>
Co-authored-by: Mayank Mittal <[email protected]>
Co-authored-by: Kelly Guo <[email protected]>
Co-authored-by: Kelly Guo <[email protected]>
1 parent 963b53b commit be41bb0

File tree: 6 files changed, +149 additions, -32 deletions


source/isaaclab/config/extension.toml (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
 [package]

 # Note: Semantic Versioning is used: https://semver.org/
-version = "0.39.4"
+version = "0.39.5"

 # Description
 title = "Isaac Lab framework for Robot Learning"

source/isaaclab/docs/CHANGELOG.rst (15 additions, 0 deletions)

@@ -1,6 +1,21 @@
 Changelog
 ---------

+0.39.5 (2025-05-16)
+~~~~~~~~~~~~~~~~~~~
+
+Added
+^^^^^
+
+* Added support for concatenation of observations along different dimensions in :class:`~isaaclab.managers.observation_manager.ObservationManager`.
+
+Changed
+^^^^^^^
+
+* Updated the :class:`~isaaclab.managers.command_manager.CommandManager` to update the command counter after the
+  resampling call.
+
+
 0.39.4 (2025-05-16)
 ~~~~~~~~~~~~~~~~~~~

source/isaaclab/isaaclab/managers/command_manager.py (2 additions, 2 deletions)

@@ -181,10 +181,10 @@ def _resample(self, env_ids: Sequence[int]):
         if len(env_ids) != 0:
             # resample the time left before resampling
             self.time_left[env_ids] = self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range)
-            # increment the command counter
-            self.command_counter[env_ids] += 1
             # resample the command
             self._resample_command(env_ids)
+            # increment the command counter
+            self.command_counter[env_ids] += 1

     """
     Implementation specific functions.
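
With this reordering, a command term can distinguish freshly reset environments inside its `_resample_command` override, because the counter has not yet been incremented at that point. Below is a minimal sketch of that pattern, assuming (as the description above implies) that the counter is zeroed in the reset path before `_resample` runs; the class name, the `my_commands` buffer, the sampling logic, and the import path are illustrative only, not part of this commit:

```python
from collections.abc import Sequence

import torch

from isaaclab.managers import CommandTerm  # import path assumed


class MyCommand(CommandTerm):
    """Hypothetical command term that treats freshly reset environments differently."""

    # (the other abstract members of CommandTerm are omitted for brevity)

    def _resample_command(self, env_ids: Sequence[int]):
        # after this change, the counter is incremented only after this call returns,
        # so a value of 0 marks environments that were just reset
        just_reset = self.command_counter[env_ids] == 0
        new_cmd = torch.empty(len(env_ids), 3, device=self.device).uniform_(-1.0, 1.0)
        # e.g. give just-reset environments a zero command instead of a random one
        new_cmd[just_reset] = 0.0
        self.my_commands[env_ids] = new_cmd  # `my_commands` is a hypothetical buffer
```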

source/isaaclab/isaaclab/managers/manager_term_cfg.py (11 additions, 1 deletion)

@@ -201,12 +201,22 @@ class ObservationGroupCfg:
     concatenate_terms: bool = True
     """Whether to concatenate the observation terms in the group. Defaults to True.

-    If true, the observation terms in the group are concatenated along the last dimension.
+    If true, the observation terms in the group are concatenated along the dimension specified through :attr:`concatenate_dim`.
     Otherwise, they are kept separate and returned as a dictionary.

     If the observation group contains terms of different dimensions, it must be set to False.
     """

+    concatenate_dim: int = -1
+    """Dimension along to concatenate the different observation terms. Defaults to -1, which
+    means the last dimension of the observation terms.
+
+    If :attr:`concatenate_terms` is True, this parameter specifies the dimension along which the observation terms are concatenated.
+    The indicated dimension depends on the shape of the observations. For instance, for a 2D RGB image of shape (H, W, C), the dimension
+    0 means concatenating along the height, 1 along the width, and 2 along the channels. The offset due
+    to the batched environment is handled automatically.
+    """
+
     enable_corruption: bool = False
     """Whether to enable corruption for the observation group. Defaults to False.
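
For reference, a minimal sketch of how the new attribute is meant to be used in an observation group config; the `dummy_image` term and its shape are made up for illustration, and only `concatenate_terms`/`concatenate_dim` come from this commit:

```python
import torch

from isaaclab.managers import ObservationGroupCfg, ObservationTermCfg
from isaaclab.utils import configclass


def dummy_image(env) -> torch.Tensor:
    """Placeholder observation term returning a (num_envs, H, W, C) tensor."""
    return torch.zeros(env.num_envs, 128, 256, 1)


@configclass
class CameraObsCfg(ObservationGroupCfg):
    """Two (H, W, C) image terms stacked side by side along the width."""

    concatenate_terms = True
    # dimension 1 of each (H, W, C) term, i.e. the width; the leading
    # environment (batch) dimension is offset automatically by the manager
    concatenate_dim = 1

    cam_left = ObservationTermCfg(func=dummy_image)
    cam_right = ObservationTermCfg(func=dummy_image)
```

With two (128, 256, 1) terms, the concatenated group observation per environment comes out as (128, 512, 1).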

source/isaaclab/isaaclab/managers/observation_manager.py (26 additions, 4 deletions)

@@ -88,8 +88,18 @@ def __init__(self, cfg: object, env: ManagerBasedEnv):
             # otherwise, keep the list of shapes as is
             if self._group_obs_concatenate[group_name]:
                 try:
-                    term_dims = [torch.tensor(dims, device="cpu") for dims in group_term_dims]
-                    self._group_obs_dim[group_name] = tuple(torch.sum(torch.stack(term_dims, dim=0), dim=0).tolist())
+                    term_dims = torch.stack([torch.tensor(dims, device="cpu") for dims in group_term_dims], dim=0)
+                    if len(term_dims.shape) > 1:
+                        if self._group_obs_concatenate_dim[group_name] >= 0:
+                            dim = self._group_obs_concatenate_dim[group_name] - 1  # account for the batch offset
+                        else:
+                            dim = self._group_obs_concatenate_dim[group_name]
+                        dim_sum = torch.sum(term_dims[:, dim], dim=0)
+                        term_dims[0, dim] = dim_sum
+                        term_dims = term_dims[0]
+                    else:
+                        term_dims = torch.sum(term_dims, dim=0)
+                    self._group_obs_dim[group_name] = tuple(term_dims.tolist())
                 except RuntimeError:
                     raise RuntimeError(
                         f"Unable to concatenate observation terms in group '{group_name}'."

@@ -330,7 +340,8 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tensor]:

         # concatenate all observations in the group together
         if self._group_obs_concatenate[group_name]:
-            return torch.cat(list(group_obs.values()), dim=-1)
+            # set the concatenate dimension, account for the batch dimension if positive dimension is given
+            return torch.cat(list(group_obs.values()), dim=self._group_obs_concatenate_dim[group_name])
         else:
             return group_obs

@@ -370,6 +381,8 @@ def _prepare_terms(self):
         self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
         self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
         self._group_obs_concatenate: dict[str, bool] = dict()
+        self._group_obs_concatenate_dim: dict[str, int] = dict()
+
         self._group_obs_term_history_buffer: dict[str, dict] = dict()
         # create a list to store modifiers that are classes
         # we store it as a separate list to only call reset on them and prevent unnecessary calls

@@ -407,6 +420,9 @@
             group_entry_history_buffer: dict[str, CircularBuffer] = dict()
             # read common config for the group
             self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
+            self._group_obs_concatenate_dim[group_name] = (
+                group_cfg.concatenate_dim + 1 if group_cfg.concatenate_dim >= 0 else group_cfg.concatenate_dim
+            )
             # check if config is dict already
             if isinstance(group_cfg, dict):
                 group_cfg_items = group_cfg.items()

@@ -415,7 +431,13 @@
             # iterate over all the terms in each group
             for term_name, term_cfg in group_cfg_items:
                 # skip non-obs settings
-                if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]:
+                if term_name in [
+                    "enable_corruption",
+                    "concatenate_terms",
+                    "history_length",
+                    "flatten_history_dim",
+                    "concatenate_dim",
+                ]:
                     continue
                 # check for non config
                 if term_cfg is None:
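
The `+ 1` offset introduced above maps the user-facing, per-term dimension onto the batched tensors that are actually concatenated. A standalone PyTorch sketch of that arithmetic (shapes chosen to mirror the test below; this is not manager code):

```python
import torch

num_envs = 4
# two image-like terms with per-term shape (H, W, C) = (128, 256, 1)
term_a = torch.zeros(num_envs, 128, 256, 1)
term_b = torch.zeros(num_envs, 128, 256, 1)

# user-facing dimension 1 refers to the width of each (H, W, C) term
concatenate_dim = 1
# non-negative dims are shifted by one to skip the leading env (batch) dimension;
# negative dims already count from the end, so they are left untouched
cat_dim = concatenate_dim + 1 if concatenate_dim >= 0 else concatenate_dim

obs = torch.cat([term_a, term_b], dim=cat_dim)
print(tuple(obs.shape))  # (4, 128, 512, 1)
```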

source/isaaclab/test/managers/test_observation_manager.py (94 additions, 24 deletions)

@@ -667,41 +667,43 @@ class CriticCfg(ObservationGroupCfg):
     assert torch.min(obs_critic["term_4"]) >= -0.5
     assert torch.max(obs_critic["term_4"]) <= 0.5

-    def test_serialize(self):
-        """Test serialize call for ManagerTermBase terms."""

-        serialize_data = {"test": 0}
+def test_serialize(setup_env):
+    """Test serialize call for ManagerTermBase terms."""
+    env = setup_env

-        class test_serialize_term(ManagerTermBase):
+    serialize_data = {"test": 0}

-            def __init__(self, cfg: RewardTermCfg, env: ManagerBasedEnv):
-                super().__init__(cfg, env)
+    class test_serialize_term(ManagerTermBase):

-            def __call__(self, env: ManagerBasedEnv) -> torch.Tensor:
-                return grilled_chicken(env)
+        def __init__(self, cfg: RewardTermCfg, env: ManagerBasedEnv):
+            super().__init__(cfg, env)

-            def serialize(self) -> dict:
-                return serialize_data
+        def __call__(self, env: ManagerBasedEnv) -> torch.Tensor:
+            return grilled_chicken(env)

-        @configclass
-        class MyObservationManagerCfg:
-            """Test config class for observation manager."""
+        def serialize(self) -> dict:
+            return serialize_data
+
+    @configclass
+    class MyObservationManagerCfg:
+        """Test config class for observation manager."""

-            @configclass
-            class PolicyCfg(ObservationGroupCfg):
-                """Test config class for policy observation group."""
+        @configclass
+        class PolicyCfg(ObservationGroupCfg):
+            """Test config class for policy observation group."""

-                concatenate_terms = False
-                term_1 = ObservationTermCfg(func=test_serialize_term)
+            concatenate_terms = False
+            term_1 = ObservationTermCfg(func=test_serialize_term)

-            policy: ObservationGroupCfg = PolicyCfg()
+        policy: ObservationGroupCfg = PolicyCfg()

-        # create observation manager
-        cfg = MyObservationManagerCfg()
-        self.obs_man = ObservationManager(cfg, self.env)
+    # create observation manager
+    cfg = MyObservationManagerCfg()
+    obs_man = ObservationManager(cfg, env)

-        # check expected output
-        self.assertEqual(self.obs_man.serialize(), {"policy": {"term_1": serialize_data}})
+    # check expected output
+    assert obs_man.serialize() == {"policy": {"term_1": serialize_data}}


 def test_modifier_invalid_config(setup_env):
@@ -728,3 +730,71 @@ class PolicyCfg(ObservationGroupCfg):

     with pytest.raises(ValueError):
         ObservationManager(cfg, env)
+
+
+def test_concatenate_dim(setup_env):
+    """Test concatenation of observations along different dimensions."""
+    env = setup_env
+
+    @configclass
+    class MyObservationManagerCfg:
+        """Test config class for observation manager."""
+
+        @configclass
+        class PolicyCfg(ObservationGroupCfg):
+            """Test config class for policy observation group."""
+
+            concatenate_terms = True
+            concatenate_dim = 1  # Concatenate along dimension 1
+            term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
+            term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
+
+        @configclass
+        class CriticCfg(ObservationGroupCfg):
+            """Test config class for critic observation group."""
+
+            concatenate_terms = True
+            concatenate_dim = 2  # Concatenate along dimension 2
+            term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
+            term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
+
+        @configclass
+        class CriticCfg_neg_dim(ObservationGroupCfg):
+            """Test config class for critic observation group."""
+
+            concatenate_terms = True
+            concatenate_dim = -1  # Concatenate along last dimension
+            term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
+            term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
+
+        policy: ObservationGroupCfg = PolicyCfg()
+        critic: ObservationGroupCfg = CriticCfg()
+        critic_neg_dim: ObservationGroupCfg = CriticCfg_neg_dim()
+
+    # create observation manager
+    cfg = MyObservationManagerCfg()
+    obs_man = ObservationManager(cfg, env)
+    # compute observation using manager
+    observations = obs_man.compute()
+
+    # obtain the group observations
+    obs_policy: torch.Tensor = observations["policy"]
+    obs_critic: torch.Tensor = observations["critic"]
+    obs_critic_neg_dim: torch.Tensor = observations["critic_neg_dim"]
+
+    # check the observation shapes
+    # For policy: concatenated along dim 1, so width should be doubled
+    assert obs_policy.shape == (env.num_envs, 128, 512, 1)
+    # For critic: concatenated along last dim, so channels should be doubled
+    assert obs_critic.shape == (env.num_envs, 128, 256, 2)
+    # For critic_neg_dim: concatenated along last dim, so channels should be doubled
+    assert obs_critic_neg_dim.shape == (env.num_envs, 128, 256, 2)
+
+    # verify the data is concatenated correctly
+    # For policy: check that the second half matches the first half
+    torch.testing.assert_close(obs_policy[:, :, :256, :], obs_policy[:, :, 256:, :])
+    # For critic: check that the second channel matches the first channel
+    torch.testing.assert_close(obs_critic[:, :, :, 0], obs_critic[:, :, :, 1])
+
+    # For critic_neg_dim: check that it is the same as critic
+    torch.testing.assert_close(obs_critic_neg_dim, obs_critic)
