# model.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
import torch.nn as nn
from torchvision.ops import sigmoid_focal_loss


class Model(nn.Module):
    """Binary sequence classifier wrapping a RoBERTa-style encoder.

    The `training` field of `args` selects the loss used in `forward`:
    "standard"/"augmentation"/"down"/"over" (plain BCE), "weight"
    (per-class averaged BCE), "cbl" (class-balanced loss), "focal"
    (sigmoid focal loss), or any other value (a focal-style,
    probability-modulated BCE).
    """

    def __init__(self, encoder, config, tokenizer, args, beta=0.9999):
        super(Model, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.args = args
        self.beta = beta  # hyperparameter of the class-balanced ("cbl") loss

    def forward(self, inputs_ids=None, attn_mask=None, position_idx=None, labels=None):
        # Embedding: position 0 marks graph nodes, positions >= 2 mark
        # ordinary tokens (GraphCodeBERT-style input layout).
        nodes_mask = position_idx.eq(0)
        token_mask = position_idx.ge(2)
        inputs_embeddings = self.encoder.roberta.embeddings.word_embeddings(inputs_ids)
        # For each node, average the embeddings of the tokens it is connected
        # to and use that average as the node's input embedding.
        nodes_to_token_mask = nodes_mask[:, :, None] & token_mask[:, None, :] & attn_mask
        nodes_to_token_mask = nodes_to_token_mask / (nodes_to_token_mask.sum(-1) + 1e-10)[:, :, None]
        avg_embeddings = torch.einsum("abc,acd->abd", nodes_to_token_mask, inputs_embeddings)
        inputs_embeddings = inputs_embeddings * (~nodes_mask)[:, :, None] + avg_embeddings * nodes_mask[:, :, None]

        outputs = self.encoder(inputs_embeds=inputs_embeddings, attention_mask=attn_mask,
                               position_ids=position_idx,
                               token_type_ids=position_idx.eq(-1).long())[0]
        logits = outputs
        prob = torch.sigmoid(logits)
        if labels is not None:
            if self.args.training in ["standard", "augmentation", "down", "over"]:
                # Plain binary cross-entropy on the positive-class probability.
                labels = labels.float()
                loss_ = torch.log(prob[:, 0] + 1e-10) * labels + torch.log((1 - prob)[:, 0] + 1e-10) * (1 - labels)
                loss = -loss_.mean()
            elif self.args.training == "weight":
                # Class-weighted BCE: average the positive and negative terms
                # separately so each class contributes equally.
                labels = labels.float()
                if len(torch.where(labels == 1)[0]) == 0:
                    loss = -torch.sum(torch.log((1 - prob)[:, 0] + 1e-10) * (1 - labels)) / len(
                        torch.where(labels == 0)[0])
                else:
                    loss = torch.sum(torch.log(prob[:, 0] + 1e-10) * labels) / len(
                        torch.where(labels == 1)[0]) + torch.sum(
                        torch.log((1 - prob)[:, 0] + 1e-10) * (1 - labels)) / len(torch.where(labels == 0)[0])
                    loss /= -2
            elif self.args.training == "cbl":
                # Class-balanced loss: weight each class by the inverse of its
                # "effective number" of samples, (1 - beta) / (1 - beta^n).
                weight_0 = (1 - self.beta) / (1 - np.power(self.beta, len(torch.where(labels == 0)[0])))
                if len(torch.where(labels == 1)[0]) > 0:
                    weight_1 = (1 - self.beta) / (1 - np.power(self.beta, len(torch.where(labels == 1)[0])))
                    labels = labels.float()
                    loss = torch.log(prob[:, 0] + 1e-10) * labels * weight_1 + torch.log((1 - prob)[:, 0] + 1e-10) * (
                        1 - labels) * weight_0
                else:
                    loss = torch.log((1 - prob)[:, 0] + 1e-10) * (1 - labels) * weight_0
                loss = -loss.mean()
            elif self.args.training == "focal":
                # Sigmoid focal loss on the raw logits.
                loss = sigmoid_focal_loss(logits, labels.view(len(labels), 1).float(), reduction="mean")
            else:
                # Focal-style BCE: each term is down-weighted by the predicted
                # probability of the opposite class.
                labels = labels.float()
                loss = torch.log(prob[:, 0] + 1e-10) * labels * ((1 - prob)[:, 0]) + torch.log(
                    (1 - prob)[:, 0] + 1e-10) * (1 - labels) * (prob[:, 0])
                loss = -2 * loss.mean()
            return loss, prob
        else:
            return prob
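

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file).
# It assumes a Hugging Face RobertaForSequenceClassification encoder with a
# single output label and a hypothetical `args` namespace whose `training`
# attribute selects the loss; the tensor shapes and values below are made up
# for the demo (position_idx == 2 marks every position as an ordinary token,
# so the graph-node averaging path is a no-op here).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    from transformers import RobertaConfig, RobertaForSequenceClassification

    config = RobertaConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2,
                           num_attention_heads=2, intermediate_size=128,
                           max_position_embeddings=512, num_labels=1)
    encoder = RobertaForSequenceClassification(config)
    args = Namespace(training="standard")  # hypothetical args object

    model = Model(encoder, config, tokenizer=None, args=args)

    batch_size, seq_len = 2, 32
    inputs_ids = torch.randint(5, config.vocab_size, (batch_size, seq_len))
    position_idx = torch.full((batch_size, seq_len), 2, dtype=torch.long)
    attn_mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.long)
    labels = torch.tensor([0, 1])

    loss, prob = model(inputs_ids, attn_mask, position_idx, labels)
    print(loss.item(), prob.shape)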