-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathutils.py
93 lines (70 loc) · 2.66 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.nn.functional as F
# Module-level log of measured durations: every enabled Timer appends the
# value of ``start.elapsed_time(end)`` here when its ``with`` block exits.
# Shared by all Timer instances; nothing shown here ever clears it.
all_times = []
class Timer:
    """Context manager that times its body with a pair of CUDA events.

    When ``enabled`` is true, ``torch.cuda.Event(enable_timing=True)`` objects
    are recorded on entry and exit. On exit the device is synchronized, and the
    elapsed time (milliseconds, as reported by ``Event.elapsed_time``) is:
    stored on ``self.elapsed``, appended to the module-level ``all_times``
    list, and printed next to ``name``. When ``enabled`` is false no CUDA
    events are created and the manager is a no-op.

    Args:
        name: Label printed alongside the measured time.
        enabled: If false, skip all timing work.
    """

    def __init__(self, name, enabled=True):
        self.name = name
        self.enabled = enabled
        # Last measured duration; stays None until an enabled block has exited.
        self.elapsed = None
        if self.enabled:
            self.start = torch.cuda.Event(enable_timing=True)
            self.end = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        if self.enabled:
            self.start.record()
        # Return self so ``with Timer(...) as t:`` can read ``t.elapsed`` later.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enabled:
            self.end.record()
            # Event recording is asynchronous; wait for the device before
            # querying the time between the two events.
            torch.cuda.synchronize()
            self.elapsed = self.start.elapsed_time(self.end)
            # Mutating the shared list needs no ``global`` declaration.
            all_times.append(self.elapsed)
            print(self.name, self.elapsed)
        # Implicitly returns None, so exceptions raised in the body propagate.
def coords_grid(b, n, h, w, **kwargs):
    """Build a pixel-coordinate grid replicated over batch and frames.

    Args:
        b: Number of batch entries to replicate the grid over.
        n: Number of frames per batch entry.
        h: Grid height.
        w: Grid width.
        **kwargs: Forwarded to ``torch.arange`` (e.g. ``device``); must not
            contain ``dtype``, which is fixed to ``torch.float``.

    Returns:
        Float tensor of shape (b, n, 2, h, w). Channel 0 holds the x (column)
        index and channel 1 the y (row) index.
    """
    cols = torch.arange(0, w, dtype=torch.float, **kwargs)
    rows = torch.arange(0, h, dtype=torch.float, **kwargs)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")
    # Put x first so the channel order is (x, y).
    grid = torch.stack([col_idx, row_idx])
    return grid.view(1, 1, 2, h, w).repeat(b, n, 1, 1, 1)
def coords_grid_with_index(d, **kwargs):
    """Build per-frame (x, y, depth) coordinates and a frame-index tensor.

    Args:
        d: Depth tensor of shape (B, n_frames, H, W).
        **kwargs: Forwarded to ``torch.arange`` (e.g. ``device``); must not
            contain ``dtype``. If ``device`` is omitted it defaults to
            ``d.device`` so the generated grids can be stacked with ``d``.

    Returns:
        coords (Tensor): (B, n_frames, 3, H, W); channels are x (column index),
            y (row index) and the depth value taken from ``d``.
        index (Tensor): (B, n_frames, 1, H, W); every entry holds its frame index.
    """
    b, n, h, w = d.shape
    # torch.stack below requires every input on one device; follow d unless
    # the caller explicitly chose a device.
    kwargs.setdefault("device", d.device)

    x = torch.arange(0, w, dtype=torch.float, **kwargs)
    y = torch.arange(0, h, dtype=torch.float, **kwargs)
    y, x = torch.meshgrid(y, x, indexing="ij")
    y = y.view(1, 1, h, w).repeat(b, n, 1, 1)
    x = x.view(1, 1, h, w).repeat(b, n, 1, 1)
    coords = torch.stack([x, y, d], dim=2)

    index = torch.arange(0, n, dtype=torch.float, **kwargs)
    index = index.view(1, n, 1, 1, 1).repeat(b, 1, 1, h, w)
    return coords, index
