Skip to content

Commit 0788139

Browse files
committed
fix: resolve flake8 F401
1 parent 3bf8680 commit 0788139

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

evaluation/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
88

9-
import evaluation.tasks # noqa: F401; needed for AutoTask.__subclass__() to work correctly
9+
import evaluation.tasks # noqa: F401
1010
from evaluation.tasks.auto_task import AutoTask
1111
from evaluation.utils.log import get_logger
1212

evaluation/train.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
88

9-
import evaluation.tasks # needed for AutoTask.__subclass__() to work correctly
9+
import evaluation.tasks # noqa: F401
1010
from evaluation.tasks.auto_task import AutoTask
1111
from evaluation.utils.log import get_logger
1212

@@ -28,6 +28,7 @@ class EvaluationArguments:
2828
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
2929
)
3030
tag: Optional[str] = field(default=None, metadata={"help": "Identifier for the evaluation run."})
31+
english_only: Optional[bool] = field(default=True, metadata={"help": "Whether to run evaluation in English only."})
3132

3233

3334
def main():
@@ -64,9 +65,15 @@ def main():
6465

6566
for eval_task in eval_args.eval_tasks:
6667
logger.info(f"Benchmarking {eval_task}...")
67-
task = AutoTask.from_task_name(eval_task, tokenizer=tokenizer, model=model, device=device)
68+
task = AutoTask.from_task_name(
69+
eval_task,
70+
model=model,
71+
tokenizer=tokenizer,
72+
device=device,
73+
english_only=eval_args.english_only,
74+
)
6875
set_seed(train_args.seed)
69-
task.train()
76+
task.evaluate()
7077
task.save_metrics(output_dir, logger)
7178

7279

0 commit comments

Comments (0)