-
Notifications
You must be signed in to change notification settings - Fork 88
/
Copy path match_segmentation.py
40 lines (28 loc) · 1.42 KB
/
match_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
class MatchSegmentation(nn.Module):
    """Greedily assign each predicted segment to its closest ground-truth plane.

    The module holds no parameters or buffers; ``forward`` is a pure function
    of its inputs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, segmentation, prob, gt_instance, gt_plane_num):
        """Match each predicted segment with a ground-truth instance (greedy).

        For every predicted segment ``k`` and every valid ground-truth plane
        ``j``, the per-pixel binary cross entropy between ``segmentation[:, k]``
        and the mask ``gt_instance[j]`` is averaged over all pixels. Each
        segment is then assigned independently to the plane with the lowest
        value, so several segments may map to the same plane.

        :param segmentation: tensor of size (N, K) with N == h * w; column k
            holds the per-pixel values of segment k, used as probabilities
            inside ``log`` (presumably in [0, 1] -- confirm with caller).
        :param prob: tensor of size (N, 1); only its first dimension is
            checked, its values are not used.
        :param gt_instance: tensor of size (P, h, w) with P >= gt_plane_num
            (e.g. 21); only the first ``gt_plane_num`` masks are used.
        :param gt_plane_num: int, number of valid ground-truth planes.
        :return: a (K, 1) long tensor holding, for each segment, the id of the
            closest ground-truth instance, starting from 0.
        :raises ValueError: if ``prob.size(0)`` or ``h * w`` differs from the
            number of rows N of ``segmentation``.
        """
        n, k = segmentation.size()
        _, h, w = gt_instance.size()
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if prob.size(0) != n or h * w != n:
            raise ValueError(
                f"expected prob.size(0) == h*w == segmentation.size(0) == {n}, "
                f"got prob.size(0)={prob.size(0)}, h*w={h * w}"
            )

        # Ignore non-planar regions: keep only the valid ground-truth planes.
        # ``reshape`` (rather than ``view``) also accepts non-contiguous inputs.
        gt = gt_instance[:gt_plane_num].reshape(1, -1, h * w)  # (1, gt_plane_num, h*w)
        gt = gt.to(torch.float32)
        seg = segmentation.t().reshape(k, 1, h * w)  # (k, 1, h*w)

        # Instance-wise binary cross entropy, broadcast to (k, gt_plane_num, h*w).
        # The 1e-6 keeps log() finite when a value is exactly 0 or 1.
        ce_loss = -(gt * torch.log(seg + 1e-6)
                    + (1 - gt) * torch.log(1 - seg + 1e-6))
        ce_loss = torch.mean(ce_loss, dim=2)  # (k, gt_plane_num)

        # The lowest mean cross entropy wins; keepdim yields the (k, 1) result.
        return torch.argmin(ce_loss, dim=1, keepdim=True)