From 3a8de2d1f93736c2905b049938d254be62944396 Mon Sep 17 00:00:00 2001 From: jumelet Date: Mon, 24 Jan 2022 16:43:42 +0100 Subject: [PATCH 1/2] Add BLiMP task --- evaluation/tasks/blimp/__init__.py | 0 evaluation/tasks/blimp/blimp.py | 67 +++++++++++++++++++++++++++ evaluation/tasks/blimp/english.json | 3 ++ evaluation/tasks/blimp/task_names.py | 69 ++++++++++++++++++++++++++++ 4 files changed, 139 insertions(+) create mode 100644 evaluation/tasks/blimp/__init__.py create mode 100644 evaluation/tasks/blimp/blimp.py create mode 100644 evaluation/tasks/blimp/english.json create mode 100644 evaluation/tasks/blimp/task_names.py diff --git a/evaluation/tasks/blimp/__init__.py b/evaluation/tasks/blimp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evaluation/tasks/blimp/blimp.py b/evaluation/tasks/blimp/blimp.py new file mode 100644 index 0000000..a87b388 --- /dev/null +++ b/evaluation/tasks/blimp/blimp.py @@ -0,0 +1,67 @@ +from datasets import load_dataset +from torch.utils.data import Dataset +from tqdm import tqdm + +from evaluation.tasks.auto_task import AutoTask + +from .task_names import blimp_task_names + + +class BLIMPDataset(Dataset): + def __init__(self): + super().__init__() + + self.items = [ + load_dataset("blimp", task, split="train") for task in blimp_task_names[:2] + ] + + def __len__(self): + return len(self.items) + + def __getitem__(self, index): + return self.items[index] + + +class BLIMPTask(AutoTask): + @staticmethod + def get_display_name() -> str: + return "blimp" + + def evaluate(self) -> None: + dataset = BLIMPDataset() + num_correct = 0 + num_items = 0 + + for task_dataset in dataset: + for sample in tqdm( + task_dataset, + desc=f"Evaluating {self.get_display_name()} - {task_dataset.config_name}", + ): + tokenized_good = self.tokenizer( + sample["sentence_good"], return_tensors="pt" + )["input_ids"] + tokenized_bad = self.tokenizer( + sample["sentence_bad"], return_tensors="pt" + )["input_ids"] + + logits_good = self.model( + input_ids=tokenized_good.to(self.device), + ).logits + logits_bad = self.model( + input_ids=tokenized_bad.to(self.device), + ).logits + + # Compute sentence log probabilities from full LM probability distribution + log_prob_good = logits_good[ + 0, range(tokenized_good.shape[1] - 1), tokenized_good[0, 1:] + ].sum() + log_prob_bad = logits_bad[ + 0, range(tokenized_bad.shape[1] - 1), tokenized_bad[0, 1:] + ].sum() + + if log_prob_good > log_prob_bad: + num_correct += 1 + + num_items += 1 + + self.metrics["accuracy"] = num_correct / num_items diff --git a/evaluation/tasks/blimp/english.json b/evaluation/tasks/blimp/english.json new file mode 100644 index 0000000..319b5d8 --- /dev/null +++ b/evaluation/tasks/blimp/english.json @@ -0,0 +1,3 @@ +{ + "target_langs": ["english"] +} \ No newline at end of file diff --git a/evaluation/tasks/blimp/task_names.py b/evaluation/tasks/blimp/task_names.py new file mode 100644 index 0000000..f4d789b --- /dev/null +++ b/evaluation/tasks/blimp/task_names.py @@ -0,0 +1,69 @@ +blimp_task_names = [ + "adjunct_island", + "anaphor_gender_agreement", + "anaphor_number_agreement", + "animate_subject_passive", + "animate_subject_trans", + "causative", + "complex_NP_island", + "coordinate_structure_constraint_complex_left_branch", + "coordinate_structure_constraint_object_extraction", + "determiner_noun_agreement_1", + "determiner_noun_agreement_2", + "determiner_noun_agreement_irregular_1", + "determiner_noun_agreement_irregular_2", + "determiner_noun_agreement_with_adj_2", + "determiner_noun_agreement_with_adj_irregular_1", + "determiner_noun_agreement_with_adj_irregular_2", + "determiner_noun_agreement_with_adjective_1", + "distractor_agreement_relational_noun", + "distractor_agreement_relative_clause", + "drop_argument", + "ellipsis_n_bar_1", + "ellipsis_n_bar_2", + "existential_there_object_raising", + "existential_there_quantifiers_1", + "existential_there_quantifiers_2", + "existential_there_subject_raising", + "expletive_it_object_raising", + "inchoative", + "intransitive", + "irregular_past_participle_adjectives", + "irregular_past_participle_verbs", + "irregular_plural_subject_verb_agreement_1", + "irregular_plural_subject_verb_agreement_2", + "left_branch_island_echo_question", + "left_branch_island_simple_question", + "matrix_question_npi_licensor_present", + "npi_present_1", + "npi_present_2", + "only_npi_licensor_present", + "only_npi_scope", + "passive_1", + "passive_2", + "principle_A_c_command", + "principle_A_case_1", + "principle_A_case_2", + "principle_A_domain_1", + "principle_A_domain_2", + "principle_A_domain_3", + "principle_A_reconstruction", + "regular_plural_subject_verb_agreement_1", + "regular_plural_subject_verb_agreement_2", + "sentential_negation_npi_licensor_present", + "sentential_negation_npi_scope", + "sentential_subject_island", + "superlative_quantifiers_1", + "superlative_quantifiers_2", + "tough_vs_raising_1", + "tough_vs_raising_2", + "transitive", + "wh_island", + "wh_questions_object_gap", + "wh_questions_subject_gap", + "wh_questions_subject_gap_long_distance", + "wh_vs_that_no_gap", + "wh_vs_that_no_gap_long_distance", + "wh_vs_that_with_gap", + "wh_vs_that_with_gap_long_distance", +] \ No newline at end of file From 660bd1cb26dbe33f101aebf80c1f7d8ae796be1b Mon Sep 17 00:00:00 2001 From: Jaap Jumelet Date: Mon, 24 Jan 2022 17:49:07 +0100 Subject: [PATCH 2/2] Remove slice from task_names --- evaluation/tasks/blimp/blimp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluation/tasks/blimp/blimp.py b/evaluation/tasks/blimp/blimp.py index a87b388..c06c07c 100644 --- a/evaluation/tasks/blimp/blimp.py +++ b/evaluation/tasks/blimp/blimp.py @@ -12,7 +12,7 @@ def __init__(self): super().__init__() self.items = [ - load_dataset("blimp", task, split="train") for task in blimp_task_names[:2] + load_dataset("blimp", task, split="train") for task in blimp_task_names ] def __len__(self):