Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion physicsnemo/utils/neighbors/radius_search/_warp_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ def radius_search_impl(
if points.device != queries.device:
raise ValueError("points and queries must be on the same device")

input_dtype = points.dtype

# Warp supports only fp32, so we have to cast:
if points.dtype != torch.float32:
points = points.to(torch.float32)
if queries.dtype != torch.float32:
queries = queries.to(torch.float32)

N_queries = len(queries)

# Compute follows data.
Expand Down Expand Up @@ -321,7 +329,7 @@ def radius_search_impl(
f"Total found neighbors is too large: {total_count} >= 2**31 - 1"
)

return gather_neighbors(
indices, points, distances, num_neighbors = gather_neighbors(
grid,
points.device,
wp_points,
Expand Down Expand Up @@ -392,6 +400,8 @@ def radius_search_impl(
)

# Handle the matrix of return values:
points = points.to(input_dtype)
distances = distances.to(input_dtype)
return indices, points, distances, num_neighbors

# This is to enable torch.compile:
Expand Down
4 changes: 4 additions & 0 deletions physicsnemo/utils/neighbors/radius_search/radius_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def radius_search(
the outputs may be ordered differently by the two backends. Do not rely on the exact order of
the neighbors in the outputs.

Note:
With the Warp backend, there will be an automatic casting of inputs to float32 from reduced precision,
and results will be returned in their original precision.

Args:
points (torch.Tensor): The reference point cloud tensor of shape (N, 3) where N is the number
of points.
Expand Down
54 changes: 36 additions & 18 deletions test/utils/neighbors/test_radius_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def test_radius_search(
else:
indexes = results

print(f"Indexes shape: {indexes.shape}")
# Basic shape checks - there should be one index array per query point
if max_points is not None:
assert indexes.shape[0] == query_space_points.shape[0]
Expand Down Expand Up @@ -149,7 +148,6 @@ def test_radius_search(
assert (dists[mask][1:] == 0).all()

if return_points:
print(points.shape)
if max_points is not None:
assert points.shape[0] == query_space_points.shape[0]
assert points.shape[1] == max_points
Expand Down Expand Up @@ -183,7 +181,6 @@ def test_radius_search(
# finds exactly one point within radius 0.1 (the 0.05 displaced point)
# This is how many are possible by the data:
expected_matches = min(int(radius / 0.05), 6)
print(radius, expected_matches)
# expected_matches = 1
if max_points is not None:
# Some limit has been imposed:
Expand All @@ -193,8 +190,6 @@ def test_radius_search(
# the first from the assertion.

matches_per_query = (indexes != 0).sum(dim=1)
# print(torch.where(matches_per_query == 2))
# print(matches_per_query[1:20])
assert (matches_per_query[1:] == expected_matches).all()

else:
Expand Down Expand Up @@ -295,12 +290,6 @@ def test_radius_search_comparison(device, max_points):
)

if max_points is not None:
print(f"out_points_warp shape: {out_points_warp.shape}")
print(f"out_points_torch shape: {out_points_torch.shape}")
# print(f'out_points_warp.sum(dim=(0)): {out_points_warp.sum(dim=(0))}')
# print(f'out_points_torch.sum(dim=(0)): {out_points_torch.sum(dim=(0))}')
print(f"out_points_warp[1]: {out_points_warp[1]}")
print(f"out_points_torch[1]: {out_points_torch[1]}")
assert torch.allclose(out_points_warp.sum(dim=1), out_points_torch.sum(dim=1))
else:
assert torch.allclose(
Expand Down Expand Up @@ -331,9 +320,6 @@ def test_radius_search_gradients(device, max_points):
points = torch.randn(n_points, 3, device=device, requires_grad=True)
queries = torch.randn(n_queries, 3, device=device, requires_grad=True)

print(f"points shape: {points.shape}")
print(f"queries shape: {queries.shape}")

grads = {}
for backend in ["warp", "torch"]:
# Clone inputs for each backend to avoid in-place ops
Expand All @@ -355,16 +341,48 @@ def test_radius_search_gradients(device, max_points):
pts.grad.detach().clone() if pts.grad is not None else None,
qrs.grad.detach().clone() if qrs.grad is not None else None,
)
print(f"Index: {index}")
# Compare gradients between backends
pts_grad_warp, qrs_grad_warp = grads["warp"]
pts_grad_torch, qrs_grad_torch = grads["torch"]

print(f"Warp points grad: {pts_grad_warp}")
print(f"Torch points grad: {pts_grad_torch}")

assert torch.allclose(pts_grad_warp, pts_grad_torch, atol=1e-5), (
"Point gradients do not match"
)

# assert torch.allclose(qrs_grad_warp, qrs_grad_torch, atol=1e-5), "Query gradients do not match"


@pytest.mark.parametrize("precision", [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("max_points", [8, None])
def test_radius_search_reduced_precision(device, precision, max_points):
"""
This is a functionality based test. We run in half precision and
make sure results are reasonable, but exact agreement from the alg is tested
elsewhere.
"""

torch.manual_seed(42)
n_points = 88
n_queries = 57
radius = 0.5

# Create points and queries with gradients enabled
points = torch.randn(n_points, 3, device=device, requires_grad=True).to(
dtype=precision
)
queries = torch.randn(n_queries, 3, device=device, requires_grad=True).to(
dtype=precision
)

index, out_points = radius_search(
points,
queries,
radius=radius,
max_points=max_points,
return_dists=False,
return_points=True,
backend="warp",
)

assert out_points.dtype == points.dtype