Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 64 additions & 38 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,37 @@ def remove_signature_from_tool_description(name: str, description: str) -> str:
@staticmethod
def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
"""Convert an OCI tool call to a LangChain ToolCall."""
parsed = json.loads(tool_call.arguments)

# If the parsed result is a string, it means the JSON was escaped, so parse again # noqa: E501
if isinstance(parsed, str):
try:
parsed = json.loads(parsed)
except json.JSONDecodeError:
# If it's not valid JSON, keep it as a string
pass
# Check if this is a Generic/Meta format (has arguments as JSON string)
# or Cohere format (has parameters as dict)
attribute_map = getattr(tool_call, "attribute_map", None) or {}

if "arguments" in attribute_map and tool_call.arguments is not None:
# Generic/Meta format: parse JSON arguments
parsed = json.loads(tool_call.arguments)

# If the parsed result is a string, it means JSON was escaped
if isinstance(parsed, str):
try:
parsed = json.loads(parsed)
except json.JSONDecodeError:
# If it's not valid JSON, keep it as a string
pass
args = parsed
else:
# Cohere format: parameters is already a dict
args = tool_call.parameters

# Get tool call ID (generate one if not present)
tool_id = (
tool_call.id
if "id" in attribute_map
else uuid.uuid4().hex[:]
)

return ToolCall(
name=tool_call.name,
args=parsed
if "arguments" in tool_call.attribute_map
else tool_call.parameters,
id=tool_call.id if "id" in tool_call.attribute_map else uuid.uuid4().hex[:],
args=args,
id=tool_id,
)


Expand Down Expand Up @@ -263,19 +278,19 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
}

# Include token usage if available
if (
hasattr(response.data.chat_response, "usage")
and response.data.chat_response.usage
):
generation_info["total_tokens"] = (
response.data.chat_response.usage.total_tokens
)
try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering why need to replace the old logic with try except in this PR?

if (
hasattr(response.data.chat_response, "usage")
and response.data.chat_response.usage
):
generation_info["total_tokens"] = (
response.data.chat_response.usage.total_tokens
)
except (KeyError, AttributeError):
pass

# Include tool calls if available
if self.chat_tool_calls(response):
generation_info["tool_calls"] = self.format_response_tool_calls(
self.chat_tool_calls(response)
)
# Note: tool_calls are now handled in _generate() to avoid redundant conversions
# The formatted tool calls will be added there if present
return generation_info

def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
Expand Down Expand Up @@ -643,18 +658,19 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
}

# Include token usage if available
if (
hasattr(response.data.chat_response, "usage")
and response.data.chat_response.usage
):
generation_info["total_tokens"] = (
response.data.chat_response.usage.total_tokens
)
try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering why need to replace the old logic with try except in this PR?

if (
hasattr(response.data.chat_response, "usage")
and response.data.chat_response.usage
):
generation_info["total_tokens"] = (
response.data.chat_response.usage.total_tokens
)
except (KeyError, AttributeError):
pass

if self.chat_tool_calls(response):
generation_info["tool_calls"] = self.format_response_tool_calls(
self.chat_tool_calls(response)
)
# Note: tool_calls are now handled in _generate() to avoid redundant conversions
# The formatted tool calls will be added there if present
return generation_info

def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
Expand Down Expand Up @@ -1400,6 +1416,9 @@ def _generate(
if stop is not None:
content = enforce_stop_tokens(content, stop)

# Fetch raw tool calls once to avoid redundant calls
raw_tool_calls = self._provider.chat_tool_calls(response)

generation_info = self._provider.chat_generation_info(response)

llm_output = {
Expand All @@ -1408,12 +1427,19 @@ def _generate(
"request_id": response.request_id,
"content-length": response.headers["content-length"],
}

# Convert tool calls once for LangChain format
tool_calls = []
if "tool_calls" in generation_info:
if raw_tool_calls:
tool_calls = [
OCIUtils.convert_oci_tool_call_to_langchain(tool_call)
for tool_call in self._provider.chat_tool_calls(response)
for tool_call in raw_tool_calls
]
# Add formatted version to generation_info if not already present
# This avoids redundant formatting in chat_generation_info()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we do not need so many explanations in the comment here? The comment does not need to record the history of code changes.

if "tool_calls" not in generation_info:
formatted = self._provider.format_response_tool_calls(raw_tool_calls)
generation_info["tool_calls"] = formatted
message = AIMessage(
content=content or "",
additional_kwargs=generation_info,
Expand Down
Loading
Loading