-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen_postprocessor.py
38 lines (29 loc) · 1.16 KB
/
gen_postprocessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Any
import torch
import torch.nn as nn
from .base_postprocessor import BasePostprocessor
class GENPostprocessor(BasePostprocessor):
    """Generalized ENtropy (GEN) postprocessor for OOD detection.

    Scores a sample by the generalized entropy of its softmax
    distribution, computed over only the top-M class probabilities.
    Lower entropy (a peaked distribution) yields a higher confidence
    that the sample is in-distribution.
    """

    def __init__(self, config):
        super().__init__(config)
        # Hyperparameters come from the config's postprocessor section.
        self.args = self.config.postprocessor.postprocessor_args
        self.gamma = self.args.gamma  # entropy exponent
        self.M = self.args.M          # number of top probabilities used
        # Candidate values for hyperparameter sweeping.
        self.args_dict = self.config.postprocessor.postprocessor_sweep

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Run the network and return (predicted labels, GEN confidences)."""
        logits = net(data)
        probs = torch.softmax(logits, dim=1)
        pred = probs.argmax(dim=1)
        conf = self.generalized_entropy(probs, self.gamma, self.M)
        return pred, conf

    def set_hyperparam(self, hyperparam: list):
        """Update (gamma, M) from a sweep candidate `[gamma, M]`."""
        self.gamma, self.M = hyperparam[0], hyperparam[1]

    def get_hyperparam(self):
        """Return the current hyperparameters as `[gamma, M]`."""
        return [self.gamma, self.M]

    def generalized_entropy(self, softmax_id_val, gamma=0.1, M=100):
        """Negated generalized entropy of the top-M softmax probabilities.

        For each row, the M largest probabilities p are kept and
        sum(p**gamma * (1 - p)**gamma) is computed; the negation makes
        higher values correspond to higher in-distribution confidence.
        """
        top_probs = torch.sort(softmax_id_val, dim=1)[0][:, -M:]
        entropy = torch.sum(top_probs**gamma * (1 - top_probs)**gamma,
                            dim=1)
        return -entropy