Skip to content

Release 22/inconsistent device tensor action in trainers [Don't Merge] #6225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 48 commits into
base: develop
Choose a base branch
from

Conversation

Jkho80
Copy link

@Jkho80 Jkho80 commented Jul 20, 2025

Proposed change(s)

Hi,

While working with mlagents-learn and running my environment with --torch-device cuda, I encountered multiple runtime errors such as:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

After investigating the root cause, I found that several tensor operations in the codebase implicitly rely on default CPU-based tensors. These issues are often introduced when:

  • Creating tensors directly with torch.tensor() Using NumPy-based Parameters or Variable without specifying device (e.g., torch.tensor(numpy_value) instead of torch.from_numpy(numpy_value).to(device)).

  • Combining GPU-based tensors with CPU-based masks or constants Computation.

This is especially problematic in scenarios where ML-Agents interacts with compute shaders or when training with Unity in GPU mode, as values originating from RAM (via NumPy) default to CPU memory and cause device mismatches in PyTorch computations.

What I did:
To address this, I:

Identified common points where tensors were created without explicit device assignment.

Ensured that all relevant tensors (especially masks, constants, and externally created inputs) are moved to the correct device using .to(device) based on the context tensor.

This should make training on CUDA more stable and prevent errors due to cross-device tensor operations.

Thanks for the great work on ML-Agents!

Useful links (Github issues, JIRA tickets, ML-Agents forum threads etc.)

None

Types of change(s)

  • Bug fix
  • New feature
  • Code refactor
  • Breaking change
  • Documentation update
  • Other (please describe)

Checklist

  • Added tests that prove my fix is effective or that my feature works
  • Updated the changelog (if applicable)
  • Updated the documentation (if applicable)
  • Updated the migration guide (if applicable)

Other comments

mlagents-learn RobotReacher.yaml --run-id robot01 --torch-device cuda --force --debug

Error and Success Run Log
2329d7c6ddebf38fe43547da85e9b35
f8da9ce9b8cac220f80e5a5d5b049da

@CLAassistant
Copy link

CLAassistant commented Jul 20, 2025

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
2 out of 4 committers have signed the CLA.

✅ Jkho80
✅ miguelalonsojr
❌ Aurimas Petrovas
❌ AlexRibard


Aurimas Petrovas seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cycode: Security vulnerabilities found in newly introduced dependency.

Ecosystem PyPI
Dependency grpcio
Dependency Paths grpcio 1.48.2
Direct Dependency Yes

The following vulnerabilities were introduced:

GHSA CVE Severity Fixed Version
GHSA-496j-2rq6-j6cc CVE-2023-33953 HIGH 1.53.2
GHSA-6628-q6j9-w8vg CVE-2023-1428 HIGH 1.53.0
GHSA-cfgp-2977-2fmm CVE-2023-32731 HIGH 1.53.0

Highest fixed version: 1.53.2

Description

Detects when new vulnerabilities affect your dependencies.

Tell us how you wish to proceed using one of the following commands:

Tag Short Description
#cycode_vulnerable_package_fix_this_violation Fix this violation via a commit to this branch
#cycode_ignore_manifest_here <reason> Applies to this manifest in this request only

⚠️ When commenting on Github, you may need to refresh the page to see the latest updates.

⚠️ Due to API limitations, we can not comment on the exact line (58)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants