"""bleu_score.py: evaluate a trained model on the validation set with BLEU.

Decodes every validation sample (greedy or beam search, selected by the
config) and prints the total, individual and cumulative 1- to 4-gram BLEU
scores computed by torchtext.
"""
# !pip install torchtext==0.6.0
import torch
from torch.utils.data import Dataset, DataLoader
from torchtext.data.metrics import bleu_score

from train import get_ds, get_model
from config import get_config, get_weights_file_path
from beam_search import greedy_search, beam_search, length_penalty
# Evaluation script: load the trained checkpoint, decode every validation
# sample, and report corpus-level BLEU (total, individual and cumulative).

# Pick the best available accelerator. `torch.has_mps` is deprecated and only
# says whether PyTorch was *built* with MPS, so rely on
# `torch.backends.mps.is_available()`, which checks MPS is usable here.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

config = get_config()
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(
    config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()
).to(device)

# NOTE(review): `get_weights_file_path` is not defined in this file; it is
# assumed to be importable from `config` next to `get_config` - confirm.
model_filename = get_weights_file_path(config)
# `map_location` lets a checkpoint saved on another device (e.g. CUDA) load
# on whichever device was selected above.
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])
model.eval()

trgs = []  # per sample: a list holding the single reference token list
pred_trgs = []  # per sample: the predicted token list
num_batches = len(val_dataloader)

with torch.no_grad():
    # Count from 1 so `index` is the number of samples processed so far.
    for index, batch in enumerate(val_dataloader, start=1):
        encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
        encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)
        assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

        if config['beam_search']:
            # NOTE(review): `[0][0]` presumably selects the best hypothesis'
            # token ids - confirm against beam_search's return value.
            model_out = beam_search(
                model, encoder_input, encoder_mask, tokenizer_src,
                tokenizer_tgt, config['seq_len'], device,
                beam_width=config['beam_width'],
            )[0][0].squeeze(0)
        else:
            model_out, _ = greedy_search(
                model, encoder_input, encoder_mask, tokenizer_src,
                tokenizer_tgt, config['seq_len'], device,
            )

        # BLEU here is computed over whitespace-separated tokens.
        target_text = batch["tgt_text"][0].split(' ')
        model_out_text = tokenizer_tgt.decode(
            model_out.detach().cpu().numpy()
        ).split(' ')
        pred_trgs.append(model_out_text)
        trgs.append([target_text])

        # Progress report every 100 samples; the labels printed below are
        # Korean for "prediction" and "answer" (runtime text, left as is).
        if index % 100 == 0:
            print(f"[{index}/{num_batches}]")
            print(f"예측: {model_out_text}")
            print(f"정답: {target_text}")

bleu = bleu_score(pred_trgs, trgs, max_n=4, weights=[0.25, 0.25, 0.25, 0.25])
print(f'Total BLEU Score = {bleu*100:.2f}')

# torchtext's bleu_score returns 0.0 as soon as ANY n-gram order up to
# `max_n` has no clipped matches, even when that order's weight is 0. Using
# max_n=n for the n-gram score keeps e.g. BLEU-1 from being zeroed by missing
# 4-gram matches; when all orders have matches the values are unchanged.
individual_bleu_scores = [
    bleu_score(pred_trgs, trgs, max_n=n, weights=[0.0] * (n - 1) + [1.0])
    for n in range(1, 5)
]
for n, score in enumerate(individual_bleu_scores, start=1):
    print(f'Individual BLEU{n} score = {score*100:.2f}')

# Cumulative BLEU-n: uniform weights over the 1..n-gram precisions.
cumulative_bleu_scores = [
    bleu_score(pred_trgs, trgs, max_n=n, weights=[1 / n] * n)
    for n in range(1, 5)
]
for n, score in enumerate(cumulative_bleu_scores, start=1):
    print(f'Cumulative BLEU{n} score = {score*100:.2f}')