@coreyjadams
Collaborator

Add reduced-precision support for radius search / ball query by transparently casting inputs to float32 and casting the outputs back to the input dtype.

Warp hashmap doesn't actually support half precision, so this is our only path to functionality at this time.

PhysicsNeMo Pull Request

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@coreyjadams
Collaborator Author

/blossom-ci

@greptile-apps
Contributor

greptile-apps bot commented Dec 4, 2025

Greptile Overview

Greptile Summary

Implements reduced precision support for radius search/ball query operations by adding transparent casting to float32 and back. Since Warp's hashmap doesn't natively support half precision, the implementation casts inputs to float32 before processing and converts results back to the original precision.

Key changes:

  • Added dtype preservation logic in _warp_impl.py to save input dtype, cast to float32, and restore output dtype
  • Updated function flow to unpack gather_neighbors return values for dtype conversion
  • Added documentation warning about automatic casting behavior
  • Cleaned up debug print statements in tests
  • Added comprehensive test coverage for torch.float16, torch.bfloat16, and torch.float32 across CPU and CUDA devices

The implementation is straightforward and correct. Output points and distances are properly converted back to match the input precision.
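
As a rough illustration of the cast-and-restore pattern described in this summary, the sketch below wraps a float32-only search. `_radius_search_fp32` is a hypothetical brute-force stand-in for the Warp kernel, and `radius_search_any_dtype` is an illustrative name; neither is the PhysicsNeMo API.

```python
import torch


def _radius_search_fp32(points, queries, radius):
    """Hypothetical float32-only stand-in for the Warp kernel (brute force)."""
    assert points.dtype == torch.float32 and queries.dtype == torch.float32
    dists = torch.cdist(queries, points)  # shape: (n_queries, n_points)
    mask = dists <= radius
    query_idx, point_idx = torch.nonzero(mask, as_tuple=True)
    return point_idx, points[point_idx], dists[query_idx, point_idx]


def radius_search_any_dtype(points, queries, radius):
    """Cast inputs to float32, run the search, cast float outputs back."""
    input_dtype = points.dtype  # remember the caller's precision
    points_fp32 = points.to(torch.float32)
    queries_fp32 = queries.to(torch.float32)
    idx, out_points, out_dists = _radius_search_fp32(points_fp32, queries_fp32, radius)
    # Indices stay integer; only floating-point outputs are restored.
    return idx, out_points.to(input_dtype), out_dists.to(input_dtype)
```

Only the gathered points and distances are cast back in this sketch. The neighbor indices are integers and are returned unchanged.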

Important Files Changed

File Analysis

| Filename | Score | Overview |
|---|---|---|
| physicsnemo/utils/neighbors/radius_search/_warp_impl.py | 5/5 | Added reduced precision support by casting inputs to float32 and casting outputs back to original dtype |
| physicsnemo/utils/neighbors/radius_search/radius_search.py | 5/5 | Added documentation note about automatic casting behavior with Warp backend |
| test/utils/neighbors/test_radius_search.py | 5/5 | Removed debug print statements and added new test for reduced precision support (float16, bfloat16, float32) |

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. physicsnemo/utils/neighbors/radius_search/_warp_impl.py, lines 367-371

    style: Variable shadowing. The points parameter is reassigned here, so the name no longer refers to the input tensor by the time the result is cast back, which could cause confusion.

    When return_points=True and max_points is not None, lines 367-371 rebind points to a newly created tensor, shadowing the input parameter. The later points.to(input_dtype) call on line 403 therefore operates on that new tensor, not on the original input. This works functionally (the new tensor is created with dtype=torch.float32), but the name collision makes the code harder to follow.

  2. physicsnemo/utils/neighbors/radius_search/_warp_impl.py, line 332

    style: The same variable shadowing issue affects points in the unlimited-neighbors path.

    gather_neighbors returns a new tensor that is assigned to points, shadowing the input parameter of the same name. This is functionally correct (gather_neighbors creates a float32 tensor that is later cast back on line 403), but the variable reuse makes the code harder to understand.
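
The shadowing concern in the two comments above can be sketched as follows. Both functions are hypothetical and are not the code in `_warp_impl.py`; the slice is only a placeholder for the gathered neighbors.

```python
import torch


def gather_shadowed(points: torch.Tensor) -> torch.dtype:
    # Rebinding the parameter: from here on, `points` is the float32 copy.
    points = points.to(torch.float32)
    # Reading points.dtype now reports float32, not the caller's dtype.
    return points.dtype


def gather_explicit(points: torch.Tensor) -> torch.Tensor:
    input_dtype = points.dtype              # capture before any cast
    points_fp32 = points.to(torch.float32)  # working copy for the kernel
    out_points = points_fp32[:1]            # placeholder for gathered neighbors
    return out_points.to(input_dtype)       # restore the caller's precision
```

Using separate names for the input, the float32 working copy, and the output keeps it obvious which tensor each later cast applies to.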

2 files reviewed, 2 comments

Collaborator

@peterdsharpe peterdsharpe left a comment

Overall this looks good. Bummer that Warp isn't dtype-flexible here, but assuming that's a given, I can't think of a better way to handle this situation.

Before merging, can we add a quick note in the docstring (either in _warp_impl or in the top-level wrapper, at your choice) indicating that this does an internal cast to fp32 (and hence, a new allocation) of the associated arrays?
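
To make the allocation point concrete, here is a small sketch using plain PyTorch tensors, independent of the PhysicsNeMo code. It relies on the documented behavior that `Tensor.to` returns the tensor itself when the dtype and device already match.

```python
import torch

x16 = torch.zeros(8, 3, dtype=torch.float16)
x32 = torch.zeros(8, 3, dtype=torch.float32)

# Already float32: Tensor.to returns the same tensor object, so no copy is made.
same = x32.to(torch.float32) is x32

# Half precision: the cast produces a new tensor with its own storage.
x16_as_32 = x16.to(torch.float32)
new_storage = x16_as_32.data_ptr() != x16.data_ptr()

# Size of the temporary float32 copy in bytes (elements * bytes per element).
extra_bytes = x16_as_32.numel() * x16_as_32.element_size()
```

In this sketch the extra memory is paid only for reduced-precision inputs; float32 inputs pass through without a new allocation.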

Collaborator

@peterdsharpe peterdsharpe left a comment

One last thing - let's check fp32 in tests

@coreyjadams
Collaborator Author

/blossom-ci

@coreyjadams
Collaborator Author

Thanks for the reviews! I added a note about the transparent casting, and I'll merge when CI is happy.

@coreyjadams
Collaborator Author

/blossom-ci

@coreyjadams coreyjadams closed this Dec 5, 2025
@coreyjadams coreyjadams reopened this Dec 5, 2025
Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments


3 participants