
Fixed buffer test #56


Open

wants to merge 1 commit into main

Conversation

amsks (Collaborator) commented Jul 1, 2025

Issue: The _promote() function in the rollout buffer did not handle log probabilities correctly. A previous hack worked around this by modifying their shape, and the tests had to be designed around that workaround in order to pass.

Fix: Extended the _promote() method in the rollout buffer, removed the transposes, and handled the log_prob shapes explicitly. The affected test cases were updated accordingly. Seems to work now.
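For context, a rough sketch of the kind of shape handling described above. Only the 3D branch and the error case appear verbatim in the diff further down; the 1D/2D branches here are assumptions for illustration, not the repository's exact code.

```python
import torch

def _promote(x: torch.Tensor, name: str) -> torch.Tensor:
    # Sketch only: normalise buffer entries to a canonical
    # (timesteps, n_envs, ...) layout without transposing log_probs.
    if x.dim() == 1:
        # (n_envs,) -> (1, n_envs): one timestep of per-env scalars
        return x.unsqueeze(0)
    if x.dim() == 2:
        # (timesteps, n_envs): already canonical for scalar entries
        # such as rewards, episode_starts and log_probs
        return x
    if x.dim() == 3 and name in ["actions", "observations"]:
        # (timesteps, n_envs, features): vector-valued entries pass through
        return x
    raise RuntimeError(f"Unexpected shape for {name}: {x.shape}")
```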

TheEimer requested a review from Copilot on July 1, 2025, 08:09
Copilot AI left a comment


Pull Request Overview

This PR fixes issues with the rollout buffer's handling of log probability tensor shapes by removing the misleading transposes and updating related tests. Key changes include:

  • Adjusting test cases in test/replay/test_rollout_buffer.py to provide log_probs with correct shapes.
  • Updating the _promote() helper in mighty/mighty_replay/mighty_rollout_buffer.py to support 3D tensors for actions and observations and removing the unnecessary transpose in add().
  • Squeezing extra dimensions in mighty/mighty_agents/ppo.py to simplify log_prob handling.
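Regarding the last bullet, a minimal sketch of the kind of squeeze involved. The helper name and the exact dimensionality check are assumptions for illustration, not the verbatim ppo.py change.

```python
import torch

def _squeeze_log_prob(log_prob: torch.Tensor) -> torch.Tensor:
    # Drop a trailing singleton dimension, e.g. (n_envs, 1) -> (n_envs,),
    # so log_prob matches the per-step shape the rollout buffer expects.
    if log_prob.dim() > 1 and log_prob.shape[-1] == 1:
        log_prob = log_prob.squeeze(-1)
    return log_prob
```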

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File | Description
test/replay/test_rollout_buffer.py | Updated test cases to align with the new rollout buffer shape handling and added debug prints
mighty/mighty_replay/mighty_rollout_buffer.py | Modified _promote() behavior and removed the log_probs transpose in add()
mighty/mighty_agents/ppo.py | Removed an extra dimension from log_prob to match expectations

Comment on lines +612 to +613
print(f"Computed advantages: {advantages_computed}")
print(f"Computed returns: {returns_computed}")
Copilot AI Jul 1, 2025

[nitpick] Consider removing or conditionally enabling debug print statements in test_compute_returns_and_advantage_multi_env once debugging is complete to reduce test output noise.

Suggested change
- print(f"Computed advantages: {advantages_computed}")
- print(f"Computed returns: {returns_computed}")
+ if DEBUG_MODE:
+     print(f"Computed advantages: {advantages_computed}")
+     print(f"Computed returns: {returns_computed}")


print(f"MaxiBatch size: {len(maxi_batch)}")
print(f"Number of minibatches: {len(list(maxi_batch))}")
# Debug the buffer contents before sampling
print("\nDEBUG: Buffer contents before sampling:")
Copilot AI Jul 1, 2025

[nitpick] Consider removing or conditionally enabling debug print statements in test_sample_with_data to keep test logs clean after the issue is resolved.


Contributor

I agree :D

@@ -291,6 +291,12 @@ def process_transition( # type: ignore
else None
)


# FIX: Remove extra dimension from log_prob if present
Contributor

Is the "FIX" necessary here? Looks a bit like a "FIXME" tag

elif x.dim() == 3 and name in ["actions", "observations"]:  # (timesteps, n_envs, features)
    return x
else:
    raise RuntimeError(f"Unexpected shape for {name}: {x.shape}")
Contributor

Adding the class name back into the error message would be nice, I think. Otherwise we have to follow the error stack.
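One possible way to act on this, sketched below. The class name is an assumption, and _promote() is shown as a method purely for illustration; the point is just to name the owning class in the message.

```python
class MightyRolloutBuffer:
    def _promote(self, x, name):
        ...  # shape-promotion branches as in the diff above
        # Include the class in the message so the failing buffer type is
        # visible without walking the traceback.
        raise RuntimeError(
            f"{type(self).__name__}._promote(): unexpected shape for {name}: {x.shape}"
        )
```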

advantages=np.array([[0.0, 0.0]]), # (1, 2) ✓
returns=np.array([[0.0, 0.0]]), # (1, 2) ✓
episode_starts=np.array([[1, 1]]), # (1, 2) ✓
log_probs=np.array([[-0.5, -0.8]]), # (1, 2) - FIXED: No more transpose needed
Contributor

"FIX" can probably be removed here, right?

advantages_computed = buffer.advantages[0] # First timestep
returns_computed = buffer.returns[0] # First timestep

print(f"Computed advantages: {advantages_computed}")
Contributor

Why print statements? Shouldn't this be another assert instead?
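A minimal sketch of what such an assert could look like inside the test. The expected arrays below are placeholders standing in for the precomputed GAE targets, not the test's real values.

```python
import numpy as np

# Placeholder targets; in the real test these would be the precomputed
# GAE advantages and returns for the buffer's rewards/values/dones.
expected_advantages = np.zeros_like(advantages_computed)
expected_returns = np.zeros_like(returns_computed)

np.testing.assert_allclose(
    advantages_computed, expected_advantages, rtol=1e-5,
    err_msg="Computed advantages differ from the expected GAE values",
)
np.testing.assert_allclose(
    returns_computed, expected_returns, rtol=1e-5,
    err_msg="Computed returns differ from the expected values",
)
```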

([[[5, 6, 7, 8]]], [[0.3, 0.4]], [[0.7, 0.8]], [[0.5]], [[-0.1]],
[[0.4]], [[0]], [[-0.8]], [[0.5]])
]
# Debug: Check what the sampling produced
Contributor

Again with the debug prints! I think the Copilot suggestion with the debug mode isn't bad. If they're not needed, we should remove them or move them to failing asserts.

values=np.array(vals), # (1, 1)
)
buffer.add(rb)
# Check each minibatch individually
Contributor

But this doesn't check, it only prints?
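One hedged way to turn the per-minibatch prints into actual checks. The field names (observations, actions, log_probs) are assumptions based on what add() stores, not verified against the minibatch class.

```python
import torch

for minibatch in maxi_batch:
    # Every minibatch field should be a tensor with a consistent batch dimension.
    assert isinstance(minibatch.observations, torch.Tensor)
    batch_size = minibatch.observations.shape[0]
    assert minibatch.actions.shape[0] == batch_size, "action batch dim mismatch"
    assert minibatch.log_probs.shape[0] == batch_size, "log_prob batch dim mismatch"
```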

assert minibatch.latents is not None, "Minibatch should have latents"
assert isinstance(minibatch.latents, torch.Tensor), "Latents should be tensor"

# First, let's understand what we actually got
Contributor

PRINTING xD

@TheEimer (Contributor) left a comment

  • Remove "FIX" tags
  • Remove plain prints in test (either via debug mode or move to asserts)
