Skip to content
62 changes: 38 additions & 24 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,37 @@
@staticmethod
def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
    """Convert an OCI tool call to a LangChain ``ToolCall``.

    Handles both OCI response formats:

    * Generic/Meta format: ``arguments`` is a JSON-encoded string
      (sometimes double-encoded, i.e. a JSON string containing JSON).
    * Cohere format: ``parameters`` is already a dict.

    Args:
        tool_call: The OCI tool-call object (SDK model instance).

    Returns:
        A LangChain ``ToolCall`` with the parsed arguments and the
        provider-supplied id, or a freshly generated UUID hex string
        when the OCI object carries no ``id`` attribute.
    """
    # The SDK's attribute_map tells us which model variant we received.
    attribute_map = getattr(tool_call, "attribute_map", None) or {}

    if "arguments" in attribute_map and tool_call.arguments is not None:
        # Generic/Meta format: arguments arrive as a JSON string.
        parsed = json.loads(tool_call.arguments)

        # A string result means the JSON was escaped/double-encoded,
        # so decode one more level.
        if isinstance(parsed, str):
            try:
                parsed = json.loads(parsed)
            except json.JSONDecodeError:
                # Not valid JSON after all — keep it as a string.
                pass
        args = parsed
    else:
        # Cohere format: parameters is already a dict.
        args = tool_call.parameters

    # Use the provider-supplied id when present, else generate one.
    # (Original had uuid.uuid4().hex[:] — the [:] full slice is a no-op.)
    tool_id = tool_call.id if "id" in attribute_map else uuid.uuid4().hex

    return ToolCall(
        name=tool_call.name,
        args=args,
        id=tool_id,
    )


Expand Down Expand Up @@ -271,11 +286,6 @@
response.data.chat_response.usage.total_tokens
)

# Include tool calls if available
if self.chat_tool_calls(response):
generation_info["tool_calls"] = self.format_response_tool_calls(
self.chat_tool_calls(response)
)
return generation_info

def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
Expand Down Expand Up @@ -651,10 +661,6 @@
response.data.chat_response.usage.total_tokens
)

if self.chat_tool_calls(response):
generation_info["tool_calls"] = self.format_response_tool_calls(
self.chat_tool_calls(response)
)
return generation_info

def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
Expand Down Expand Up @@ -1400,6 +1406,9 @@
if stop is not None:
content = enforce_stop_tokens(content, stop)

# Fetch raw tool calls once to avoid redundant calls
raw_tool_calls = self._provider.chat_tool_calls(response)

generation_info = self._provider.chat_generation_info(response)

llm_output = {
Expand All @@ -1408,12 +1417,17 @@
"request_id": response.request_id,
"content-length": response.headers["content-length"],
}

tool_calls = []
if "tool_calls" in generation_info:
if raw_tool_calls:
tool_calls = [
OCIUtils.convert_oci_tool_call_to_langchain(tool_call)
for tool_call in self._provider.chat_tool_calls(response)
for tool_call in raw_tool_calls
]
if "tool_calls" not in generation_info:
generation_info["tool_calls"] = self._provider.format_response_tool_calls(
raw_tool_calls
)
message = AIMessage(
content=content or "",
additional_kwargs=generation_info,
Expand Down Expand Up @@ -1484,7 +1498,7 @@

Attributes:
auth (httpx.Auth): Authentication handler for OCI request signing.
compartment_id (str): Compartment OCID used for resource isolation.

Check failure on line 1501 in libs/oci/langchain_oci/chat_models/oci_generative_ai.py

View workflow job for this annotation

GitHub Actions / cd libs/oci / make lint #3.9

Ruff (E501)

langchain_oci/chat_models/oci_generative_ai.py:1501:89: E501 Line too long (90 > 88)
model (str): Name of OpenAI model to use.
conversation_store_id (str | None): Conversation Store Id to use
when generating responses.
Expand Down
Loading