1+ from datasets import load_dataset
2+ from promptsource .templates import TemplateCollection
3+ from promptsource .utils import removeHyphen
14from transformers import AutoTokenizer
25
36from evaluation .tasks .tydiqa_secondary .tydiqa_secondary import TyDiQADataset
@@ -15,3 +18,18 @@ def test_prompt():
1518 "such as wound, ostomy, and continence nursing and burn center care.\n "
1619 ) in prompt
1720 assert prompt .endswith ("Answer:" )
21+
22+
def test_promptsource_template():
    """Check the promptsource prompt rendered for an English TyDiQA record.

    Streams the ``validation`` split of ``tydiqa`` / ``secondary_task``, takes
    the first record whose ``id`` reads ``"english"`` before its first hyphen,
    applies the ``simple_question_reading_comp_2`` template to it, and asserts
    that the expected passage text occurs in the rendered prompt.
    """
    ds_key, sub_key = "tydiqa", "secondary_task"
    validation_stream = load_dataset(ds_key, sub_key, split="validation", streaming=True)
    # Generator is lazy: records are only consumed once next() is called below.
    english_records = (
        record
        for record in validation_stream
        if record["id"].split("-")[0] == "english"
    )
    templates = TemplateCollection().get_dataset(ds_key, sub_key)
    template = templates["simple_question_reading_comp_2"]
    # NOTE(review): removeHyphen is an opaque promptsource helper here;
    # presumably it normalises the record before templating -- confirm.
    prompt, _ = template.apply(removeHyphen(next(english_records)))
    expected_passage = (
        "Wound care encourages and speeds wound healing via cleaning and protection from reinjury or infection. "
        "Depending on each patient's needs, it can range from the simplest first aid to entire nursing specialties "
        "such as wound, ostomy, and continence nursing and burn center care.\n "
    )
    assert expected_passage in prompt
0 commit comments