-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathba.py
182 lines (132 loc) · 5.34 KB
/
ba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import torch
from torch_scatter import scatter_sum
from . import fastba
from . import lietorch
from .lietorch import SE3
from .utils import Timer
from . import projective_ops as pops
class CholeskySolver(torch.autograd.Function):
    """Differentiable batched solve of ``H x = b`` through a Cholesky factorization.

    ``H`` is factorized with ``torch.linalg.cholesky_ex``, which reports
    failure through an ``info`` tensor instead of raising. If any matrix in
    the batch fails to factorize, the forward pass returns zeros shaped like
    ``b`` and the backward pass returns no gradients. A single bad system
    therefore does not crash training.
    """

    @staticmethod
    def forward(ctx, H, b):
        # Lower-triangular factor L with H = L L^T; `info` is nonzero where
        # the factorization failed.
        factor, info = torch.linalg.cholesky_ex(H)
        ctx.failed = bool(torch.any(info))

        if ctx.failed:
            # Don't crash training: hand back a zero solution instead.
            return torch.zeros_like(b)

        solution = torch.cholesky_solve(b, factor)
        ctx.save_for_backward(factor, solution)
        return solution

    @staticmethod
    def backward(ctx, grad_x):
        # No graph was saved on failure, so no gradients flow to H or b.
        if ctx.failed:
            return None, None

        factor, solution = ctx.saved_tensors
        # Solve H dz = grad_x with the saved factor; dz is the gradient for b.
        grad_b = torch.cholesky_solve(grad_x, factor)
        # Gradient for H is -x dz^T.
        # NOTE(review): the textbook dL/dH is -dz x^T. This is its
        # transpose, which agrees with it on symmetric perturbations of H.
        grad_H = -(solution @ grad_b.transpose(-1, -2))
        return grad_H, grad_b
# utility functions for scattering ops
def safe_scatter_add_mat(A, ii, jj, n, m):
    """Accumulate per-edge blocks into a flattened n x m grid of blocks.

    Slice ``A[:, e]`` is summed into slot ``ii[e] * m + jj[e]`` along dim 1.
    Entries with ``ii`` outside ``[0, n)`` or ``jj`` outside ``[0, m)`` are
    silently dropped. The result has ``n * m`` slots along dim 1, laid out
    row-major, and callers can ``view`` it as an (n, m) grid.
    """
    keep = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m)
    flat_index = ii[keep] * m + jj[keep]
    return scatter_sum(A[:, keep], flat_index, dim=1, dim_size=n * m)
def safe_scatter_add_vec(b, ii, n):
    """Accumulate per-edge entries of ``b`` into ``n`` slots along dim 1.

    Slice ``b[:, e]`` is summed into slot ``ii[e]``. Entries whose index lies
    outside ``[0, n)`` are silently dropped.
    """
    keep = (ii >= 0) & (ii < n)
    return scatter_sum(b[:, keep], ii[keep], dim=1, dim_size=n)
def disp_retr(disps, dz, ii):
    """Apply additive updates ``dz`` to the inverse-depth values ``disps``.

    Updates are summed into dim 1 of ``disps`` at the positions given by
    ``ii``. ``ii`` is first moved to ``dz``'s device. Positions not listed in
    ``ii`` receive a zero update.
    """
    index = ii.to(device=dz.device)
    increment = scatter_sum(dz, index, dim=1, dim_size=disps.shape[1])
    return disps + increment
def pose_retr(poses, dx, ii):
    """Apply the retraction operator to ``poses`` with updates ``dx``.

    ``dx`` is summed into dim 1 at the frame positions ``ii`` (moved to
    ``dx``'s device first). Frames not listed in ``ii`` get a zero update.
    The dense update is then passed to ``poses.retr``.
    """
    index = ii.to(device=dx.device)
    delta = scatter_sum(dx, index, dim=1, dim_size=poses.shape[1])
    return poses.retr(delta)
def block_matmul(A, B):
    """Multiply two batched block matrices.

    ``A`` has shape (b, n1, m1, p1, q1): an n1 x m1 grid of p1 x q1 blocks.
    ``B`` has shape (b, n2, m2, p2, q2). Both are flattened to dense
    matrices, multiplied, and the product is returned in block form with
    shape (b, n1, m2, p1, q2).
    """
    _, rows_a, cols_a, bh_a, bw_a = A.shape
    # Batch size for every reshape below is taken from B.
    batch, rows_b, cols_b, bh_b, bw_b = B.shape

    dense_a = A.permute(0, 1, 3, 2, 4).reshape(batch, rows_a * bh_a, cols_a * bw_a)
    dense_b = B.permute(0, 1, 3, 2, 4).reshape(batch, rows_b * bh_b, cols_b * bw_b)

    product = torch.matmul(dense_a, dense_b)
    blocked = product.reshape(batch, rows_a, bh_a, cols_b, bw_b)
    return blocked.permute(0, 1, 3, 2, 4)
