diff --git a/conf/workarena.yaml b/conf/workarena.yaml new file mode 100644 index 00000000..1d4a00a4 --- /dev/null +++ b/conf/workarena.yaml @@ -0,0 +1,91 @@ +defaults: + - base + - override streams: redis + - override finetune: ppo + - _self_ + +world: + actor_fraction: 2 + preprocessor_fraction: 0 + finetune_fraction: 4 + +save_tapes: true + +output_dir: results/workarena/${now:%Y-%m-%d}/${now:%H-%M-%S} +model_path: meta-llama/Llama-3.1-8B-Instruct +use_ray: true + +finetune: + seq_length: 16384 # input + output tokens + max_train_steps: 1000 # 1000 optim steps = 1000 * bs samples + train_batch_size: 1 + gradient_accumulation_passes: 1024 + +eval_every_n_versions: 10240 # 1024 effective bs * 10 "optim steps" + +llm: + use_cache: false + parameters: + max_tokens: 4096 # output tokens + temperature: 1.0 +test_llm: + parameters: + max_tokens: ${...llm.parameters.max_tokens} + temperature: 0.0 + top_p: 1.0 + top_k: 50 + +vllm_config: + use_v1: false + vllm_kwargs: + max-num-seqs: 256 + max-num-batched-tokens: 32000 + max_model_len: 16384 + gpu-memory-utilization: 0.9 + +actor: + rollout_policy: pipelinerl.domains.workarena.rollouts.generate_workarena_rollout + llm_max_rollouts: 256 + problem_queue_size: 2 + async_batch_size: 1 + rollout_workers: 32 + shared_memory_entry_size: 100000000 + collect_logprobs: false + +preprocess: + n_workers: 32 # Increase from 8 + chunk_n_groups: 8 # Increase from 2 for better throughput + # queue for loaded raw groups + raw_queue_size: 32 # Increase from 8 + # queue for processed chunks of multiple groups + input_queue_size: 64 # Increase from 32 + # queue for ready chunks for multiple groups + output_queue_size: 64 # Increase from 32 + # ring buffer to replace old samples with new ones when training is slow + ring_buffer_size: 1024 # Increase from 128 + # "virtual" sample queue per lead trainer + max_ready_samples_per_lead: 256 # Increase from 64 + shared_memory_entry_size: 1000000000 # Increase from 100M + +# AGENT CONFIGURATION +agent_max_loops: 30 # max number of agent - environment interactions for each task +agent_attempts: 3 # number of attempts to run the agent (retry on errors) +rollout_timeout: 600 # overall timeout for entire rollout in seconds (10 minutes) +agent: + _target_: examples.workarena.agent.WorkArenaAgent + +# ENVIRONMENT CONFIGURATION +start_attempts: 3 # number of attempts to start each task +environment: + _target_: pipelinerl.domains.workarena.environment.WorkArenaEnvironment + exp_path: ${output_dir}/browser + headless: true + +# DATASET CONFIGURATION +dataset_loader: pipelinerl.domains.workarena.load_tasks.load_tasks +dataset_loader_params: + seeds: [0, 42, 1337, 900, 103] +train_dataset_names: + - l1 +test_dataset_names: + - l1 \ No newline at end of file diff --git a/pipelinerl/domains/workarena/environment.py b/pipelinerl/domains/workarena/environment.py new file mode 100644 index 00000000..9feb7cc8 --- /dev/null +++ b/pipelinerl/domains/workarena/environment.py @@ -0,0 +1,94 @@ +import logging +import os +from typing import Any + +from browsergym.workarena.tasks.base import AbstractServiceNowTask + +from tapeagents.core import LLMOutputParsingFailureAction, Observation +from tapeagents.environment import Environment +from tapeagents.steps import ActionExecutionFailure +from tapeagents.tools.browser import Browser +from tapeagents.utils import FatalError + +from examples.workarena.steps import Action, FinalAnswerAction, ReflectionThought, WorkArenaTape, WorkArenaTask + +logger = logging.getLogger(__name__) + + +class WorkArenaEnvironment(Environment): + """ + WorkArena environment for running tasks. + Translates action steps into gym browser python commands in the form of a string. + """ + + def __init__(self, exp_path: str, headless: bool = True) -> None: + super().__init__() + os.makedirs(exp_path, exist_ok=True) + self.exp_path = exp_path + self.headless = headless + + def initialize(self): + self.browser = Browser(headless=self.headless, exp_path=self.exp_path) + + def start_task( + self, task_entrypoint: type[AbstractServiceNowTask], seed: int = 42 + ) -> tuple[WorkArenaTape, dict[str, Any]]: + task_id = f"browsergym/{task_entrypoint.get_task_id()}" + info = self.browser.start_task(task_id, seed, wait_for_user_message=False) # type: ignore + obs = self.browser.run_browser_action("noop()") + tape = WorkArenaTape(steps=[obs, WorkArenaTask(task=info["goal"])]) + return tape, info + + def actions(self) -> tuple[type[Action], ...]: + return self.browser.actions + + def validate_task(self, tape: WorkArenaTape) -> tuple[bool, dict]: + answer = tape.steps[-1].text if isinstance(tape.steps[-1], FinalAnswerAction) else "Task finished" + self.browser._env.unwrapped.chat.add_message(role="assistant", msg=answer) + assert self.browser._env.unwrapped.task is not None + reward, stop, message, info = self.browser._env.unwrapped.task.validate( + self.browser._env.unwrapped.page, self.browser._env.unwrapped.chat.messages + ) + result_dict = { + "reward": reward, + "stop": stop, + "message": message, + "info": info, + } + return bool(reward > 0), result_dict + + def react(self, tape: WorkArenaTape) -> WorkArenaTape: + actions = [] + for step in tape.steps[-tape.metadata.n_added_steps :]: + if isinstance(step, Action): + actions.append(step) + elif isinstance(step, ReflectionThought): + # send reflection to chat for user to see + self.browser._env.unwrapped.chat.add_message( + role="assistant", msg=f"{step.last_action_achieved_effect}\nTodo: {step.next_action}" + ) + for action in actions: + try: + if isinstance(action, LLMOutputParsingFailureAction): + continue + observation = self.step(action) + tape = tape.append(observation) # type: ignore + except FatalError: + raise + except Exception as e: + logger.exception(f"Error during action execution: {e}") + tape = tape.append(ActionExecutionFailure(error=str(e))) + break + return tape + + def step(self, action: Action) -> Observation: + return self.browser.run(action) + + def reset(self): + self.browser.reset() + + def close(self) -> None: + try: + self.browser.close() + except Exception as e: + logger.error(f"Failed to properly close task: {e}") diff --git a/pipelinerl/domains/workarena/load_tasks.py b/pipelinerl/domains/workarena/load_tasks.py new file mode 100644 index 00000000..6d9a8175 --- /dev/null +++ b/pipelinerl/domains/workarena/load_tasks.py @@ -0,0 +1,51 @@ +import random +from browsergym.core.task import AbstractBrowserTask +from browsergym.workarena import ALL_WORKARENA_TASKS, workarena_tasks_all, workarena_tasks_l1, workarena_tasks_atomic + +ALL_TASKS_DICT = {task.get_task_id(): task for task in ALL_WORKARENA_TASKS} + + +def load_tasks(dataset_names: list[str], seeds: list[int] = [0, 1, 2, 3, 4]): + all_shuffled_task_ids = list(workarena_tasks_all) + atomic_shuffled_task_ids = list(workarena_tasks_atomic) + l1_shuffled_task_ids = list(workarena_tasks_l1) + random.seed(42) + random.shuffle(all_shuffled_task_ids) + random.shuffle(atomic_shuffled_task_ids) + random.shuffle(l1_shuffled_task_ids) + tasks = [] + for name in dataset_names: + if name == "all": + tasks.extend( + [ + {"dataset": "workarena.all", "task": task_id, "seed": seed} + for task_id in all_shuffled_task_ids + for seed in seeds + ] + ) + elif name == "atomic": + tasks.extend( + [ + {"dataset": "workarena.atomic", "task": task_id, "seed": seed} + for task_id in atomic_shuffled_task_ids + for seed in seeds + ] + ) + elif name == "l1": + tasks.extend( + [ + {"dataset": "workarena.l1", "task": task_id, "seed": seed} + for task_id in l1_shuffled_task_ids + for seed in seeds + ] + ) + else: + raise ValueError(f"Invalid dataset name: {name}") + return tasks + + +def get_task_by_id(task_id: str) -> AbstractBrowserTask: + if task_id in ALL_TASKS_DICT: + return ALL_TASKS_DICT[task_id] + else: + raise ValueError(f"Task {task_id} not found") diff --git a/pipelinerl/domains/workarena/rollouts.py b/pipelinerl/domains/workarena/rollouts.py new file mode 100644 index 00000000..ec523555 --- /dev/null +++ b/pipelinerl/domains/workarena/rollouts.py @@ -0,0 +1,210 @@ +import logging +import os +import time + +from examples.rl_webagent.steps import WebTape +from examples.workarena.agent import WorkArenaAgent +from examples.workarena.environment import WorkArenaEnvironment +from hydra.utils import instantiate +from omegaconf import DictConfig +from tapeagents.core import LLMCall, LLMOutputParsingFailureAction, Observation +from tapeagents.io import save_json_tape +from tapeagents.llms.trainable import TrainableLLM +from tapeagents.orchestrator import execute_agent + +from pipelinerl.async_llm import make_training_text +from pipelinerl.domains.workarena.load_tasks import get_task_by_id +from pipelinerl.rollouts import BaseMetrics, RolloutResult + +logger = logging.getLogger(__name__) + + +class WorkarenaMetrics(BaseMetrics): + reward: float + success: bool + no_error: bool + no_answer: bool + overflow: bool + n_llm_calls: int + n_step_errors: int + n_page_observations: int + n_steps: int + total_execution_time: float + env_start_time: float + env_close_time: float + env_agent_creation_time: float + agent_execution_time: float + environment_execution_time: float + env_step_time: float + agent_step_time: float + llm_call_time: float + env_call_time: float + total_llm_call_time: float + total_env_call_time: float + + +def tape_contains_an_error(tape: WebTape) -> bool: + """ + Returns true if the tape ends with an error, ie if one of the following is true: + - the last step is an LLMOutputParsingFailureAction + - the tape metadata has an error + - the last step is a PageObservation with an error + """ + return ( + len(tape.steps) == 0 + or isinstance(tape.steps[-1], LLMOutputParsingFailureAction) + or tape.metadata.result.get("error") is not None + or (tape.steps[-1].__class__.__name__ == "PageObservation" and tape.steps[-1].error) + ) + + +def compute_reward(tape: WebTape, success: bool, result: dict) -> float: + """ + TODO: Improve this + """ + return 1.0 if success else -1.0 + + +def generate_workarena_rollout(cfg: DictConfig, llm: TrainableLLM, problem: dict) -> RolloutResult: + # make agent and env + # set the llm + # run the agent + # get llm calls from tape + # compute rewards + # get training text from llm calls + + start_time = time.perf_counter() + + environment: WorkArenaEnvironment = instantiate(cfg.environment) + environment.initialize() + agent = WorkArenaAgent.create(llm) + logger.info(f"Agent and environment loaded, using llm {llm.model_name} at {llm.get_base_url()}") + env_agent_creation_time = time.perf_counter() - start_time + try: + task_entrypoint = get_task_by_id(problem["task"]) + start_attempts = cfg.start_attempts + t = time.perf_counter() + while True: + try: + tape, _ = environment.start_task(task_entrypoint) + break + except Exception as e: + logger.exception(f"Failed to start task {problem['dataset']}/{problem['task']}/{problem['seed']}: {e}") + start_attempts -= 1 + if start_attempts <= 0: + raise Exception( + f"Failed to start task {problem['dataset']}/{problem['task']}/{problem['seed']} after {cfg.start_attempts} attempts" + ) + else: + logger.warning("retry after 1 seconds") + time.sleep(1) + env_start_time = time.perf_counter() - t + logger.info( + f"Task {problem['dataset']}/{problem['task']}/{problem['seed']} started in {env_start_time:.2f} seconds" + ) + logger.info(f"Running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']}") + ex_t = time.perf_counter() + tape = execute_agent(agent, tape, environment, max_loops=cfg.agent_max_loops) + execution_time = time.perf_counter() - ex_t + success, result = environment.validate_task(tape) + finally: + close_t = time.perf_counter() + environment.close() + env_close_time = time.perf_counter() - close_t + logger.info( + f"Agent finished task {problem['dataset']}/{problem['task']}/{problem['seed']}, times: start {env_start_time:.2f} sec, exec {execution_time:.2f} sec, close {env_close_time:.2f} sec, produced tape with {len(tape.steps)} steps" + ) + total_execution_time = time.perf_counter() - t + tape.metadata.result.update( + { + "total_execution_time": total_execution_time, + "env_start_time": env_start_time, + "env_agent_creation_time": env_agent_creation_time, + "execution_time": execution_time, + "env_close_time": env_close_time, + } + ) + + # save the tape as we go + if cfg.save_tapes: + try: + tape_name = problem.get("_task_id", tape.metadata.id) + save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape_name) + except Exception as e: + logger.error(f"Error saving tape {tape_name}: {e}") + + reward = compute_reward(tape, success, result) + + # (3) Get LLM calls from Tape + llm_calls = [step for step in tape.steps if step.metadata.other.get("llm_call") is not None] + n_llm_calls = len(llm_calls) + llm_calls: list[LLMCall] = [ + LLMCall(**step.metadata.other["llm_call"]) + if isinstance(step.metadata.other["llm_call"], dict) + else step.metadata.other["llm_call"] + for step in llm_calls + ] + llm_call_times = [ + step.metadata.other.get("llm_call_time") for step in tape.steps if "llm_call_time" in step.metadata.other + ] + env_call_times = [ + step.metadata.other.get("action_execution_time") + for step in tape.steps + if "action_execution_time" in step.metadata.other + ] + total_llm_call_time = sum(llm_call_times) + total_env_call_time = sum(env_call_times) + llm_call_time = total_llm_call_time / len(llm_call_times) if len(llm_call_times) > 0 else -1.0 + env_call_time = total_env_call_time / len(env_call_times) if len(env_call_times) > 0 else -1.0 + + # (4) # For each LLM interaction in the tape, make a training example. + all_finished = 1 + prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls] + output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls] + training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] + for text in training_texts: + text.reward = reward + all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0 + + latency = time.perf_counter() - start_time + agent_time = tape.metadata.result.get("agent_execution_time", -1.0) + env_time = tape.metadata.result.get("environment_execution_time", -1.0) + n_observations = len( + [s for s in tape.steps if isinstance(s, Observation)] + ) # TODO: is this not the same n_page_observations?? + n_other_steps = len(tape.steps) - n_observations + n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)]) + n_page_observations = len([step for step in tape.steps if step.__class__.__name__ == "PageObservation"]) + no_error = not tape_contains_an_error(tape) + metrics = WorkarenaMetrics( + reward=reward, + success=reward > 0.5, + no_error=no_error, + no_answer=reward < 0, + overflow=not all_finished, + n_llm_calls=n_llm_calls, + n_step_errors=n_step_errors, + n_page_observations=n_page_observations, + n_steps=len(tape.steps), + total_execution_time=total_execution_time, + env_start_time=env_start_time, + env_close_time=env_close_time, + env_agent_creation_time=env_agent_creation_time, + agent_execution_time=agent_time, + environment_execution_time=env_time, + env_step_time=env_time / n_observations if env_time > 0 and n_observations > 0 else -1.0, + agent_step_time=agent_time / n_other_steps if agent_time > 0 and n_other_steps > 0 else -1.0, + llm_call_time=llm_call_time, + env_call_time=env_call_time, + total_llm_call_time=total_llm_call_time, + total_env_call_time=total_env_call_time, + ) + + return RolloutResult( + training_texts=training_texts, + metrics=metrics, + latency=latency, + dataset_name=problem["dataset"], + prompt_tokens=prompt_tokens, + output_tokens=output_tokens, + ) diff --git a/run_workarena_actor.sh b/run_workarena_actor.sh new file mode 100755 index 00000000..6de41229 --- /dev/null +++ b/run_workarena_actor.sh @@ -0,0 +1,13 @@ +#!/bin/bash +OUTPUT_DIR="results/workarena" +DATE_ID=$(date +%Y_%m_%d__%H_%M_%S) + +python -m pipelinerl.launch \ + output_dir=${OUTPUT_DIR}/debug_${DATE_ID} \ + wandb.wandb_workspace_root=${OUTPUT_DIR} \ + actor.rollout_workers=2 \ + debug.mode=actor \ + world.actor_fraction=8 \ + world.finetune_fraction=0 \ + world.preprocessor_fraction=0 \ + --config-name workarena \ No newline at end of file