Skip to content
93 changes: 61 additions & 32 deletions examples/scripts/openenv/browsergym.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--space-url", default="https://openenv-browsergym-env.hf.space")
parser.add_argument("--dataset-prompt", default="Complete the web task successfully.")
parser.add_argument("--dataset-size", type=int, default=1000)
parser.add_argument("--max-steps", type=int, default=10)
parser.add_argument("--max-steps", type=int, default=10, help="Max steps per episode.")
parser.add_argument("--max-completion-length", type=int, default=1024)
parser.add_argument("--image-size", type=int, default=512, help="Resize screenshots to this size. 0 to disable.")
parser.add_argument("--num-generations", type=int, default=4)
parser.add_argument("--gradient-accumulation-steps", type=int, default=32)
parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
parser.add_argument("--learning-rate", type=float, default=5e-6)
parser.add_argument("--num-epochs", type=int, default=1)
parser.add_argument("--logging-steps", type=int, default=1)
Expand All @@ -86,28 +86,24 @@ def sanitize_name(name: str) -> str:
return name.replace("/", "-")


SYSTEM_PROMPT = """You control a web browser to complete tasks.

The page structure shows elements as: [bid] element_type 'element_text'
For example: [13] button 'Click Me!' means the element has bid='13'.

You will see a screenshot of the page after each action. Use the visual information
along with the page structure to decide your next action.

Use the available tools to interact with the page:
- click: Click an element by its bid
- fill: Fill an input field with text
- send_keys: Send keyboard input
- scroll: Scroll the page
- noop: Do nothing

Complete the given task as efficiently as possible."""
SYSTEM_PROMPT = """You are interacting with a web page. Use the available tools to complete the task."""


def reward_completion(completions, environments, **kwargs) -> list[float]:
"""Reward for task completion."""
return [env.reward for env in environments]


def reward_efficiency(completions, **kwargs) -> list[float]:
"""Penalize extra tool calls beyond the first one."""
rewards = []
for comp in completions:
n_tool_calls = sum(1 for m in comp if isinstance(m, dict) and m.get("tool_calls"))
extra_calls = max(0, n_tool_calls - 1)
rewards.append(-0.1 * extra_calls)
return rewards


def main() -> None:
args = parse_args()

Expand All @@ -134,15 +130,48 @@ def __init__(self):
self.done = False
self._step_count = 0

def _ensure_large_max_size(self):
"""Raise WebSocket max message size for large observations (screenshots + axtree).

openenv-core<=0.2.1 does not pass max_size to ws_connect, so the websockets library
defaults to 1MB. We force a connection and patch it to 100MB before any messages are sent.
"""
import websockets

self.client.connect()
ws = self.client._ws
if ws is not None and ws.protocol is not None:
proto = ws.protocol
# websockets renamed max_size to max_message_size in version 16
if int(websockets.__version__.split(".")[0]) >= 16:
if proto.max_message_size == 2**20:
proto.max_message_size = 100 * 1024 * 1024
else:
if proto.max_size == 2**20:
proto.max_size = 100 * 1024 * 1024

def reset(self, **kwargs) -> str | None:
self.reward = 0.0
self.done = False
self._step_count = 0
self._ensure_large_max_size()
result = self.client.reset()
self.done = result.done
return self._format_observation(result.observation)

def click(self, bid: str) -> list:
@staticmethod
def _normalize_bid(bid) -> str:
"""Coerce bid to a clean string.

Qwen-style XML tool calls may parse the bid as int (e.g. 13) or as
a stringified list ("[13]") copied verbatim from the axtree. BrowserGym
expects a plain string id, so strip brackets and unwrap 1-element lists.
"""
if isinstance(bid, list) and len(bid) == 1:
bid = bid[0]
return str(bid).strip().strip("[]")
Comment thread
sergiopaniego marked this conversation as resolved.

def click(self, bid: str) -> list | str:
"""Click an element on the page.

Args:
Expand All @@ -151,9 +180,10 @@ def click(self, bid: str) -> list:
Returns:
The updated page observation with screenshot.
"""
return self._do_action(f"click('{bid}')")
bid = self._normalize_bid(bid)
return self._do_action(f"click({bid!r})")

def fill(self, bid: str, text: str) -> list:
def fill(self, bid: str, text: str) -> list | str:
"""Fill an input field with text.

Args:
Expand All @@ -163,9 +193,10 @@ def fill(self, bid: str, text: str) -> list:
Returns:
The updated page observation with screenshot.
"""
return self._do_action(f"fill('{bid}', '{text}')")
bid = self._normalize_bid(bid)
return self._do_action(f"fill({bid!r}, {text!r})")

def send_keys(self, text: str) -> list:
def send_keys(self, text: str) -> list | str:
"""Send keyboard input to the page.

Args:
Expand All @@ -174,9 +205,9 @@ def send_keys(self, text: str) -> list:
Returns:
The updated page observation with screenshot.
"""
return self._do_action(f"send_keys('{text}')")
return self._do_action(f"send_keys({text!r})")

def scroll(self, direction: str) -> list:
def scroll(self, direction: str) -> list | str:
"""Scroll the page.

Args:
Expand All @@ -185,19 +216,19 @@ def scroll(self, direction: str) -> list:
Returns:
The updated page observation with screenshot.
"""
return self._do_action(f"scroll('{direction}')")
return self._do_action(f"scroll({direction!r})")

def noop(self) -> list:
def noop(self) -> list | str:
"""Do nothing and observe the current page state.

Returns:
The current page observation with screenshot.
"""
return self._do_action("noop()")

def _do_action(self, action_str: str) -> list:
def _do_action(self, action_str: str) -> list | str:
if self.done:
raise ValueError("Episode is done.")
return "Episode is done. No further actions needed."

self._step_count += 1
result = self.client.step(BrowserGymAction(action_str=action_str))
Expand Down Expand Up @@ -233,15 +264,13 @@ def _format_observation_multimodal(self, observation) -> list:
"""Format observation as multimodal content blocks (screenshot + text)."""
content = []

# Add screenshot if available
if observation.screenshot is not None:
screenshot_array = np.array(observation.screenshot, dtype=np.uint8)
screenshot_image = Image.fromarray(screenshot_array)
if image_size > 0:
screenshot_image.thumbnail((image_size, image_size), Image.LANCZOS)
content.append({"type": "image", "image": screenshot_image})

# Add text observation
parts = []
if observation.goal:
parts.append(f"Goal: {observation.goal}")
Expand All @@ -263,7 +292,7 @@ def _format_observation_multimodal(self, observation) -> list:

trainer = GRPOTrainer(
model=args.model_id,
reward_funcs=reward_completion,
reward_funcs=[reward_completion, reward_efficiency],
train_dataset=dataset,
args=GRPOConfig(
use_vllm=args.use_vllm,
Expand Down
Loading
Loading