-
Notifications
You must be signed in to change notification settings - Fork 16
Optimize tool call conversions to eliminate redundant API lookups #53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
0465fa5
05de5f4
d6dddb1
fd52dc5
24f31b6
8665f8e
864a3b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -100,22 +100,37 @@ def remove_signature_from_tool_description(name: str, description: str) -> str: | |
| @staticmethod | ||
| def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall: | ||
| """Convert an OCI tool call to a LangChain ToolCall.""" | ||
| parsed = json.loads(tool_call.arguments) | ||
|
|
||
| # If the parsed result is a string, it means the JSON was escaped, so parse again # noqa: E501 | ||
| if isinstance(parsed, str): | ||
| try: | ||
| parsed = json.loads(parsed) | ||
| except json.JSONDecodeError: | ||
| # If it's not valid JSON, keep it as a string | ||
| pass | ||
| # Check if this is a Generic/Meta format (has arguments as JSON string) | ||
| # or Cohere format (has parameters as dict) | ||
| attribute_map = getattr(tool_call, "attribute_map", None) or {} | ||
|
|
||
| if "arguments" in attribute_map and tool_call.arguments is not None: | ||
| # Generic/Meta format: parse JSON arguments | ||
| parsed = json.loads(tool_call.arguments) | ||
|
|
||
| # If the parsed result is a string, it means JSON was escaped | ||
| if isinstance(parsed, str): | ||
| try: | ||
| parsed = json.loads(parsed) | ||
| except json.JSONDecodeError: | ||
| # If it's not valid JSON, keep it as a string | ||
| pass | ||
| args = parsed | ||
| else: | ||
| # Cohere format: parameters is already a dict | ||
| args = tool_call.parameters | ||
|
|
||
| # Get tool call ID (generate one if not present) | ||
| tool_id = ( | ||
| tool_call.id | ||
| if "id" in attribute_map | ||
| else uuid.uuid4().hex[:] | ||
| ) | ||
|
|
||
| return ToolCall( | ||
| name=tool_call.name, | ||
| args=parsed | ||
| if "arguments" in tool_call.attribute_map | ||
| else tool_call.parameters, | ||
| id=tool_call.id if "id" in tool_call.attribute_map else uuid.uuid4().hex[:], | ||
| args=args, | ||
| id=tool_id, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -263,19 +278,19 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]: | |
| } | ||
|
|
||
| # Include token usage if available | ||
| if ( | ||
| hasattr(response.data.chat_response, "usage") | ||
| and response.data.chat_response.usage | ||
| ): | ||
| generation_info["total_tokens"] = ( | ||
| response.data.chat_response.usage.total_tokens | ||
| ) | ||
| try: | ||
| if ( | ||
| hasattr(response.data.chat_response, "usage") | ||
| and response.data.chat_response.usage | ||
| ): | ||
| generation_info["total_tokens"] = ( | ||
| response.data.chat_response.usage.total_tokens | ||
| ) | ||
| except (KeyError, AttributeError): | ||
| pass | ||
|
|
||
| # Include tool calls if available | ||
| if self.chat_tool_calls(response): | ||
| generation_info["tool_calls"] = self.format_response_tool_calls( | ||
| self.chat_tool_calls(response) | ||
| ) | ||
| # Note: tool_calls are now handled in _generate() to avoid redundant conversions | ||
| # The formatted tool calls will be added there if present | ||
| return generation_info | ||
|
|
||
| def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: | ||
|
|
@@ -643,18 +658,19 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]: | |
| } | ||
|
|
||
| # Include token usage if available | ||
| if ( | ||
| hasattr(response.data.chat_response, "usage") | ||
| and response.data.chat_response.usage | ||
| ): | ||
| generation_info["total_tokens"] = ( | ||
| response.data.chat_response.usage.total_tokens | ||
| ) | ||
| try: | ||
|
||
| if ( | ||
| hasattr(response.data.chat_response, "usage") | ||
| and response.data.chat_response.usage | ||
| ): | ||
| generation_info["total_tokens"] = ( | ||
| response.data.chat_response.usage.total_tokens | ||
| ) | ||
| except (KeyError, AttributeError): | ||
| pass | ||
|
|
||
| if self.chat_tool_calls(response): | ||
| generation_info["tool_calls"] = self.format_response_tool_calls( | ||
| self.chat_tool_calls(response) | ||
| ) | ||
| # Note: tool_calls are now handled in _generate() to avoid redundant conversions | ||
| # The formatted tool calls will be added there if present | ||
| return generation_info | ||
|
|
||
| def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: | ||
|
|
@@ -1400,6 +1416,9 @@ def _generate( | |
| if stop is not None: | ||
| content = enforce_stop_tokens(content, stop) | ||
|
|
||
| # Fetch raw tool calls once to avoid redundant calls | ||
| raw_tool_calls = self._provider.chat_tool_calls(response) | ||
|
|
||
| generation_info = self._provider.chat_generation_info(response) | ||
|
|
||
| llm_output = { | ||
|
|
@@ -1408,12 +1427,19 @@ def _generate( | |
| "request_id": response.request_id, | ||
| "content-length": response.headers["content-length"], | ||
| } | ||
|
|
||
| # Convert tool calls once for LangChain format | ||
| tool_calls = [] | ||
| if "tool_calls" in generation_info: | ||
| if raw_tool_calls: | ||
| tool_calls = [ | ||
| OCIUtils.convert_oci_tool_call_to_langchain(tool_call) | ||
| for tool_call in self._provider.chat_tool_calls(response) | ||
| for tool_call in raw_tool_calls | ||
| ] | ||
| # Add formatted version to generation_info if not already present | ||
| # This avoids redundant formatting in chat_generation_info() | ||
|
||
| if "tool_calls" not in generation_info: | ||
| formatted = self._provider.format_response_tool_calls(raw_tool_calls) | ||
| generation_info["tool_calls"] = formatted | ||
| message = AIMessage( | ||
| content=content or "", | ||
| additional_kwargs=generation_info, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am wondering why need to replace the old logic with try except in this PR?