diff --git a/physicsnemo/utils/neighbors/radius_search/_warp_impl.py b/physicsnemo/utils/neighbors/radius_search/_warp_impl.py index efe92ec973..7a1393a98c 100644 --- a/physicsnemo/utils/neighbors/radius_search/_warp_impl.py +++ b/physicsnemo/utils/neighbors/radius_search/_warp_impl.py @@ -278,6 +278,14 @@ def radius_search_impl( if points.device != queries.device: raise ValueError("points and queries must be on the same device") + input_dtype = points.dtype + + # Warp supports only fp32, so we have to cast: + if points.dtype != torch.float32: + points = points.to(torch.float32) + if queries.dtype != torch.float32: + queries = queries.to(torch.float32) + N_queries = len(queries) # Compute follows data. @@ -321,7 +329,7 @@ def radius_search_impl( f"Total found neighbors is too large: {total_count} >= 2**31 - 1" ) - return gather_neighbors( + indices, points, distances, num_neighbors = gather_neighbors( grid, points.device, wp_points, @@ -392,6 +400,8 @@ def radius_search_impl( ) # Handle the matrix of return values: + points = points.to(input_dtype) + distances = distances.to(input_dtype) return indices, points, distances, num_neighbors # This is to enable torch.compile: diff --git a/physicsnemo/utils/neighbors/radius_search/radius_search.py b/physicsnemo/utils/neighbors/radius_search/radius_search.py index 29572c2274..a8a4e03c3f 100644 --- a/physicsnemo/utils/neighbors/radius_search/radius_search.py +++ b/physicsnemo/utils/neighbors/radius_search/radius_search.py @@ -60,6 +60,10 @@ def radius_search( the outputs may be ordered differently by the two backends. Do not rely on the exact order of the neighbors in the outputs. + Note: + With the Warp backend, there will be an automatic casting of inputs to float32 from reduced precision, + and results will be returned in their original precision. + Args: points (torch.Tensor): The reference point cloud tensor of shape (N, 3) where N is the number of points. 
diff --git a/test/utils/neighbors/test_radius_search.py b/test/utils/neighbors/test_radius_search.py index 8c77627c04..a03765aeaa 100644 --- a/test/utils/neighbors/test_radius_search.py +++ b/test/utils/neighbors/test_radius_search.py @@ -113,7 +113,6 @@ def test_radius_search( else: indexes = results - print(f"Indexes shape: {indexes.shape}") # Basic shape checks - there should be one index array per query point if max_points is not None: assert indexes.shape[0] == query_space_points.shape[0] @@ -149,7 +148,6 @@ def test_radius_search( assert (dists[mask][1:] == 0).all() if return_points: - print(points.shape) if max_points is not None: assert points.shape[0] == query_space_points.shape[0] assert points.shape[1] == max_points @@ -183,7 +181,6 @@ def test_radius_search( # finds exactly one point within radius 0.1 (the 0.05 displaced point) # This is how many are possible by the data: expected_matches = min(int(radius / 0.05), 6) - print(radius, expected_matches) # expected_matches = 1 if max_points is not None: # Some limit has been imposed: @@ -193,8 +190,6 @@ def test_radius_search( # the first from the assertion. 
matches_per_query = (indexes != 0).sum(dim=1) - # print(torch.where(matches_per_query == 2)) - # print(matches_per_query[1:20]) assert (matches_per_query[1:] == expected_matches).all() else: @@ -295,12 +290,6 @@ def test_radius_search_comparison(device, max_points): ) if max_points is not None: - print(f"out_points_warp shape: {out_points_warp.shape}") - print(f"out_points_torch shape: {out_points_torch.shape}") - # print(f'out_points_warp.sum(dim=(0)): {out_points_warp.sum(dim=(0))}') - # print(f'out_points_torch.sum(dim=(0)): {out_points_torch.sum(dim=(0))}') - print(f"out_points_warp[1]: {out_points_warp[1]}") - print(f"out_points_torch[1]: {out_points_torch[1]}") assert torch.allclose(out_points_warp.sum(dim=1), out_points_torch.sum(dim=1)) else: assert torch.allclose( @@ -331,9 +320,6 @@ def test_radius_search_gradients(device, max_points): points = torch.randn(n_points, 3, device=device, requires_grad=True) queries = torch.randn(n_queries, 3, device=device, requires_grad=True) - print(f"points shape: {points.shape}") - print(f"queries shape: {queries.shape}") - grads = {} for backend in ["warp", "torch"]: # Clone inputs for each backend to avoid in-place ops @@ -355,16 +341,48 @@ def test_radius_search_gradients(device, max_points): pts.grad.detach().clone() if pts.grad is not None else None, qrs.grad.detach().clone() if qrs.grad is not None else None, ) - print(f"Index: {index}") # Compare gradients between backends pts_grad_warp, qrs_grad_warp = grads["warp"] pts_grad_torch, qrs_grad_torch = grads["torch"] - print(f"Warp points grad: {pts_grad_warp}") - print(f"Torch points grad: {pts_grad_torch}") - assert torch.allclose(pts_grad_warp, pts_grad_torch, atol=1e-5), ( "Point gradients do not match" ) # assert torch.allclose(qrs_grad_warp, qrs_grad_torch, atol=1e-5), "Query gradients do not match" + + +@pytest.mark.parametrize("precision", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) 
+@pytest.mark.parametrize("max_points", [8, None])
+def test_radius_search_reduced_precision(device, precision, max_points):
+    """
+    Functionality test for reduced-precision inputs on the Warp backend.
+
+    Runs the search with bfloat16 / float16 / float32 inputs and checks that
+    the returned points come back in the dtype of the inputs and, when
+    ``max_points`` is set, with one row per query and ``max_points`` slots
+    per row. Exact agreement of the algorithm is tested elsewhere.
+    """
+
+    torch.manual_seed(42)
+    n_points = 88
+    n_queries = 57
+    radius = 0.5
+
+    # Inputs are drawn in the default dtype and then cast to the precision
+    # under test. requires_grad=True is kept on the source tensors, but no
+    # backward pass is run in this test.
+    points = torch.randn(n_points, 3, device=device, requires_grad=True).to(
+        dtype=precision
+    )
+    queries = torch.randn(n_queries, 3, device=device, requires_grad=True).to(
+        dtype=precision
+    )
+
+    # The index output is not inspected here, only the gathered points.
+    _, out_points = radius_search(
+        points,
+        queries,
+        radius=radius,
+        max_points=max_points,
+        return_dists=False,
+        return_points=True,
+        backend="warp",
+    )
+
+    # Results must be handed back in the caller's precision.
+    assert out_points.dtype == precision, (
+        f"Expected output dtype {precision}, got {out_points.dtype}"
+    )
+
+    if max_points is not None:
+        # Same layout checks as test_radius_search: one row per query,
+        # each holding max_points slots.
+        assert out_points.shape[0] == n_queries
+        assert out_points.shape[1] == max_points