forked from cafeelmore/Pretty-Pytorch-Text-Classification
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_eval.py
More file actions
199 lines (165 loc) · 7.91 KB
/
train_eval.py
File metadata and controls
199 lines (165 loc) · 7.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# coding: UTF-8
import sys
import time
import datetime
from copy import deepcopy
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from sklearn import metrics
from sklearn.metrics import classification_report
from torchmetrics import Accuracy, F1Score
from utils import get_time_dif
from bert_optimizer import BertAdam
def test_model(config, model, loss_fn, metrics_dict, test_data):
    """Evaluate the best saved checkpoint on the test set.

    Loads the weights from ``config.save_path``, runs one evaluation epoch,
    prints every metric in ``metrics_dict`` and a per-class
    ``classification_report``.

    Args:
        config: experiment configuration; must provide ``save_path``,
            ``device`` and ``label_dict``.
        model: the model to evaluate (its weights are overwritten in place).
        loss_fn: loss function used to report the test loss.
        metrics_dict: name -> metric mapping (deep-copied so the caller's
            metric states stay untouched).
        test_data: iterable of (features, labels) batches.
    """
    # 3,test -------------------------------------------------
    # map_location keeps a GPU-trained checkpoint loadable on CPU-only hosts.
    model.load_state_dict(torch.load(config.save_path, map_location=config.device))
    test_step_runner = StepRunner(model = model, stage = "test",
                        loss_fn = loss_fn, metrics_dict = deepcopy(metrics_dict))
    test_epoch_runner = EpochRunner(test_step_runner)
    with torch.no_grad():
        test_metrics = test_epoch_runner(test_data)
    print("<<<<<< Test Result >>>>>>")
    for name, metric in test_metrics.items():
        print("- {0} : {1}".format(name, metric))
    labels = test_step_runner.all_labels
    preds = test_step_runner.all_preds
    # classification_report expects a sequence of names, not a dict view.
    class_names = list(config.label_dict.keys())
    print(classification_report(labels, preds, target_names=class_names))
# api for training and testing models
def train_and_test(config, model, train_iter, dev_iter, test_iter, monitor="val_loss"):
    """Full experiment driver: build optimizer and metrics, train, then test.

    Args:
        config: experiment configuration (``learning_rate``, ``num_epochs``,
            ``num_classes``, ``device``, ``save_path``, ``patience``).
        model: the model to train and evaluate.
        train_iter: training batch iterator.
        dev_iter: validation batch iterator.
        test_iter: test batch iterator.
        monitor: history key watched for checkpointing / early stopping.
            Defaults to "val_loss", preserving the original behavior.

    Returns:
        pd.DataFrame with the per-epoch training/validation history.
    """
    loss_fn = nn.CrossEntropyLoss()
    param_optimizer = list(model.named_parameters())  # get all parameters
    # Bias and LayerNorm parameters conventionally get no weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr = config.learning_rate,
                         warmup = 0.05,  # linear LR warmup over the first 5% of steps
                         t_total = len(train_iter) * config.num_epochs)
    metrics_dict = {"acc": Accuracy().to(config.device),
                    'f1': F1Score(num_classes = config.num_classes, average = 'macro').to(config.device)}
    df_history = train_model(config, model, optimizer, loss_fn, metrics_dict,
                             train_data = train_iter, val_data = dev_iter, monitor=monitor)
    test_model(config, model, loss_fn, metrics_dict, test_iter)
    return df_history
def print_log(info):
    """Print *info* beneath a timestamped separator banner."""
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    banner = "=" * 80
    print(f"\n{banner}{stamp}")
    print(f"{info}\n")
class StepRunner:
    '''Runs a single forward/backward pass for one mini-batch.
    Param:
        model: the model being trained or evaluated
        loss_fn: loss function
        stage: one of 'train', 'val' or 'test'. Default: 'train'.
        metrics_dict: mapping of metric name -> metric callable
        optimizer: optimizer applied only when stage == 'train'
    '''
    def __init__(self, model, loss_fn, stage="train", metrics_dict=None, optimizer=None):
        self.model = model
        self.loss_fn = loss_fn
        self.stage = stage
        self.metrics_dict = metrics_dict
        self.optimizer = optimizer
        # Only the test stage accumulates raw predictions/labels,
        # later consumed by the classification report.
        if self.stage == 'test':
            self.all_preds = []
            self.all_labels = []

    def step(self, features, labels):
        # forward pass and loss
        preds = self.model(features)
        loss = self.loss_fn(preds, labels)
        # parameter update happens only during training
        if self.stage == "train" and self.optimizer is not None:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        # per-step metric values, keyed with the stage prefix
        step_metrics = {}
        for name, metric_fn in self.metrics_dict.items():
            step_metrics[f"{self.stage}_{name}"] = metric_fn(preds, labels).item()
        if self.stage == "test":
            batch_labels = labels.data.cpu().numpy()
            batch_preds = torch.max(preds.data, 1)[1].cpu().numpy()
            self.all_preds = np.append(self.all_preds, batch_preds)
            self.all_labels = np.append(self.all_labels, batch_labels)
        return loss.item(), step_metrics

    def train_step(self, features, labels):
        self.model.train()  # training mode: dropout layers are active
        return self.step(features, labels)

    @torch.no_grad()
    def eval_step(self, features, labels):
        self.model.eval()  # eval mode: dropout layers are disabled
        return self.step(features, labels)

    def __call__(self, features, labels):
        if self.stage != "train":
            return self.eval_step(features, labels)
        return self.train_step(features, labels)
