Skip to content

Commit ebd386c

Browse files
authored
Merge pull request #58 from bigscience-workshop/wmt2
Add WMT dataset
2 parents 744476c + 0573e93 commit ebd386c

File tree

3 files changed

+59
-0
lines changed

3 files changed

+59
-0
lines changed

evaluation/tasks/wmt/__init__.py

Whitespace-only changes.

evaluation/tasks/wmt/english.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"pair": "kk-en",
3+
"stride": 512,
4+
"batch_size": 8
5+
}

evaluation/tasks/wmt/wmt.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,56 @@
11
# Module for any additional processing required for the WMT dataset
22
# HuggingFace dataset link: https://huggingface.co/datasets/wmt19
3+
import torch
4+
from datasets import load_dataset
5+
from torch.utils.data import DataLoader, Dataset
6+
from tqdm import tqdm
7+
8+
from evaluation.tasks.auto_task import AutoTask
9+
10+
11+
class WMTEnglishDataset(Dataset):
    """Sliding-window token dataset over the English side of a WMT19 validation split.

    The English sentences of the ``validation`` split of ``wmt19`` for the
    given language pair are joined with single spaces, tokenized as one long
    sequence, and cut into overlapping windows of ``max_len`` tokens whose
    starts are ``stride`` tokens apart. Per ``torch.Tensor.unfold`` semantics,
    trailing tokens that do not fill a complete window are discarded.

    Args:
        tokenizer: Callable applied to the joined text with
            ``return_tensors="pt"``; its result must expose ``.input_ids``.
            Assumed to be a HuggingFace-style tokenizer returning a tensor of
            shape ``(1, seq_len)`` -- TODO confirm against callers.
        stride: Number of tokens between the starts of consecutive windows.
        max_len: Number of tokens in every window.
        pair: WMT19 language-pair code such as ``"kk-en"``; one side must be
            ``en``.

    Raises:
        AssertionError: If ``pair`` does not contain ``en`` as a language code.
        ValueError: If ``stride``/``max_len`` are not positive or the tokenized
            text is shorter than a single window.
    """

    def __init__(self, tokenizer, stride=512, max_len=1024, pair="kk-en"):
        super().__init__()
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O`; the exception type is kept for backward compatibility.
        # Comparing against the split codes avoids accidental substring hits.
        if "en" not in pair.split("-"):
            raise AssertionError(f"Expected `pair` to contain English, but got {pair} instead")
        wmt = load_dataset("wmt19", pair, split="validation")["translation"]
        text = " ".join(item["en"] for item in wmt)
        encoded = tokenizer(text, return_tensors="pt", verbose=False)
        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse a length-1 sequence into a 0-dim tensor.
        input_ids = encoded.input_ids.squeeze(0)
        self.input_ids = self._make_windows(input_ids, max_len=max_len, stride=stride)

    @staticmethod
    def _make_windows(token_ids, max_len, stride):
        """Return overlapping ``max_len``-sized windows over the last dimension of ``token_ids``."""
        if max_len <= 0 or stride <= 0:
            raise ValueError(
                f"`max_len` and `stride` must be positive, but got max_len={max_len}, stride={stride}"
            )
        num_tokens = token_ids.size(-1)
        if num_tokens < max_len:
            # unfold() would fail here with an opaque size error.
            raise ValueError(
                f"Tokenized text has {num_tokens} tokens, fewer than one window of max_len={max_len}"
            )
        return token_ids.unfold(size=max_len, step=stride, dimension=-1)

    def __len__(self):
        """Return the number of windows."""
        return len(self.input_ids)

    def __getitem__(self, index):
        """Return the window of token ids at position ``index``."""
        return self.input_ids[index]
26+
27+
28+
class WMTTask(AutoTask):
    """Perplexity evaluation on the English side of a WMT19 validation split.

    Reads ``stride``, ``pair`` and ``batch_size`` from ``self.task_config`` and
    stores the result in ``self.metrics["perplexity"]``. The attributes
    ``tokenizer``, ``model``, ``device``, ``task_config`` and ``metrics`` are
    expected to be provided by ``AutoTask`` (not visible in this module).
    """

    @staticmethod
    def get_display_name() -> str:
        """Return the short name used to identify this task."""
        return "wmt"

    def evaluate(self) -> None:
        """Compute sliding-window perplexity and record it in ``self.metrics``.

        Each window is ``self.model.config.n_positions`` tokens long; only the
        last ``stride`` positions of a window are scored, the rest serve as
        context. Every window is evaluated, including a final batch smaller
        than ``batch_size``, so the result does not depend on the batch size.
        """
        stride = self.task_config["stride"]
        dataset = WMTEnglishDataset(
            self.tokenizer,
            stride=stride,
            max_len=self.model.config.n_positions,
            pair=self.task_config["pair"],
        )
        # TODO: resolve conflict with tokenizer to support num_workers
        loader = DataLoader(
            dataset,
            batch_size=self.task_config["batch_size"],
            shuffle=False,
            # Keep the final partial batch: dropping it would silently discard
            # windows and leave the loader empty for small datasets.
            drop_last=False,
        )
        weighted_losses = []
        num_windows = 0
        for input_ids in tqdm(loader, desc=f"Evaluating {self.get_display_name()}"):
            input_ids = input_ids.to(self.device)
            target_ids = input_ids.clone()
            # Exclude context tokens from loss computation.
            # NOTE(review): -100 is assumed to be the label value the model's
            # loss ignores (HuggingFace convention) -- confirm for self.model.
            target_ids[:, :-stride] = -100
            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)
            # NOTE(review): outputs[0] is assumed to be the mean negative
            # log-likelihood over the scored tokens of the batch -- confirm.
            batch_windows = input_ids.size(0)
            # Every window scores the same number of tokens, so weighting each
            # batch mean by its window count gives the exact overall mean even
            # when the last batch is smaller.
            weighted_losses.append(outputs[0] * batch_windows)
            num_windows += batch_windows
        mean_nll = torch.stack(weighted_losses).sum() / num_windows
        self.metrics["perplexity"] = torch.exp(mean_nll).item()

0 commit comments

Comments
 (0)