164 changes: 160 additions & 4 deletions verifiers/envs/environment.py
@@ -547,10 +547,21 @@ def make_dataset(
push_to_hub: bool = False,
hub_name: str | None = None,
state_columns: list[str] | None = None,
concatenate_safe: bool = True,
**kwargs,
) -> Dataset:
"""
Make a dataset from the evaluation results.

Args:
results: The evaluation results to convert to a dataset
push_to_hub: Whether to push the dataset to the Hugging Face Hub
hub_name: The name of the dataset on the Hugging Face Hub
state_columns: List of state columns to include in the dataset
concatenate_safe: Whether to ensure the dataset can be concatenated with others
by standardizing column types (useful for combining results from
different environments). Defaults to True for consistent schemas.
**kwargs: Additional arguments passed to Dataset creation
"""
# TODO: enable saving of multimodal datasets
state_columns = state_columns or []
@@ -567,9 +578,20 @@ def make_dataset(
"task": results.task,
"reward": results.reward,
}

        # Handle info column: serialize dicts to JSON strings for a stable schema
        if results.info and results.info[0] != {}:
            results_dict["info"] = []
            for info_item in results.info:
                if isinstance(info_item, dict):
                    results_dict["info"].append(json.dumps(info_item))
                else:
                    results_dict["info"].append(
                        str(info_item) if info_item is not None else json.dumps({})
                    )
cols.append("info")
elif concatenate_safe:
results_dict["info"] = [json.dumps({})] * len(results.prompt)
cols.append("info")

for i in range(len(results.completion)):
results_dict["completion"].append(
sanitize_tool_calls(results.completion[i])
@@ -585,12 +607,146 @@ def make_dataset(
self.logger.warning(
f"Column {col} not found in state, skipping from dataset."
)

        # Standardize the schema for concatenation safety
        if concatenate_safe:
            standard_columns = ["prompt", "completion", "answer", "task", "reward"]
            for col in standard_columns:
                if col not in results_dict:
                    results_dict[col] = [None] * len(results.prompt)

            if "info" not in results_dict:
                results_dict["info"] = [json.dumps({})] * len(results.prompt)
            elif len(results_dict["info"]) != len(results.prompt):
                # Broadcast a singleton; otherwise fall back to empty JSON objects
                if len(results_dict["info"]) == 1:
                    results_dict["info"] = results_dict["info"] * len(results.prompt)
                else:
                    results_dict["info"] = [json.dumps({})] * len(results.prompt)

            # Deduplicate while preserving order so the column order is deterministic
            cols = list(dict.fromkeys(cols + standard_columns + ["info"]))

# Apply PyArrow type standardization to ensure compatibility
# Convert any incompatible types to strings for better concatenation
import pyarrow as pa
for col in cols:
if col in results_dict:
try:
pa.array(results_dict[col])
except (pa.ArrowInvalid, pa.ArrowTypeError):
                        results_dict[col] = [
                            str(item) if item is not None else None
                            for item in results_dict[col]
                        ]

# Create dataset with standardized schema
dataset = Dataset.from_dict({col: results_dict[col] for col in cols})

if push_to_hub:
            assert hub_name is not None, "hub_name is required when push_to_hub=True"
dataset.push_to_hub(hub_name)
return dataset
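
A minimal usage sketch for the new flag (editor's illustration, not part of the diff; `vf_env` stands in for any Environment instance, `results` for its evaluation output, and the "turn" state column is hypothetical):

    dataset = vf_env.make_dataset(
        results,
        state_columns=["turn"],
        concatenate_safe=True,  # default; always emits a JSON-string "info" column
    )
    dataset.push_to_hub("user/eval-results")  # or pass push_to_hub=True, hub_name="..."

With concatenate_safe=True, datasets produced by different environments share the same core schema, which is what makes the concatenation below safe.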

@staticmethod
def concatenate_datasets(datasets: list[Dataset], split_names: list[str] | None = None) -> Dataset:
"""
Concatenate multiple datasets from different environments into a single dataset.

This function handles schema mismatches by:
1. Standardizing column types across datasets
2. Adding missing columns with None values
3. Converting incompatible types to strings
4. Optionally adding a 'split' column to identify the source of each example

Args:
datasets: List of datasets to concatenate
split_names: Optional list of names for each dataset (used to create a 'split' column)

Returns:
A single concatenated dataset with standardized schema

Example:
>>> # After running evaluations on different environments
>>> math_dataset = math_env.make_dataset(math_results)
>>> qa_dataset = qa_env.make_dataset(qa_results)
>>> # Concatenate with split identifiers
>>> combined_dataset = Environment.concatenate_datasets(
... [math_dataset, qa_dataset],
... split_names=["math", "qa"]
... )
            >>> # Push the combined dataset to the Hub (the 'split' column
            >>> # marks each example's source)
            >>> combined_dataset.push_to_hub("my-eval-results", split="train")
"""
if not datasets:
raise ValueError("At least one dataset must be provided")

if split_names and len(split_names) != len(datasets):
raise ValueError("split_names must have the same length as datasets")

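        # Aliased to avoid a name clash with this method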
from datasets import concatenate_datasets as hf_concatenate_datasets

        # Union of columns across datasets, in first-seen order for a deterministic schema
        all_columns: list[str] = []
        for dataset in datasets:
            for col in dataset.column_names:
                if col not in all_columns:
                    all_columns.append(col)

        # Infer a unified type per column: float if every non-None value parses
        # as a float, otherwise string
        column_types = {}
        for col in all_columns:
            all_values = []
            for dataset in datasets:
                if col in dataset.column_names:
                    all_values.extend(dataset[col])

            can_be_float = True
            for val in all_values:
                if val is None:
                    continue
                try:
                    float(val)
                except (ValueError, TypeError):
                    can_be_float = False
                    break

            column_types[col] = "float" if can_be_float else "string"

        # Standardize each dataset to have all columns with consistent types
        standardized_datasets = []
        for dataset in datasets:
            dataset_dict = {}
            for col in all_columns:
                if col in dataset.column_names:
                    column_data = list(dataset[col])
                    if column_types[col] == "float":
                        standardized_data = []
                        for val in column_data:
                            if val is None:
                                standardized_data.append(None)
                            else:
                                try:
                                    standardized_data.append(float(val))
                                except (ValueError, TypeError):
                                    standardized_data.append(None)
                        dataset_dict[col] = standardized_data
                    else:
                        dataset_dict[col] = [
                            str(item) if item is not None else None
                            for item in column_data
                        ]
                else:
                    # Columns missing from this dataset are filled with None
                    dataset_dict[col] = [None] * len(dataset)

            standardized_datasets.append(Dataset.from_dict(dataset_dict))

        # Tag each example with the name of its source dataset
        if split_names:
            standardized_datasets = [
                ds.add_column("split", [name] * len(ds))
                for ds, name in zip(standardized_datasets, split_names)
            ]

return hf_concatenate_datasets(standardized_datasets)
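
A quick sketch of the type unification above (editor's illustration; the toy columns and values are not from the PR). "reward" parses as a float in every row, so the whole column is coerced to float; "task" does not, so it stays a string; "extra" is absent from the first dataset and is backfilled with None:

    from datasets import Dataset
    from verifiers.envs.environment import Environment  # assumed import path

    ds_a = Dataset.from_dict({"reward": [1.0, 0.5], "task": ["math", "math"]})
    ds_b = Dataset.from_dict({"reward": ["0.25"], "task": ["qa"], "extra": ["x"]})

    combined = Environment.concatenate_datasets([ds_a, ds_b], split_names=["math", "qa"])
    print(combined["reward"])  # [1.0, 0.5, 0.25] -- unified to float
    print(combined["split"])   # ['math', 'math', 'qa']
    print(combined["extra"])   # [None, None, 'x']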

#########################################################
# Optional helper functions for parsing vLLM completions
#########################################################
@@ -644,7 +800,7 @@ def parse_chat_completion_tokens(
]
return tokens

def process_completion_tokens(self, completion: Completion) -> list[int]:
"""Parses the output token ids from a list of chat completions returned by vLLM OAI server."""
assert len(completion.choices) == 1, "Response should always have one choice"
assert completion.choices[0].logprobs is not None, (
@@ -783,7 +939,7 @@ def process_completion_format_vllm(
text, response = zipped[i]
# model-generated case -- use response
if response is not None:
completion_turn_ids = self.process_completion_tokens(response)
completion_turn_mask = [1] * len(completion_turn_ids)
completion_turn_logprobs = self.parse_completion_logprobs(response)
completion_ids.extend(completion_turn_ids)