class EpochRunner:
    '''Drives a StepRunner over every mini-batch of one epoch.
    Param:
        step_runner: the StepRunner executed for each batch
    '''
    def __init__(self, step_runner):
        self.step_runner = step_runner
        self.stage = self.step_runner.stage

    def __call__(self, dataloader):
        running_loss = 0.0
        n_steps = 0
        n_batches = len(dataloader)
        progress = tqdm(enumerate(dataloader), total=n_batches, file=sys.stdout)
        for batch_idx, batch in progress:
            loss, step_metrics = self.step_runner(*batch)
            running_loss += loss
            n_steps += 1
            if batch_idx != n_batches - 1:
                # intermediate batches: show this step's numbers
                progress.set_postfix(**dict({self.stage + "_loss": loss}, **step_metrics))
            else:
                # last batch: report epoch-level aggregates instead
                epoch_loss = running_loss / n_steps
                epoch_metrics = {}
                for name, metric_fn in self.step_runner.metrics_dict.items():
                    epoch_metrics[self.stage + "_" + name] = metric_fn.compute().item()
                epoch_log = dict({self.stage + "_loss": epoch_loss}, **epoch_metrics)
                progress.set_postfix(**epoch_log)
                # clear metric state so the metrics can be reused next epoch
                for metric_fn in self.step_runner.metrics_dict.values():
                    metric_fn.reset()
        return epoch_log
def train_model(config, model, optimizer, loss_fn, metrics_dict,
                train_data, val_data=None, monitor="val_loss"):
    """Train with per-epoch validation, checkpointing and early stopping.

    The model is checkpointed to ``config.save_path`` whenever the monitored
    score improves, and training stops after ``config.patience`` epochs
    without improvement.  On exit the best checkpoint is loaded back into
    ``model``.

    Args:
        config: must provide ``num_epochs``, ``save_path``, ``patience``
            and ``device``.
        model: model to train (updated in place).
        optimizer: optimizer used for the training steps.
        loss_fn: loss function.
        metrics_dict: name -> metric mapping; deep-copied per stage so each
            stage keeps independent metric state.
        train_data: training batch iterator.
        val_data: optional validation batch iterator.
        monitor: history key to watch.  Keys containing "loss" are
            minimized; any other key (e.g. "val_acc") is maximized.

    Returns:
        pd.DataFrame of the per-epoch history.
    """
    epochs = config.num_epochs
    ckpt_path = config.save_path
    patience = config.patience
    # Lower is better for losses; higher is better for accuracy/F1-style
    # metrics.  The original code always used argmin, which is only correct
    # for loss monitors.
    find_best = np.argmin if "loss" in monitor else np.argmax
    history = {}
    for epoch in range(1, epochs + 1):
        print_log("Epoch {0} / {1}".format(epoch, epochs))
        # 1,train -------------------------------------------------
        train_step_runner = StepRunner(model = model, stage = "train",
                            loss_fn = loss_fn, metrics_dict = deepcopy(metrics_dict),
                            optimizer = optimizer)
        train_epoch_runner = EpochRunner(train_step_runner)
        train_metrics = train_epoch_runner(train_data)
        for name, metric in train_metrics.items():
            history[name] = history.get(name, []) + [metric]
        # 2,validate -------------------------------------------------
        if val_data:
            val_step_runner = StepRunner(model = model, stage = "val",
                                loss_fn = loss_fn, metrics_dict = deepcopy(metrics_dict))
            val_epoch_runner = EpochRunner(val_step_runner)
            with torch.no_grad():
                val_metrics = val_epoch_runner(val_data)
            val_metrics["epoch"] = epoch
            for name, metric in val_metrics.items():
                history[name] = history.get(name, []) + [metric]
        # 3,early-stopping -------------------------------------------------
        # NOTE(review): raises KeyError if monitor targets a "val_*" key and
        # val_data is None — same as the original behavior.
        arr_scores = history[monitor]
        best_score_idx = find_best(arr_scores)
        if best_score_idx == len(arr_scores) - 1:
            # current epoch is the best so far -> checkpoint it
            torch.save(model.state_dict(), ckpt_path)
            print("<<<<<< reach best {0} : {1} >>>>>>".format(monitor, arr_scores[best_score_idx]))
        if len(arr_scores) - best_score_idx > patience:
            print("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
                monitor, patience))
            break
    # restore the best checkpoint before returning; map_location keeps a
    # GPU-saved checkpoint loadable on CPU-only hosts.
    model.load_state_dict(torch.load(ckpt_path, map_location=config.device))
    return pd.DataFrame(history)