From 26bc3a26b59f47779dcfcf0e1c9ba9d4f3c51337 Mon Sep 17 00:00:00 2001 From: Advait1306 Date: Tue, 22 Apr 2025 02:07:18 +0530 Subject: [PATCH] Refactor tools and update API integration - Updated tool classes to support multiple versions, including BashTool, ComputerTool, and EditTool. - Enhanced API integration with improved error handling and response parsing. - Introduced prompt caching mechanism for better performance. - Updated requirements to use the latest version of the anthropic library. - Refined system prompt and configuration management for better user experience. --- activate.sh | 14 +++ loop.py | 201 ++++++++++++++++++++++++----------- requirements.txt | 2 +- streamlit.py | 249 +++++++++++++++++++++++++++++++++++--------- tools/__init__.py | 20 ++-- tools/base.py | 2 +- tools/bash.py | 25 +++-- tools/collection.py | 2 +- tools/computer.py | 166 +++++++++++++++++++++++++++-- tools/edit.py | 26 ++--- tools/groups.py | 33 ++++++ tools/run.py | 2 +- 12 files changed, 589 insertions(+), 153 deletions(-) create mode 100755 activate.sh create mode 100644 tools/groups.py diff --git a/activate.sh b/activate.sh new file mode 100755 index 00000000..377e75aa --- /dev/null +++ b/activate.sh @@ -0,0 +1,14 @@ +#!/bin/bash +source venv/bin/activate +export PYTHONPATH=$PYTHONPATH:$(pwd) + +echo "Virtual environment activated!" +echo "" +echo "To start the application:" +echo "1. Set your API key:" +echo " export ANTHROPIC_API_KEY=your_api_key_here" +echo "2. Set display dimensions (recommended):" +echo " export WIDTH=1280" +echo " export HEIGHT=800" +echo "3. 
Run the Streamlit app:" +echo " streamlit run streamlit.py" diff --git a/loop.py b/loop.py index 263328c6..89e1dac7 100644 --- a/loop.py +++ b/loop.py @@ -8,23 +8,35 @@ from enum import StrEnum from typing import Any, cast -from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse -from anthropic.types import ( - ToolResultBlockParam, +import httpx +from anthropic import ( + Anthropic, + AnthropicBedrock, + AnthropicVertex, + APIError, + APIResponseValidationError, + APIStatusError, ) from anthropic.types.beta import ( - BetaContentBlock, + BetaCacheControlEphemeralParam, BetaContentBlockParam, BetaImageBlockParam, BetaMessage, BetaMessageParam, + BetaTextBlock, BetaTextBlockParam, BetaToolResultBlockParam, + BetaToolUseBlockParam, ) -from tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult +from tools import ( + TOOL_GROUPS_BY_VERSION, + ToolCollection, + ToolResult, + ToolVersion, +) -BETA_FLAG = "computer-use-2024-10-22" +PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31" class APIProvider(StrEnum): @@ -33,33 +45,13 @@ class APIProvider(StrEnum): VERTEX = "vertex" -PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = { - APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022", - APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0", - APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022", -} - - # This system prompt is optimized for the Docker environment in this repository and # specific tool combinations enabled. # We encourage modifying this system prompt to ensure the model has context for the # environment it is running in, and to provide any additional information that may be # helpful for the task at hand. -# SYSTEM_PROMPT = f""" -# * You are utilizing a macOS Sonoma 15.7 environment using {platform.machine()} architecture with internet access. -# * You can install applications using homebrew with your bash tool. Use curl instead of wget. 
-# * To open Chrome, please just click on the Chrome icon in the Dock or use Spotlight. -# * Using bash tool you can start GUI applications. GUI apps can be launched directly or with `open -a "Application Name"`. GUI apps will appear natively within macOS, but they may take some time to appear. Take a screenshot to confirm it did. -# * When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B -A ` to confirm output. -# * When viewing a page it can be helpful to zoom out so that you can see everything on the page. In Chrome, use Command + "-" to zoom out or Command + "+" to zoom in. -# * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. -# * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. -# -# -# * When using Chrome, if any first-time setup dialogs appear, IGNORE THEM. Instead, click directly in the address bar and enter the appropriate search term or URL there. -# * If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext (available via homebrew) to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. -# """ -SYSTEM_PROMPT = f""" +SYSTEM_PROMPT = f""" + * You are utilizing a macOS Sonoma 15.7 environment using {platform.machine()} architecture with command line internet access. * Package management: - Use homebrew for package installation @@ -95,79 +87,118 @@ class APIProvider(StrEnum): * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. 
""" + async def sampling_loop( *, model: str, provider: APIProvider, system_prompt_suffix: str, messages: list[BetaMessageParam], - output_callback: Callable[[BetaContentBlock], None], + output_callback: Callable[[BetaContentBlockParam], None], tool_output_callback: Callable[[ToolResult, str], None], - api_response_callback: Callable[[APIResponse[BetaMessage]], None], + api_response_callback: Callable[ + [httpx.Request, httpx.Response | object | None, Exception | None], None + ], api_key: str, only_n_most_recent_images: int | None = None, max_tokens: int = 4096, + tool_version: ToolVersion, + thinking_budget: int | None = None, + token_efficient_tools_beta: bool = False, ): """ Agentic sampling loop for the assistant/tool interaction of computer use. """ - tool_collection = ToolCollection( - ComputerTool(), - BashTool(), - EditTool(), - ) - system = ( - f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}" + tool_group = TOOL_GROUPS_BY_VERSION[tool_version] + tool_collection = ToolCollection(*(ToolCls() for ToolCls in tool_group.tools)) + system = BetaTextBlockParam( + type="text", + text=f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}", ) while True: - if only_n_most_recent_images: - _maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images) - + enable_prompt_caching = False + betas = [tool_group.beta_flag] if tool_group.beta_flag else [] + if token_efficient_tools_beta: + betas.append("token-efficient-tools-2025-02-19") + image_truncation_threshold = only_n_most_recent_images or 0 if provider == APIProvider.ANTHROPIC: - client = Anthropic(api_key=api_key) + client = Anthropic(api_key=api_key, max_retries=4) + enable_prompt_caching = True elif provider == APIProvider.VERTEX: client = AnthropicVertex() elif provider == APIProvider.BEDROCK: client = AnthropicBedrock() + if enable_prompt_caching: + betas.append(PROMPT_CACHING_BETA_FLAG) + _inject_prompt_caching(messages) + # Because cached reads 
are 10% of the price, we don't think it's + # ever sensible to break the cache by truncating images + only_n_most_recent_images = 0 + # Use type ignore to bypass TypedDict check until SDK types are updated + system["cache_control"] = {"type": "ephemeral"} # type: ignore + + if only_n_most_recent_images: + _maybe_filter_to_n_most_recent_images( + messages, + only_n_most_recent_images, + min_removal_threshold=image_truncation_threshold, + ) + extra_body = {} + if thinking_budget: + # Ensure we only send the required fields for thinking + extra_body = { + "thinking": {"type": "enabled", "budget_tokens": thinking_budget} + } + # Call the API # we use raw_response to provide debug information to streamlit. Your # implementation may be able call the SDK directly with: # `response = client.messages.create(...)` instead. - raw_response = client.beta.messages.with_raw_response.create( - max_tokens=max_tokens, - messages=messages, - model=model, - system=system, - tools=tool_collection.to_params(), - betas=[BETA_FLAG], - ) + try: + raw_response = client.beta.messages.with_raw_response.create( + max_tokens=max_tokens, + messages=messages, + model=model, + system=[system], + tools=tool_collection.to_params(), + betas=betas, + extra_body=extra_body, + ) + except (APIStatusError, APIResponseValidationError) as e: + api_response_callback(e.request, e.response, e) + return messages + except APIError as e: + api_response_callback(e.request, e.body, e) + return messages - api_response_callback(cast(APIResponse[BetaMessage], raw_response)) + api_response_callback( + raw_response.http_response.request, raw_response.http_response, None + ) response = raw_response.parse() + response_params = _response_to_params(response) messages.append( { "role": "assistant", - "content": cast(list[BetaContentBlockParam], response.content), + "content": response_params, } ) tool_result_content: list[BetaToolResultBlockParam] = [] - for content_block in cast(list[BetaContentBlock], response.content): - 
print("CONTENT", content_block) + for content_block in response_params: output_callback(content_block) - if content_block.type == "tool_use": + if content_block["type"] == "tool_use": result = await tool_collection.run( - name=content_block.name, - tool_input=cast(dict[str, Any], content_block.input), + name=content_block["name"], + tool_input=cast(dict[str, Any], content_block["input"]), ) tool_result_content.append( - _make_api_tool_result(result, content_block.id) + _make_api_tool_result(result, content_block["id"]) ) - tool_output_callback(result, content_block.id) + tool_output_callback(result, content_block["id"]) if not tool_result_content: return messages @@ -178,7 +209,7 @@ async def sampling_loop( def _maybe_filter_to_n_most_recent_images( messages: list[BetaMessageParam], images_to_keep: int, - min_removal_threshold: int = 10, + min_removal_threshold: int, ): """ With the assumption that images are screenshots that are of diminishing value as @@ -190,7 +221,7 @@ def _maybe_filter_to_n_most_recent_images( return messages tool_result_blocks = cast( - list[ToolResultBlockParam], + list[BetaToolResultBlockParam], [ item for message in messages @@ -224,6 +255,54 @@ def _maybe_filter_to_n_most_recent_images( tool_result["content"] = new_content +def _response_to_params( + response: BetaMessage, +) -> list[BetaContentBlockParam]: + res: list[BetaContentBlockParam] = [] + for block in response.content: + if isinstance(block, BetaTextBlock): + if block.text: + res.append(BetaTextBlockParam(type="text", text=block.text)) + elif getattr(block, "type", None) == "thinking": + # Handle thinking blocks - include signature field + thinking_block = { + "type": "thinking", + "thinking": getattr(block, "thinking", None), + } + if hasattr(block, "signature"): + thinking_block["signature"] = getattr(block, "signature", None) + res.append(cast(BetaContentBlockParam, thinking_block)) + else: + # Handle tool use blocks normally + res.append(cast(BetaToolUseBlockParam, 
block.model_dump())) + return res + + +def _inject_prompt_caching( + messages: list[BetaMessageParam], +): + """ + Set cache breakpoints for the 3 most recent turns + one cache breakpoint is left for tools/system prompt, to be shared across sessions + """ + + breakpoints_remaining = 3 + for message in reversed(messages): + if message["role"] == "user" and isinstance( + content := message["content"], list + ): + if breakpoints_remaining: + breakpoints_remaining -= 1 + # Use type ignore to bypass TypedDict check until SDK types are updated + content[-1]["cache_control"] = BetaCacheControlEphemeralParam( # type: ignore + {"type": "ephemeral"} + ) + else: + content[-1].pop("cache_control", None) + # we'll only ever have one extra turn per loop + break + + def _make_api_tool_result( result: ToolResult, tool_use_id: str ) -> BetaToolResultBlockParam: @@ -263,4 +342,4 @@ def _make_api_tool_result( def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str): if result.system: result_text = f"{result.system}\n{result_text}" - return result_text + return result_text \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d3113e7b..fefc6bf9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ streamlit>=1.38.0 -anthropic[bedrock,vertex]>=0.37.1 +anthropic[bedrock,vertex]>=0.49.0 jsonschema==4.22.0 keyboard>=0.13.5 boto3>=1.28.57 diff --git a/streamlit.py b/streamlit.py index a57a5607..94cb2f6e 100644 --- a/streamlit.py +++ b/streamlit.py @@ -6,49 +6,86 @@ import base64 import os import subprocess -from datetime import datetime +import traceback +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime, timedelta from enum import StrEnum from functools import partial from pathlib import PosixPath -from typing import cast +from typing import cast, get_args +import httpx import streamlit as st -from anthropic import APIResponse -from anthropic.types import ( - TextBlock, 
+from anthropic import RateLimitError +from anthropic.types.beta import ( + BetaContentBlockParam, + BetaTextBlockParam, + BetaToolResultBlockParam, ) -from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock -from anthropic.types.tool_use_block import ToolUseBlock from streamlit.delta_generator import DeltaGenerator from loop import ( - PROVIDER_TO_DEFAULT_MODEL_NAME, APIProvider, sampling_loop, ) -from tools import ToolResult -from dotenv import load_dotenv +from tools import ToolResult, ToolVersion -load_dotenv() +PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = { + APIProvider.ANTHROPIC: "claude-3-7-sonnet-20250219", + APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0", + APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022", +} +@dataclass(kw_only=True, frozen=True) +class ModelConfig: + tool_version: ToolVersion + max_output_tokens: int + default_output_tokens: int + has_thinking: bool = False + + +SONNET_3_5_NEW = ModelConfig( + tool_version="computer_use_20241022", + max_output_tokens=1024 * 8, + default_output_tokens=1024 * 4, +) + +SONNET_3_7 = ModelConfig( + tool_version="computer_use_20250124", + max_output_tokens=128_000, + default_output_tokens=1024 * 16, + has_thinking=True, +) + +MODEL_TO_MODEL_CONF: dict[str, ModelConfig] = { + "claude-3-7-sonnet-20250219": SONNET_3_7, +} + CONFIG_DIR = PosixPath("~/.anthropic").expanduser() API_KEY_FILE = CONFIG_DIR / "api_key" STREAMLIT_STYLE = """ """ -WARNING_TEXT = "" +WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior" +INTERRUPT_TEXT = "(user stopped or interrupted and wrote the following)" +INTERRUPT_TOOL_ERROR = "human stopped or interrupted tool execution" class Sender(StrEnum): @@ -80,17 +117,41 @@ def setup_state(): if "tools" not in st.session_state: st.session_state.tools = {} if "only_n_most_recent_images" not in st.session_state: - 
st.session_state.only_n_most_recent_images = 10 + st.session_state.only_n_most_recent_images = 3 if "custom_system_prompt" not in st.session_state: st.session_state.custom_system_prompt = load_from_storage("system_prompt") or "" if "hide_images" not in st.session_state: st.session_state.hide_images = False + if "token_efficient_tools_beta" not in st.session_state: + st.session_state.token_efficient_tools_beta = False + if "in_sampling_loop" not in st.session_state: + st.session_state.in_sampling_loop = False def _reset_model(): st.session_state.model = PROVIDER_TO_DEFAULT_MODEL_NAME[ cast(APIProvider, st.session_state.provider) ] + _reset_model_conf() + + +def _reset_model_conf(): + model_conf = ( + SONNET_3_7 + if "3-7" in st.session_state.model + else MODEL_TO_MODEL_CONF.get(st.session_state.model, SONNET_3_5_NEW) + ) + + # If we're in radio selection mode, use the selected tool version + if hasattr(st.session_state, "tool_versions"): + st.session_state.tool_version = st.session_state.tool_versions + else: + st.session_state.tool_version = model_conf.tool_version + + st.session_state.has_thinking = model_conf.has_thinking + st.session_state.output_tokens = model_conf.default_output_tokens + st.session_state.max_output_tokens = model_conf.max_output_tokens + st.session_state.thinking_budget = int(model_conf.default_output_tokens / 2) async def main(): @@ -101,7 +162,8 @@ async def main(): st.title("Claude Computer Use for Mac") - st.markdown("""This is from [Mac Computer Use](https://github.com/deedy/mac_computer_use), a fork of [Anthropic Computer Use](https://github.com/anthropics/anthropic-quickstarts/blob/main/computer-use-demo/README.md) to work natively on Mac.""") + if not os.getenv("HIDE_WARNING", False): + st.warning(WARNING_TEXT) with st.sidebar: @@ -120,7 +182,7 @@ def _reset_api_provider(): on_change=_reset_api_provider, ) - st.text_input("Model", key="model") + st.text_input("Model", key="model", on_change=_reset_model_conf) if 
st.session_state.provider == APIProvider.ANTHROPIC: st.text_input( @@ -145,6 +207,30 @@ def _reset_api_provider(): ), ) st.checkbox("Hide screenshots", key="hide_images") + st.checkbox( + "Enable token-efficient tools beta", key="token_efficient_tools_beta" + ) + versions = get_args(ToolVersion) + st.radio( + "Tool Versions", + key="tool_versions", + options=versions, + index=versions.index(st.session_state.tool_version), + on_change=lambda: setattr( + st.session_state, "tool_version", st.session_state.tool_versions + ), + ) + + st.number_input("Max Output Tokens", key="output_tokens", step=1) + + st.checkbox("Thinking Enabled", key="thinking", value=False) + st.number_input( + "Thinking Budget", + key="thinking_budget", + max_value=st.session_state.max_output_tokens, + step=1, + disabled=not st.session_state.thinking, + ) if st.button("Reset", type="primary"): with st.spinner("Resetting..."): @@ -185,19 +271,22 @@ def _reset_api_provider(): else: _render_message( message["role"], - cast(BetaTextBlock | BetaToolUseBlock, block), + cast(BetaContentBlockParam | ToolResult, block), ) # render past http exchanges - for identity, response in st.session_state.responses.items(): - _render_api_response(response, identity, http_logs) + for identity, (request, response) in st.session_state.responses.items(): + _render_api_response(request, response, identity, http_logs) # render past chats if new_message: st.session_state.messages.append( { "role": Sender.USER, - "content": [TextBlock(type="text", text=new_message)], + "content": [ + *maybe_add_interruption_blocks(), + BetaTextBlockParam(type="text", text=new_message), + ], } ) _render_message(Sender.USER, new_message) @@ -211,7 +300,7 @@ def _reset_api_provider(): # we don't have a user message to respond to, exit early return - with st.spinner("Running Agent..."): + with track_sampling_loop(): # run the agent sampling loop with the newest message st.session_state.messages = await sampling_loop( 
system_prompt_suffix=st.session_state.custom_system_prompt, @@ -229,7 +318,44 @@ def _reset_api_provider(): ), api_key=st.session_state.api_key, only_n_most_recent_images=st.session_state.only_n_most_recent_images, + tool_version=st.session_state.tool_versions, + max_tokens=st.session_state.output_tokens, + thinking_budget=st.session_state.thinking_budget + if st.session_state.thinking + else None, + token_efficient_tools_beta=st.session_state.token_efficient_tools_beta, + ) + + +def maybe_add_interruption_blocks(): + if not st.session_state.in_sampling_loop: + return [] + # If this function is called while we're in the sampling loop, we can assume that the previous sampling loop was interrupted + # and we should annotate the conversation with additional context for the model and heal any incomplete tool use calls + result = [] + last_message = st.session_state.messages[-1] + previous_tool_use_ids = [ + block["id"] for block in last_message["content"] if block["type"] == "tool_use" + ] + for tool_use_id in previous_tool_use_ids: + st.session_state.tools[tool_use_id] = ToolResult(error=INTERRUPT_TOOL_ERROR) + result.append( + BetaToolResultBlockParam( + tool_use_id=tool_use_id, + type="tool_result", + content=INTERRUPT_TOOL_ERROR, + is_error=True, ) + ) + result.append(BetaTextBlockParam(type="text", text=INTERRUPT_TEXT)) + return result + + +@contextmanager +def track_sampling_loop(): + st.session_state.in_sampling_loop = True + yield + st.session_state.in_sampling_loop = False def validate_auth(provider: APIProvider, api_key: str | None): @@ -281,16 +407,20 @@ def save_to_storage(filename: str, data: str) -> None: def _api_response_callback( - response: APIResponse[BetaMessage], + request: httpx.Request, + response: httpx.Response | object | None, + error: Exception | None, tab: DeltaGenerator, - response_state: dict[str, APIResponse[BetaMessage]], + response_state: dict[str, tuple[httpx.Request, httpx.Response | object | None]], ): """ Handle an API response by 
storing it to state and rendering it. """ response_id = datetime.now().isoformat() - response_state[response_id] = response - _render_api_response(response, response_id, tab) + response_state[response_id] = (request, response) + if error: + _render_error(error) + _render_api_response(request, response, response_id, tab) def _tool_output_callback( @@ -302,33 +432,51 @@ def _tool_output_callback( def _render_api_response( - response: APIResponse[BetaMessage], response_id: str, tab: DeltaGenerator + request: httpx.Request, + response: httpx.Response | object | None, + response_id: str, + tab: DeltaGenerator, ): """Render an API response to a streamlit tab""" with tab: with st.expander(f"Request/Response ({response_id})"): newline = "\n\n" st.markdown( - f"`{response.http_request.method} {response.http_request.url}`{newline}{newline.join(f'`{k}: {v}`' for k, v in response.http_request.headers.items())}" - ) - st.json(response.http_request.read().decode()) - st.markdown( - f"`{response.http_response.status_code}`{newline}{newline.join(f'`{k}: {v}`' for k, v in response.headers.items())}" + f"`{request.method} {request.url}`{newline}{newline.join(f'`{k}: {v}`' for k, v in request.headers.items())}" ) - st.json(response.http_response.text) + st.json(request.read().decode()) + st.markdown("---") + if isinstance(response, httpx.Response): + st.markdown( + f"`{response.status_code}`{newline}{newline.join(f'`{k}: {v}`' for k, v in response.headers.items())}" + ) + st.json(response.text) + else: + st.write(response) + + +def _render_error(error: Exception): + if isinstance(error, RateLimitError): + body = "You have been rate limited." + if retry_after := error.response.headers.get("retry-after"): + body += f" **Retry after {str(timedelta(seconds=int(retry_after)))} (HH:MM:SS).** See our API [documentation](https://docs.anthropic.com/en/api/rate-limits) for more details." 
+ body += f"\n\n{error.message}" + else: + body = str(error) + body += "\n\n**Traceback:**" + lines = "\n".join(traceback.format_exception(error)) + body += f"\n\n```{lines}```" + save_to_storage(f"error_{datetime.now().timestamp()}.md", body) + st.error(f"**{error.__class__.__name__}**\n\n{body}", icon=":material/error:") def _render_message( sender: Sender, - message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, + message: str | BetaContentBlockParam | ToolResult, ): """Convert input from the user or output from the agent to a streamlit message.""" # streamlit's hotreloading breaks isinstance checks, so we need to check for class names - is_tool_result = not isinstance(message, str) and ( - isinstance(message, ToolResult) - or message.__class__.__name__ == "ToolResult" - or message.__class__.__name__ == "CLIResult" - ) + is_tool_result = not isinstance(message, str | dict) if not message or ( is_tool_result and st.session_state.hide_images @@ -348,13 +496,20 @@ def _render_message( st.error(message.error) if message.base64_image and not st.session_state.hide_images: st.image(base64.b64decode(message.base64_image)) - elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock): - st.write(message.text) - elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock): - st.code(f"Tool Use: {message.name}\nInput: {message.input}") + elif isinstance(message, dict): + if message["type"] == "text": + st.write(message["text"]) + elif message["type"] == "thinking": + thinking_content = message.get("thinking", "") + st.markdown(f"[Thinking]\n\n{thinking_content}") + elif message["type"] == "tool_use": + st.code(f'Tool Use: {message["name"]}\nInput: {message["input"]}') + else: + # only expected return types are text and tool_use + raise Exception(f'Unexpected response type {message["type"]}') else: st.markdown(message) if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main()) \ No newline at end of file diff --git 
a/tools/__init__.py b/tools/__init__.py index 1fd037f1..b4c4f137 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -1,14 +1,20 @@ from .base import CLIResult, ToolResult -from .bash import BashTool +from .bash import BashTool20241022, BashTool20250124 from .collection import ToolCollection -from .computer import ComputerTool -from .edit import EditTool +from .computer import ComputerTool20241022, ComputerTool20250124 +from .edit import EditTool20241022, EditTool20250124 +from .groups import TOOL_GROUPS_BY_VERSION, ToolVersion __ALL__ = [ - BashTool, + BashTool20241022, + BashTool20250124, CLIResult, - ComputerTool, - EditTool, + ComputerTool20241022, + ComputerTool20250124, + EditTool20241022, + EditTool20250124, ToolCollection, ToolResult, -] + ToolVersion, + TOOL_GROUPS_BY_VERSION, +] \ No newline at end of file diff --git a/tools/base.py b/tools/base.py index d6f13712..08c7d054 100644 --- a/tools/base.py +++ b/tools/base.py @@ -66,4 +66,4 @@ class ToolError(Exception): """Raised when a tool encounters an error.""" def __init__(self, message): - self.message = message + self.message = message \ No newline at end of file diff --git a/tools/bash.py b/tools/bash.py index db13db00..9ad93587 100644 --- a/tools/bash.py +++ b/tools/bash.py @@ -1,8 +1,6 @@ import asyncio import os -from typing import ClassVar, Literal - -from anthropic.types.beta import BetaToolBash20241022Param +from typing import Any, Literal from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult @@ -103,20 +101,27 @@ async def run(self, command: str): return CLIResult(output=output, error=error) -class BashTool(BaseAnthropicTool): +class BashTool20250124(BaseAnthropicTool): """ A tool that allows the agent to run bash commands. The tool parameters are defined by Anthropic and are not editable. 
""" _session: _BashSession | None - name: ClassVar[Literal["bash"]] = "bash" - api_type: ClassVar[Literal["bash_20241022"]] = "bash_20241022" + + api_type: Literal["bash_20250124"] = "bash_20250124" + name: Literal["bash"] = "bash" def __init__(self): self._session = None super().__init__() + def to_params(self) -> Any: + return { + "type": self.api_type, + "name": self.name, + } + async def __call__( self, command: str | None = None, restart: bool = False, **kwargs ): @@ -137,8 +142,6 @@ async def __call__( raise ToolError("no command provided.") - def to_params(self) -> BetaToolBash20241022Param: - return { - "type": self.api_type, - "name": self.name, - } + +class BashTool20241022(BashTool20250124): + api_type: Literal["bash_20241022"] = "bash_20241022" # pyright: ignore[reportIncompatibleVariableOverride] \ No newline at end of file diff --git a/tools/collection.py b/tools/collection.py index c4e8c95c..7b9e0dc6 100644 --- a/tools/collection.py +++ b/tools/collection.py @@ -31,4 +31,4 @@ async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult: try: return await tool(**tool_input) except ToolError as e: - return ToolFailure(error=e.message) + return ToolFailure(error=e.message) \ No newline at end of file diff --git a/tools/computer.py b/tools/computer.py index 0e7646fb..dd3b72cd 100644 --- a/tools/computer.py +++ b/tools/computer.py @@ -6,10 +6,10 @@ import keyboard from enum import StrEnum from pathlib import Path -from typing import Literal, TypedDict +from typing import Literal, TypedDict, cast, get_args from uuid import uuid4 -from anthropic.types.beta import BetaToolComputerUse20241022Param +from anthropic.types.beta import BetaToolComputerUse20241022Param, BetaToolUnionParam from .base import BaseAnthropicTool, ToolError, ToolResult from .run import run @@ -19,7 +19,7 @@ TYPING_DELAY_MS = 12 TYPING_GROUP_SIZE = 50 -Action = Literal[ +Action_20241022 = Literal[ "key", "type", "mouse_move", @@ -32,6 +32,20 @@ "cursor_position", ] 
+Action_20250124 = ( + Action_20241022 + | Literal[ + "left_mouse_down", + "left_mouse_up", + "scroll", + "hold_key", + "wait", + "triple_click", + ] +) + +ScrollDirection = Literal["up", "down", "left", "right"] + class Resolution(TypedDict): width: int @@ -47,6 +61,14 @@ class Resolution(TypedDict): } SCALE_DESTINATION = MAX_SCALING_TARGETS["FWXGA"] +CLICK_BUTTONS = { + "left_click": 1, + "right_click": 3, + "middle_click": 2, + "double_click": "--repeat 2 --delay 10 1", + "triple_click": "--repeat 3 --delay 10 1", +} + class ScalingSource(StrEnum): COMPUTER = "computer" @@ -63,7 +85,7 @@ def chunks(s: str, chunk_size: int) -> list[str]: return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)] -class ComputerTool(BaseAnthropicTool): +class BaseComputerTool: """ A tool that allows the agent to interact with the screen, keyboard, and mouse of the current macOS computer. The tool parameters are defined by Anthropic and are not editable. @@ -71,7 +93,6 @@ class ComputerTool(BaseAnthropicTool): """ name: Literal["computer"] = "computer" - api_type: Literal["computer_20241022"] = "computer_20241022" width: int height: int display_num: int | None @@ -81,15 +102,15 @@ class ComputerTool(BaseAnthropicTool): @property def options(self) -> ComputerToolOptions: + width, height = self.scale_coordinates( + ScalingSource.COMPUTER, self.width, self.height + ) return { - "display_width_px": self.width, - "display_height_px": self.height, + "display_width_px": width, + "display_height_px": height, "display_number": self.display_num, } - def to_params(self) -> BetaToolComputerUse20241022Param: - return {"name": self.name, "type": self.api_type, **self.options} - def __init__(self): super().__init__() @@ -100,7 +121,7 @@ def __init__(self): async def __call__( self, *, - action: Action, + action: Action_20241022, text: str | None = None, coordinate: tuple[int, int] | None = None, **kwargs, @@ -279,3 +300,126 @@ def scale_coordinates(self, source: ScalingSource, x: int, 
y: int) -> tuple[int, else: # Scale down from original resolution to SCALE_DESTINATION return round(x * x_scaling_factor), round(y * y_scaling_factor) + +class ComputerTool20241022(BaseComputerTool, BaseAnthropicTool): + api_type: Literal["computer_20241022"] = "computer_20241022" + + def to_params(self) -> BetaToolComputerUse20241022Param: + return {"name": self.name, "type": self.api_type, **self.options} + +class ComputerTool20250124(BaseComputerTool, BaseAnthropicTool): + api_type: Literal["computer_20250124"] = "computer_20250124" + + def to_params(self): + return cast( + BetaToolUnionParam, + {"name": self.name, "type": self.api_type, **self.options}, + ) + + async def __call__( + self, + *, + action: Action_20250124, + text: str | None = None, + coordinate: tuple[int, int] | None = None, + scroll_direction: ScrollDirection | None = None, + scroll_amount: int | None = None, + duration: int | float | None = None, + key: str | None = None, + **kwargs, + ): + if action in ("left_mouse_down", "left_mouse_up"): + if coordinate is not None: + raise ToolError(f"coordinate is not accepted for {action=}.") + click_cmd = { + "left_mouse_down": "dd:.", # Press down + "left_mouse_up": "du:.", # Release up + }[action] + return await self.shell(f"cliclick {click_cmd}") + + if action == "scroll": + if scroll_direction is None or scroll_direction not in get_args( + ScrollDirection + ): + raise ToolError( + f"{scroll_direction=} must be 'up', 'down', 'left', or 'right'" + ) + if not isinstance(scroll_amount, int) or scroll_amount < 0: + raise ToolError(f"{scroll_amount=} must be a non-negative int") + + if coordinate is not None: + x, y = self.scale_coordinates(ScalingSource.API, coordinate[0], coordinate[1]) + await self.shell(f"cliclick m:{x},{y}") + + # Map scroll directions to cliclick scroll commands + scroll_cmd = { + "up": "kp:page-up", + "down": "kp:page-down", + "left": "kp:arrow-left", + "right": "kp:arrow-right" + }[scroll_direction] + + commands = [] + if text: + 
commands.append(f"cliclick kd:{text}") + for _ in range(scroll_amount): + commands.append(f"cliclick {scroll_cmd}") + if text: + commands.append(f"cliclick ku:{text}") + + for cmd in commands: + await self.shell(cmd) + return await self.screenshot() + + if action in ("hold_key", "wait"): + if duration is None or not isinstance(duration, (int, float)): + raise ToolError(f"{duration=} must be a number") + if duration < 0: + raise ToolError(f"{duration=} must be non-negative") + if duration > 100: + raise ToolError(f"{duration=} is too long.") + + if action == "hold_key": + if text is None: + raise ToolError(f"text is required for {action}") + await self.shell(f"cliclick kd:{text}") + await asyncio.sleep(duration) + await self.shell(f"cliclick ku:{text}") + return await self.screenshot() + + if action == "wait": + await asyncio.sleep(duration) + return await self.screenshot() + + if action in ( + "left_click", + "right_click", + "double_click", + "triple_click", + "middle_click", + ): + if text is not None: + raise ToolError(f"text is not accepted for {action}") + + if coordinate is not None: + x, y = self.scale_coordinates(ScalingSource.API, coordinate[0], coordinate[1]) + await self.shell(f"cliclick m:{x},{y}") + + click_cmd = { + "left_click": "c:.", + "right_click": "rc:.", + "middle_click": "mc:.", + "double_click": "dc:.", + "triple_click": "tc:.", # Note: cliclick may not support triple click natively + }[action] + + if key: + await self.shell(f"cliclick kd:{key}") + result = await self.shell(f"cliclick {click_cmd}") + if key: + await self.shell(f"cliclick ku:{key}") + return result + + return await super().__call__( + action=action, text=text, coordinate=coordinate, key=key, **kwargs + ) \ No newline at end of file diff --git a/tools/edit.py b/tools/edit.py index 3cca1f36..3c18bd3f 100644 --- a/tools/edit.py +++ b/tools/edit.py @@ -1,8 +1,6 @@ from collections import defaultdict from pathlib import Path -from typing import Literal, get_args - -from 
anthropic.types.beta import BetaToolTextEditor20241022Param +from typing import Any, Literal, get_args from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult from .run import maybe_truncate, run @@ -17,13 +15,13 @@ SNIPPET_LINES: int = 4 -class EditTool(BaseAnthropicTool): +class EditTool20250124(BaseAnthropicTool): """ An filesystem editor tool that allows the agent to view, create, and edit files. The tool parameters are defined by Anthropic and are not editable. """ - api_type: Literal["text_editor_20241022"] = "text_editor_20241022" + api_type: Literal["text_editor_20250124"] = "text_editor_20250124" name: Literal["str_replace_editor"] = "str_replace_editor" _file_history: dict[Path, list[str]] @@ -32,7 +30,7 @@ def __init__(self): self._file_history = defaultdict(list) super().__init__() - def to_params(self) -> BetaToolTextEditor20241022Param: + def to_params(self) -> Any: return { "name": self.name, "type": self.api_type, @@ -55,13 +53,13 @@ async def __call__( if command == "view": return await self.view(_path, view_range) elif command == "create": - if not file_text: + if file_text is None: raise ToolError("Parameter `file_text` is required for command: create") self.write_file(_path, file_text) self._file_history[_path].append(file_text) return ToolResult(output=f"File created successfully at: {_path}") elif command == "str_replace": - if not old_str: + if old_str is None: raise ToolError( "Parameter `old_str` is required for command: str_replace" ) @@ -71,7 +69,7 @@ async def __call__( raise ToolError( "Parameter `insert_line` is required for command: insert" ) - if not new_str: + if new_str is None: raise ToolError("Parameter `new_str` is required for command: insert") return self.insert(_path, insert_line, new_str) elif command == "undo_edit": @@ -133,15 +131,15 @@ async def view(self, path: Path, view_range: list[int] | None = None): init_line, final_line = view_range if init_line < 1 or init_line > n_lines_file: raise ToolError( - 
f"Invalid `view_range`: {view_range}. It's first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}" + f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}" ) if final_line > n_lines_file: raise ToolError( - f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`" + f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`" ) if final_line != -1 and final_line < init_line: raise ToolError( - f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be larger or equal than its first `{init_line}`" + f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`" ) if final_line == -1: @@ -288,3 +286,7 @@ def _make_output( + file_content + "\n" ) + + +class EditTool20241022(EditTool20250124): + api_type: Literal["text_editor_20241022"] = "text_editor_20241022" # pyright: ignore[reportIncompatibleVariableOverride] \ No newline at end of file diff --git a/tools/groups.py b/tools/groups.py new file mode 100644 index 00000000..51bf209c --- /dev/null +++ b/tools/groups.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from typing import Literal + +from .base import BaseAnthropicTool +from .bash import BashTool20241022, BashTool20250124 +from .computer import ComputerTool20241022, ComputerTool20250124 +from .edit import EditTool20241022, EditTool20250124 + +ToolVersion = Literal["computer_use_20250124", "computer_use_20241022"] +BetaFlag = Literal["computer-use-2024-10-22", "computer-use-2025-01-24"] + + +@dataclass(frozen=True, kw_only=True) +class ToolGroup: + version: ToolVersion + tools: list[type[BaseAnthropicTool]] + beta_flag: BetaFlag | None = None + + +TOOL_GROUPS: 
list[ToolGroup] = [ + ToolGroup( + version="computer_use_20241022", + tools=[ComputerTool20241022, EditTool20241022, BashTool20241022], + beta_flag="computer-use-2024-10-22", + ), + ToolGroup( + version="computer_use_20250124", + tools=[ComputerTool20250124, EditTool20250124, BashTool20250124], + beta_flag="computer-use-2025-01-24", + ), +] + +TOOL_GROUPS_BY_VERSION = {tool_group.version: tool_group for tool_group in TOOL_GROUPS} \ No newline at end of file diff --git a/tools/run.py b/tools/run.py index 89db980a..0bf447b8 100644 --- a/tools/run.py +++ b/tools/run.py @@ -39,4 +39,4 @@ async def run( pass raise TimeoutError( f"Command '{cmd}' timed out after {timeout} seconds" - ) from exc + ) from exc \ No newline at end of file