@@ -12,7 +12,7 @@
 from torch.utils.tensorboard import SummaryWriter
 from apex import amp, parallel
 from tqdm import tqdm
-from transformers import BertTokenizerFast
+from transformers import DistilBertTokenizerFast
 
 from model.dst_no_history import DST
 from config import Config
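Side note on the tokenizer swap: distilbert-base-uncased reuses BERT's uncased WordPiece vocabulary, so previously produced token ids should stay valid, but DistilBERT has no segment embeddings, so the fast tokenizer stops emitting token_type_ids. A standalone sanity-check sketch (not part of this commit; the sample sentence is illustrative):

    from transformers import BertTokenizerFast, DistilBertTokenizerFast

    bert = BertTokenizerFast.from_pretrained("bert-base-uncased")
    distil = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

    enc_b = bert("i need a cheap hotel in the north")
    enc_d = distil("i need a cheap hotel in the north")
    # should hold, since DistilBERT reuses BERT's uncased vocab
    assert enc_b["input_ids"] == enc_d["input_ids"]
    # DistilBERT has no segment ids, so the key is absent by default
    assert "token_type_ids" not in enc_d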
@@ -69,79 +69,48 @@ def train(model, reader, optimizer, writer, hparams, tokenizer):
         # learning rate scheduling
         for param in optimizer.param_groups:
             param["lr"] = learning_rate_schedule(train.global_step, train.max_iter, hparams)
+
+        prev_belief = None # belief for next turn
+        for turn_idx in range(turns):
+            distributed_batch_size = math.ceil(batch_size / hparams.num_gpus)
+
+            # distribute batches to each gpu
+            for key, value in inputs[turn_idx].items():
+                inputs[turn_idx][key] = distribute_data(value, hparams.num_gpus)[hparams.local_rank]
+            contexts[turn_idx] = distribute_data(contexts[turn_idx], hparams.num_gpus)[hparams.local_rank]
+            spans[turn_idx] = distribute_data(spans[turn_idx], hparams.num_gpus)[hparams.local_rank]
 
-        try:
-            prev_belief = None # belief for next turn
-            for turn_idx in range(turns):
-                distributed_batch_size = math.ceil(batch_size / hparams.num_gpus)
-
-                # split batches for gpu memory
-                context_len = 0
-                for idx in range(distributed_batch_size):
-                    context_len_ = len(contexts[turn_idx][idx])
-                    if context_len < context_len_:
-                        context_len = context_len_
-                if context_len >= 40:
-                    small_batch_size = min(int(hparams.batch_size/hparams.num_gpus / 2), distributed_batch_size)
-                else:
-                    small_batch_size = distributed_batch_size
-
-                # distribute batches to each gpu
-                for key, value in inputs[turn_idx].items():
-                    inputs[turn_idx][key] = distribute_data(value, hparams.num_gpus)[hparams.local_rank]
-                contexts[turn_idx] = distribute_data(contexts[turn_idx], hparams.num_gpus)[hparams.local_rank]
-                spans[turn_idx] = distribute_data(spans[turn_idx], hparams.num_gpus)[hparams.local_rank]
+            first_turn = (turn_idx == 0)
 
-                first_turn = (turn_idx == 0)
+            if not first_turn:
+                inputs[turn_idx]["belief_gen"] = prev_belief
 
-                if not first_turn:
-                    inputs[turn_idx]["belief_gen"] = prev_belief
+            optimizer.zero_grad()
+            loss, acc = model.forward(inputs[turn_idx], contexts[turn_idx], spans[turn_idx], first_turn) # loss: [batch], acc: [batch, slot]
+
+            if turn_idx+1 < turns:
+                prev_belief = inputs[turn_idx]["belief_gen"]
 
-                prev_belief = []
-
-                for small_batch_idx in range(math.ceil(distributed_batch_size/small_batch_size)):
-                    small_inputs = {}
-                    for key, value in inputs[turn_idx].items():
-                        small_inputs[key] = value[small_batch_size*small_batch_idx:small_batch_size*(small_batch_idx+1)]
-                    small_contexts = contexts[turn_idx][small_batch_size*small_batch_idx:small_batch_size*(small_batch_idx+1)]
-                    small_spans = spans[turn_idx][small_batch_size*small_batch_idx:small_batch_size*(small_batch_idx+1)]
-
-                    optimizer.zero_grad()
-                    loss, acc = model.forward(small_inputs, small_contexts, small_spans, first_turn) # loss: [batch], acc: [batch, slot]
-
-                    prev_belief.append(small_inputs["belief_gen"])
-
-                    total_loss += loss.sum(dim=0).item()
-                    slot_acc += acc.sum(dim=1).sum(dim=0).item()
-                    joint_acc += (acc.mean(dim=1) == 1).sum(dim=0).item()
-                    batch_count += small_batch_size
-                    loss = loss.mean(dim=0)
-
-                    # distributed training
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
-
-                    optimizer.step()
-                    torch.cuda.empty_cache()
-
-                prev_belief_ = []
-                for belief in prev_belief:
-                    prev_belief_ += belief
-                prev_belief = prev_belief_
-
-            total_loss = total_loss / batch_count
-            slot_acc = slot_acc / batch_count / len(ontology.all_info_slots) * 100
-            joint_acc = joint_acc / batch_count * 100
-            train.global_step += 1
-            if hparams.local_rank == 0:
-                writer.add_scalar("Train/loss", total_loss, train.global_step)
-                t.set_description("iter: {}, loss: {:.4f}, joint accuracy: {:.4f}, slot accuracy: {:.4f}".format(batch_idx+1, total_loss, joint_acc, slot_acc))
-        except RuntimeError as e:
-            if hparams.local_rank == 0:
-                print("\n!!! Error: {}".format(e))
-                print("batch size: {}, context length: {}".format(small_batch_size, context_len))
+            total_loss += loss.sum(dim=0).item()
+            slot_acc += acc.sum(dim=1).sum(dim=0).item()
+            joint_acc += (acc.mean(dim=1) == 1).sum(dim=0).item()
+            batch_count += distributed_batch_size
+            loss = loss.mean(dim=0)
+
+            # distributed training
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+
+            optimizer.step()
             torch.cuda.empty_cache()
-            exit(0)
+
+        total_loss = total_loss / batch_count
+        slot_acc = slot_acc / batch_count / len(ontology.all_info_slots) * 100
+        joint_acc = joint_acc / batch_count * 100
+        train.global_step += 1
+        if hparams.local_rank == 0:
+            writer.add_scalar("Train/loss", total_loss, train.global_step)
+            t.set_description("iter: {}, loss: {:.4f}, joint accuracy: {:.4f}, slot accuracy: {:.4f}".format(batch_idx+1, total_loss, joint_acc, slot_acc))
 
 def validate(model, reader, hparams, tokenizer):
     model.eval()
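Note on the hunk above: distribute_data is a project-local helper that this change leans on but the diff does not show. From its call sites (distribute_data(value, hparams.num_gpus)[hparams.local_rank]) it presumably splits a full batch into one shard per GPU rank; a hypothetical sketch consistent with that usage:

    import math

    def distribute_data(data, num_gpus):
        # Hypothetical reconstruction, not the repo's actual code:
        # cut the batch into num_gpus contiguous shards; the last shard
        # may be smaller when the batch does not divide evenly.
        shard = math.ceil(len(data) / num_gpus)
        return [data[i * shard:(i + 1) * shard] for i in range(num_gpus)]

Also worth noting from the new accounting: joint_acc increments only when acc.mean(dim=1) == 1, i.e. a turn counts as correct only if every slot in it is predicted correctly, while slot_acc counts individual slot hits.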
@@ -193,7 +162,6 @@ def validate(model, reader, hparams, tokenizer):
        t.set_description("iter: {}".format(batch_idx+1))
 
    model.train()
-    model.module.slot_encoder.eval()
    model.module.value_encoder.eval() # fix value encoder
    val_loss = val_loss / batch_count
    slot_acc = slot_acc / batch_count / len(ontology.all_info_slots) * 100
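Because model.train() flips every submodule back to training mode, the surviving line re-pins the value encoder; dropping slot_encoder.eval() means the slot encoder now behaves normally in training again. Keep in mind that .eval() only changes runtime behaviour (dropout, batch norm); if the "fix value encoder" comment means its weights should be frozen as well, that additionally requires disabling gradients. A minimal sketch of the distinction:

    model.train()                        # all submodules -> training mode
    model.module.value_encoder.eval()    # value encoder stays in eval mode

    # .eval() alone does not stop weight updates; freezing would also need:
    for p in model.module.value_encoder.parameters():
        p.requires_grad_(False)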
@@ -252,7 +220,7 @@ def load(model, optimizer, save_path):
     end = time.time()
     logger.info("Loaded. {} secs".format(end-start))
 
-    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
 
     model = DST(hparams).cuda()
     optimizer = Adam(model.parameters(), hparams.lr)
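For completeness: the Apex wiring itself sits outside this diff, but the from apex import amp, parallel import, the amp.scale_loss call in train, and the model.module.* accesses in validate all point to the standard setup after the model and optimizer are built. A sketch under that assumption (the "O1" opt_level is a guess, not taken from the repo):

    # Standard Apex AMP + DDP wiring (sketch); "O1" is an assumed opt_level.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = parallel.DistributedDataParallel(model)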