
Commit 8f4ef80

use distil bert

committed Mar 24, 2020 · 1 parent e5f34ff

5 files changed (+53, -88 lines)
 

.gitignore (+1)

@@ -5,3 +5,4 @@ MultiWOZ_2.1/
 log/
 save/
 runs/
+config.py

config.py (+1, -1)

@@ -5,7 +5,7 @@ class Config:
     parser = argparse.ArgumentParser()
 
     parser.add_argument("--data_path", default="data/MultiWOZ_2.1", type=str)
-    parser.add_argument("--batch_size", default=16, type=int)
+    parser.add_argument("--batch_size", default=32, type=int)
     parser.add_argument("--max_len", default=100, type=int)
     parser.add_argument("--max_value_len", default=20, type=int)
     parser.add_argument("--max_context_len", default=450, type=int)

model/dst_no_history.py (+8, -8)

@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import BertModel, BertTokenizerFast
+from transformers import DistilBertModel, DistilBertTokenizerFast
 import numpy as np
 
 sys.path.append("../")
@@ -34,10 +34,10 @@ def __init__(self, hparams):
         """
 
         super(DST, self).__init__()
-        self.context_encoder = BertModel.from_pretrained("bert-base-uncased") # use fine-tuning
+        self.context_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased") # use fine-tuning
         self.context_encoder.train()
-        self.value_encoder = BertModel.from_pretrained("bert-base-uncased").requires_grad_(False) # fix parameter
-        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        self.value_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased").requires_grad_(False) # fix parameter
+        self.tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
         self.hidden_size = self.context_encoder.embeddings.word_embeddings.embedding_dim # 768
         self.linear_gate = nn.Linear(self.hidden_size, 3) # none, don't care, prediction
         self.linear_span = nn.Linear(self.hidden_size, 2) # start, end
@@ -98,11 +98,11 @@ def forward(self, turn_input, turn_context, turn_span, first_turn=False, train=T
            context[idx, :len(temp)] = torch.tensor(temp, dtype=torch.int64).cuda()
            if max_len < len(temp):
                max_len = len(temp)
-
+
        context = context[:, :max_len]
        context_mask = (context != 0)
 
-       outputs, _ = self.context_encoder(context, attention_mask=context_mask) # output: [batch, context_len, hidden]
+       outputs = self.context_encoder(context, attention_mask=context_mask)[0] # output: [batch, context_len, hidden]
        gate_output = self.linear_gate(outputs[:, 0, :]) # gate_output: [batch, 3]
        gate_output = F.log_softmax(gate_output, dim=1)
 
@@ -133,7 +133,7 @@ def forward(self, turn_input, turn_context, turn_span, first_turn=False, train=T
            value_list = self.value_ontology[slot_] + ["none"]
            for value in value_list:
                value_output = torch.tensor([self.tokenizer.encode(value)]).cuda()
-               value_output, _ = self.value_encoder(value_output) # value_outputs: [1, value_len, hidden]
+               value_output = self.value_encoder(value_output)[0] # value_outputs: [1, value_len, hidden]
                value_prob = torch.cosine_similarity(outputs[:, 0, :], value_output[:, 0, :], dim=1).unsqueeze(dim=1) # value_prob: [batch, 1]
                if value_probs is None:
                    value_probs = value_prob
@@ -142,7 +142,7 @@ def forward(self, turn_input, turn_context, turn_span, first_turn=False, train=T
 
            # cosine similarity of true value with context
            value_mask = (value_label != 0)
-           true_value_output, _ = self.value_encoder(value_label, attention_mask=value_mask) # true_value_output: [batch, value_len, hidden]
+           true_value_output = self.value_encoder(value_label, attention_mask=value_mask)[0] # true_value_output: [batch, value_len, hidden]
            true_value_probs = torch.cosine_similarity(outputs[:, 0, :], true_value_output[:, 0, :], dim=1).unsqueeze(dim=1) # true_value_prob: [batch, 1]
 
            acc_slot = torch.ones(batch_size).cuda() # acc: [batch]
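
The switch from tuple unpacking (outputs, _ = ...) to [0] indexing above is needed because DistilBertModel, unlike BertModel, returns no pooled output; element 0 of the model output is the last hidden state for both encoders, and the hidden size stays 768. A minimal standalone sketch of the call pattern (not part of the commit; the example sentence is made up):

# Sketch: DistilBertModel returns only hidden states (no pooler), so taking
# index [0] yields the last hidden state, shaped [batch, seq_len, 768].
import torch
from transformers import DistilBertModel, DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")

batch = tokenizer(["i need a cheap hotel in the north"], return_tensors="pt")
with torch.no_grad():
    hidden = encoder(batch["input_ids"], attention_mask=batch["attention_mask"])[0]

print(hidden.shape)           # torch.Size([1, seq_len, 768])
cls_vector = hidden[:, 0, :]  # the [CLS]-position vector used by the gate and cosine-similarity heads above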

reader.py (+4, -8)

@@ -4,7 +4,7 @@
 import random
 
 import torch
-from transformers import BertTokenizerFast
+from transformers import DistilBertTokenizerFast
 
 import ontology
 
@@ -16,7 +16,7 @@ def __init__(self, hparams):
         self.test = {}
         self.data_turns = {}
         self.data_path = hparams.data_path
-        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        self.tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
         self.batch_size = hparams.batch_size
         self.max_len = hparams.max_len
         self.max_value_len = hparams.max_value_len
@@ -300,9 +300,7 @@ def make_input(self, batch):
                turn_context_[idx, :len(turn_context[idx])+2] = torch.tensor([self.tokenizer.cls_token_id] + turn_context[idx] + [self.tokenizer.sep_token_id])
            for resp in batch[turn]["response"]:
                prev_resp.append(resp[1:-1])
-           turn_context_ = turn_context_[:, :context_len+2]
-
-           turn_context_ = turn_context_.cuda()
+           turn_context_ = turn_context_[:, :context_len+2].cuda()
 
            contexts.append(turn_context_.clone().long())
        else: # not first turn
@@ -319,9 +317,7 @@ def make_input(self, batch):
                turn_context_[idx, :len(turn_context[idx])+2] = torch.tensor([self.tokenizer.cls_token_id] + turn_context[idx] + [self.tokenizer.sep_token_id])
            for resp in batch[turn]["response"]:
                prev_resp.append(resp[1:-1])
-           turn_context_ = turn_context_[:, :min(context_len, self.max_context_len)+2]
-
-           turn_context_ = turn_context_.cuda()
+           turn_context_ = turn_context_[:, :min(context_len, self.max_context_len)+2].cuda()
 
            contexts.append(turn_context_.clone().long())
 
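
For reference, make_input wraps each tokenized context in [CLS] ... [SEP] ids, zero-pads the batch into one tensor, and now slices and moves it to the GPU in a single step. A standalone sketch of that padding pattern (illustrative only; the sentences and values are not from the repo):

# Sketch of the [CLS] ... [SEP] zero-padding pattern used in make_input.
import torch
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

turn_context = [tokenizer.encode("i need a taxi", add_special_tokens=False),
                tokenizer.encode("book a cheap hotel in the north please", add_special_tokens=False)]
context_len = max(len(ids) for ids in turn_context)

turn_context_ = torch.zeros(len(turn_context), context_len + 2, dtype=torch.long)  # 0 == [PAD]
for idx, ids in enumerate(turn_context):
    turn_context_[idx, :len(ids) + 2] = torch.tensor(
        [tokenizer.cls_token_id] + ids + [tokenizer.sep_token_id])

# In the diff the slice and the device move are fused into one step:
# turn_context_ = turn_context_[:, :context_len + 2].cuda()
print(turn_context_.shape)  # torch.Size([2, context_len + 2])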

train_distributed_no_history.py (+39, -71)

@@ -12,7 +12,7 @@
 from torch.utils.tensorboard import SummaryWriter
 from apex import amp, parallel
 from tqdm import tqdm
-from transformers import BertTokenizerFast
+from transformers import DistilBertTokenizerFast
 
 from model.dst_no_history import DST
 from config import Config
@@ -69,79 +69,48 @@ def train(model, reader, optimizer, writer, hparams, tokenizer):
         # learning rate scheduling
         for param in optimizer.param_groups:
             param["lr"] = learning_rate_schedule(train.global_step, train.max_iter, hparams)
+
+        prev_belief = None # belief for next turn
+        for turn_idx in range(turns):
+            distributed_batch_size = math.ceil(batch_size / hparams.num_gpus)
+
+            # distribute batches to each gpu
+            for key, value in inputs[turn_idx].items():
+                inputs[turn_idx][key] = distribute_data(value, hparams.num_gpus)[hparams.local_rank]
+            contexts[turn_idx] = distribute_data(contexts[turn_idx], hparams.num_gpus)[hparams.local_rank]
+            spans[turn_idx] = distribute_data(spans[turn_idx], hparams.num_gpus)[hparams.local_rank]
 
-        try:
-            prev_belief = None # belief for next turn
-            for turn_idx in range(turns):
-                distributed_batch_size = math.ceil(batch_size / hparams.num_gpus)
-
-                # split batches for gpu memory
-                context_len = 0
-                for idx in range(distributed_batch_size):
-                    context_len_ = len(contexts[turn_idx][idx])
-                    if context_len < context_len_:
-                        context_len = context_len_
-                if context_len >= 40:
-                    small_batch_size = min(int(hparams.batch_size/hparams.num_gpus / 2), distributed_batch_size)
-                else:
-                    small_batch_size = distributed_batch_size
-
-                # distribute batches to each gpu
-                for key, value in inputs[turn_idx].items():
-                    inputs[turn_idx][key] = distribute_data(value, hparams.num_gpus)[hparams.local_rank]
-                contexts[turn_idx] = distribute_data(contexts[turn_idx], hparams.num_gpus)[hparams.local_rank]
-                spans[turn_idx] = distribute_data(spans[turn_idx], hparams.num_gpus)[hparams.local_rank]
+            first_turn = (turn_idx == 0)
 
-                first_turn = (turn_idx == 0)
+            if not first_turn:
+                inputs[turn_idx]["belief_gen"] = prev_belief
 
-                if not first_turn:
-                    inputs[turn_idx]["belief_gen"] = prev_belief
+            optimizer.zero_grad()
+            loss, acc = model.forward(inputs[turn_idx], contexts[turn_idx], spans[turn_idx], first_turn) # loss: [batch], acc: [batch, slot]
+
+            if turn_idx+1 < turns:
+                prev_belief = inputs[turn_idx]["belief_gen"]
 
-                prev_belief = []
-
-                for small_batch_idx in range(math.ceil(distributed_batch_size/small_batch_size)):
-                    small_inputs = {}
-                    for key, value in inputs[turn_idx].items():
-                        small_inputs[key] = value[small_batch_size*small_batch_idx:small_batch_size*(small_batch_idx+1)]
-                    small_contexts = contexts[turn_idx][small_batch_size*small_batch_idx:small_batch_size*(small_batch_idx+1)]
-                    small_spans = spans[turn_idx][small_batch_size*small_batch_idx:small_batch_size*(small_batch_idx+1)]
-
-                    optimizer.zero_grad()
-                    loss, acc = model.forward(small_inputs, small_contexts, small_spans, first_turn) # loss: [batch], acc: [batch, slot]
-
-                    prev_belief.append(small_inputs["belief_gen"])
-
-                    total_loss += loss.sum(dim=0).item()
-                    slot_acc += acc.sum(dim=1).sum(dim=0).item()
-                    joint_acc += (acc.mean(dim=1) == 1).sum(dim=0).item()
-                    batch_count += small_batch_size
-                    loss = loss.mean(dim=0)
-
-                    # distributed training
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
-
-                    optimizer.step()
-                    torch.cuda.empty_cache()
-
-                prev_belief_ = []
-                for belief in prev_belief:
-                    prev_belief_ += belief
-                prev_belief = prev_belief_
-
-            total_loss = total_loss / batch_count
-            slot_acc = slot_acc / batch_count / len(ontology.all_info_slots) * 100
-            joint_acc = joint_acc / batch_count * 100
-            train.global_step += 1
-            if hparams.local_rank == 0:
-                writer.add_scalar("Train/loss", total_loss, train.global_step)
-                t.set_description("iter: {}, loss: {:.4f}, joint accuracy: {:.4f}, slot accuracy: {:.4f}".format(batch_idx+1, total_loss, joint_acc, slot_acc))
-        except RuntimeError as e:
-            if hparams.local_rank == 0:
-                print("\n!!! Error: {}".format(e))
-                print("batch size: {}, context length: {}".format(small_batch_size, context_len))
+            total_loss += loss.sum(dim=0).item()
+            slot_acc += acc.sum(dim=1).sum(dim=0).item()
+            joint_acc += (acc.mean(dim=1) == 1).sum(dim=0).item()
+            batch_count += distributed_batch_size
+            loss = loss.mean(dim=0)
+
+            # distributed training
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+
+            optimizer.step()
            torch.cuda.empty_cache()
-            exit(0)
+
+        total_loss = total_loss / batch_count
+        slot_acc = slot_acc / batch_count / len(ontology.all_info_slots) * 100
+        joint_acc = joint_acc / batch_count * 100
+        train.global_step += 1
+        if hparams.local_rank == 0:
+            writer.add_scalar("Train/loss", total_loss, train.global_step)
+            t.set_description("iter: {}, loss: {:.4f}, joint accuracy: {:.4f}, slot accuracy: {:.4f}".format(batch_idx+1, total_loss, joint_acc, slot_acc))
 
 def validate(model, reader, hparams, tokenizer):
     model.eval()
@@ -193,7 +162,6 @@ def validate(model, reader, hparams, tokenizer):
            t.set_description("iter: {}".format(batch_idx+1))
 
    model.train()
-   model.module.slot_encoder.eval()
    model.module.value_encoder.eval() # fix value encoder
    val_loss = val_loss / batch_count
    slot_acc = slot_acc / batch_count / len(ontology.all_info_slots) * 100
@@ -252,7 +220,7 @@ def load(model, optimizer, save_path):
 end = time.time()
 logger.info("Loaded. {} secs".format(end-start))
 
-tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
 
 model = DST(hparams).cuda()
 optimizer = Adam(model.parameters(), hparams.lr)
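
The rewritten train loop above runs one forward/backward pass per turn on each GPU's slice, dropping the small-batch splitting and the RuntimeError fallback. distribute_data itself is defined elsewhere in the repo and is not shown in this commit; the sketch below is only a hypothetical chunk-per-rank split consistent with how it is called (the helper name and values are made up):

# Hypothetical sketch of a distribute_data-style split: one chunk per GPU,
# each process keeping only the chunk indexed by its local_rank.
import math

def distribute_data_sketch(data, num_gpus):
    # ceil-sized chunks so every element lands in exactly one chunk
    chunk = math.ceil(len(data) / num_gpus)
    return [data[i * chunk:(i + 1) * chunk] for i in range(num_gpus)]

batch = list(range(10))        # stand-in for contexts[turn_idx], spans[turn_idx], ...
num_gpus, local_rank = 4, 1    # per-process values in the distributed setup
local_slice = distribute_data_sketch(batch, num_gpus)[local_rank]
print(local_slice)             # [3, 4, 5]: three items, matching math.ceil(10 / 4)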
