Skip to content

Commit 744476c

Browse files
authored
Merge pull request #59 from bigscience-workshop/piqa
Add PIQA dataset
2 parents 0788139 + 1e9a7d6 commit 744476c

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

evaluation/tasks/piqa/__init__.py

Whitespace-only changes.

evaluation/tasks/piqa/english.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{}

evaluation/tasks/piqa/piqa.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Module for any additional processing required for the PIQA dataset
2+
# HuggingFace dataset link: https://huggingface.co/datasets/piqa
3+
from datasets import load_dataset
4+
from jinja2 import Template
5+
from torch.utils.data import Dataset
6+
from tqdm import tqdm
7+
8+
from evaluation.tasks.auto_task import AutoTask
9+
10+
11+
# Prompt shown to the model for each PIQA sample. It is rendered with the
# variables ``goal``, ``sol1`` and ``sol2``.
# The "Solution 1"/"Solution 2" captions are plain template text; they do not
# need to be wrapped in a Jinja expression to render verbatim.
# NOTE: Jinja2 drops a single trailing newline by default, so the rendered
# prompt ends at "Answer:"; the leading newline is kept.
TEMPLATE = Template(
    """
Given a goal and 2 solutions, choose the most appropriate solution.
Goal: {{goal}}
Solution 1: {{sol1}}
Solution 2: {{sol2}}
Answer:
"""
)
20+
21+
22+
class PIQADataset(Dataset):
    """In-memory dataset of tokenized prompts built from the PIQA validation split.

    Every entry of ``self.items`` is a dict with the keys ``prompt``,
    ``input_ids``, ``attention_mask``, ``input_len`` and ``label``.
    """

    def __init__(self, tokenizer):
        """Load the "piqa" validation split and eagerly encode every sample.

        Args:
            tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``
                whose result exposes ``"input_ids"`` and ``"attention_mask"``
                (presumably a HuggingFace tokenizer; confirm against callers).
        """
        super().__init__()
        validation_split = load_dataset("piqa", split="validation")
        self.items = [self._encode(tokenizer, row) for row in validation_split]

    @staticmethod
    def _encode(tokenizer, row):
        """Render the prompt for one raw sample and tokenize it into an item dict."""
        candidates = (row["sol1"], row["sol2"])
        text = TEMPLATE.render(
            goal=row["goal"],
            sol1=candidates[0],
            sol2=candidates[1],
        )

        # Tokenize and construct this sample
        encoded = tokenizer(text, return_tensors="pt")
        mask = encoded["attention_mask"]
        return {
            "prompt": text,
            "input_ids": encoded["input_ids"],
            "attention_mask": mask,
            # Second dimension of the attention mask, i.e. the prompt length in tokens.
            "input_len": mask.shape[1],
            # The reference answer is the text of the solution selected by the label index.
            "label": candidates[row["label"]],
        }

    def __len__(self):
        """Return the number of encoded samples."""
        return len(self.items)

    def __getitem__(self, index):
        """Return the encoded sample stored at ``index``."""
        return self.items[index]
55+
56+
57+
class PIQATask(AutoTask):
    """Task that scores a generative model on the PIQA validation split."""

    @staticmethod
    def get_display_name() -> str:
        """Return the task name (used here to label the progress bar)."""
        return "piqa"

    def evaluate(self) -> None:
        """Generate an answer for every PIQA prompt and record the match rate.

        For each sample the model continues the prompt; the sample counts as a
        hit when the reference solution text occurs (case-insensitively) in the
        generated continuation. The result is stored in ``self.metrics`` under
        ``"substring_match"`` as a percentage.
        """
        dataset = PIQADataset(self.tokenizer)

        substring_matches = 0
        for sample in tqdm(dataset, desc=f"Evaluating {self.get_display_name()}"):
            output = self.model.generate(
                input_ids=sample["input_ids"].to(self.device),
                attention_mask=sample["attention_mask"].to(self.device),
                max_length=min(sample["input_len"] * 2, self.model.config.n_positions),
            )

            # Strip the prompt at the token level rather than by character
            # count: decoding the prompt tokens is not guaranteed to reproduce
            # the prompt string character-for-character (e.g. whitespace
            # clean-up), which would misalign a character-based slice.
            # Assumes generate() returns the prompt tokens followed by the
            # continuation, as the previous character-based slice also did.
            generated_ids = output[0][sample["input_len"]:]
            predicted_answer = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

            label = sample["label"]
            substring_matches += int(label.lower() in predicted_answer.lower())

        num_samples = len(dataset)
        self.metrics = {
            # Guard against an empty dataset instead of raising ZeroDivisionError.
            "substring_match": substring_matches / num_samples * 100 if num_samples else 0.0,
        }

0 commit comments

Comments
 (0)