def patchify(x, patch_size=3):
    """Extract every dense square patch from a batch of videos.

    Args:
        x: Tensor of shape (b, n, c, h, w).
        patch_size: Side length of the square patches; extraction uses
            ``F.unfold`` defaults (stride 1, no padding, no dilation).

    Returns:
        Tensor of shape (b, n * L, c, patch_size, patch_size), where
        L = (h - patch_size + 1) * (w - patch_size + 1). Patches are grouped
        frame by frame, in ``F.unfold``'s sliding-window order within a frame.
    """
    b, n, c, h, w = x.shape
    # reshape (not view) so inputs with non-contiguous batch/frame axes,
    # e.g. after a transpose, are accepted instead of raising.
    frames = x.reshape(b * n, c, h, w)
    cols = F.unfold(frames, patch_size)  # (b*n, c*p*p, L)
    cols = cols.transpose(1, 2)          # (b*n, L, c*p*p)
    return cols.reshape(b, -1, c, patch_size, patch_size)
def pyramidify(fmap, lvls=(1,)):
    """Turn a feature map into a pyramid of average-pooled copies.

    Args:
        fmap: Tensor of shape (b, n, c, h, w).
        lvls: Iterable of integer pooling factors. Each level is average-pooled
            with kernel = stride = lvl, so a factor of 1 keeps the values
            unchanged. A tuple default avoids a shared mutable default.

    Returns:
        List with one tensor per entry of ``lvls``, each of shape
        (b, n, c, h // lvl, w // lvl).
    """
    b, n, c, h, w = fmap.shape
    # reshape (not view) so inputs with non-contiguous batch/frame axes work;
    # done once because it does not depend on the level.
    frames = fmap.reshape(b * n, c, h, w)
    pyramid = []
    for lvl in lvls:
        gmap = F.avg_pool2d(frames, lvl, stride=lvl)
        pyramid.append(gmap.view(b, n, c, h // lvl, w // lvl))
    return pyramid
def all_pairs_exclusive(n, **kwargs):
    """Return every ordered index pair (i, j) over range(n) with i != j.

    Args:
        n: Number of items.
        **kwargs: Forwarded to ``torch.arange`` (e.g. ``device``, ``dtype``).

    Returns:
        (ii, jj): two 1-D tensors of length n * (n - 1), ordered by ``i`` and
        then ``j`` (row-major over the n x n grid with the diagonal removed).
    """
    # "ij" is meshgrid's historical default; passing it explicitly keeps the
    # ordering and avoids torch's warning about the missing argument.
    ii, jj = torch.meshgrid(
        torch.arange(n, **kwargs), torch.arange(n, **kwargs), indexing="ij"
    )
    off_diagonal = ii != jj
    # Boolean-mask indexing already returns flat 1-D tensors.
    return ii[off_diagonal], jj[off_diagonal]
def set_depth(patches, depth):
    """Overwrite channel 2 of every patch with its depth value, in place.

    ``depth`` gains two trailing singleton axes so each value broadcasts over
    the last two (patch) dimensions of ``patches[..., 2, :, :]``.

    Returns:
        The same ``patches`` object, mutated in place.
    """
    depth_plane = depth[..., None, None]
    patches[..., 2, :, :] = depth_plane
    return patches
def flatmeshgrid(*args, **kwargs):
    """Meshgrid the given tensors and flatten each resulting grid.

    Args:
        *args: Tensors forwarded to ``torch.meshgrid``.
        **kwargs: Forwarded to ``torch.meshgrid``. ``indexing`` defaults to
            ``"ij"`` (meshgrid's historical default) when not supplied.

    Returns:
        A generator yielding one flattened (1-D) tensor per grid, suitable for
        tuple-unpacking, e.g. ``ii, jj = flatmeshgrid(a, b)``.
    """
    # Passing indexing explicitly keeps "ij" ordering and silences torch's
    # warning about the missing argument; a caller-supplied value still wins.
    kwargs.setdefault("indexing", "ij")
    grid = torch.meshgrid(*args, **kwargs)
    return (x.reshape(-1) for x in grid)