Skip to content

Commit 9c02851

Browse files
ryichando and claude committed
feat: add GPU-accelerated LBVH with radix sort and optimize frontend
Replace CPU-based BVH with GPU-accelerated Linear BVH (LBVH) implementation using radix sort for morton code ordering. This improves collision detection performance significantly for large meshes. Additional optimizations: - Add numba-accelerated mesh generation (grid/rect faces, cylinder) - Optimize sphere violation checks using squared distance - Add early exit to triangle-triangle/edge distance calculations - Parallelize face-to-vertex weight computation - Standardize platform detection using platform.system() - Remove obsolete bvh-alloc-factor parameter 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 34f0f43 commit 9c02851

29 files changed

+2039
-662
lines changed

build-win-native/build.bat

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ set EIGEN_DIR=%DEPS%\eigen-3.4.0
139139

140140
REM Source files
141141
set CPP_SRCS=%CPP_DIR%\simplelog\SimpleLog.cpp %CPP_DIR%\stub.cpp
142-
set CU_SRCS=%CPP_DIR%\buffer\buffer.cu %CPP_DIR%\main\main.cu %CPP_DIR%\utility\utility.cu %CPP_DIR%\utility\dispatcher.cu %CPP_DIR%\csrmat\csrmat.cu %CPP_DIR%\contact\contact.cu %CPP_DIR%\energy\energy.cu %CPP_DIR%\eigenanalysis\eigenanalysis.cu %CPP_DIR%\barrier\barrier.cu %CPP_DIR%\strainlimiting\strainlimiting.cu %CPP_DIR%\solver\solver.cu %CPP_DIR%\kernels\reduce.cu %CPP_DIR%\kernels\exclusive_scan.cu %CPP_DIR%\kernels\vec_ops.cu
142+
set CU_SRCS=%CPP_DIR%\buffer\buffer.cu %CPP_DIR%\main\main.cu %CPP_DIR%\utility\utility.cu %CPP_DIR%\utility\dispatcher.cu %CPP_DIR%\csrmat\csrmat.cu %CPP_DIR%\contact\contact.cu %CPP_DIR%\energy\energy.cu %CPP_DIR%\eigenanalysis\eigenanalysis.cu %CPP_DIR%\barrier\barrier.cu %CPP_DIR%\strainlimiting\strainlimiting.cu %CPP_DIR%\solver\solver.cu %CPP_DIR%\kernels\reduce.cu %CPP_DIR%\kernels\exclusive_scan.cu %CPP_DIR%\kernels\vec_ops.cu %CPP_DIR%\kernels\radix_sort.cu %CPP_DIR%\lbvh\lbvh.cu
143143

144144
REM Compiler flags
145145
set NVCC_FLAGS=-std=c++17 --expt-relaxed-constexpr --extended-lambda -O3 -rdc=true -shared -Wno-deprecated-gpu-targets

build-win-native/clear-cache.bat

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,21 @@ if exist "!JUPYTER_DIR!" (
7878
echo [SKIP] Jupyter directory not found
7979
)
8080

