Skip to content

Commit 4207819

Browse files
committed
feat: use promptsource templates
1 parent 573d10f commit 4207819

File tree

4 files changed

+37
-1
lines changed

4 files changed

+37
-1
lines changed

poetry.lock

Lines changed: 17 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ tensorflow = "2.5.0"
1919
torch = "1.9.0"
2020
tqdm = "4.62.0"
2121
transformers = "4.9.1"
22+
promptsource = {git = "https://git@github.com/bigscience-workshop/promptsource.git", rev = "main"}
2223

2324
[tool.poetry.dev-dependencies]
2425
isort = "^5.9.3"

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ tensorflow==2.5.0
44
torch==1.9.0
55
tqdm==4.62.0
66
transformers==4.9.1
7+
promptsource @ git+ssh://git@github.com/bigscience-workshop/promptsource.git@main

tests/test_tydiqa_secondary.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from datasets import load_dataset
2+
from promptsource.templates import TemplateCollection
3+
from promptsource.utils import removeHyphen
14
from transformers import AutoTokenizer
25

36
from evaluation.tasks.tydiqa_secondary.tydiqa_secondary import TyDiQADataset
@@ -15,3 +18,18 @@ def test_prompt():
1518
"such as wound, ostomy, and continence nursing and burn center care.\n"
1619
) in prompt
1720
assert prompt.endswith("Answer:")
21+
22+
23+
def test_promptsource_template():
24+
ds_key, sub_key = "tydiqa", "secondary_task"
25+
tydiqa_sec_vld_ds = load_dataset(ds_key, sub_key, split="validation", streaming=True)
26+
tydiqa_sec_vld_ds_en = filter(lambda x: x["id"].split("-")[0] == "english", tydiqa_sec_vld_ds)
27+
template_collection = TemplateCollection()
28+
tydiqa_sec_tmpls = template_collection.get_dataset(ds_key, sub_key)
29+
tmpl = tydiqa_sec_tmpls["simple_question_reading_comp_2"]
30+
prompt, _ = tmpl.apply(removeHyphen(next(tydiqa_sec_vld_ds_en)))
31+
assert (
32+
"Wound care encourages and speeds wound healing via cleaning and protection from reinjury or infection. "
33+
"Depending on each patient's needs, it can range from the simplest first aid to entire nursing specialties "
34+
"such as wound, ostomy, and continence nursing and burn center care.\n"
35+
) in prompt

0 commit comments

Comments
 (0)