def block_solve(A, B, ep=1.0, lm=1e-4):
    """Solve the damped batched block system ``A X = B``.

    ``A`` has shape (b, n1, m1, p1, q1) and ``B`` has shape
    (b, n2, m2, p2, q2), each a grid of dense blocks. They are flattened to
    ordinary matrices. Each diagonal entry of the flattened ``A`` is
    increased by ``ep + lm * A[i, i]`` (Levenberg-Marquardt-style damping),
    and the system is solved with ``CholeskySolver``, which returns zeros if
    the factorization fails. ``X`` is returned in block form with shape
    (b, n1, m2, p1, q2).

    Args:
        A: block system matrix; the flattened form must be square.
        B: block right-hand side.
        ep: constant added to every diagonal entry.
        lm: relative damping factor applied to the diagonal of ``A``.
    """
    _, rows_a, cols_a, bh_a, bw_a = A.shape
    # Batch size for every reshape below is taken from B.
    batch, rows_b, cols_b, bh_b, bw_b = B.shape

    dense_a = A.permute(0, 1, 3, 2, 4).reshape(batch, rows_a * bh_a, cols_a * bw_a)
    dense_b = B.permute(0, 1, 3, 2, 4).reshape(batch, rows_b * bh_b, cols_b * bw_b)

    # Multiplying by the identity keeps only the diagonal of (ep + lm * A).
    identity = torch.eye(rows_a * bh_a, device=dense_a.device)
    damped = dense_a + (ep + lm * dense_a) * identity

    solution = CholeskySolver.apply(damped, dense_b)
    blocked = solution.reshape(batch, rows_a, bh_a, cols_b, bw_b)
    return blocked.permute(0, 1, 3, 2, 4)
def block_show(A):
    """Display the first batch element of a block matrix as an image.

    ``A`` has shape (b, n1, m1, p1, q1). It is flattened to a dense
    (n1*p1) x (m1*q1) matrix, and batch element 0 is shown with matplotlib,
    which is imported lazily so it is only needed when debugging.
    """
    import matplotlib.pyplot as plt

    batch, rows, cols, bh, bw = A.shape
    dense = A.permute(0, 1, 3, 2, 4).reshape(batch, rows * bh, cols * bw)
    image = dense[0].detach().cpu().numpy()
    plt.imshow(image)
    plt.show()
