Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions activate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash
# Activate the project virtualenv and put the repo root on PYTHONPATH.
# NOTE(review): this script must be *sourced* (`source activate.sh` or
# `. activate.sh`), not executed, or the activation and the exported
# PYTHONPATH will not persist in the caller's shell.
source venv/bin/activate
# ${PYTHONPATH:+...} avoids a leading ":" (an empty entry implicitly adds
# the current directory to sys.path) when PYTHONPATH starts out unset;
# quoting keeps paths containing spaces intact.
export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$(pwd)"

echo "Virtual environment activated!"
echo ""
echo "To start the application:"
echo "1. Set your API key:"
echo "   export ANTHROPIC_API_KEY=your_api_key_here"
echo "2. Set display dimensions (recommended):"
echo "   export WIDTH=1280"
echo "   export HEIGHT=800"
echo "3. Run the Streamlit app:"
echo "   streamlit run streamlit.py"
201 changes: 140 additions & 61 deletions loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,35 @@
from enum import StrEnum
from typing import Any, cast

from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
from anthropic.types import (
ToolResultBlockParam,
import httpx
from anthropic import (
Anthropic,
AnthropicBedrock,
AnthropicVertex,
APIError,
APIResponseValidationError,
APIStatusError,
)
from anthropic.types.beta import (
BetaContentBlock,
BetaCacheControlEphemeralParam,
BetaContentBlockParam,
BetaImageBlockParam,
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaTextBlockParam,
BetaToolResultBlockParam,
BetaToolUseBlockParam,
)

from tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult
from tools import (
TOOL_GROUPS_BY_VERSION,
ToolCollection,
ToolResult,
ToolVersion,
)

BETA_FLAG = "computer-use-2024-10-22"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"


class APIProvider(StrEnum):
Expand All @@ -33,33 +45,13 @@ class APIProvider(StrEnum):
VERTEX = "vertex"


PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
}


# This system prompt is optimized for the Docker environment in this repository and
# specific tool combinations enabled.
# We encourage modifying this system prompt to ensure the model has context for the
# environment it is running in, and to provide any additional information that may be
# helpful for the task at hand.
# SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
# * You are utilizing a macOS Sonoma 15.7 environment using {platform.machine()} architecture with internet access.
# * You can install applications using homebrew with your bash tool. Use curl instead of wget.
# * To open Chrome, please just click on the Chrome icon in the Dock or use Spotlight.
# * Using bash tool you can start GUI applications. GUI apps can be launched directly or with `open -a "Application Name"`. GUI apps will appear natively within macOS, but they may take some time to appear. Take a screenshot to confirm it did.
# * When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B <lines before> -A <lines after> <query> <filename>` to confirm output.
# * When viewing a page it can be helpful to zoom out so that you can see everything on the page. In Chrome, use Command + "-" to zoom out or Command + "+" to zoom in.
# * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
# * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
# </SYSTEM_CAPABILITY>
# <IMPORTANT>
# * When using Chrome, if any first-time setup dialogs appear, IGNORE THEM. Instead, click directly in the address bar and enter the appropriate search term or URL there.
# * If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext (available via homebrew) to convert it to a text file, and then read that text file directly with your StrReplaceEditTool.
# </IMPORTANT>"""
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
SYSTEM_PROMPT = f"""
<SYSTEM_CAPABILITY>
* You are utilizing a macOS Sonoma 15.7 environment using {platform.machine()} architecture with command line internet access.
* Package management:
- Use homebrew for package installation
Expand Down Expand Up @@ -95,79 +87,118 @@ class APIProvider(StrEnum):
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
</SYSTEM_CAPABILITY>"""


async def sampling_loop(
*,
model: str,
provider: APIProvider,
system_prompt_suffix: str,
messages: list[BetaMessageParam],
output_callback: Callable[[BetaContentBlock], None],
output_callback: Callable[[BetaContentBlockParam], None],
tool_output_callback: Callable[[ToolResult, str], None],
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
api_response_callback: Callable[
[httpx.Request, httpx.Response | object | None, Exception | None], None
],
api_key: str,
only_n_most_recent_images: int | None = None,
max_tokens: int = 4096,
tool_version: ToolVersion,
thinking_budget: int | None = None,
token_efficient_tools_beta: bool = False,
):
"""
Agentic sampling loop for the assistant/tool interaction of computer use.
"""
tool_collection = ToolCollection(
ComputerTool(),
BashTool(),
EditTool(),
)
system = (
f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}"
tool_group = TOOL_GROUPS_BY_VERSION[tool_version]
tool_collection = ToolCollection(*(ToolCls() for ToolCls in tool_group.tools))
system = BetaTextBlockParam(
type="text",
text=f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}",
)

while True:
if only_n_most_recent_images:
_maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images)

enable_prompt_caching = False
betas = [tool_group.beta_flag] if tool_group.beta_flag else []
if token_efficient_tools_beta:
betas.append("token-efficient-tools-2025-02-19")
image_truncation_threshold = only_n_most_recent_images or 0
if provider == APIProvider.ANTHROPIC:
client = Anthropic(api_key=api_key)
client = Anthropic(api_key=api_key, max_retries=4)
enable_prompt_caching = True
elif provider == APIProvider.VERTEX:
client = AnthropicVertex()
elif provider == APIProvider.BEDROCK:
client = AnthropicBedrock()

if enable_prompt_caching:
betas.append(PROMPT_CACHING_BETA_FLAG)
_inject_prompt_caching(messages)
# Because cached reads are 10% of the price, we don't think it's
# ever sensible to break the cache by truncating images
only_n_most_recent_images = 0
# Use type ignore to bypass TypedDict check until SDK types are updated
system["cache_control"] = {"type": "ephemeral"} # type: ignore

