diff --git a/.env.example b/.env.example
index c9bb6809..78b4a36d 100644
--- a/.env.example
+++ b/.env.example
@@ -15,6 +15,9 @@ HF_TOKEN=YOUR_HUGGINGFACE_TOKEN
OPENPIPE_API_KEY=YOUR_OPENPIPE_API_KEY
# Optional, Together API key (used for deploying models to Together)
TOGETHER_API_KEY=YOUR_TOGETHER_API_KEY
+# Optional, OpenRouter API key (used for benchmarking other models)
+OPENROUTER_API_KEY=YOUR_OPENROUTER_API_KEY
+
# Optional, S3 configuration for log and model backups
AWS_ACCESS_KEY_ID=YOUR_AWS_ACCESS_KEY_ID
diff --git a/examples/wikihop/judge_group.py b/examples/wikihop/judge_group.py
new file mode 100644
index 00000000..f9997e99
--- /dev/null
+++ b/examples/wikihop/judge_group.py
@@ -0,0 +1,127 @@
+from openai.types.chat import ChatCompletionMessageParam
+import tenacity
+import json
+import os
+from textwrap import dedent
+import art
+from typing import List
+from openai import AsyncOpenAI
+from pydantic import BaseModel, Field
+from litellm import acompletion
+
+judge_client = AsyncOpenAI(
+ api_key=os.getenv("OPENROUTER_API_KEY"),
+ base_url="https://openrouter.ai/api/v1",
+)
+
+
+class JudgeGroupScore(BaseModel):
+ rollout_id: str = Field(description="The id of the rollout being scored.")
+ explanation: str = Field(
+ description="A short explanation of why you gave this score."
+ )
+ score: float = Field(description="A score between 0 and 1.")
+
+
+class JudgeGroupResponse(BaseModel):
+ scores: List[JudgeGroupScore]
+
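+# For illustration only: the judge is expected to return JSON matching
+# JudgeGroupResponse, e.g. (values are made up):
+#
+# {
+#   "scores": [
+#     {"rollout_id": "1", "explanation": "Reached the target quickly.", "score": 0.9},
+#     {"rollout_id": "2", "explanation": "Never reached the target.", "score": 0.2}
+#   ]
+# }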
+
+@tenacity.retry(stop=tenacity.stop_after_attempt(10))
+async def judge_group(
+ _model_name: str, # Just included for observability
+ trajectories: list[art.Trajectory],
+ judge_model_name: str = "openai/o4-mini",
+ *,
+ debug: bool = False,
+) -> list[art.Trajectory]:
+ """Judge a list of trajectories with an LLM-as-a-judge.
+
+ This keeps the original trajectories but overwrites ``reward`` with the
+ score returned by the judge (0–1). The original reward is preserved in
+ ``traj.metrics['independent_reward']`` and the new score is written to
+ ``traj.metrics['judge_group_reward']``.
+ """
+
+ # Serialize each rollout's messages (keeping tool_calls as-is)
+ serialized_rollouts: List[str] = []
+ # Keep structured messages for nicer debug printing
+    debug_rollouts: List[tuple[int, list]] = []
+ for idx, traj in enumerate(trajectories, start=1):
+ # Save the original reward
+ traj.metrics["independent_reward"] = traj.reward
+ # Flatten messages to regular OpenAI format (role/content/…)
+ messages = traj.messages()
+ if debug:
+ debug_rollouts.append((idx, messages))
+        # Tag each rollout with its index so the judge can reference it by id
+        serialized_rollouts.append(
+            f'<rollout id="{idx}">\n' + json.dumps(messages) + "\n</rollout>"
+        )
+
+ if debug:
+ print("\n[judge_group] Serialized rollouts (pretty JSON):")
+ for idx, msg_list in debug_rollouts:
+ print(f"\nRollout {idx}:")
+ print(json.dumps(msg_list, indent=2, ensure_ascii=False))
+
+ print("\n[judge_group] Rollout metrics:")
+ for idx, traj in enumerate(trajectories, start=1):
+ print(f"\nRollout {idx} metrics:")
+ print(json.dumps(traj.metrics, indent=2, ensure_ascii=False))
+
+ rubric_text = dedent(
+ """
+        All of the rollouts below have been given the same goal. Your job is to consider each of them and give each a score between 0 and 1, using your best judgement of the agent's goal.
+
+ Grading standards:
+ - A rollout that achieves its goal should always get a significantly higher score than a rollout that does not achieve its goal.
+        - A rollout that achieves its goal more efficiently (e.g. by avoiding unproductive detours) should get a higher score than a rollout that achieves its goal less efficiently.
+ - If one rollout is only slightly better than another, the difference in scores should be small. If it is significantly better, the difference in scores should be large.
+ - You may give some partial credit for a rollout that makes progress towards its goal but does not complete it.
+ """
+ )
+
+ user_text = "Rollouts:\n\n" + "\n\n".join(serialized_rollouts)
+
+    # Decide which LLM should act as the judge. The judge defaults to
+    # `judge_model_name` ("openai/o4-mini") and can be overridden by the caller.
+
+ messages: list[ChatCompletionMessageParam] = [
+ {"role": "system", "content": rubric_text},
+ {"role": "user", "content": user_text},
+ ]
+
+ response = await acompletion(
+ model=judge_model_name,
+ messages=messages,
+ response_format=JudgeGroupResponse,
+ caching=True,
+ )
+
+    first_choice = response.choices[0]  # type: ignore[attr-defined]
+    content = first_choice.message.content or "{}"  # type: ignore[attr-defined]
+
+    if debug:
+        print("\n[judge_group] Raw LLM choice content:")
+        print(content)
+
+        try:
+            print("\n[judge_group] Pretty-printed LLM choice JSON:")
+            print(json.dumps(json.loads(content), indent=2, ensure_ascii=False))
+        except json.JSONDecodeError as e:
+            print(f"[judge_group] Could not parse choice content as JSON: {e}")
+
+ parsed = JudgeGroupResponse.model_validate_json(content)
+ assert len(parsed.scores) == len(trajectories)
+
+    for traj, score in zip(trajectories, parsed.scores):
+ traj.metrics["judge_group_reward"] = score.score
+ traj.reward = score.score
+ if traj.metrics.get("failed_format_validation", 0) > 0:
+ traj.reward = 0
+ traj.metadata["judge_group_explanation"] = score.explanation
+
+ return trajectories
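+
+
+# Example usage (sketch): re-score one gathered trajectory group before training.
+# `group` is assumed to be an art.TrajectoryGroup produced by
+# art.gather_trajectory_groups, as in the accompanying notebook:
+#
+#   judged = await judge_group(
+#       _model_name="my-model",
+#       trajectories=group.trajectories,
+#   )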
diff --git a/examples/wikihop/wikihop.ipynb b/examples/wikihop/wikihop.ipynb
new file mode 100644
index 00000000..ad91746d
--- /dev/null
+++ b/examples/wikihop/wikihop.ipynb
@@ -0,0 +1,896 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To train this agent, click _Runtime_ and press _Run all_. Make sure you've enabled a free Tesla T4 GPU!\n",
+ "\n",
+ "\n",
+ "Questions? Join the Discord and ask away! For feature requests or to leave a star, visit our [Github](https://github.com/openpipe/art).\n",
+ "\n",
+ "\n",
+ "This notebook shows how to train a Qwen 2.5 7B model to navigate Wikipedia. It will demonstrate how to set up a multi-turn agent that learns to hop between Wikipedia pages by selecting the best links to reach target pages.\n",
+ "\n",
+ "Completions will be logged to OpenPipe, and metrics will be logged to Weights & Biases.\n",
+ "\n",
+ "You will learn how to construct an [agentic environment](#Environment), how to define a [rollout](#Rollout), and how to run a [training loop](#Loop).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[2mUsing Python 3.10.13 environment at: /root/sky_workdir/.venv\u001b[0m\n",
+ "\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 10ms\u001b[0m\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "!uv pip install \"numpy<2.0.0\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### WARNING:\n",
+ "\n",
+ "If you are running in Google Colab and installing numpy does not say \"Requirement already satisfied: numpy<2.0.0\" then click \"Runtime\" and \"Restart Session.\"\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Numpy version is 1.*.*, you're good to go!\n"
+ ]
+ }
+ ],
+ "source": [
+ "# make sure we're using numpy 1.*.*\n",
+ "import numpy as np\n",
+ "\n",
+ "if (np.__version__).startswith(\"1.\"):\n",
+ " print(\"Numpy version is 1.*.*, you're good to go!\")\n",
+ "else:\n",
+ " raise ValueError(\"Please restart your runtime using the above instructions!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Environment Variables\n",
+ "\n",
+ "Later on in the notebook, we'll be creating a model that can automatically logs metrics to Weights & Biases. In order to do so, you'll need to provide your Weights & Biases API key as an environment variable.\n",
+ "\n",
+ "You can also optionally initiate an OpenPipe client to report completions to a [dashboard](https://app.openpipe.ai) to get a feel for what the completions your model is generating look like, and how they change over time. Logging to OpenPipe is free, but is not required for training!\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "\n",
+ "# Optional\n",
+ "WANDB_API_KEY = \"\"\n",
+ "if WANDB_API_KEY:\n",
+ " os.environ[\"WANDB_API_KEY\"] = WANDB_API_KEY\n",
+ "\n",
+ "# Optional\n",
+ "OPENPIPE_API_KEY = \"\"\n",
+ "if OPENPIPE_API_KEY:\n",
+ " os.environ[\"OPENPIPE_API_KEY\"] = OPENPIPE_API_KEY"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Installation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "!uv pip install openpipe-art==0.3.11 openpipe accelerate==1.7.0 requests beautifulsoup4 --prerelease allow --no-cache-dir"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Agentic Environment\n",
+ "\n",
+ "\n",
+ "\n",
+ "ART allows your agent to learn by interacting with its environment. In this example, we'll create an environment in which the agent navigates Wikipedia by selecting links to reach target pages.\n",
+ "\n",
+ "The agent starts at the Philosophy Wikipedia page and must navigate to various target pages by selecting the best links from the first paragraph of each page. The environment functions handle Wikipedia scraping, link selection, and target matching.\n",
+ "\n",
+ "Feel free to read as much or as little of this section's code as you'd like. The important thing to understand is that we're defining the rules of this agent's environment.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "import re\n",
+ "import xml.etree.ElementTree as ET\n",
+ "from typing import List\n",
+ "import requests\n",
+ "from bs4 import BeautifulSoup\n",
+ "\n",
+ "\n",
+ "def read_first_paragraphs_with_links(url: str, num_paragraphs: int = 3) -> str:\n",
+ " \"\"\"Given a link to a wikipedia page, read the first paragraphs with links.\"\"\"\n",
+ " try:\n",
+ " response = requests.get(url, timeout=10)\n",
+ " response.raise_for_status()\n",
+ " \n",
+ " soup = BeautifulSoup(response.content, 'html.parser')\n",
+ " \n",
+ " # Find the main content div\n",
+ " content = soup.find('div', {'id': 'mw-content-text'})\n",
+ " if not content:\n",
+ " return \"Could not find main content\"\n",
+ "\n",
+ " num_good_paragraphs = 0\n",
+ " parsed_paragraphs = []\n",
+ " # also count ordered and unordered lists\n",
+ " paragraphs = content.find_all(['p', 'ol', 'ul'])\n",
+ " paragraph_index = 0\n",
+ "\n",
+ " while num_good_paragraphs < num_paragraphs and paragraph_index < len(paragraphs):\n",
+ " p = paragraphs[paragraph_index]\n",
+ " text = p.get_text().strip()\n",
+ " # Get the paragraph with links preserved\n",
+ " paragraph_html = str(p)\n",
+ " parsed_paragraphs.append(paragraph_html)\n",
+ " if len(text) > 50 and p.find('a'): # Must have substantial content and a link\n",
+ " num_good_paragraphs += 1\n",
+ " paragraph_index += 1\n",
+ " \n",
+ " if len(parsed_paragraphs) == 0:\n",
+ " raise Exception(\"No substantial first paragraphs found\")\n",
+ " \n",
+ " return \"\\n\\n\".join(parsed_paragraphs)\n",
+ " \n",
+ " except Exception as e:\n",
+ " return f\"Error reading page: {str(e)}\"\n",
+ "\n",
+ "\n",
+ "def extract_urls(first_paragraph: str) -> List[str]:\n",
+ " \"\"\"Extract Wikipedia URLs from the first paragraph HTML.\"\"\"\n",
+ " soup = BeautifulSoup(first_paragraph, 'html.parser')\n",
+ " links = []\n",
+ " for a_tag in soup.find_all('a', href=True):\n",
+ " href = a_tag['href']\n",
+ " text = a_tag.get_text().strip()\n",
+ " if href.startswith('/wiki/') and ':' not in href and text:\n",
+ " full_url = f\"https://en.wikipedia.org{href}\"\n",
+ " links.append(full_url)\n",
+ " return links\n",
+ "\n",
+ "\n",
+ "\n",
+ "def check_target_match(current_url: str, target_page: str) -> bool:\n",
+ " \"\"\"Given the current url, determine whether it is a match for the target page.\"\"\"\n",
+ " # Normalize URLs for comparison\n",
+ " current_url = current_url.strip().rstrip('/')\n",
+ " target_page = target_page.strip().rstrip('/')\n",
+ " \n",
+ " # Direct match\n",
+ " if current_url == target_page:\n",
+ " return True\n",
+ " \n",
+ " # Extract the page title from both URLs\n",
+ " def extract_page_title(url):\n",
+ " if '/wiki/' in url:\n",
+ " return url.split('/wiki/')[-1]\n",
+ " return url\n",
+ " \n",
+ " current_title = extract_page_title(current_url)\n",
+ " target_title = extract_page_title(target_page)\n",
+ "\n",
+ " \n",
+ " return current_title == target_title\n",
+ "\n",
+ "\n",
+ "# Target URLs for training scenarios\n",
+ "TARGET_URLS = [\n",
+ " \"https://en.wikipedia.org/wiki/Unsupervised_learning\",\n",
+ " \"https://en.wikipedia.org/wiki/Exploration%E2%80%93exploitation_dilemma\",\n",
+ " \"https://en.wikipedia.org/wiki/Markov_decision_process\",\n",
+ " \"https://en.wikipedia.org/wiki/Autoencoder\",\n",
+ " \"https://en.wikipedia.org/wiki/Gradient_descent\",\n",
+ " \"https://en.wikipedia.org/wiki/Data_compression\",\n",
+ " \"https://en.wikipedia.org/wiki/Barcode\",\n",
+ " \"https://en.wikipedia.org/wiki/Edge_detection\",\n",
+ " \"https://en.wikipedia.org/wiki/Evolutionary_algorithm\",\n",
+ " \"https://en.wikipedia.org/wiki/Fitness_function\",\n",
+ " \"https://en.wikipedia.org/wiki/Planning\",\n",
+ " \"https://en.wikipedia.org/wiki/Forecasting\",\n",
+ " \"https://en.wikipedia.org/wiki/Agentic_AI\",\n",
+ " \"https://en.wikipedia.org/wiki/Outer_space\",\n",
+ " \"https://en.wikipedia.org/wiki/Knowledge_representation_and_reasoning\",\n",
+ " \"https://en.wikipedia.org/wiki/Particle_physics\",\n",
+ " \"https://en.wikipedia.org/wiki/Hadron\",\n",
+ " \"https://en.wikipedia.org/wiki/Ancient_Greek\",\n",
+ " \"https://en.wikipedia.org/wiki/Logic\",\n",
+ " \"https://en.wikipedia.org/wiki/Renaissance\",\n",
+ " \"https://en.wikipedia.org/wiki/Action_selection\",\n",
+ " \"https://en.wikipedia.org/wiki/French_Revolution\",\n",
+ " \"https://en.wikipedia.org/wiki/Golden_Rule\"\n",
+ "]\n",
+ "\n",
+ "STARTING_URL = \"https://en.wikipedia.org/wiki/Reinforcement_learning\"\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['https://en.wikipedia.org/wiki/Cooking', 'https://en.wikipedia.org/wiki/Arabic_language', 'https://en.wikipedia.org/wiki/Latin_language', 'https://en.wikipedia.org/wiki/Viscous', 'https://en.wikipedia.org/wiki/Solution_(chemistry)', 'https://en.wikipedia.org/wiki/Sugar', 'https://en.wikipedia.org/wiki/Crystal', 'https://en.wikipedia.org/wiki/Molasses', 'https://en.wikipedia.org/wiki/Hydrogen_bond', 'https://en.wikipedia.org/wiki/Hydroxyl', 'https://en.wikipedia.org/wiki/Agave_nectar', 'https://en.wikipedia.org/wiki/Agave', 'https://en.wikipedia.org/wiki/Cane_syrup', 'https://en.wikipedia.org/wiki/Chocolate_syrup', 'https://en.wikipedia.org/wiki/Corn_syrup', 'https://en.wikipedia.org/wiki/Glucose_syrup', 'https://en.wikipedia.org/wiki/Golden_syrup', 'https://en.wikipedia.org/wiki/Sugar', 'https://en.wikipedia.org/wiki/High_fructose_corn_syrup', 'https://en.wikipedia.org/wiki/Maple_syrup', 'https://en.wikipedia.org/wiki/Table_syrup', 'https://en.wikipedia.org/wiki/Mixed_drink']\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(extract_urls(read_first_paragraphs_with_links(\"https://en.wikipedia.org/wiki/Syrup\")))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Creating a Model\n",
+ "\n",
+ "Now that we've defined the rules of our environment, we can create a model that will learn to navigate Wikipedia. We'll use a Qwen 2.5 3B model for this example. The `name` parameter will be associated with a wandb run, and the `base_model` parameter is the model that we'll be training a LoRA on top of."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "OpenPipe client initialized\n"
+ ]
+ }
+ ],
+ "source": [
+ "import art\n",
+ "from dotenv import load_dotenv\n",
+ "\n",
+ "from openpipe.client import OpenPipe\n",
+ "from art.local import LocalBackend\n",
+ "\n",
+ "load_dotenv()\n",
+ "\n",
+ "op_client = OpenPipe()\n",
+ "print(\"OpenPipe client initialized\")\n",
+ "\n",
+ "random.seed(42)\n",
+ "\n",
+ "backend = LocalBackend(path=\"./.art\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Model Registration\n",
+ "\n",
+ "Now we'll register our model with the ART backend. This creates the infrastructure needed for training and tracks our model's progress through the training steps.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "TEST_MODE = False\n",
+ "\n",
+ "\n",
+ "if TEST_MODE:\n",
+ " model_or = art.Model(\n",
+ " name=\"closed_or\",\n",
+ " project=\"wikihop-navigation\",\n",
+ " inference_base_url=\"https://openrouter.ai/api/v1\",\n",
+ " inference_api_key=os.getenv(\"OPENROUTER_API_KEY\"),\n",
+ " inference_model_name=\"openai/gpt-4.1\",\n",
+ " )\n",
+ " model = art.Model(\n",
+ " name=\"closed\",\n",
+ " project=\"wikihop-navigation\",\n",
+ " inference_base_url=\"https://api.openai.com/v1\",\n",
+ " inference_api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
+ " inference_model_name=\"gpt-4.1\",\n",
+ " )\n",
+ "else:\n",
+ "\n",
+ " model = art.TrainableModel(\n",
+ " name=\"007-wikihop\", project=\"wikihop-navigation\", base_model=\"Qwen/Qwen2.5-7B-Instruct\"\n",
+ " )\n",
+ " await model.register(backend)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Defining a Rollout\n",
+ "\n",
+ "\n",
+ "\n",
+ "A rollout is a single episode of an agent performing its task. It generates one or more trajectories, which are lists of messages and choices.\n",
+ "\n",
+ "In this example, the rollout function starts the agent at the Philosophy Wikipedia page and gives it a target page to reach. The agent navigates by selecting links from the first paragraph of each page until it either reaches the target or makes too many moves.\n",
+ "\n",
+ "When the navigation is finished, the `reward` for the agent's performance is calculated based on whether it successfully reached the target, how many steps it took, and whether it made any errors.\n",
+ "\n",
+ "This rollout function will be called many times in parallel during each step of the training loop."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import art\n",
+ "import openai\n",
+ "import time\n",
+ "from pydantic import BaseModel\n",
+ "\n",
+ "class WikihopScenario(BaseModel):\n",
+ " step: int\n",
+ " target_url: str\n",
+ "\n",
+ "max_hops = 20\n",
+ "\n",
+ "\n",
+ "@art.retry(exceptions=(openai.LengthFinishReasonError,))\n",
+ "async def rollout(\n",
+ " model: art.Model, scenario: WikihopScenario\n",
+ ") -> art.Trajectory:\n",
+ " current_url = STARTING_URL\n",
+ " target_url = scenario.target_url\n",
+ " \n",
+ " trajectory = art.Trajectory(\n",
+ " messages_and_choices=[\n",
+ " {\n",
+ " \"role\": \"system\",\n",
+ " \"content\": f\"You are a Wikipedia navigator. Your goal is to reach the target page: {target_url}\\n\\nYou will be shown the first few paragraphs of each Wikipedia page. Select the link from these paragraphs that is most likely to get you closer to your target. If you see a direct link to your target, choose it immediately.\\n\\nRespond with ONLY the full URL in this exact format: https://en.wikipedia.org/wiki/YourChoice. You must always choose a link from the list of available links.\",\n",
+ " }\n",
+ " ],\n",
+ " reward=0,\n",
+ " )\n",
+ "\n",
+ " hop_number = 0\n",
+ " \n",
+ " while hop_number < max_hops:\n",
+ " # Check if we've reached the target\n",
+ " if check_target_match(current_url, target_url):\n",
+ " print(\"reached target\")\n",
+ " trajectory.reward = max_hops - hop_number # Negative of total hops used\n",
+ " trajectory.metrics[\"success\"] = 1\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " break\n",
+ " \n",
+ " # Read the first paragraph of the current page\n",
+ " try:\n",
+ " first_paragraph = read_first_paragraphs_with_links(current_url)\n",
+ " if first_paragraph.startswith(\"Error\"):\n",
+ " trajectory.reward = hop_number - max_hops\n",
+ " trajectory.metrics[\"success\"] = 0\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " trajectory.metadata[\"error\"] = \"page_read_error\"\n",
+ " break\n",
+ " except Exception as e:\n",
+ " trajectory.reward = hop_number - max_hops\n",
+ " trajectory.metrics[\"success\"] = 0\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " trajectory.metadata[\"error\"] = \"page_read_exception\"\n",
+ " break\n",
+ "\n",
+ " # Extract links from the first paragraph\n",
+ " links = extract_urls(first_paragraph)\n",
+ " \n",
+ " if not links:\n",
+ " trajectory.reward = hop_number - max_hops\n",
+ " trajectory.metrics[\"success\"] = 0\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " trajectory.metadata[\"error\"] = \"no_valid_links\"\n",
+ " break\n",
+ "\n",
+ " links_text = \"\\n\".join(links)\n",
+ "\n",
+ " # Add the current page content to the trajectory\n",
+ " page_title = current_url.split('/wiki/')[-1].replace('_', ' ')\n",
+ " trajectory.messages_and_choices.append({\n",
+ " \"role\": \"user\", \n",
+ " \"content\": f\"Current page: {page_title}\\nHop {hop_number + 1}/{max_hops}\\n\\nChoose from one of the following available links:\\n{links_text}\"\n",
+ " })\n",
+ "\n",
+ " requested_at = int(time.time() * 1000)\n",
+ "\n",
+ " try:\n",
+ " \n",
+ " \n",
+ " # Get the model's choice of next URL\n",
+ " client = model.openai_client()\n",
+ "\n",
+ " chat_completion = None\n",
+ " chat_completion = await client.chat.completions.create(\n",
+ " model=model.get_inference_name(),\n",
+ " messages=trajectory.messages(),\n",
+ " max_completion_tokens=2000,\n",
+ " )\n",
+ "\n",
+ " \n",
+ " response = chat_completion.choices[0].message.content\n",
+ " if not response:\n",
+ " trajectory.reward = hop_number - max_hops\n",
+ " trajectory.metrics[\"success\"] = 0\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " trajectory.metadata[\"error\"] = \"empty_model_response\"\n",
+ " break\n",
+ " \n",
+ " # Parse the XML response to get the selected URL\n",
+ " try:\n",
+ " root = ET.fromstring(response)\n",
+ " next_url = root.text\n",
+ " if not next_url:\n",
+ " trajectory.reward = hop_number - max_hops\n",
+ " trajectory.metrics[\"success\"] = 0\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " trajectory.metadata[\"error\"] = \"empty_url_in_response\"\n",
+ " break\n",
+ " next_url = next_url.strip()\n",
+ " except ET.ParseError:\n",
+ " # Try to extract URL with regex as fallback\n",
+ " url_pattern = r'https://en\\.wikipedia\\.org/wiki/[^\\s<>]+'\n",
+ " matches = re.findall(url_pattern, response)\n",
+ " if matches:\n",
+ " next_url = matches[0]\n",
+ " else:\n",
+ " trajectory.reward = hop_number - max_hops\n",
+ " trajectory.metrics[\"success\"] = 0\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " trajectory.metadata[\"error\"] = \"could_not_parse_url\"\n",
+ " break\n",
+ "\n",
+ " print(target_url.split('/wiki/')[-1], next_url)\n",
+ " \n",
+ " # Validate that the selected URL is actually in the first paragraph\n",
+ " if next_url not in links:\n",
+ " trajectory.reward = hop_number - max_hops\n",
+ " trajectory.metrics[\"success\"] = 0\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " trajectory.metadata[\"error\"] = \"selected_url_not_in_text\"\n",
+ " break\n",
+ " \n",
+ " except openai.LengthFinishReasonError as e:\n",
+ " raise e\n",
+ " except Exception as e:\n",
+ " print(f\"caught exception during URL selection: {e}\")\n",
+ " trajectory.reward = hop_number - max_hops\n",
+ " trajectory.metrics[\"success\"] = 0\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " trajectory.metadata[\"error\"] = \"url_selection_error\"\n",
+ " break\n",
+ "\n",
+ " # Record the choice\n",
+ " trajectory.messages_and_choices.append(chat_completion.choices[0])\n",
+ "\n",
+ " # Move to the next page\n",
+ " current_url = next_url\n",
+ " hop_number += 1\n",
+ "\n",
+ " # If we ran out of hops without reaching the target\n",
+ " if hop_number >= max_hops and not check_target_match(current_url, target_url):\n",
+ " trajectory.reward = hop_number - max_hops # Error penalty for not reaching target, but reward good initial hops\n",
+ " trajectory.metrics[\"success\"] = 0\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ " trajectory.metadata[\"error\"] = \"max_hops_exceeded\"\n",
+ " \n",
+ " # Final metrics\n",
+ " if \"success\" not in trajectory.metrics:\n",
+ " trajectory.metrics[\"success\"] = 1 if check_target_match(current_url, target_url) else 0\n",
+ " if \"hops\" not in trajectory.metrics:\n",
+ " trajectory.metrics[\"hops\"] = hop_number\n",
+ "\n",
+ " if op_client.api_key:\n",
+ " messages = trajectory.messages()\n",
+ " if messages[-1][\"role\"] == \"assistant\":\n",
+ " messages = messages[:-1]\n",
+ "\n",
+ " try:\n",
+ " op_client.report(\n",
+ " requested_at=requested_at,\n",
+ " received_at=int(time.time() * 1000),\n",
+ " req_payload={\n",
+ " \"model\": model.name,\n",
+ " \"messages\": messages,\n",
+ " \"metadata\": {\n",
+ " \"notebook-id\": \"wikihop\",\n",
+ " \"step\": str(scenario.step),\n",
+ " \"final_hops\": str(hop_number),\n",
+ " \"success\": str(trajectory.metrics[\"success\"]),\n",
+ " \"reward\": str(trajectory.reward),\n",
+ " \"target_url\": target_url,\n",
+ " \"final_url\": current_url,\n",
+ " \"error\": trajectory.metadata[\"error\"] if \"error\" in trajectory.metadata else \"none\",\n",
+ " },\n",
+ " },\n",
+ " resp_payload=chat_completion,\n",
+ " status_code=200,\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(f\"Error reporting to OpenPipe: {e}\")\n",
+ "\n",
+ " return trajectory\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "if TEST_MODE:\n",
+ " await rollout(model, WikihopScenario(step=0, target_url=\"https://en.wikipedia.org/wiki/Particle_physics\"))\n",
+ " raise Exception(\"stopping early for test mode\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "### Training Loop\n",
+ "\n",
+ "The training loop is where the magic happens. For each of the 50 steps defined below, the rollout function will be called 48 times in parallel with different target Wikipedia pages. This means that 48 Wikipedia navigation tasks will be performed at once, each with a randomly selected target page.\n",
+ "\n",
+ "The `gather` step will wait for all of the trajectories to be generated, then it will delete all but the most recent checkpoint and train the model on the new trajectories.\n",
+ "\n",
+ "Inference will be blocked until the training is complete.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import asyncio\n",
+ "import judge_group\n",
+ "import importlib\n",
+ "\n",
+ "# refresh judge_group.py\n",
+ "importlib.reload(judge_group)\n",
+ "\n",
+ "\n",
+ "batch_size = 3\n",
+ "\n",
+ "for i in range(await model.get_step(), 50):\n",
+ "\n",
+ " batch_start_idx = i * batch_size % len(TARGET_URLS)\n",
+ " batch_end_idx = (i + 1) * batch_size % len(TARGET_URLS)\n",
+ "\n",
+ " batch_urls = TARGET_URLS[batch_start_idx:batch_end_idx]\n",
+ "\n",
+ " \n",
+ " train_groups = await art.gather_trajectory_groups(\n",
+ " (\n",
+ " art.TrajectoryGroup(\n",
+ " rollout(model, WikihopScenario(step=i, target_url=target_url)) for _ in range(8)\n",
+ " )\n",
+ " for target_url in batch_urls\n",
+ " ),\n",
+ " pbar_desc=\"gather\",\n",
+ " )\n",
+ " # judge simultaneously\n",
+ " judge_promises = []\n",
+ " for group in train_groups:\n",
+ " judge_promises.append(judge_group.judge_group(_model_name=model.name, trajectories=group.trajectories, debug=True))\n",
+ " await asyncio.gather(*judge_promises)\n",
+ "\n",
+ " await model.delete_checkpoints()\n",
+ " await model.train(train_groups, config=art.TrainConfig(learning_rate=5e-5))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Using the Model\n",
+ "\n",
+ "Just like that, you've trained an agent to navigate Wikipedia! Now it's time to use your model outside of ART, in the wild! The easiest way to do that is to load it from disk, where it was saved after each training step, and either run inference on it locally or upload it to a central hub like HuggingFace.\n",
+ "\n",
+ "Check out the code below for a small demo of the model you just trained navigating Wikipedia to reach a target page!\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "loading model from .art/tic-tac-toe-local/models/001-script/0100\n",
+ "\n",
+ "==((====))== Unsloth 2025.3.19: Fast Qwen2 patching. Transformers: 4.51.1. vLLM: 0.7.3.\n",
+ " \\\\ /| NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.097 GB. Platform: Linux.\n",
+ "O^O/ \\_/ \\ Torch: 2.5.1+cu124. CUDA: 9.0. CUDA Toolkit: 12.4. Triton: 3.1.0\n",
+ "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post1. FA2 = False]\n",
+ " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
+ "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
+ "\n",
+ "move 1\n",
+ "board:\n",
+ " 1 2 3\n",
+ "A _ | _ | _\n",
+ "B _ | _ | _\n",
+ "C _ | _ | _\n",
+ "\n",
+ "agent move: B1\n",
+ "updated board:\n",
+ " 1 2 3\n",
+ "A _ | _ | _\n",
+ "B x | _ | _\n",
+ "C _ | _ | _\n",
+ "\n",
+ "\n",
+ "move 3\n",
+ "board:\n",
+ " 1 2 3\n",
+ "A _ | _ | _\n",
+ "B x | _ | _\n",
+ "C _ | o | _\n",
+ "\n",
+ "agent move: A1\n",
+ "updated board:\n",
+ " 1 2 3\n",
+ "A x | _ | _\n",
+ "B x | _ | _\n",
+ "C _ | o | _\n",
+ "\n",
+ "\n",
+ "move 5\n",
+ "board:\n",
+ " 1 2 3\n",
+ "A x | o | _\n",
+ "B x | _ | _\n",
+ "C _ | o | _\n",
+ "\n",
+ "agent move: C1\n",
+ "updated board:\n",
+ " 1 2 3\n",
+ "A x | o | _\n",
+ "B x | _ | _\n",
+ "C x | o | _\n",
+ "\n",
+ "game finished in 5 moves\n",
+ "game won! 💪\n",
+ "final board:\n",
+ "\n",
+ " 1 2 3\n",
+ "A x | o | _\n",
+ "B x | _ | _\n",
+ "C x | o | _\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from unsloth import FastLanguageModel\n",
+ "\n",
+ "\n",
+ "# example: .art/wikihop-navigation/models/001-wikihop/0003\n",
+ "lora_model_path = (\n",
+ " f\".art/{model.project}/models/{model.name}/{await model.get_step():04d}\"\n",
+ ")\n",
+ "\n",
+ "print(f\"loading model from {lora_model_path}\\n\")\n",
+ "\n",
+ "peft_model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=lora_model_path,\n",
+ " max_seq_length=16384,\n",
+ " dtype=torch.bfloat16,\n",
+ " load_in_4bit=True,\n",
+ ")\n",
+ "FastLanguageModel.for_inference(peft_model)\n",
+ "\n",
+ "# Demo navigation to a target page\n",
+ "current_url = STARTING_URL\n",
+ "target_url = random.choice(TARGET_URLS)\n",
+ "max_hops = 5\n",
+ "hop_number = 0\n",
+ "\n",
+ "print(f\"🎯 Target: {target_url}\")\n",
+ "print(f\"🚀 Starting from: {current_url}\\n\")\n",
+ "\n",
+ "messages = [\n",
+ " {\n",
+ " \"role\": \"system\",\n",
+ " \"content\": f\"You are a Wikipedia navigator. Your goal is to reach the target page: {target_url}\\n\\nYou will be shown the first paragraph of each Wikipedia page. Select the link that is most likely to get you closer to your target. If you see a direct link to your target, choose it immediately.\\n\\nRespond with ONLY the full URL in this exact format: https://en.wikipedia.org/wiki/YourChoice\",\n",
+ " },\n",
+ "]\n",
+ "\n",
+ "while hop_number < max_hops:\n",
+ " # Check if we've reached the target\n",
+ " if check_target_match(current_url, target_url):\n",
+ " print(f\"🎉 SUCCESS! Reached target in {hop_number} hops!\")\n",
+ " break\n",
+ " \n",
+ " # Read the first paragraph of the current page\n",
+ " try:\n",
+ " first_paragraph = read_first_paragraphs_with_links(current_url)\n",
+ " if first_paragraph.startswith(\"Error\"):\n",
+ " print(f\"❌ Error reading page: {first_paragraph}\")\n",
+ " break\n",
+ " except Exception as e:\n",
+ " print(f\"❌ Exception reading page: {e}\")\n",
+ " break\n",
+ "\n",
+ " # Show current page\n",
+ " page_title = current_url.split('/wiki/')[-1].replace('_', ' ')\n",
+ " print(f\"📖 Hop {hop_number + 1}: Currently on '{page_title}'\")\n",
+ " \n",
+ " # Extract and show available links\n",
+ " links = extract_urls(first_paragraph)\n",
+ " print(f\"🔗 Found {len(links)} links: {', '.join([link.split(' -> ')[0] for link in links[:5]])}{'...' if len(links) > 5 else ''}\")\n",
+ " \n",
+ " # Add the current page content to the trajectory\n",
+ " page_content = f\"Current page: {page_title}\\nTarget: {target_url}\\nHop {hop_number + 1}/{max_hops}\\n\\nFirst paragraph:\\n{first_paragraph}\"\n",
+ " messages.append({\"role\": \"user\", \"content\": page_content})\n",
+ "\n",
+ " inputs = tokenizer.apply_chat_template(\n",
+ " messages, return_tensors=\"pt\", add_generation_prompt=True\n",
+ " ).to(\"cuda\")\n",
+ "\n",
+ " def get_completion() -> str:\n",
+ " with torch.no_grad():\n",
+ " outputs = peft_model.generate(\n",
+ " input_ids=inputs,\n",
+ " max_new_tokens=256,\n",
+ " do_sample=True,\n",
+ " temperature=0.7,\n",
+ " top_p=0.9,\n",
+ " )\n",
+ " return tokenizer.decode(\n",
+ " outputs[0][inputs.shape[1] :], skip_special_tokens=True\n",
+ " )\n",
+ "\n",
+ " try:\n",
+ " content = get_completion()\n",
+ " print(f\"🤖 Model response: {content}\")\n",
+ " \n",
+ " # Parse the URL from the response\n",
+ " try:\n",
+ " root = ET.fromstring(content)\n",
+ " next_url = root.text\n",
+ " if not next_url:\n",
+ " raise ValueError(\"Empty URL in response\")\n",
+ " next_url = next_url.strip()\n",
+ " except ET.ParseError:\n",
+ " # Try to extract URL with regex as fallback\n",
+ " import re\n",
+ " url_pattern = r'https://en\\.wikipedia\\.org/wiki/[^\\s<>]+'\n",
+ " matches = re.findall(url_pattern, content)\n",
+ " if matches:\n",
+ " next_url = matches[0]\n",
+ " else:\n",
+ " print(\"❌ Could not parse URL from response\")\n",
+ " break\n",
+ " \n",
+ " except Exception as e:\n",
+ " print(f\"❌ Error generating completion: {e}\")\n",
+ " break\n",
+ "\n",
+ " messages.append({\"role\": \"assistant\", \"content\": content})\n",
+ "\n",
+ " # Show the selected link\n",
+ " next_title = next_url.split('/wiki/')[-1].replace('_', ' ')\n",
+ " print(f\"🔗 Selected: '{next_title}' -> {next_url}\")\n",
+ " \n",
+ " # Move to the next page\n",
+ " current_url = next_url\n",
+ " hop_number += 1\n",
+ " print()\n",
+ "\n",
+ "# Final result\n",
+ "if hop_number >= max_hops and not check_target_match(current_url, target_url):\n",
+ " print(f\"⏰ Reached maximum hops ({max_hops}) without finding target\")\n",
+ " final_title = current_url.split('/wiki/')[-1].replace('_', ' ')\n",
+ " print(f\"📍 Final page: '{final_title}'\")\n",
+ "elif check_target_match(current_url, target_url):\n",
+ " print(f\"🎉 SUCCESS! Found target '{target_url}' in {hop_number} hops!\")\n",
+ "else:\n",
+ " print(f\"❌ Navigation stopped due to error\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "Questions? Join the Discord and ask away! For feature requests or to leave a star, visit our [Github](https://github.com/openpipe/art).\n",
+ "\n",
+ "
\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}