Currently, mjx_env.step(state, action, n_substeps) returns only the final state after substepping. As a result, rewards can only be computed for that final state, rather than being accumulated or averaged over the substeps as is normally done when frame skipping. I think this is a fairly simple feature in theory: add an argument such as return_substeps=False and, when it is set, return the states from all substeps. I'm just wondering whether there's a more memory-efficient way to do this?
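To make the memory question concrete, here is a rough sketch of one alternative to returning every intermediate state: keep a running reward total inside the substep loop, so only the latest state and a scalar are carried. This is an illustration under assumptions, not how mjx_env.step is implemented. physics_step, reward_fn, and step_with_accumulated_reward are hypothetical stand-ins with toy dynamics; only jax.lax.scan and jax.jit are real JAX calls.

```python
from functools import partial

import jax
import jax.numpy as jnp


def physics_step(state, action):
    # Hypothetical stand-in for one physics substep (toy dynamics, not the MJX API).
    return state + 0.1 * action


def reward_fn(state):
    # Hypothetical per-substep reward: negative squared distance from the origin.
    return -jnp.sum(state ** 2)


@partial(jax.jit, static_argnames="n_substeps")
def step_with_accumulated_reward(state, action, n_substeps):
    def body(carry, _):
        state, total = carry
        state = physics_step(state, action)
        # Fold the reward into the carry instead of stacking every state.
        return (state, total + reward_fn(state)), None

    (final_state, total), _ = jax.lax.scan(
        body, (state, jnp.zeros(())), None, length=n_substeps
    )
    # Return the final state plus the reward averaged over the substeps.
    return final_state, total / n_substeps


final_state, mean_reward = step_with_accumulated_reward(
    jnp.zeros(2), jnp.ones(2), n_substeps=4
)
```

Because the scan body emits None as its per-iteration output, nothing is stacked across substeps, so memory use in this sketch does not grow with n_substeps. The trade-off is that the reward function has to be available inside the step, whereas a return_substeps flag would leave the reward computation to the caller.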