if only_n_most_recent_images:
_maybe_filter_to_n_most_recent_images(
messages,
only_n_most_recent_images,
min_removal_threshold=image_truncation_threshold,
)
extra_body = {}
if thinking_budget:
# Ensure we only send the required fields for thinking
extra_body = {
"thinking": {"type": "enabled", "budget_tokens": thinking_budget}
}

# Call the API
# we use raw_response to provide debug information to streamlit. Your
# implementation may be able call the SDK directly with:
# `response = client.messages.create(...)` instead.
raw_response = client.beta.messages.with_raw_response.create(
max_tokens=max_tokens,
messages=messages,
model=model,
system=system,
tools=tool_collection.to_params(),
betas=[BETA_FLAG],
)
try:
raw_response = client.beta.messages.with_raw_response.create(
max_tokens=max_tokens,
messages=messages,
model=model,
system=[system],
tools=tool_collection.to_params(),
betas=betas,
extra_body=extra_body,
)
except (APIStatusError, APIResponseValidationError) as e:
api_response_callback(e.request, e.response, e)
return messages
except APIError as e:
api_response_callback(e.request, e.body, e)
return messages

api_response_callback(cast(APIResponse[BetaMessage], raw_response))
api_response_callback(
raw_response.http_response.request, raw_response.http_response, None
)

response = raw_response.parse()

response_params = _response_to_params(response)
messages.append(
{
"role": "assistant",
"content": cast(list[BetaContentBlockParam], response.content),
"content": response_params,
}
)

tool_result_content: list[BetaToolResultBlockParam] = []
for content_block in cast(list[BetaContentBlock], response.content):
print("CONTENT", content_block)
for content_block in response_params:
output_callback(content_block)
if content_block.type == "tool_use":
if content_block["type"] == "tool_use":
result = await tool_collection.run(
name=content_block.name,
tool_input=cast(dict[str, Any], content_block.input),
name=content_block["name"],
tool_input=cast(dict[str, Any], content_block["input"]),
)
tool_result_content.append(
_make_api_tool_result(result, content_block.id)
_make_api_tool_result(result, content_block["id"])
)
tool_output_callback(result, content_block.id)
tool_output_callback(result, content_block["id"])

if not tool_result_content:
return messages
Expand All @@ -178,7 +209,7 @@ async def sampling_loop(
def _maybe_filter_to_n_most_recent_images(
messages: list[BetaMessageParam],
images_to_keep: int,
min_removal_threshold: int = 10,
min_removal_threshold: int,
):
"""
With the assumption that images are screenshots that are of diminishing value as
Expand All @@ -190,7 +221,7 @@ def _maybe_filter_to_n_most_recent_images(
return messages

tool_result_blocks = cast(
list[ToolResultBlockParam],
list[BetaToolResultBlockParam],
[
item
for message in messages
Expand Down Expand Up @@ -224,6 +255,54 @@ def _maybe_filter_to_n_most_recent_images(
tool_result["content"] = new_content


def _response_to_params(
    response: BetaMessage,
) -> list[BetaContentBlockParam]:
    """Translate an API response's content blocks into request parameters.

    Text blocks are re-emitted as text params (empty text is dropped),
    thinking blocks are passed through with their signature field when one
    is present, and every other block (tool use) is serialized verbatim
    via ``model_dump``.
    """
    params: list[BetaContentBlockParam] = []
    for block in response.content:
        block_type = getattr(block, "type", None)
        if isinstance(block, BetaTextBlock):
            # Empty text blocks carry no information; drop them.
            if block.text:
                params.append(BetaTextBlockParam(type="text", text=block.text))
        elif block_type == "thinking":
            # Rebuild the thinking block by hand so the signature survives
            # the round-trip back into request parameters.
            entry: dict[str, Any] = {
                "type": "thinking",
                "thinking": getattr(block, "thinking", None),
            }
            if hasattr(block, "signature"):
                entry["signature"] = block.signature
            params.append(cast(BetaContentBlockParam, entry))
        else:
            # Tool-use (and any future) block kinds round-trip verbatim.
            params.append(cast(BetaToolUseBlockParam, block.model_dump()))
    return params


def _inject_prompt_caching(
messages: list[BetaMessageParam],
):
"""
Set cache breakpoints for the 3 most recent turns
one cache breakpoint is left for tools/system prompt, to be shared across sessions
"""

breakpoints_remaining = 3
for message in reversed(messages):
if message["role"] == "user" and isinstance(
content := message["content"], list
):
if breakpoints_remaining:
breakpoints_remaining -= 1
# Use type ignore to bypass TypedDict check until SDK types are updated
content[-1]["cache_control"] = BetaCacheControlEphemeralParam( # type: ignore
{"type": "ephemeral"}
)
else:
content[-1].pop("cache_control", None)
# we'll only every have one extra turn per loop
break


def _make_api_tool_result(
result: ToolResult, tool_use_id: str
) -> BetaToolResultBlockParam:
Expand Down Expand Up @@ -263,4 +342,4 @@ def _make_api_tool_result(
def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
if result.system:
result_text = f"<system>{result.system}</system>\n{result_text}"
return result_text
return result_text
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
streamlit>=1.38.0
anthropic[bedrock,vertex]>=0.37.1
anthropic[bedrock,vertex]>=0.49.0
jsonschema==4.22.0
keyboard>=0.13.5
boto3>=1.28.57
Expand Down
Loading