diff --git a/pyproject.toml b/pyproject.toml
index 8229568..9099ae7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ dependencies = [
     "requests>=2.31.0",
     "click>=8.0.0",
     "openai>=1.0.0",
+    "anthropic>=0.18.0",
 ]
 
 [project.scripts]
diff --git a/src/concierge_clients/client_tool_calling.py b/src/concierge_clients/client_tool_calling.py
index 7d6cb69..f375cd6 100644
--- a/src/concierge_clients/client_tool_calling.py
+++ b/src/concierge_clients/client_tool_calling.py
@@ -3,17 +3,17 @@ import time
 import threading
 import requests
-from openai import OpenAI
 from enum import Enum
 from pathlib import Path
 
 sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
 from concierge.config import SERVER_HOST, SERVER_PORT
+from concierge_clients.providers.factory import get_provider
 
 class Mode(Enum):
     """Client operation modes"""
-    USER = "user"
+    USER = "user"
     SERVER = "server"
 
@@ -49,20 +49,32 @@ def _spin(self):
 
 class ToolCallingClient:
-    
-    def __init__(self, api_base: str, api_key: str, verbose: bool = False):
-        self.llm = OpenAI(base_url=api_base, api_key=api_key)
-        self.model = "gpt-5"
+
+    def __init__(self, api_base: str, api_key: str, provider_name: str = "openai", model: str | None = None, verbose: bool = False):
+        self.provider_name = provider_name
+
+        # Initialize provider
+        provider_kwargs = {"api_key": api_key}
+        if provider_name == "openai":
+            provider_kwargs["api_base"] = api_base
+            if model:
+                provider_kwargs["model"] = model
+        elif provider_name == "anthropic":
+            if model:
+                provider_kwargs["model"] = model
+
+        self.llm_provider = get_provider(provider_name, **provider_kwargs)
+
         self.concierge_url = f"http://{SERVER_HOST}:{SERVER_PORT}"
-        
+
         self.mode = Mode.USER
         self.in_context_servers = []
-        
-        self.workflow_sessions = {}
+
+        self.workflow_sessions = {}
         self.current_workflow = None
         self.current_tools = []
         self.verbose = verbose
-        
+
         self.conversation_history = [{
             "role": "system",
             "content": """You are an AI assistant with access to remote Concierge workflows. 
@@ -83,13 +95,13 @@ def __init__(self, api_base: str, api_key: str, verbose: bool = False): Always provide a seamless, conversational experience and explain what you're doing.""" }] - - + + def _log(self, message: str, style: str = "info"): """Pretty print log messages""" if not self.verbose: return - + colors = { "info": "\033[36m", # Cyan "success": "\033[32m", # Green @@ -99,17 +111,17 @@ def _log(self, message: str, style: str = "info"): } color = colors.get(style, colors["info"]) print(f"{color}{message}{colors['reset']}") - + def _status(self, message: str, icon: str = "→"): """Print status message with elegant styling""" print(f" \033[38;5;110m{icon}\033[0m \033[38;5;252m{message}\033[0m") - + def _success(self, message: str, detail: str = ""): """Print success message with checkmark""" print(f" \033[38;5;82m✓\033[0m \033[1;38;5;189m{message}\033[0m") if detail: print(f" \033[38;5;245m{detail}\033[0m") - + def _action(self, heading: str, detail: str | None = None): """Print action being performed""" line = f" \033[38;5;183m▪\033[0m \033[38;5;252m{heading}\033[0m" @@ -117,7 +129,7 @@ def _action(self, heading: str, detail: str | None = None): line += f": \033[38;5;110m{detail}\033[0m" print(line) - + def search_remote_servers(self, search_query: str) -> list: """Search for available remote servers/workflows - ALWAYS OVERWRITES in-context servers""" try: @@ -125,73 +137,73 @@ def search_remote_servers(self, search_query: str) -> list: response = requests.get(f"{self.concierge_url}/api/workflows", params={"search": search_query}) response.raise_for_status() workflows = response.json().get('workflows', []) - + self.in_context_servers = workflows self._log(f"[SEARCH] Query: '{search_query}'", "info") self._log(f"[IN-CONTEXT] Found {len(workflows)} servers", "success") - + if workflows: self._success(f"Found {len(workflows)} workflow{'s' if len(workflows) != 1 else ''}") - + return workflows except Exception as e: self._log(f"[ERROR] Search failed: {e}", "error") self.in_context_servers = [] return [] - + def establish_connection(self, server_name: str) -> dict: """Establish connection with a discovered server and switch to SERVER MODE""" try: server = next((s for s in self.in_context_servers if s.get("name") == server_name), None) if not server: return {"error": f"Server '{server_name}' not found in current context. 
Search first."} - + self._status(f"Connecting to {server_name}") - + server_url = server.get("url", f"{self.concierge_url}") headers = {} if server_name in self.workflow_sessions: headers["X-Session-Id"] = self.workflow_sessions[server_name] - + payload = { "action": "handshake", "workflow_name": server_name } - + response = requests.post(f"{server_url}/execute", json=payload, headers=headers, timeout=30) response.raise_for_status() - + if 'X-Session-Id' in response.headers: self.workflow_sessions[server_name] = response.headers['X-Session-Id'] - + result = response.json() - - self.current_tools = self.concierge_to_openai_tools(result.get("tools", [])) + + self.current_tools = self.llm_provider.convert_tools(result.get("tools", [])) self.current_workflow = server_name - + self.mode = Mode.SERVER - + self._log(f"[CONNECTED] Session: {self.workflow_sessions.get(server_name, 'N/A')[:8]}...", "success") self._log(f"[MODE] Switched to SERVER mode", "info") self._log(f"[TOOLS] {len(self.current_tools)} tools available", "info") - + # Beautiful success message session_id = self.workflow_sessions.get(server_name, 'N/A') current_stage = result.get('current_stage', 'unknown') self._success(f"Connected to {server_name}", f"session: {session_id[:8]}... • stage: {current_stage}") - + return { "status": "connected", "server": server_name, "current_stage": result.get("current_stage"), - "tools": [t["function"]["name"] for t in self.current_tools], + "tools": [t.get("function", {}).get("name") or t.get("name") for t in self.current_tools], "message": f"Connected to {server_name}. Ready to use server tools." } - + except Exception as e: self._log(f"[ERROR] Connection failed: {e}", "error") return {"error": str(e)} - + def get_user_mode_tools(self) -> list: """Tools available in USER mode - dynamically generates establish_connection with in-context servers""" tools = [ @@ -213,7 +225,7 @@ def get_user_mode_tools(self) -> list: } } ] - + if self.in_context_servers: server_options = [] for server in self.in_context_servers: @@ -224,7 +236,7 @@ def get_user_mode_tools(self) -> list: "const": name, "description": desc }) - + establish_tool = { "type": "function", "function": { @@ -238,7 +250,7 @@ def get_user_mode_tools(self) -> list: * Use all the server's workflow tools * Perform tasks specific to that server * Disconnect when done to search for other servers - + After connecting, you'll have access to the server's tools to help the user accomplish their goal.""", "parameters": { "type": "object", @@ -252,72 +264,72 @@ def get_user_mode_tools(self) -> list: } } tools.append(establish_tool) - - return tools - - + + return self.llm_provider.convert_tools(tools) + + def call_workflow(self, workflow_name: str, payload: dict) -> dict: """Call current workflow with an action""" if workflow_name not in self.workflow_sessions: raise ValueError(f"Not connected to workflow: {workflow_name}") - + headers = {"X-Session-Id": self.workflow_sessions[workflow_name]} payload["workflow_name"] = workflow_name - + self._log(f"[{workflow_name.upper()}] Action: {payload.get('action')}", "info") - + response = requests.post(f"{self.concierge_url}/execute", json=payload, headers=headers) response.raise_for_status() - + result = json.loads(response.text) - + # Update tools if they changed if "tools" in result: - self.current_tools = self.concierge_to_openai_tools(result["tools"]) - + self.current_tools = self.llm_provider.convert_tools(result["tools"]) + return result - + def disconnect_server(self) -> dict: """Disconnect from 
current server and return to USER mode""" if not self.current_workflow: return {"status": "no_active_connection"} - + try: workflow_name = self.current_workflow - + if workflow_name in self.workflow_sessions: headers = {"X-Session-Id": self.workflow_sessions[workflow_name]} payload = {"action": "terminate_session", "workflow_name": workflow_name} - + try: response = requests.post(f"{self.concierge_url}/execute", json=payload, headers=headers, timeout=10) except: pass - + if workflow_name in self.workflow_sessions: del self.workflow_sessions[workflow_name] - + self.current_workflow = None self.current_tools = [] self.in_context_servers = [] self.mode = Mode.USER - + self._log(f"[DISCONNECTED] Server: {workflow_name}", "info") self._log(f"[MODE] Switched to USER mode", "info") print(f" \033[38;5;147m◇\033[0m \033[38;5;245mDisconnected from\033[0m \033[38;5;183m{workflow_name}\033[0m") - + return {"status": "disconnected", "server": workflow_name} - + except Exception as e: print(f"[ERROR] Disconnect failed: {e}") return {"error": str(e)} - + def get_server_mode_tools(self) -> list: """Tools available in SERVER mode""" tools = list(self.current_tools) # Copy workflow tools - + # Add disconnect tool with detailed context - tools.append({ + disconnect_tool = { "type": "function", "function": { "name": "disconnect_server", @@ -330,7 +342,7 @@ def get_server_mode_tools(self) -> list: * Search for other remote servers * Discover different workflows * Connect to a different server - + Use this when: - The user wants to work with a different server - The current task is complete @@ -342,30 +354,18 @@ def get_server_mode_tools(self) -> list: "required": [] } } - }) - + } + tools.extend(self.llm_provider.convert_tools([disconnect_tool])) + return tools - - - def concierge_to_openai_tools(self, concierge_tools: list) -> list: - """Convert Concierge tools to OpenAI format""" - openai_tools = [] - for tool in concierge_tools: - openai_tools.append({ - "type": "function", - "function": { - "name": tool["name"], - "description": tool["description"], - "parameters": tool["input_schema"] - } - }) - return openai_tools - + + + def openai_to_concierge_action(self, tool_call) -> dict: """Convert OpenAI tool_call to Concierge contract""" - function_name = tool_call.function.name - arguments = json.loads(tool_call.function.arguments) - + function_name = tool_call["function"]["name"] + arguments = json.loads(tool_call["function"]["arguments"]) + # Server control actions if function_name == "transition_stage": return {"action": "stage_transition", "stage": arguments["target_stage"]} @@ -375,59 +375,46 @@ def openai_to_concierge_action(self, tool_call) -> dict: return {"action": "terminate_session", "reason": arguments.get("reason", "completed")} else: return {"action": "method_call", "task": function_name, "args": arguments} - - + + def chat(self, user_message: str) -> str: """Main chat loop with mode-aware tool selection""" self.conversation_history.append({ "role": "user", "content": user_message }) - + max_iterations = 15 for iteration in range(max_iterations): if self.mode == Mode.USER: tools = self.get_user_mode_tools() else: tools = self.get_server_mode_tools() - + self._log(f"[ITERATION {iteration + 1}] Mode: {self.mode.value.upper()}, Tools: {len(tools)}", "info") - + spinner_message = "Executing workflow plan" if self.mode == Mode.SERVER else "Evaluating request" with Spinner(spinner_message): - response = self.llm.chat.completions.create( - model=self.model, + response = self.llm_provider.chat( 
messages=self.conversation_history, - tools=tools, - tool_choice="auto" + tools=tools ) - - message = response.choices[0].message - - assistant_message = {"role": "assistant", "content": message.content} - if message.tool_calls: - assistant_message["tool_calls"] = [ - { - "id": tc.id, - "type": tc.type, - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments - } - } - for tc in message.tool_calls - ] + parsed_response = self.llm_provider.parse_response(response) + + assistant_message = {"role": "assistant", "content": parsed_response["content"]} + if parsed_response["tool_calls"]: + assistant_message["tool_calls"] = parsed_response["tool_calls"] self.conversation_history.append(assistant_message) - - if not message.tool_calls: - print(f" \033[38;5;147m◈\033[0m \033[1;38;5;252m{message.content}\033[0m") - return message.content - - for tool_call in message.tool_calls: - function_name = tool_call.function.name - arguments = json.loads(tool_call.function.arguments) - + + if not parsed_response["tool_calls"]: + print(f" \033[38;5;147m◈\033[0m \033[1;38;5;252m{parsed_response['content']}\033[0m") + return parsed_response["content"] + + for tool_call in parsed_response["tool_calls"]: + function_name = tool_call["function"]["name"] + arguments = json.loads(tool_call["function"]["arguments"]) + if function_name == "transition_stage": self._action("Stage transition", arguments.get('target_stage')) elif function_name == "get_all_products": @@ -459,7 +446,7 @@ def chat(self, user_message: str) -> str: self._action(readable) self._log(f"[TOOL CALL] {function_name}({json.dumps(arguments, indent=2)})", "info") - + if self.mode == Mode.USER: if function_name == "search_remote_servers": servers = self.search_remote_servers(arguments["search_query"]) @@ -468,19 +455,19 @@ def chat(self, user_message: str) -> str: "count": len(servers), "message": f"Found {len(servers)} servers. Use establish_connection to connect." }) - + elif function_name == "establish_connection": result = self.establish_connection(arguments["server_name"]) result_content = json.dumps(result) - + else: result_content = json.dumps({"error": f"Tool '{function_name}' not available in USER mode"}) - + else: if function_name == "disconnect_server": result = self.disconnect_server() result_content = json.dumps(result) - + else: if not self.current_workflow: result_content = json.dumps({"error": "Not connected to any server"}) @@ -488,18 +475,18 @@ def chat(self, user_message: str) -> str: action = self.openai_to_concierge_action(tool_call) result = self.call_workflow(self.current_workflow, action) result_content = result.get("content", json.dumps(result)) - + if "current_stage" in result: print(f"\033[90m Current stage: {result['current_stage']}\033[0m") - + self.conversation_history.append({ "role": "tool", - "tool_call_id": tool_call.id, + "tool_call_id": tool_call["id"], "content": result_content }) - + return "Max iterations reached. Please try again." 
-    
+
     def run(self):
         """Interactive chat loop"""
         # Clean, powerful banner
@@ -512,19 +499,19 @@ def run(self):
         print("\033[38;5;147m│\033[0m \033[38;5;147m│\033[0m")
         print("\033[38;5;147m╰────────────────────────────────────────────────────────╯\033[0m")
         print()
-        
+
         # Status line
         status_parts = []
-        status_parts.append(f"\033[38;5;183m{self.model}\033[0m")
+        status_parts.append(f"\033[38;5;183m{self.provider_name.upper()}\033[0m")
         status_parts.append(f"\033[38;5;147m{self.mode.value.upper()}\033[0m")
         if self.verbose:
             status_parts.append("\033[38;5;110mVERBOSE\033[0m")
-        
+
         separator = " \033[38;5;238m•\033[38;5;245m "
         print(f"\033[38;5;245m {separator.join(status_parts)}\033[0m")
         print(f"\033[38;5;245m Type \033[3;38;5;183mexit\033[0m \033[38;5;245mto quit\033[0m")
         print()
-        
+
         while True:
             try:
                 user_input = input(" \033[1;38;5;189m›\033[0m \033[1mYou:\033[0m ").strip()
                 if user_input.lower() in ['exit', 'quit']:
@@ -537,11 +524,11 @@ def run(self):
                     print("\033[38;5;147m╰────────────────────────────────────────────────────────╯\033[0m")
                     print()
                     break
-                
+
                 print()  # Add spacing before response
                 self.chat(user_input)
                 print()  # Add spacing after response
-            
+
             except KeyboardInterrupt:
                 print()
                 print()
@@ -556,12 +543,18 @@ def run(self):
         traceback.print_exc()
 
 if __name__ == "__main__":
-    if len(sys.argv) < 3:
-        print("Usage: python client_tool_calling.py <api_base> <api_key>")
-        sys.exit(1)
-    
-    api_base = sys.argv[1]
-    api_key = sys.argv[2]
-    
-    client = ToolCallingClient(api_base, api_key)
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Concierge Tool Calling Client")
+    parser.add_argument("api_base", help="API base URL (used by the OpenAI provider; ignored by Anthropic)")
+    parser.add_argument("api_key", help="API key")
+    parser.add_argument("--provider", default="openai", choices=["openai", "anthropic"], help="LLM provider")
+    parser.add_argument("--model", help="Model name (each provider supplies its own default)")
+    parser.add_argument("--verbose", action="store_true", help="Verbose logging")
+
+    args = parser.parse_args()
+
+    # api_base is only consumed by the OpenAI provider; Anthropic needs just the API key.
+
+    client = ToolCallingClient(args.api_base, args.api_key, args.provider, args.model, args.verbose)
     client.run()
diff --git a/src/concierge_clients/providers/anthropic_provider.py b/src/concierge_clients/providers/anthropic_provider.py
new file mode 100644
index 0000000..a51a96e
--- /dev/null
+++ b/src/concierge_clients/providers/anthropic_provider.py
@@ -0,0 +1,112 @@
+from typing import Any, Dict, List, Optional
+import json
+try:
+    from anthropic import Anthropic
+except ImportError:
+    Anthropic = None
+
+from .base import LLMProvider
+
+class AnthropicProvider(LLMProvider):
+    def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
+        if Anthropic is None:
+            raise ImportError("anthropic package is required for AnthropicProvider")
+        self.client = Anthropic(api_key=api_key)
+        self.model = model
+
+    def chat(self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None) -> Any:
+        # Convert OpenAI-style messages to the Anthropic Messages format: the system
+        # message becomes the top-level `system` parameter; only user/assistant turns remain.
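+        # Each OpenAI-style "tool" message is emitted below as its own user turn
+        # carrying a single tool_result block, mirroring the assistant's tool_use blocks.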
+ + system_prompt = None + anthropic_messages = [] + + for msg in messages: + if msg["role"] == "system": + system_prompt = msg["content"] + elif msg["role"] == "tool": + # Anthropic tool results + anthropic_messages.append({ + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": msg["tool_call_id"], + "content": msg["content"] + } + ] + }) + elif "tool_calls" in msg and msg["tool_calls"]: + # Assistant message with tool calls + content = [] + if msg["content"]: + content.append({"type": "text", "text": msg["content"]}) + + for tc in msg["tool_calls"]: + content.append({ + "type": "tool_use", + "id": tc["id"], + "name": tc["function"]["name"], + "input": json.loads(tc["function"]["arguments"]) + }) + anthropic_messages.append({"role": "assistant", "content": content}) + else: + anthropic_messages.append(msg) + + kwargs = { + "model": self.model, + "messages": anthropic_messages, + "max_tokens": 4096, + } + + if system_prompt: + kwargs["system"] = system_prompt + + if tools: + kwargs["tools"] = tools + + return self.client.messages.create(**kwargs) + + def convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + anthropic_tools = [] + for tool in tools: + # Check if it's already in OpenAI format (which we might receive) and convert to Anthropic + if "type" in tool and tool["type"] == "function": + func = tool["function"] + anthropic_tools.append({ + "name": func["name"], + "description": func["description"], + "input_schema": func["parameters"] + }) + else: + # Assume Concierge format which is close to Anthropic's + anthropic_tools.append({ + "name": tool["name"], + "description": tool["description"], + "input_schema": tool["input_schema"] + }) + return anthropic_tools + + def parse_response(self, response: Any) -> Dict[str, Any]: + content = "" + tool_calls = [] + + for block in response.content: + if block.type == "text": + content += block.text + elif block.type == "tool_use": + tool_calls.append({ + "id": block.id, + "type": "function", + "function": { + "name": block.name, + "arguments": json.dumps(block.input) + } + }) + + return { + "content": content, + "tool_calls": tool_calls + } diff --git a/src/concierge_clients/providers/base.py b/src/concierge_clients/providers/base.py new file mode 100644 index 0000000..4601272 --- /dev/null +++ b/src/concierge_clients/providers/base.py @@ -0,0 +1,47 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +class LLMProvider(ABC): + """Abstract base class for LLM providers.""" + + @abstractmethod + def chat(self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None) -> Any: + """ + Send a chat request to the provider. + + Args: + messages: List of message dictionaries (role, content). + tools: Optional list of tools in the provider's format (or generic format to be converted). + + Returns: + The raw response object from the provider. + """ + pass + + @abstractmethod + def convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Convert generic tools to the provider-specific format. + + Args: + tools: List of generic tool definitions. + + Returns: + List of tools in the provider's format. + """ + pass + + @abstractmethod + def parse_response(self, response: Any) -> Dict[str, Any]: + """ + Parse the provider's response into a standard format. + + Args: + response: The raw response from the provider. + + Returns: + A dictionary containing: + - content: The text content of the response. 
+            - tool_calls: A list of tool calls (if any).
+        """
+        pass
diff --git a/src/concierge_clients/providers/factory.py b/src/concierge_clients/providers/factory.py
new file mode 100644
index 0000000..2b2fd45
--- /dev/null
+++ b/src/concierge_clients/providers/factory.py
@@ -0,0 +1,24 @@
+from .base import LLMProvider
+from .openai_provider import OpenAIProvider
+from .anthropic_provider import AnthropicProvider
+
+def get_provider(provider_name: str, **kwargs) -> LLMProvider:
+    """
+    Factory function to get an LLM provider instance.
+
+    Args:
+        provider_name: The name of the provider ("openai" or "anthropic").
+        **kwargs: Arguments to pass to the provider constructor.
+
+    Returns:
+        An instance of LLMProvider.
+
+    Raises:
+        ValueError: If the provider name is not supported.
+    """
+    if provider_name.lower() == "openai":
+        return OpenAIProvider(**kwargs)
+    elif provider_name.lower() == "anthropic":
+        return AnthropicProvider(**kwargs)
+    else:
+        raise ValueError(f"Unsupported provider: {provider_name}")
diff --git a/src/concierge_clients/providers/openai_provider.py b/src/concierge_clients/providers/openai_provider.py
new file mode 100644
index 0000000..3953ada
--- /dev/null
+++ b/src/concierge_clients/providers/openai_provider.py
@@ -0,0 +1,64 @@
+from typing import Any, Dict, List, Optional
+from openai import OpenAI
+from .base import LLMProvider
+
+class OpenAIProvider(LLMProvider):
+    def __init__(self, api_base: str, api_key: str, model: str = "gpt-5"):
+        self.client = OpenAI(base_url=api_base, api_key=api_key)
+        self.model = model
+
+    def chat(self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None) -> Any:
+        kwargs = {
+            "model": self.model,
+            "messages": messages,
+        }
+        if tools:
+            kwargs["tools"] = tools
+            kwargs["tool_choice"] = "auto"
+
+        return self.client.chat.completions.create(**kwargs)
+
+    def convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        # OpenAI expects tools as {"type": "function", "function": {...}}; input may
+        # already be in that format, or in Concierge format and need wrapping.
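+        # e.g. Concierge: {"name": ..., "description": ..., "input_schema": {...}}
+        #      OpenAI:    {"type": "function", "function": {"name": ..., "parameters": ...}}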
+
+        openai_tools = []
+        for tool in tools:
+            # Check if it's already in OpenAI format
+            if "type" in tool and tool["type"] == "function":
+                openai_tools.append(tool)
+            else:
+                # Convert from Concierge format
+                openai_tools.append({
+                    "type": "function",
+                    "function": {
+                        "name": tool["name"],
+                        "description": tool["description"],
+                        "parameters": tool["input_schema"]
+                    }
+                })
+        return openai_tools
+
+    def parse_response(self, response: Any) -> Dict[str, Any]:
+        message = response.choices[0].message
+        result = {
+            "content": message.content,
+            "tool_calls": []
+        }
+
+        if message.tool_calls:
+            result["tool_calls"] = [
+                {
+                    "id": tc.id,
+                    "type": tc.type,
+                    "function": {
+                        "name": tc.function.name,
+                        "arguments": tc.function.arguments
+                    }
+                }
+                for tc in message.tool_calls
+            ]
+
+        return result
diff --git a/tests/test_provider_abstraction.py b/tests/test_provider_abstraction.py
new file mode 100644
index 0000000..05ac8c1
--- /dev/null
+++ b/tests/test_provider_abstraction.py
@@ -0,0 +1,76 @@
+import sys
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+# Add src to path
+sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
+
+# Mock third-party dependencies before importing modules that use them
+sys.modules["openai"] = MagicMock()
+sys.modules["requests"] = MagicMock()
+sys.modules["yaml"] = MagicMock()
+
+# Mock pydantic with a real class for BaseModel so issubclass works
+class MockBaseModel:
+    pass
+mock_pydantic = MagicMock()
+mock_pydantic.BaseModel = MockBaseModel
+sys.modules["pydantic"] = mock_pydantic
+sys.modules["pydantic_core"] = MagicMock()
+
+# Mock concierge package to avoid importing dependencies
+mock_concierge = MagicMock()
+mock_config = MagicMock()
+mock_config.SERVER_HOST = "localhost"
+mock_config.SERVER_PORT = 8080
+mock_concierge.config = mock_config
+sys.modules["concierge"] = mock_concierge
+sys.modules["concierge.config"] = mock_config
+
+from concierge_clients.providers.factory import get_provider
+from concierge_clients.providers.openai_provider import OpenAIProvider
+from concierge_clients.providers.anthropic_provider import AnthropicProvider
+from concierge_clients.client_tool_calling import ToolCallingClient
+
+def test_factory():
+    openai_provider = get_provider("openai", api_base="http://test", api_key="test")
+    assert isinstance(openai_provider, OpenAIProvider)
+
+    # Mock Anthropic import for test environments where it might not be installed
+    with patch("concierge_clients.providers.anthropic_provider.Anthropic") as mock_anthropic:
+        anthropic_provider = get_provider("anthropic", api_key="test")
+        assert isinstance(anthropic_provider, AnthropicProvider)
+
+def test_client_initialization():
+    # Test OpenAI initialization
+    client_openai = ToolCallingClient(api_base="http://test", api_key="test", provider_name="openai")
+    assert isinstance(client_openai.llm_provider, OpenAIProvider)
+
+    # Test Anthropic initialization
+    with patch("concierge_clients.providers.anthropic_provider.Anthropic") as mock_anthropic:
+        client_anthropic = ToolCallingClient(api_base="http://test", api_key="test", provider_name="anthropic")
+        assert isinstance(client_anthropic.llm_provider, AnthropicProvider)
+
+def test_openai_provider_conversion():
+    provider = OpenAIProvider(api_base="http://test", api_key="test")
+    tools = [{"name": "test_tool", "description": "test", "input_schema": {"type": "object"}}]
+    converted = provider.convert_tools(tools)
+    assert converted[0]["type"] == "function"
+    assert converted[0]["function"]["name"] == "test_tool"
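+    # The input schema should pass through unchanged as "parameters"
+    assert converted[0]["function"]["parameters"] == {"type": "object"}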
+ +def test_anthropic_provider_conversion(): + with patch("concierge_clients.providers.anthropic_provider.Anthropic") as mock_anthropic: + provider = AnthropicProvider(api_key="test") + tools = [{"name": "test_tool", "description": "test", "input_schema": {"type": "object"}}] + converted = provider.convert_tools(tools) + assert converted[0]["name"] == "test_tool" + assert "input_schema" in converted[0] + +if __name__ == "__main__": + test_factory() + test_client_initialization() + test_openai_provider_conversion() + test_anthropic_provider_conversion() + print("All tests passed!")
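
Example invocation (a sketch; assumes a reachable Concierge server per concierge.config and valid API keys in the environment):

    # OpenAI: the positional api_base is passed through to the client
    python src/concierge_clients/client_tool_calling.py https://api.openai.com/v1 "$OPENAI_API_KEY" --provider openai --model gpt-5

    # Anthropic: argparse still requires the positional api_base, but the provider ignores it
    python src/concierge_clients/client_tool_calling.py unused "$ANTHROPIC_API_KEY" --provider anthropic --model claude-3-opus-20240229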