Skip to content

Commit

Permalink
minors and a refined diff_pipeline tool
Browse files Browse the repository at this point in the history
  • Loading branch information
nuneslu committed Mar 18, 2024
1 parent b114fec commit ebdc87e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
20 changes: 13 additions & 7 deletions lidiff/tools/diff_completion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
class DiffCompletion(LightningModule):
def __init__(self, diff_path, refine_path, denoising_steps, cond_weight):
super().__init__()
hparams = yaml.safe_load(open(diff_path.split('checkpoints')[0] + '/hparams.yaml'))
self.save_hyperparameters(hparams)
ckpt_diff = torch.load(diff_path)
self.save_hyperparameters(ckpt_diff['hyper_parameters'])
assert denoising_steps <= self.hparams['diff']['t_steps'], \
f"The number of denoising steps cannot be bigger than T={self.hparams['diff']['t_steps']} (you've set '-T {denoising_steps}')"

ckpt_diff = torch.load(diff_path)
self.partial_enc = minknet.MinkGlobalEnc(in_channels=3, out_channels=self.hparams['model']['out_dim']).cuda()
self.model = minknet.MinkUNetDiff(in_channels=3, out_channels=self.hparams['model']['out_dim']).cuda()
self.model_refine = minknet.MinkUNet(in_channels=3, out_channels=3*6)
Expand Down Expand Up @@ -169,9 +168,17 @@ def completion_loop(self, x_init, x_t, x_cond, x_uncond):

return x_t.F.cpu().detach().numpy()

def load_pcd(pcd_file):
if pcd_file.endswith('.bin'):
return np.fromfile(pcd_file, dtype=np.float32).reshape((-1,4))[:,:3]
elif pcd_file.endswith('.ply'):
return np.array(o3d.io.read_point_cloud(pcd_file).points)
else:
print(f"Point cloud format '.{pcd_file.split('.')[-1]}' not supported. (supported formats: .bin (kitti format), .ply)")

@click.command()
@click.option('--diff', '-d', type=str, default='', help='path to the scan sequence')
@click.option('--refine', '-r', type=str, default='', help='path to the scan sequence')
@click.option('--diff', '-d', type=str, default='checkpoints/diff_net.ckpt', help='path to the scan sequence')
@click.option('--refine', '-r', type=str, default='checkpoints/refine_net.ckpt', help='path to the scan sequence')
@click.option('--denoising_steps', '-T', type=int, default=50, help='number of denoising steps (default: 50)')
@click.option('--cond_weight', '-s', type=float, default=6.0, help='conditioning weight (default: 6.0)')
def main(diff, refine, denoising_steps, cond_weight):
Expand All @@ -188,8 +195,7 @@ def main(diff, refine, denoising_steps, cond_weight):

for pcd_path in tqdm.tqdm(natsorted(os.listdir(path))):
pcd_file = os.path.join(path, pcd_path)
input_pcd = o3d.io.read_point_cloud(pcd_file)
points = np.array(input_pcd.points)
points = load_pcd(pcd_file)

start = time.time()
refine_scan, diff_scan = diff_completion.complete_scan(points)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from setuptools import setup, find_packages

pkg_name = 'pcdiff'
pkg_name = 'lidiff'
setup(name=pkg_name, version='1.0', packages=find_packages())

0 comments on commit ebdc87e

Please sign in to comment.