diff --git a/examples/scripts/openenv/browsergym.py b/examples/scripts/openenv/browsergym.py index 49a0b4c712f..63b28334fa1 100644 --- a/examples/scripts/openenv/browsergym.py +++ b/examples/scripts/openenv/browsergym.py @@ -67,11 +67,11 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--space-url", default="https://openenv-browsergym-env.hf.space") parser.add_argument("--dataset-prompt", default="Complete the web task successfully.") parser.add_argument("--dataset-size", type=int, default=1000) - parser.add_argument("--max-steps", type=int, default=10) + parser.add_argument("--max-steps", type=int, default=10, help="Max steps per episode.") parser.add_argument("--max-completion-length", type=int, default=1024) parser.add_argument("--image-size", type=int, default=512, help="Resize screenshots to this size. 0 to disable.") parser.add_argument("--num-generations", type=int, default=4) - parser.add_argument("--gradient-accumulation-steps", type=int, default=32) + parser.add_argument("--gradient-accumulation-steps", type=int, default=1) parser.add_argument("--learning-rate", type=float, default=5e-6) parser.add_argument("--num-epochs", type=int, default=1) parser.add_argument("--logging-steps", type=int, default=1) @@ -86,28 +86,24 @@ def sanitize_name(name: str) -> str: return name.replace("/", "-") -SYSTEM_PROMPT = """You control a web browser to complete tasks. - -The page structure shows elements as: [bid] element_type 'element_text' -For example: [13] button 'Click Me!' means the element has bid='13'. - -You will see a screenshot of the page after each action. Use the visual information -along with the page structure to decide your next action. - -Use the available tools to interact with the page: -- click: Click an element by its bid -- fill: Fill an input field with text -- send_keys: Send keyboard input -- scroll: Scroll the page -- noop: Do nothing - -Complete the given task as efficiently as possible.""" +SYSTEM_PROMPT = """You are interacting with a web page. Use the available tools to complete the task.""" def reward_completion(completions, environments, **kwargs) -> list[float]: + """Reward for task completion.""" return [env.reward for env in environments] +def reward_efficiency(completions, **kwargs) -> list[float]: + """Penalize extra tool calls beyond the first one.""" + rewards = [] + for comp in completions: + n_tool_calls = sum(1 for m in comp if isinstance(m, dict) and m.get("tool_calls")) + extra_calls = max(0, n_tool_calls - 1) + rewards.append(-0.1 * extra_calls) + return rewards + + def main() -> None: args = parse_args() @@ -134,15 +130,48 @@ def __init__(self): self.done = False self._step_count = 0 + def _ensure_large_max_size(self): + """Raise WebSocket max message size for large observations (screenshots + axtree). + + openenv-core<=0.2.1 does not pass max_size to ws_connect, so the websockets library + defaults to 1MB. We force a connection and patch it to 100MB before any messages are sent. + """ + import websockets + + self.client.connect() + ws = self.client._ws + if ws is not None and ws.protocol is not None: + proto = ws.protocol + # websockets renamed max_size to max_message_size in version 16 + if int(websockets.__version__.split(".")[0]) >= 16: + if proto.max_message_size == 2**20: + proto.max_message_size = 100 * 1024 * 1024 + else: + if proto.max_size == 2**20: + proto.max_size = 100 * 1024 * 1024 + def reset(self, **kwargs) -> str | None: self.reward = 0.0 self.done = False self._step_count = 0 + self._ensure_large_max_size() result = self.client.reset() self.done = result.done return self._format_observation(result.observation) - def click(self, bid: str) -> list: + @staticmethod + def _normalize_bid(bid) -> str: + """Coerce bid to a clean string. + + Qwen-style XML tool calls may parse the bid as int (e.g. 13) or as + a stringified list ("[13]") copied verbatim from the axtree. BrowserGym + expects a plain string id, so strip brackets and unwrap 1-element lists. + """ + if isinstance(bid, list) and len(bid) == 1: + bid = bid[0] + return str(bid).strip().strip("[]") + + def click(self, bid: str) -> list | str: """Click an element on the page. Args: @@ -151,9 +180,10 @@ def click(self, bid: str) -> list: Returns: The updated page observation with screenshot. """ - return self._do_action(f"click('{bid}')") + bid = self._normalize_bid(bid) + return self._do_action(f"click({bid!r})") - def fill(self, bid: str, text: str) -> list: + def fill(self, bid: str, text: str) -> list | str: """Fill an input field with text. Args: @@ -163,9 +193,10 @@ def fill(self, bid: str, text: str) -> list: Returns: The updated page observation with screenshot. """ - return self._do_action(f"fill('{bid}', '{text}')") + bid = self._normalize_bid(bid) + return self._do_action(f"fill({bid!r}, {text!r})") - def send_keys(self, text: str) -> list: + def send_keys(self, text: str) -> list | str: """Send keyboard input to the page. Args: @@ -174,9 +205,9 @@ def send_keys(self, text: str) -> list: Returns: The updated page observation with screenshot. """ - return self._do_action(f"send_keys('{text}')") + return self._do_action(f"send_keys({text!r})") - def scroll(self, direction: str) -> list: + def scroll(self, direction: str) -> list | str: """Scroll the page. Args: @@ -185,9 +216,9 @@ def scroll(self, direction: str) -> list: Returns: The updated page observation with screenshot. """ - return self._do_action(f"scroll('{direction}')") + return self._do_action(f"scroll({direction!r})") - def noop(self) -> list: + def noop(self) -> list | str: """Do nothing and observe the current page state. Returns: @@ -195,9 +226,9 @@ def noop(self) -> list: """ return self._do_action("noop()") - def _do_action(self, action_str: str) -> list: + def _do_action(self, action_str: str) -> list | str: if self.done: - raise ValueError("Episode is done.") + return "Episode is done. No further actions needed." self._step_count += 1 result = self.client.step(BrowserGymAction(action_str=action_str)) @@ -233,7 +264,6 @@ def _format_observation_multimodal(self, observation) -> list: """Format observation as multimodal content blocks (screenshot + text).""" content = [] - # Add screenshot if available if observation.screenshot is not None: screenshot_array = np.array(observation.screenshot, dtype=np.uint8) screenshot_image = Image.fromarray(screenshot_array) @@ -241,7 +271,6 @@ def _format_observation_multimodal(self, observation) -> list: screenshot_image.thumbnail((image_size, image_size), Image.LANCZOS) content.append({"type": "image", "image": screenshot_image}) - # Add text observation parts = [] if observation.goal: parts.append(f"Goal: {observation.goal}") @@ -263,7 +292,7 @@ def _format_observation_multimodal(self, observation) -> list: trainer = GRPOTrainer( model=args.model_id, - reward_funcs=reward_completion, + reward_funcs=[reward_completion, reward_efficiency], train_dataset=dataset, args=GRPOConfig( use_vllm=args.use_vllm, diff --git a/examples/scripts/openenv/browsergym_llm.py b/examples/scripts/openenv/browsergym_llm.py index ae68f98e578..932ebb6f243 100644 --- a/examples/scripts/openenv/browsergym_llm.py +++ b/examples/scripts/openenv/browsergym_llm.py @@ -29,34 +29,22 @@ The environment runs on a Hugging Face Space by default. -Setup (Option A - Install from HF Space, recommended): +Setup: ```sh uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env ``` -Setup (Option B - Clone OpenEnv repo, for development): +Usage: ```sh -git clone https://github.com/meta-pytorch/OpenEnv.git -cd OpenEnv/envs/browsergym_env -uv pip install -e . -``` - -# Option 1: HF Spaces + Colocated vLLM (1 GPU required) -```sh +# HF Spaces + Colocated vLLM (1 GPU required) python examples/scripts/openenv/browsergym_llm.py --vllm-mode colocate -``` -# Option 2: HF Spaces + Separate vLLM server (2 GPUs required) - -# Spin up vLLM server (Terminal 1) -```sh +# HF Spaces + Separate vLLM server (2 GPUs required) +# Terminal 1: CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-0.6B --host 0.0.0.0 --port 8001 -``` - -# Run training (Terminal 2) -```sh +# Terminal 2: CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym_llm.py --vllm-mode server --vllm-server-url http://localhost:8001 ``` """ @@ -75,155 +63,20 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run GRPO training for BrowserGym MiniWoB using OpenEnv environment.") - parser.add_argument( - "--model-id", - default="Qwen/Qwen3-0.6B", - help="Model identifier passed to GRPOTrainer for fine-tuning.", - ) - parser.add_argument( - "--space-url", - type=str, - default="https://openenv-browsergym-env.hf.space", - help="URL for the Hugging Face Space running the BrowserGym environment.", - ) - parser.add_argument( - "--benchmark", - default="miniwob", - help="BrowserGym benchmark to use (miniwob, webarena, etc.).", - ) - parser.add_argument( - "--task-name", - default="click-test", - help="Specific task within the benchmark (e.g., click-test, click-button).", - ) - parser.add_argument( - "--dataset-prompt", - default="Complete the web task successfully.", - help="Prompt text used to seed the training dataset.", - ) - parser.add_argument( - "--dataset-size", - type=int, - default=1000, - help="Number of entries to include in the synthetic training dataset.", - ) - parser.add_argument( - "--max-steps", - type=int, - default=10, - help="Maximum number of steps per episode.", - ) - parser.add_argument( - "--max-completion-length", - type=int, - default=1024, - help="Maximum completion length in tokens for tool-calling generation.", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.7, - help="Sampling temperature used during rollout generation.", - ) - parser.add_argument( - "--top-k", - type=int, - default=50, - help="Top-k sampling parameter forwarded to vLLM.", - ) - parser.add_argument( - "--top-p", - type=float, - default=None, - help="Optional top-p sampling parameter forwarded to vLLM.", - ) - parser.add_argument( - "--learning-rate", - type=float, - default=5e-6, - help="Learning rate for GRPO training.", - ) - parser.add_argument( - "--weight-decay", - type=float, - default=0.0, - help="Weight decay applied during optimization.", - ) - parser.add_argument( - "--gradient-accumulation-steps", - type=int, - default=32, - help="Gradient accumulation steps for GRPO training.", - ) - parser.add_argument( - "--warmup-steps", - type=int, - default=10, - help="Warmup steps for the scheduler.", - ) - parser.add_argument( - "--per-device-batch-size", - type=int, - default=1, - help="Per-device train batch size.", - ) - parser.add_argument( - "--num-generations", - type=int, - default=4, - help="Number of rollout generations per dataset prompt.", - ) - parser.add_argument( - "--num-epochs", - type=int, - default=1, - help="Number of training epochs.", - ) - parser.add_argument( - "--save-interval", - type=int, - default=50, - help="Interval (in steps) between checkpoint saves.", - ) - parser.add_argument( - "--save-total-limit", - type=int, - default=None, - help="Maximum number of checkpoints to keep.", - ) - parser.add_argument( - "--output-dir", - default=None, - help="Directory where training outputs and checkpoints are stored.", - ) - parser.add_argument( - "--run-name", - default=None, - help="Optional run name for logging systems.", - ) - parser.add_argument( - "--project", - default=None, - help="Optional project identifier for logging systems.", - ) - parser.add_argument( - "--vllm-mode", - choices=("colocate", "server"), - default="colocate", - help="vLLM execution mode: 'colocate' or 'server'.", - ) - parser.add_argument( - "--vllm-server-url", - type=str, - default="http://localhost:8001", - help="URL for the vLLM server (only used when --vllm-mode=server).", - ) - parser.add_argument( - "--logging-steps", - type=int, - default=1, - help="Frequency of logging steps for GRPO training.", - ) + parser.add_argument("--model-id", default="Qwen/Qwen3-0.6B") + parser.add_argument("--space-url", default="https://openenv-browsergym-env.hf.space") + parser.add_argument("--dataset-prompt", default="Complete the web task successfully.") + parser.add_argument("--dataset-size", type=int, default=1000) + parser.add_argument("--max-steps", type=int, default=10, help="Max steps per episode.") + parser.add_argument("--max-completion-length", type=int, default=1024) + parser.add_argument("--learning-rate", type=float, default=5e-6) + parser.add_argument("--gradient-accumulation-steps", type=int, default=1) + parser.add_argument("--num-generations", type=int, default=4) + parser.add_argument("--num-epochs", type=int, default=1) + parser.add_argument("--logging-steps", type=int, default=1) + parser.add_argument("--output-dir", default=None) + parser.add_argument("--vllm-mode", choices=("colocate", "server"), default="colocate") + parser.add_argument("--vllm-server-url", default="http://localhost:8001") return parser.parse_args() @@ -231,40 +84,14 @@ def sanitize_name(name: str) -> str: return name.replace("/", "-") -# --------------------------------------------------------------------------- -# System Prompt -# --------------------------------------------------------------------------- - -SYSTEM_PROMPT = """You control a web browser to complete tasks. - -The page structure shows elements as: [bid] element_type 'element_text' -For example: [13] button 'Click Me!' means the element has bid='13'. +SYSTEM_PROMPT = """You are interacting with a web page. Elements have numeric IDs shown in brackets like [13]. Use the available tools to complete the task.""" -Use the available tools to interact with the page: -- click: Click an element by its bid -- fill: Fill an input field with text -- send_keys: Send keyboard input -- scroll: Scroll the page -- noop: Do nothing -Complete the given task as efficiently as possible.""" - - -# --------------------------------------------------------------------------- -# Reward -# --------------------------------------------------------------------------- - - -def reward_completion(environments, **kwargs) -> list[float]: +def reward_completion(completions, environments, **kwargs) -> list[float]: """Reward for task completion.""" return [env.reward for env in environments] -# --------------------------------------------------------------------------- -# Main entrypoint -# --------------------------------------------------------------------------- - - def main() -> None: args = parse_args() @@ -296,14 +123,19 @@ def _ensure_large_max_size(self): openenv-core<=0.2.1 does not pass max_size to ws_connect, so the websockets library defaults to 1MB. We force a connection and patch it to 100MB before any messages are sent. """ + import websockets + self.client.connect() ws = self.client._ws - if ws is not None and hasattr(ws, "protocol"): + if ws is not None and ws.protocol is not None: proto = ws.protocol - # websockets <16: max_size; websockets >=16: max_message_size - attr = "max_size" if hasattr(proto, "max_size") else "max_message_size" - if getattr(proto, attr) == 2**20: - setattr(proto, attr, 100 * 1024 * 1024) + # websockets renamed max_size to max_message_size in version 16 + if int(websockets.__version__.split(".")[0]) >= 16: + if proto.max_message_size == 2**20: + proto.max_message_size = 100 * 1024 * 1024 + else: + if proto.max_size == 2**20: + proto.max_size = 100 * 1024 * 1024 def reset(self, **kwargs) -> str: self.reward = 0.0 @@ -369,7 +201,7 @@ def noop(self) -> str: def _do_action(self, action_str: str) -> str: if self._done: - raise ValueError("Episode is done.") + return "Episode is done. No further actions needed." self._step_count += 1 result = self.client.step(BrowserGymAction(action_str=action_str)) @@ -377,7 +209,6 @@ def _do_action(self, action_str: str) -> str: step_reward = float(result.reward or 0.0) self._done = result.done - # Reward shaping: binary success/failure on completion if self._done and step_reward > 0: self.reward = 1.0 elif self._done: @@ -385,7 +216,6 @@ def _do_action(self, action_str: str) -> str: else: self.reward = step_reward - # Enforce max steps if self._step_count >= max_steps: self._done = True @@ -408,54 +238,29 @@ def _format_observation(self, observation) -> str: default_output_dir = Path("outputs") / f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}" output_dir = Path(args.output_dir or default_output_dir) - grpo_config = GRPOConfig( - use_vllm=True, - vllm_mode=args.vllm_mode, - vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, - vllm_gpu_memory_utilization=0.4, - output_dir=str(output_dir), - num_train_epochs=args.num_epochs, - learning_rate=args.learning_rate, - weight_decay=args.weight_decay, - gradient_accumulation_steps=args.gradient_accumulation_steps, - per_device_train_batch_size=args.per_device_batch_size, - warmup_steps=args.warmup_steps, - num_generations=args.num_generations, - generation_batch_size=args.num_generations, - max_completion_length=args.max_completion_length, - logging_steps=args.logging_steps, - report_to="trackio", - trackio_space_id=f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}", - save_strategy="steps", - save_steps=args.save_interval, - save_total_limit=args.save_total_limit, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - chat_template_kwargs={"enable_thinking": False}, - ) - - grpo_config.run_name = args.run_name or f"run-{timestamp}" - grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}" - trainer = GRPOTrainer( model=args.model_id, - reward_funcs=[reward_completion], + reward_funcs=reward_completion, train_dataset=dataset, - args=grpo_config, + args=GRPOConfig( + use_vllm=True, + vllm_mode=args.vllm_mode, + vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, + vllm_gpu_memory_utilization=0.4, + output_dir=str(output_dir), + num_train_epochs=args.num_epochs, + learning_rate=args.learning_rate, + gradient_accumulation_steps=args.gradient_accumulation_steps, + num_generations=args.num_generations, + max_completion_length=args.max_completion_length, + logging_steps=args.logging_steps, + log_completions=True, + report_to="trackio", + trackio_space_id=f"browsergym-llm-grpo-{sanitize_name(args.model_id)}", + chat_template_kwargs={"enable_thinking": False}, + ), environment_factory=BrowserGymLLMEnv, ) - - print("=" * 80) - print("Starting GRPO training with BrowserGym environment (LLM mode)") - print(f"Benchmark: {args.benchmark}") - print(f"Task: {args.task_name}") - print(f"Model: {args.model_id}") - print("Mode: LLM (text-only, using accessibility tree)") - print(f"Using {args.num_generations} rollouts per dataset prompt") - print(f"Output directory: {output_dir}") - print("=" * 80) - trainer.train()