Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions verifiers/envs/textarena_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ def __init__(
rubric: Rubric | None = None,
feedback_fn: Callable[[str], str] = lambda x: x,
seed: int = 0,
game_kwargs: dict[str, Any] | None = None,
**kwargs,
):
# default parser in textarena is XMLParser
parser = parser or XMLParser(fields=["think", "guess"], answer_field="guess")

self.game = game
self.ta_env = ta.make(env_id=game)
self.game_kwargs = game_kwargs or {}
self.ta_env = ta.make(env_id=game, **self.game_kwargs)
self.num_train_examples = num_train_examples
self.num_eval_examples = num_eval_examples
self.seed = seed
Expand Down Expand Up @@ -90,7 +92,9 @@ async def env_response(
if "ta_env" not in state:
ta_env = deepcopy(self.ta_env)
ta_env.reset(num_players=1)
ta_env.state.game_state["secret_word"] = state["answer"]
# Only set secret_word if the environment has this concept (for word games)
if hasattr(ta_env, 'state') and hasattr(ta_env.state, 'game_state') and 'secret_word' in ta_env.state.game_state:
ta_env.state.game_state["secret_word"] = state["answer"]
state["ta_env"] = ta_env
else:
ta_env = state["ta_env"]
Expand All @@ -107,19 +111,33 @@ async def env_response(
def ta_to_hf(self) -> tuple[Dataset, Dataset | None]:
dataset_rows = []
eval_dataset_rows = []
ta_env = ta.make(env_id=self.game)
ta_env = ta.make(env_id=self.game, **self.game_kwargs)
ta_env.reset(num_players=1)
_, user_prompt = ta_env.get_observation()
words = ta_env.word_list
# set seed
random.seed(self.seed)
for i in range(self.num_train_examples + self.num_eval_examples):
question = user_prompt
answer = random.choice(words)
if i < self.num_train_examples:
dataset_rows.append({"question": question, "answer": answer})
else:
eval_dataset_rows.append({"question": question, "answer": answer})

# Handle word-based games (like Wordle) that have a word_list
if hasattr(ta_env, 'word_list'):
words = ta_env.word_list
# set seed
random.seed(self.seed)
for i in range(self.num_train_examples + self.num_eval_examples):
question = user_prompt
answer = random.choice(words)
if i < self.num_train_examples:
dataset_rows.append({"question": question, "answer": answer})
else:
eval_dataset_rows.append({"question": question, "answer": answer})
else:
# Handle non-word-based games (like Tower of Hanoi) that don't have predefined answers
# For these games, the "answer" is just an empty string or the game completion state
for i in range(self.num_train_examples + self.num_eval_examples):
question = user_prompt
answer = "" # No specific answer for puzzle games
if i < self.num_train_examples:
dataset_rows.append({"question": question, "answer": answer})
else:
eval_dataset_rows.append({"question": question, "answer": answer})

dataset = Dataset.from_list(dataset_rows)
if self.num_eval_examples > 0:
eval_dataset = Dataset.from_list(eval_dataset_rows)
Expand Down