-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
36 lines (29 loc) · 1.53 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from .csrc.build.libsegmentator import segment_mesh as segment_mesh_fn, segment_point as segment_point_fn
def segment_mesh(vertices, faces, kThresh=0.01, segMinVerts=20):
    """Segment a mesh into clusters (CPU).

    Args:
        vertices (torch.Tensor): vertex positions, shape (nv, 3).
        faces (torch.Tensor): face vertex indices, shape (nf, 3).
        kThresh (float): clustering threshold; larger values lead to larger
            segments.
        segMinVerts (int): minimum number of vertices per segment; smaller
            clusters are merged into larger segments to enforce it.

    Returns:
        torch.Tensor: cluster label for each entry produced by the extension,
        relabeled to consecutive integers starting from 0.
    """
    raw_labels = segment_mesh_fn(vertices, faces, kThresh, segMinVerts)
    # The extension's labels may be non-contiguous; the inverse indices of
    # torch.unique map them onto 0..K-1 (in sorted order of the raw labels).
    _, compact_labels = torch.unique(raw_labels, return_inverse=True)
    return compact_labels
def segment_point(vertices, normals, edges, kThresh=0.01, segMinVerts=20):
    """Segment a point cloud into clusters (CPU).

    Args:
        vertices (torch.Tensor): point positions, shape (nv, 3).
        normals (torch.Tensor): one normal per point, shape (nv, 3).
        edges (torch.Tensor): connectivity between points, shape (ne, 2)
            (presumably pairs of vertex indices — confirm against the
            extension).
        kThresh (float): clustering threshold; larger values lead to larger
            segments.
        segMinVerts (int): minimum number of vertices per segment; smaller
            clusters are merged into larger segments to enforce it.

    Returns:
        torch.Tensor: cluster label for each entry produced by the extension,
        relabeled to consecutive integers starting from 0.
    """
    raw_labels = segment_point_fn(vertices, normals, edges, kThresh, segMinVerts)
    # The extension's labels may be non-contiguous; the inverse indices of
    # torch.unique map them onto 0..K-1 (in sorted order of the raw labels).
    _, compact_labels = torch.unique(raw_labels, return_inverse=True)
    return compact_labels