81+
REM Clear export directory in examples (legacy location cleanup)
82+
set EXPORT_DIR=!PROJ_ROOT!\examples\export
83+
if exist "!EXPORT_DIR!" (
84+
echo Removing export directory in examples...
85+
rmdir /s /q "!EXPORT_DIR!"
86+
if exist "!EXPORT_DIR!" (
87+
echo [FAIL] Could not remove !EXPORT_DIR!
88+
set HAS_ERROR=1
89+
) else (
90+
echo [OK] Removed !EXPORT_DIR!
91+
)
92+
) else (
93+
echo [SKIP] Export directory in examples not found
94+
)
95+
8196
echo.
8297
if %HAS_ERROR%==1 (
8398
echo === [FAIL] Some caches could not be cleared ===

frontend/_app_.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import os
77
import pickle
8+
import platform
89
import shutil
910

1011
from typing import Optional
@@ -187,7 +188,7 @@ def get_data_dirpath():
187188
with open(branch_file) as f:
188189
git_branch = f.read().strip()
189190
if git_branch:
190-
if os.name == 'nt': # Windows
191+
if platform.system() == "Windows": # Windows
191192
return os.path.join(
192193
base_dir, "local", "share", "ppf-cts", f"git-{git_branch}"
193194
)
@@ -212,7 +213,7 @@ def get_data_dirpath():
212213
except (subprocess.CalledProcessError, FileNotFoundError):
213214
git_branch = "unknown"
214215

215-
if os.name == 'nt': # Windows
216+
if platform.system() == "Windows": # Windows
216217
return os.path.join(
217218
base_dir, "local", "share", "ppf-cts", f"git-{git_branch}"
218219
)
@@ -243,7 +244,7 @@ def __init__(self, name: str, renew: bool, cache_dir: str = ""):
243244
if cache_dir:
244245
self._cache_dir = cache_dir
245246
else:
246-
if os.name == 'nt': # Windows - use project-relative cache
247+
if platform.system() == "Windows": # Windows - use project-relative cache
247248
frontend_dir = os.path.dirname(os.path.realpath(__file__))
248249
base_dir = os.path.dirname(frontend_dir)
249250
self._cache_dir = os.path.join(base_dir, "cache", "ppf-cts")

frontend/_invisible_collider_.py

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -42,51 +42,6 @@ def _check_wall_violation_single(
4242
return signed_dist < 0.0
4343

4444

45-
@njit(cache=True)
46-
def _check_sphere_violation_single(
47-
vertex: np.ndarray,
48-
sphere_center: np.ndarray,
49-
sphere_radius: float,
50-
is_inverted: bool,
51-
is_hemisphere: bool,
52-
) -> bool:
53-
"""Check if a single vertex violates sphere constraint.
54-
55-
For normal sphere: vertex must be OUTSIDE (distance >= radius)
56-
For inverted sphere: vertex must be INSIDE (distance <= radius)
57-
For hemisphere (bowl): above center.y, treat as cylinder (only check horizontal distance)
58-
59-
Args:
60-
vertex: Vertex position (3,)
61-
sphere_center: Center of sphere (3,)
62-
sphere_radius: Radius of sphere
63-
is_inverted: If True, collision is with inside of sphere
64-
is_hemisphere: If True, top half is open (bowl shape)
65-
66-
Returns:
67-
True if vertex violates the sphere constraint
68-
"""
69-
# For hemisphere, if vertex is above center.y, project center to vertex's y-level
70-
# This creates a cylinder-like region above the hemisphere
71-
if is_hemisphere and vertex[1] > sphere_center[1]:
72-
center = np.array([sphere_center[0], vertex[1], sphere_center[2]])
73-
else:
74-
center = sphere_center
75-
76-
diff = vertex - center
77-
dist_sq = diff[0] * diff[0] + diff[1] * diff[1] + diff[2] * diff[2]
78-
dist = np.sqrt(dist_sq)
79-
80-
if is_inverted:
81-
# For inverted sphere, vertex must be inside (distance <= radius)
82-
# Violation if distance > radius (outside the sphere)
83-
return dist > sphere_radius
84-
else:
85-
# For normal sphere, vertex must be outside (distance >= radius)
86-
# Violation if distance < radius (inside the sphere)
87-
return dist < sphere_radius
88-
89-
9045
@njit(parallel=True, cache=True)
9146
def _check_wall_violations_parallel(
9247
vertices: np.ndarray,
@@ -108,19 +63,59 @@ def _check_wall_violations_parallel(
10863
Number of violations found
10964
"""
11065
n_verts = len(vertices)
66+
67+
# Check violations in parallel
11168
for i in prange(n_verts):
11269
if not is_pinned[i] and _check_wall_violation_single(
11370
vertices[i], wall_pos, wall_normal
11471
):
11572
violations[i] = True
11673

74+
# Parallel reduction for counting
11775
count = 0
118-
for i in range(n_verts):
76+
for i in prange(n_verts):
11977
if violations[i]:
12078
count += 1
12179
return count
12280

12381

82+
@njit(cache=True)
83+
def _check_sphere_violation(
84+
vertex: np.ndarray,
85+
sphere_center: np.ndarray,
86+
sphere_radius_sq: float,
87+
is_inverted: bool,
88+
is_hemisphere: bool,
89+
) -> bool:
90+
"""Check if a single vertex violates sphere constraint using squared distance.
91+
92+
This version compares the squared distance against the squared radius, so no sqrt is computed.
93+
"""
94+
# For hemisphere, if vertex is above center.y, project center to vertex's y-level
95+
if is_hemisphere and vertex[1] > sphere_center[1]:
96+
cx = sphere_center[0]
97+
cy = vertex[1]
98+
cz = sphere_center[2]
99+
else:
100+
cx = sphere_center[0]
101+
cy = sphere_center[1]
102+
cz = sphere_center[2]
103+
104+
dx = vertex[0] - cx
105+
dy = vertex[1] - cy
106+
dz = vertex[2] - cz
107+
dist_sq = dx*dx + dy*dy + dz*dz
108+
109+
if is_inverted:
110+
# For inverted sphere, vertex must be inside (distance <= radius)
111+
# Violation if distance > radius (outside the sphere)
112+
return dist_sq > sphere_radius_sq
113+
else:
114+
# For normal sphere, vertex must be outside (distance >= radius)
115+
# Violation if distance < radius (inside the sphere)
116+
return dist_sq < sphere_radius_sq
117+
118+
124119
@njit(parallel=True, cache=True)
125120
def _check_sphere_violations_parallel(
126121
vertices: np.ndarray,
@@ -146,14 +141,18 @@ def _check_sphere_violations_parallel(
146141
Number of violations found
147142
"""
148143
n_verts = len(vertices)
144+
sphere_radius_sq = sphere_radius * sphere_radius
145+
146+
# Check violations in parallel using squared distance
149147
for i in prange(n_verts):
150-
if not is_pinned[i] and _check_sphere_violation_single(
151-
vertices[i], sphere_center, sphere_radius, is_inverted, is_hemisphere
148+
if not is_pinned[i] and _check_sphere_violation(
149+
vertices[i], sphere_center, sphere_radius_sq, is_inverted, is_hemisphere
152150
):
153151
violations[i] = True
154152

153+
# Parallel reduction for counting
155154
count = 0
156-
for i in range(n_verts):
155+
for i in prange(n_verts):
157156
if violations[i]:
158157
count += 1
159158
return count

0 commit comments

Comments
 (0)