-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpad_tensor.py
31 lines (28 loc) · 912 Bytes
/
pad_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import numpy as np
import torch.nn.functional as functional
def pad_tensor(inp):
    """
    Pad a list of tensors so that they all share the same shape.

    Every tensor is zero-padded at the *end* of each dimension (e.g. right /
    bottom for a 2-D tensor) up to the element-wise maximum size of that
    dimension across all inputs. All inputs must have the same number of
    dimensions.

    :param inp: sequence (or any iterable) of torch.Tensor objects
    :return: list of padded tensors, in input order, all with the same shape;
             an empty list if ``inp`` is empty
    :raises TypeError: if any element is not a torch.Tensor
    :raises ValueError: if the tensors do not all have the same number of
                        dimensions
    """
    # Materialize once so generators work and we can iterate more than once.
    tensors = list(inp)
    if not tensors:
        return []

    for idx, t in enumerate(tensors):
        # isinstance (not type ==) so Tensor subclasses such as nn.Parameter
        # are accepted; explicit raise so the check survives `python -O`.
        if not isinstance(t, torch.Tensor):
            raise TypeError(
                f"pad_tensor expects torch.Tensor elements, "
                f"got {type(t).__name__} at index {idx}"
            )

    ndim = tensors[0].dim()
    for idx, t in enumerate(tensors):
        if t.dim() != ndim:
            raise ValueError(
                f"pad_tensor expects tensors with the same number of "
                f"dimensions: element 0 has {ndim}, element {idx} has {t.dim()}"
            )

    # Element-wise maximum size of every dimension across all tensors.
    max_shape = [max(sizes) for sizes in zip(*(t.shape for t in tensors))]

    padded_ts = []
    for t in tensors:
        # functional.pad takes (before, after) pairs starting from the LAST
        # dimension, so walk the dims in reverse and pad only the trailing side.
        pad_pattern = []
        for target, size in zip(reversed(max_shape), reversed(t.shape)):
            pad_pattern.extend((0, target - size))
        padded_ts.append(functional.pad(t, tuple(pad_pattern), 'constant', 0))
    return padded_ts