def BA(poses, patches, intrinsics, targets, weights, lmbda, ii, jj, kk, bounds, ep=100.0, PRINT=False, fixedp=1, structure_only=False):
    """One damped Gauss-Newton bundle-adjustment step over poses and patch depths.

    Each edge e refers to frames ``ii[e]`` and ``jj[e]`` and to patch
    ``kk[e]`` (presumably patch kk[e] reprojected from frame ii[e] into
    frame jj[e]; confirm against ``pops.transform``). The step proceeds as
    follows:
      1. Reproject the patches and take the residual at each patch's center
         pixel against ``targets``.
      2. Mask out edges with large residuals or out-of-bounds projections.
      3. Build the normal equations.
      4. Solve them with a Schur complement that eliminates the per-patch
         depth variables first.

    Args:
        poses: pose object indexed by frame along dim 1; updated through
            ``pose_retr`` (which calls ``poses.retr``).
        patches: tensor whose dim 2 unbinds into (x, y, disps); only
            ``disps`` is updated.
        intrinsics: forwarded unchanged to ``pops.transform``.
        targets: target coordinates compared with the reprojected patch
            centers.
        weights: residual weights; multiplied by the validity mask before use.
        lmbda: damping added to the depth Hessian ``C``. Either a number or
            a tensor with as many elements as ``C``.
        ii, jj: frame indices per edge.
        kk: patch index per edge.
        bounds: (xmin, ymin, xmax, ymax). Edges whose projected center is
            not strictly inside are masked out.
        ep: constant diagonal damping for the reduced pose system
            (passed to ``block_solve``).
        PRINT: if True, print the mean masked residual norm.
        fixedp: number of leading frames held fixed.
        structure_only: if True, only depths are updated.

    Returns:
        ``(poses, patches)`` after one update. Every patch disparity is
        clamped to [1e-3, 10], including patches not touched by any edge.
    """
    # NOTE(review): batch size is hard-coded to 1 in every view below.
    b = 1
    # Number of frames referenced by the edges.
    # NOTE(review): .max() raises on an empty tensor, so empty ii/jj crash here.
    n = max(ii.max().item(), jj.max().item()) + 1

    # Reprojected coordinates, a validity mask v, and Jacobians w.r.t. the
    # two poses (Ji, Jj) and the patch depth (Jz). The 6x6 / 6x1 block
    # views below imply 6 pose columns and 1 depth column.
    coords, v, (Ji, Jj, Jz) = \
        pops.transform(poses, patches, intrinsics, ii, jj, kk, jacobian=True)

    # p is the patch side length (assumes a square patch grid; confirm).
    # Only the center pixel of each patch contributes a residual.
    p = coords.shape[3]
    r = targets - coords[...,p//2,p//2,:]

    # Drop edges with a gross residual (norm >= 250).
    v *= (r.norm(dim=-1) < 250).float()

    # Drop edges whose projected center is not strictly inside `bounds`.
    in_bounds = \
        (coords[...,p//2,p//2,0] > bounds[0]) & \
        (coords[...,p//2,p//2,1] > bounds[1]) & \
        (coords[...,p//2,p//2,0] < bounds[2]) & \
        (coords[...,p//2,p//2,1] < bounds[3])

    v *= in_bounds.float()

    if PRINT:
        print((r * v[...,None]).norm(dim=-1).mean().item())

    # Apply the mask to residuals and weights; add a trailing dim for matmul.
    r = (v[...,None] * r).unsqueeze(dim=-1)
    weights = (v[...,None] * weights).unsqueeze(dim=-1)

    # Weighted, transposed Jacobians (J^T W).
    wJiT = (weights * Ji).transpose(2,3)
    wJjT = (weights * Jj).transpose(2,3)
    wJzT = (weights * Jz).transpose(2,3)

    # Per-edge normal-equation terms:
    # pose-pose blocks (B**), pose-depth coupling (E*k), pose gradients (vi, vj).
    Bii = torch.matmul(wJiT, Ji)
    Bij = torch.matmul(wJiT, Jj)
    Bji = torch.matmul(wJjT, Ji)
    Bjj = torch.matmul(wJjT, Jj)

    Eik = torch.matmul(wJiT, Jz)
    Ejk = torch.matmul(wJjT, Jz)

    vi = torch.matmul(wJiT, r)
    vj = torch.matmul(wJjT, r)

    # Hold the first `fixedp` frames fixed. Shifting the indices makes those
    # frames negative, so safe_scatter_* drops their terms. n becomes the
    # number of free frames.
    ii = ii.clone()
    jj = jj.clone()

    n = n - fixedp
    ii = ii - fixedp
    jj = jj - fixedp

    # Remap patch ids to a compact range 0..m-1. kx holds the original
    # (sorted, unique) ids, used later to write the depth updates back.
    kx, kk = torch.unique(kk, return_inverse=True, sorted=True)
    m = len(kx)

    # Assemble the global system: B is (b, n, n, 6, 6) pose Hessian,
    # E is (b, n, m, 6, 1) pose-depth coupling, C holds one depth Hessian
    # entry per patch. v (now reused as the pose gradient, no longer the
    # mask) is (b, n, 1, 6, 1), and w is the depth gradient.
    B = safe_scatter_add_mat(Bii, ii, ii, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bij, ii, jj, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bji, jj, ii, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bjj, jj, jj, n, n).view(b, n, n, 6, 6)

    E = safe_scatter_add_mat(Eik, ii, kk, n, m).view(b, n, m, 6, 1) + \
        safe_scatter_add_mat(Ejk, jj, kk, n, m).view(b, n, m, 6, 1)

    C = safe_scatter_add_vec(torch.matmul(wJzT, Jz), kk, m)

    v = safe_scatter_add_vec(vi, ii, n).view(b, n, 1, 6, 1) + \
        safe_scatter_add_vec(vj, jj, n).view(b, n, 1, 6, 1)

    w = safe_scatter_add_vec(torch.matmul(wJzT, r), kk, m)

    if isinstance(lmbda, torch.Tensor):
        lmbda = lmbda.reshape(*C.shape)

    # Inverse of the damped depth block. C has one entry per patch, so the
    # inversion is elementwise.
    Q = 1.0 / (C + lmbda)

    ### solve w/ schur complement ###
    EQ = E * Q[:,None]

    if structure_only or n == 0:
        # No free poses: each depth update is Q * w.
        dZ = (Q * w).view(b, -1, 1, 1)

    else:
        # Reduced pose system: S = B - E Q E^T, rhs y = v - E Q w.
        # (y here is the reduced right-hand side; it is rebound below.)
        S = B - block_matmul(EQ, E.permute(0,2,1,4,3))
        y = v - block_matmul(EQ, w.unsqueeze(dim=2))
        dX = block_solve(S, y, ep=ep, lm=1e-4)

        # Back-substitute for the depths: dZ = Q (w - E^T dX).
        dZ = Q * (w - block_matmul(E.permute(0,2,1,4,3), dX).squeeze(dim=-1))
        dX = dX.view(b, -1, 6)
        dZ = dZ.view(b, -1, 1, 1)

    # Write the depth updates back at the original patch ids and clamp all
    # disparities. Note that y is rebound here to the patch y-channel.
    x, y, disps = patches.unbind(dim=2)
    disps = disp_retr(disps, dZ, kx).clamp(min=1e-3, max=10.0)
    patches = torch.stack([x, y, disps], dim=2)

    # Update only the free frames fixedp .. fixedp + n - 1.
    if not structure_only and n > 0:
        poses = pose_retr(poses, dX, fixedp + torch.arange(n))

    return poses, patches