diff --git a/demo_agent/agent.py b/demo_agent/agent.py index 632c0bbc..43e9e5aa 100644 --- a/demo_agent/agent.py +++ b/demo_agent/agent.py @@ -29,6 +29,11 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image): return f"data:image/jpeg;base64,{image_base64}" +def extract_code_blocks(text) -> list[tuple[str, str]]: + pattern = re.compile(r"```(\w*\n)?(.*?)```", re.DOTALL) + + matches = pattern.findall(text) + return [(match[0].strip(), match[1].strip()) for match in matches] class DemoAgent(Agent): """A basic agent using OpenAI API, to demonstrate BrowserGym's functionalities.""" @@ -322,7 +327,7 @@ def get_action(self, obs: dict) -> tuple[str, dict]: {"role": "user", "content": user_msgs}, ], ) - action = response.choices[0].message.content + action = extract_code_blocks(response.choices[0].message.content) self.action_history.append(action)