Merged
67 commits
2696a49
use inspect-ai to evaluate aime25 and gsm8k
NathanHB Oct 7, 2025
578d530
revert file
NathanHB Oct 7, 2025
21fa870
working for 3 tasks
NathanHB Oct 7, 2025
27b2af1
parallel evals of tasks
NathanHB Oct 7, 2025
b9a610d
adds gpqa diamond to inspect
NathanHB Oct 8, 2025
25c1128
move tasks to individual files
NathanHB Oct 13, 2025
0d42edf
move tasks to individual files
NathanHB Oct 13, 2025
6cc3c04
enable extended tasks as well
NathanHB Oct 13, 2025
4c38951
run precommit hook
NathanHB Oct 13, 2025
d2fd5e1
fix mkqa
NathanHB Oct 13, 2025
2ddb0f9
change extended suite to lighteval
NathanHB Oct 13, 2025
ee97122
change extended suite to lighteval
NathanHB Oct 14, 2025
e2c8e22
add metadata to tasks
NathanHB Oct 14, 2025
c980ddb
add metadata to tasks
NathanHB Oct 14, 2025
57fe390
remove license notice and put docstring on top of file
NathanHB Oct 14, 2025
ee081f2
homogenize tags
NathanHB Oct 14, 2025
1ed1602
add docstring for all multilingual tasks
NathanHB Oct 14, 2025
f4b0e27
add docstring for all multilingual tasks
NathanHB Oct 14, 2025
81d9e4e
add name and dataset to metadata
NathanHB Oct 15, 2025
b734532
use TASKS_TABLE for multilingual tasks
NathanHB Oct 15, 2025
c3911fc
use TASKS_TABLE for default tasks
NathanHB Oct 15, 2025
e439f70
use TASKS_TABLE for default tasks
NathanHB Oct 15, 2025
6447ee7
loads all tasks correctly
NathanHB Oct 15, 2025
88754bf
move community tasks to default tasks and update doc
NathanHB Oct 16, 2025
5445f5c
move community tasks to default tasks and update doc
NathanHB Oct 16, 2025
f53bd76
Merge remote-tracking branch 'origin/main' into nathan-reorg-tasks
NathanHB Oct 16, 2025
6a0c615
revert unneeded changes
NathanHB Oct 16, 2025
1435e38
fix doc build
NathanHB Oct 16, 2025
15f41f2
fix doc build
NathanHB Oct 16, 2025
74e5c0f
remove custom tasks and let user decide if loading multilingual tasks
NathanHB Oct 16, 2025
aad136c
load-tasks multilingual fix
NathanHB Oct 16, 2025
242bc43
update doc
NathanHB Oct 16, 2025
6806bf8
remove unneeded file
NathanHB Oct 16, 2025
e94fa59
update readme
NathanHB Oct 16, 2025
8800d1a
update readme
NathanHB Oct 16, 2025
970f33b
update readme
NathanHB Oct 16, 2025
b8c26dc
fix test
NathanHB Oct 16, 2025
764de72
add back the custom tasks
NathanHB Oct 17, 2025
a326ea8
add back the custom tasks
NathanHB Oct 17, 2025
81081cd
fix tasks
NathanHB Oct 17, 2025
74b40f6
fix tasks
NathanHB Oct 17, 2025
083fb1b
fix tasks
NathanHB Oct 17, 2025
2dab2bf
fix tests
NathanHB Oct 17, 2025
57ca0e5
fix tests
NathanHB Oct 17, 2025
480e40a
add inspect-ai
NathanHB Oct 20, 2025
ade2900
add tasks
NathanHB Oct 29, 2025
079ceaf
add gpqa
NathanHB Oct 29, 2025
8d00799
make model config work
NathanHB Oct 29, 2025
cea5e99
Update src/lighteval/metrics/metrics.py
NathanHB Oct 29, 2025
fb47bb7
init
NathanHB Oct 30, 2025
2736bc9
Merge branch 'nathan-move-to-inspectai' of github.com:huggingface/lig…
NathanHB Oct 30, 2025
d5e6c9f
Merge branch 'main' into nathan-move-to-inspectai
NathanHB Oct 30, 2025
e55a9af
fix tests
NathanHB Oct 30, 2025
ba41f1c
Merge branch 'nathan-move-to-inspectai' of github.com:huggingface/lig…
NathanHB Oct 30, 2025
59c5dcc
fix tests
NathanHB Oct 30, 2025
40254db
fix tests
NathanHB Oct 30, 2025
53275fe
fix tests
NathanHB Oct 30, 2025
72e5c2b
add correct system prompt for hle
NathanHB Oct 30, 2025
7fc1753
add correct system prompt for hle
NathanHB Oct 30, 2025
260d744
review suggestions
NathanHB Nov 3, 2025
835b799
add doc
NathanHB Nov 3, 2025
c216a27
change buttons
NathanHB Nov 3, 2025
21e6020
change buttons
NathanHB Nov 3, 2025
7e65400
change buttons
NathanHB Nov 3, 2025
0a4f6be
move benchmark finder to openeval org
NathanHB Nov 3, 2025
b661d0d
better help for eval
NathanHB Nov 3, 2025
f142b39
better help for eval
NathanHB Nov 3, 2025
1 change: 1 addition & 0 deletions pyproject.toml
@@ -57,6 +57,7 @@ keywords = ["evaluation", "nlp", "llm"]
dependencies = [
# Base dependencies
"transformers>=4.54.0",
"inspect-ai",
"accelerate",
"huggingface_hub[hf_xet]>=0.30.2",
"torch>=2.0,<3.0",
2 changes: 2 additions & 0 deletions src/lighteval/__main__.py
@@ -29,6 +29,7 @@
import lighteval.main_baseline
import lighteval.main_custom
import lighteval.main_endpoint
import lighteval.main_inspect
import lighteval.main_nanotron
import lighteval.main_sglang
import lighteval.main_tasks
@@ -69,6 +70,7 @@
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_vllm.vllm)
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_custom.custom)
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_sglang.sglang)
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_inspect.eval)
app.add_typer(
lighteval.main_endpoint.app,
name="endpoint",
249 changes: 249 additions & 0 deletions src/lighteval/main_inspect.py
@@ -0,0 +1,249 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
from collections import defaultdict
from typing import Literal

from inspect_ai import Epochs, Task, task
from inspect_ai import eval_set as inspect_ai_eval_set
from inspect_ai.dataset import hf_dataset
from inspect_ai.scorer import exact
from inspect_ai.solver import generate, system_message
from pytablewriter import MarkdownTableWriter

from lighteval.models.abstract_model import InspectAIModelConfig
from lighteval.tasks.lighteval_task import LightevalTaskConfig


logger = logging.getLogger(__name__)


@task
def get_inspect_ai_task(lighteval_task_config: LightevalTaskConfig) -> Task:
name = lighteval_task_config.name
sample_fields = lighteval_task_config.sample_fields

dataset_repo = lighteval_task_config.hf_repo
dataset_subset = lighteval_task_config.hf_subset
dataset_split = lighteval_task_config.evaluation_splits[0]

dataset = hf_dataset(dataset_repo, name=dataset_subset, split=dataset_split, sample_fields=sample_fields)
if lighteval_task_config.filter is not None:
dataset = dataset.filter(lighteval_task_config.filter)
solver = lighteval_task_config.solver or [
generate(cache=True),
]
scorers = lighteval_task_config.scorer or exact()
# TODO: have per task epoch and epoch reducer
epochs = 1
epochs_reducer = "mean"

if lighteval_task_config.num_fewshots > 0:
name += f"_{lighteval_task_config.num_fewshots}_shots"
# TODO: use fewshot split
fewshots = hf_dataset(
path=dataset_repo,
name=dataset_subset,
split=dataset_split,
sample_fields=sample_fields,
shuffle=True,
seed=42,
limit=lighteval_task_config.num_fewshots,
)
solver.insert(
0,
system_message("\n\n".join([lighteval_task_config.sample_to_fewshot(sample) for sample in fewshots])),
)

return Task(dataset=dataset, solver=solver, scorer=scorers, name=name, epochs=Epochs(epochs, epochs_reducer))


def mean_metrics_by_prefix(results_per_model_per_task, sep=":"):
out = {}
for model, tasks in results_per_model_per_task.items():
pref_metrics = defaultdict(lambda: defaultdict(list))
# Collect both per-task metrics and values for prefix aggregation
per_model_out = {}
for task_name, metrics in tasks.items():
if sep not in task_name:
# No subtasks: keep metrics as-is for this task
per_task_vals = {}
for mname, metric in metrics.items():
per_task_vals[mname] = getattr(metric, "value", metric)
per_model_out[task_name] = per_task_vals
continue
prefix = task_name.split(sep, 1)[0]
# Keep non-averaged task metrics
per_task_vals = {}
for mname, metric in metrics.items():
value = getattr(metric, "value", metric)
per_task_vals[mname] = value
pref_metrics[prefix][mname].append(value)
per_model_out[task_name] = per_task_vals
# Add the averaged metrics per prefix
for p, md in pref_metrics.items():
per_model_out[p] = {m: sum(v) / len(v) for m, v in md.items()}
out[model] = per_model_out
return out


def results_to_markdown_table(
results_per_model_per_task,
metric: str = "accuracy",
stderr_metric: str = "stderr",
max_total_columns: int | None = None,
means_only_task_threshold: int = 10,
) -> str:
cols = _collect_columns(results_per_model_per_task, means_only_task_threshold, max_total_columns)

writer = MarkdownTableWriter()
writer.headers = ["Model"] + cols

rows = []
for model in sorted(results_per_model_per_task.keys()):
row = [model]
data = results_per_model_per_task[model]
for col in cols:
row.append(_format_metric_cell(data, col, metric, stderr_metric))
rows.append(row)

writer.value_matrix = rows
return writer.dumps()
Comment on lines +122 to +143
Member
could you reuse the output functions we already have?



def _collect_columns(
results_per_model_per_task, means_only_task_threshold: int, max_total_columns: int | None
) -> list[str]:
all_cols = set()
for model_data in results_per_model_per_task.values():
all_cols.update(model_data.keys())
agg_cols = sorted([c for c in all_cols if ":" not in c])
task_cols = sorted([c for c in all_cols if ":" in c])

if len(task_cols) > means_only_task_threshold:
logger.info(
f"Only showing the meaned tasks (aggregates only) because there are more than {means_only_task_threshold} tasks"
)
return agg_cols

cols = agg_cols + task_cols
if max_total_columns is not None and len(cols) > max_total_columns:
keep_left = max(1, max_total_columns // 2)
keep_right = max_total_columns - keep_left
left_cols = cols[:keep_left]
right_cols = cols[-keep_right:] if keep_right > 0 else []
return left_cols + ["…"] + right_cols
return cols


def _format_metric_cell(data: dict, col: str, metric: str, stderr_metric: str) -> str:
if col == "…":
return "…"
metrics = data.get(col)
if not metrics:
return "-"
val = metrics.get(metric)
if isinstance(val, dict):
val = val.get("value", None)
if val is not None:
return "%.2f" % val
return "-"


def eval(
models: list[str],
tasks: str,
epochs: int = 1,
epochs_reducer: Literal["mean", "median", "mode", "max", "at_least_{n}", "pass_at_{k}"] | None = None,
max_connections: int = 50,
timeout: int = 30,
retry_on_error: int = 1,
max_retries: int = 5,
log_dir: str = "lighteval-logs",
log_dir_allow_dirty: bool = True,
display: Literal["rich", "full", "conversations", "plain", "log", "none"] = "rich",
model_config: str | None = None,
max_samples: int | None = None,
max_tasks: int | None = None,
):
from lighteval.tasks.registry import Registry

registry = Registry(tasks=tasks, custom_tasks=None, load_multilingual=False)
task_configs = registry.task_to_configs
inspect_ai_tasks = []

for task_name, configs in task_configs.items():
for task_config in configs:
inspect_ai_tasks.append(get_inspect_ai_task(task_config))

if model_config is not None and model_config.endswith(".yaml"):
model_config = InspectAIModelConfig.from_path(model_config).dict()
elif model_config is not None:
model_config = InspectAIModelConfig.from_args(model_config).dict()
else:
model_config = {}

success, logs = inspect_ai_eval_set(
inspect_ai_tasks,
model=models,
max_connections=max_connections,
timeout=timeout,
retry_on_error=retry_on_error,
max_retries=max_retries,
limit=max_samples,
max_tasks=max_tasks,
log_dir=log_dir,
log_dir_allow_dirty=log_dir_allow_dirty,
display=display,
**model_config,
)

if not success:
return

results_per_model_per_task = {}

for model in models:
results_per_model_per_task[model] = {}

for log in logs:
if log.eval.model == model:
results_per_model_per_task[model][log.eval.task] = log.results.metrics

results_per_model_per_task_agg = mean_metrics_by_prefix(results_per_model_per_task)
table_md = results_to_markdown_table(results_per_model_per_task_agg)
print()
print(table_md)
print(f"results saved to {log_dir}")
print(f'run "inspect view --log-dir {log_dir}" to view the results')


if __name__ == "__main__":
task = "lighteval|gsm8k|5,lighteval|gsm8k|1,lighteval|gsm8k|0"
task = "lighteval|agieval|0"
task = "lighteval|hle|0"
task = "lighteval|ifeval|0"
task = "lighteval|gpqa|0"
task = "lighteval|ifbench_test|0"
model = "hf-inference-providers/meta-llama/Llama-3.1-8B-Instruct:nebius"
eval(models=[model], tasks=task)
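A minimal usage sketch of the two result helpers above (not part of the diff); the model name, task names, and numbers are made up for illustration, and the import path assumes the module location shown in this PR.

from lighteval.main_inspect import mean_metrics_by_prefix, results_to_markdown_table

# Hypothetical per-model, per-task metrics, keyed the way eval() builds them.
toy_results = {
    "my-model": {
        "mmlu:abstract_algebra": {"accuracy": 0.50},
        "mmlu:anatomy": {"accuracy": 0.70},
        "gsm8k": {"accuracy": 0.80},
    }
}

aggregated = mean_metrics_by_prefix(toy_results)
# Subtasks sharing a ":" prefix are averaged into a new "mmlu" entry (0.60 here);
# tasks without the separator, like "gsm8k", pass through unchanged.
print(results_to_markdown_table(aggregated, metric="accuracy"))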
64 changes: 64 additions & 0 deletions src/lighteval/metrics/metrics.py
@@ -25,6 +25,8 @@

import numpy as np
from aenum import Enum
from inspect_ai.scorer import Score, Target, accuracy, scorer, stderr
from inspect_ai.solver import TaskState

from lighteval.metrics.dynamic_metrics import MultilingualExtractiveMatchMetric
from lighteval.metrics.harness_compatibility.drop import DropMetrics
@@ -66,6 +68,8 @@
ExprExtractionConfig,
IndicesExtractionConfig,
LatexExtractionConfig,
extract_target_from_pred,
get_extraction_regexes_inspect,
)
from lighteval.metrics.utils.metric_utils import (
CorpusLevelMetric,
@@ -77,6 +81,66 @@
from lighteval.utils.language import Language


@scorer(metrics=[accuracy()])
def math_scorer():
gold_extraction_target = (ExprExtractionConfig(),)
pred_extraction_target = (ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0))
language = Language.ENGLISH
fallback_mode = "first_match"
extraction_mode = "first_match"
timeout_seconds = 5

gold_extraction_regexes = get_extraction_regexes_inspect(gold_extraction_target, language, len_choices=1)
pred_extraction_regexes = get_extraction_regexes_inspect(pred_extraction_target, language, len_choices=1)

async def score(state: TaskState, target: Target):
extracted_predictions = extract_target_from_pred(
state.output.completion, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds
)
extracted_gold = extract_target_from_pred(
target.text, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds
)
return Score(
value="C" if extracted_predictions == extracted_gold else "I",
explanation=state.output.completion,
answer=str(extracted_predictions),
)

return score


@scorer(metrics=[accuracy(), stderr()])
def multichoice_scorer():
language = Language.ENGLISH
gold_extraction_target = (
IndicesExtractionConfig(prefix_for_extraction="NativeLetters", try_extract_without_anchor=True),
)
pred_extraction_target = (
IndicesExtractionConfig(prefix_for_extraction="NativeLetters", try_extract_without_anchor=True),
)
fallback_mode = "first_match"
extraction_mode = "first_match"
timeout_seconds = 5

gold_extraction_regexes = get_extraction_regexes_inspect(gold_extraction_target, language)
pred_extraction_regexes = get_extraction_regexes_inspect(pred_extraction_target, language)
Comment on lines +114 to +127
Member
this behavior of nested functions behaving as classes is really meh for legibility, customizability and maintainability

Member Author
definitely could be better! but that's how inspect is expecting it. Will work on a better format once we start having more metrics compatible with it. (A parameterized sketch follows this file's diff.)


async def score(state: TaskState, target: Target):
extracted_predictions = extract_target_from_pred(
state.output.completion, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds
)
extracted_gold = extract_target_from_pred(
target.text, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds
)
return Score(
value="C" if extracted_predictions == extracted_gold else "I",
explanation=state.output.completion,
answer=str(extracted_predictions),
)

return score


class Metrics(Enum):
acc_golds_likelihood = SampleLevelMetric( # todo: we need a better name for this!
metric_name="acc",
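Following up on the review exchange above about the nested scorer functions: a minimal sketch, not part of this PR, of how the two scorers could share one parameterized factory. It assumes inspect_ai scorer factories accept arguments and that the import paths match the surrounding diff; extractive_match_scorer is a hypothetical name.

from inspect_ai.scorer import Score, Target, accuracy, scorer, stderr
from inspect_ai.solver import TaskState

from lighteval.metrics.utils.extractive_match_utils import (
    IndicesExtractionConfig,
    extract_target_from_pred,
    get_extraction_regexes_inspect,
)
from lighteval.utils.language import Language


@scorer(metrics=[accuracy(), stderr()])
def extractive_match_scorer(gold_extraction_target, pred_extraction_target, language=Language.ENGLISH, len_choices=1):
    # Build the extraction regexes once per scorer instance, as the scorers above do.
    gold_regexes = get_extraction_regexes_inspect(gold_extraction_target, language, len_choices)
    pred_regexes = get_extraction_regexes_inspect(pred_extraction_target, language, len_choices)

    async def score(state: TaskState, target: Target):
        extracted_pred = extract_target_from_pred(
            state.output.completion, pred_regexes, "first_match", "first_match", 5
        )
        extracted_gold = extract_target_from_pred(target.text, gold_regexes, "first_match", "first_match", 5)
        return Score(
            value="C" if extracted_pred == extracted_gold else "I",
            explanation=state.output.completion,
            answer=str(extracted_pred),
        )

    return score


# The multichoice scorer above would then reduce to a single call:
indices = (IndicesExtractionConfig(prefix_for_extraction="NativeLetters", try_extract_without_anchor=True),)
multichoice = extractive_match_scorer(indices, indices)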
31 changes: 31 additions & 0 deletions src/lighteval/metrics/utils/extractive_match_utils.py
@@ -344,6 +344,37 @@ def lazy_indices_regex(
return [(re.compile(pattern), priority) for pattern, priority in regexes]


def get_extraction_regexes_inspect(
target_types: Sequence[ExtractionTarget], language: Language, len_choices: int = 1
) -> list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]]:
"""Get extraction regexes for inspect AI.
Temporary implementation.
TODO: refactor this function to share code with get_extraction_regexes
"""
extraction_regexes: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]] = [
(lazy_latex_regex(target_type, language), target_type)
if isinstance(target_type, LatexExtractionConfig)
else (lazy_expr_regex(target_type, language), target_type)
if isinstance(target_type, ExprExtractionConfig)
else (lazy_indices_regex(target_type, len_choices, language), target_type)
for target_type in target_types
]

# Sort the extraction regexes so that the order is indices, latex, expr
def get_target_type_order(target_type: ExtractionTarget) -> int:
match target_type:
case IndicesExtractionConfig():
return 0
case LatexExtractionConfig():
return 1
case ExprExtractionConfig():
return 2

extraction_regexes = sorted(extraction_regexes, key=lambda x: get_target_type_order(x[1]))

return extraction_regexes


def get_extraction_regexes(
formatted_doc: Doc, target_types: Sequence[ExtractionTarget], language: Language
) -> list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]]:
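One more illustrative sketch (not in the diff) of the ordering the new helper enforces; it assumes the extraction configs can be constructed as they are in the metrics.py hunk above.

from lighteval.metrics.utils.extractive_match_utils import (
    ExprExtractionConfig,
    IndicesExtractionConfig,
    LatexExtractionConfig,
    get_extraction_regexes_inspect,
)
from lighteval.utils.language import Language

regexes = get_extraction_regexes_inspect(
    (
        ExprExtractionConfig(),
        LatexExtractionConfig(),
        IndicesExtractionConfig(prefix_for_extraction="NativeLetters"),
    ),
    Language.ENGLISH,
    len_choices=4,
)
# Whatever the input order, the (regex list, config) pairs come back sorted
# indices -> latex -> expr:
print([type(cfg).__name__ for _, cfg in regexes])
# ['IndicesExtractionConfig', 'LatexExtractionConfig', 'ExprExtractionConfig']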