# Module for any additional processing required for the WMT dataset
# HuggingFace dataset link: https://huggingface.co/datasets/wmt19
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from evaluation.tasks.auto_task import AutoTask


class WMTEnglishDataset(Dataset):
    def __init__(self, tokenizer, stride=512, max_len=1024, pair="kk-en"):
        super().__init__()
        assert "en" in pair, f"Expected `pair` to contain English, but got {pair} instead"
        # Concatenate the English side of the validation split into one long text
        wmt = load_dataset("wmt19", pair, split="validation")["translation"]
        text_list = [item["en"] for item in wmt]
        text = " ".join(text_list)
        # Tokenize the full text once, then slice it into overlapping windows of
        # `max_len` tokens that advance by `stride` tokens
        input_ids = tokenizer(text, return_tensors="pt", verbose=False).input_ids.squeeze()
        self.input_ids = input_ids.unfold(dimension=-1, size=max_len, step=stride)
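        # e.g. with size=4 and step=2, tokens [t0 .. t7] unfold into the windows
        # [t0 t1 t2 t3], [t2 t3 t4 t5], [t4 t5 t6 t7]: consecutive windows share
        # `max_len - stride` tokens, which serve as context during evaluation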

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, index):
        return self.input_ids[index]
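
    # Minimal usage sketch (the "gpt2" checkpoint here is an assumption for
    # illustration; any HF tokenizer matching the evaluated model works):
    #
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained("gpt2")
    #   dataset = WMTEnglishDataset(tokenizer, stride=512, max_len=1024, pair="kk-en")
    #   dataset[0]  # -> LongTensor of shape (1024,)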


class WMTTask(AutoTask):
    @staticmethod
    def get_display_name() -> str:
        return "wmt"

    def evaluate(self) -> None:
        stride = self.task_config["stride"]
        # Window length is tied to the model's maximum context size
        # (`n_positions` on GPT-2-style configs)
        dataset = WMTEnglishDataset(
            self.tokenizer, stride=stride, max_len=self.model.config.n_positions, pair=self.task_config["pair"]
        )
        # TODO: resolve conflict with tokenizer to support num_workers
        # drop_last=True keeps every batch the same size, so the per-batch mean
        # losses below can be averaged directly
        loader = DataLoader(
            dataset,
            batch_size=self.task_config["batch_size"],
            shuffle=False,
            drop_last=True,
        )
        neg_log_likelihoods = []
        for input_ids in tqdm(loader, desc=f"Evaluating {self.get_display_name()}"):
            input_ids = input_ids.to(self.device)
            target_ids = input_ids.clone()
            # Exclude context tokens from loss computation: -100 is the
            # ignore_index of the cross-entropy loss, so only the last
            # `stride` tokens of each window are scored
            target_ids[:, :-stride] = -100
            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)
                # outputs[0] is the mean negative log-likelihood per target token
                neg_log_likelihood = outputs[0]
            neg_log_likelihoods.append(neg_log_likelihood)
        # Perplexity is exp of the average per-token negative log-likelihood;
        # averaging the batch means is exact here because every batch scores the
        # same number of tokens (fixed window shape, drop_last=True)
        perplexity = torch.exp(torch.stack(neg_log_likelihoods).mean())
        self.metrics["perplexity"] = perplexity.item()
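

# Minimal end-to-end sketch (assumes AutoTask wires up `model`, `tokenizer`,
# `device`, `task_config`, and `metrics`; the constructor call below is
# hypothetical -- see evaluation.tasks.auto_task for the real interface):
#
#   task = WMTTask(model, tokenizer, device="cuda",
#                  task_config={"stride": 512, "pair": "kk-en", "batch_size": 4})
#   task.evaluate()
#   print(task.metrics["perplexity"])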