-
Notifications
You must be signed in to change notification settings - Fork 0
Fixed buffer test #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes issues with the rollout buffer's handling of log probability tensor shapes by removing the misleading transposes and updating related tests. Key changes include:
- Adjusting test cases in test/replay/test_rollout_buffer.py to provide log_probs with correct shapes.
- Updating the _promote() helper in mighty/mighty_replay/mighty_rollout_buffer.py to support 3D tensors for actions and observations and removing the unnecessary transpose in add().
- Squeezing extra dimensions in mighty/mighty_agents/ppo.py to simplify log_prob handling.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
test/replay/test_rollout_buffer.py | Updated test cases to align with new rollout buffer shape handling and added debug prints |
mighty/mighty_replay/mighty_rollout_buffer.py | Modified _promote() behavior and removed log_probs transpose in add() |
mighty/mighty_agents/ppo.py | Removed an extra dimension from log_prob to match expectations |
print(f"Computed advantages: {advantages_computed}") | ||
print(f"Computed returns: {returns_computed}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Consider removing or conditionally enabling debug print statements in test_compute_returns_and_advantage_multi_env once debugging is complete to reduce test output noise.
print(f"Computed advantages: {advantages_computed}") | |
print(f"Computed returns: {returns_computed}") | |
if DEBUG_MODE: | |
print(f"Computed advantages: {advantages_computed}") | |
print(f"Computed returns: {returns_computed}") |
Copilot uses AI. Check for mistakes.
print(f"MaxiBatch size: {len(maxi_batch)}") | ||
print(f"Number of minibatches: {len(list(maxi_batch))}") | ||
# Debug the buffer contents before sampling | ||
print("\nDEBUG: Buffer contents before sampling:") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Consider removing or conditionally enabling debug print statements in test_sample_with_data to keep test logs clean after the issue is resolved.
Copilot uses AI. Check for mistakes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree :D
@@ -291,6 +291,12 @@ def process_transition( # type: ignore | |||
else None | |||
) | |||
|
|||
|
|||
# FIX: Remove extra dimension from log_prob if present |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the "FIX" necessary here? Looks a bit like a "FIXME" tag
elif x.dim() == 3 and name in ["actions", "observations"]: # (timesteps, n_envs, features) | ||
return x | ||
else: | ||
raise RuntimeError(f"Unexpected shape for {name}: {x.shape}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding the class back in would be nice, I think. Else we have to follow the error stack
advantages=np.array([[0.0, 0.0]]), # (1, 2) ✓ | ||
returns=np.array([[0.0, 0.0]]), # (1, 2) ✓ | ||
episode_starts=np.array([[1, 1]]), # (1, 2) ✓ | ||
log_probs=np.array([[-0.5, -0.8]]), # (1, 2) - FIXED: No more transpose needed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"FIX" can probably be removed here, right?
advantages_computed = buffer.advantages[0] # First timestep | ||
returns_computed = buffer.returns[0] # First timestep | ||
|
||
print(f"Computed advantages: {advantages_computed}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why print statements? Shouldn't this be another assert instead?
([[[5, 6, 7, 8]]], [[0.3, 0.4]], [[0.7, 0.8]], [[0.5]], [[-0.1]], | ||
[[0.4]], [[0]], [[-0.8]], [[0.5]]) | ||
] | ||
# Debug: Check what the sampling produced |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again with the debug prints! I think the copilot suggestion with the debug mode isn't bad. If that's not needed, we should remove them or move them to failing asserts.
values=np.array(vals), # (1, 1) | ||
) | ||
buffer.add(rb) | ||
# Check each minibatch individually |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But this doesn't check, it only prints?
assert minibatch.latents is not None, "Minibatch should have latents" | ||
assert isinstance(minibatch.latents, torch.Tensor), "Latents should be tensor" | ||
|
||
# First, let's understand what we actually got |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PRINTING xD
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Remove "FIX" tags
- Remove plain prints in test (either via debug mode or move to asserts)
Issue: The _promote() function in the rollout buffer was not working with log probabilities and would. Therefore, a previous hack was to modify their shape to make it work. This necessitated the tests to be designed around this change for them to pass.
Fix: Extended functionality of the _promote() method in Rollout buffer, removed the transposes, and handled the case of logprpb shapes. Consequently, also changed the affected test cases. Seems to work now.