Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def run(
if signature(self._run).parameters.get("run_manager"):
tool_kwargs |= {"run_manager": run_manager}
if config_param := _get_runnable_config_param(self._run):
tool_kwargs |= {config_param: config}
tool_kwargs |= {config_param: child_config}
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
Expand Down Expand Up @@ -976,7 +976,7 @@ async def arun(
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
tool_kwargs[config_param] = child_config

coro = self._arun(*tool_args, **tool_kwargs)
response = await coro_with_context(coro, context)
Expand Down Expand Up @@ -1258,6 +1258,11 @@ def _get_runnable_config_param(func: Callable) -> Optional[str]:
for name, type_ in type_hints.items():
if type_ is RunnableConfig:
return name
# Handle Optional[RunnableConfig] and Union[RunnableConfig, None, ...]
if get_origin(type_) is Union:
union_args = get_args(type_)
if RunnableConfig in union_args:
return name
return None


Expand Down
46 changes: 46 additions & 0 deletions libs/core/tests/unit_tests/utils/test_json_schema_recursion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from langchain_core.utils.json_schema import dereference_refs


def test_dereference_refs_self_reference_no_recursion() -> None:
"""Ensure self-referential schemas are handled without infinite recursion."""
schema = {
"$defs": {
"Node": {
"type": "object",
"properties": {
"value": {"type": "string"},
"child": {"$ref": "#/$defs/Node"},
},
}
},
"type": "object",
"properties": {"root": {"$ref": "#/$defs/Node"}},
}

# Should not raise RecursionError and should return a dictionary
actual = dereference_refs(schema)
assert isinstance(actual, dict)
# The $defs should be preserved and recursion should be broken within
# dereferenced parts
assert "$defs" in actual
assert "properties" in actual


def test_dereference_refs_circular_chain_no_recursion() -> None:
"""Ensure multi-node circular chains are handled without infinite recursion."""
schema = {
"$defs": {
"A": {"type": "object", "properties": {"to_b": {"$ref": "#/$defs/B"}}},
"B": {"type": "object", "properties": {"to_c": {"$ref": "#/$defs/C"}}},
"C": {"type": "object", "properties": {"to_a": {"$ref": "#/$defs/A"}}},
},
"type": "object",
"properties": {"start": {"$ref": "#/$defs/A"}},
}

# Should not raise RecursionError
actual = dereference_refs(schema)
assert isinstance(actual, dict)
# Spot-check top-level dereference occurred
assert "properties" in actual
assert "start" in actual["properties"]
85 changes: 81 additions & 4 deletions libs/langchain/langchain/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import builtins
import contextlib
import importlib
import json
import logging
import time
Expand All @@ -19,7 +20,6 @@
cast,
)

import yaml
from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentFinish, AgentStep
from langchain_core.callbacks import (
Expand Down Expand Up @@ -52,8 +52,14 @@
from langchain.chains.llm import LLMChain
from langchain.utilities.asyncio import asyncio_timeout

# Replace direct yaml import with dynamic import typed as Any
yaml: Any = importlib.import_module("yaml")

logger = logging.getLogger(__name__)

# Sentinel used to detect absence of attribute without widening type
_NOTSET = object()


class BaseSingleActionAgent(BaseModel):
"""Base Single Action Agent class."""
Expand Down Expand Up @@ -1183,6 +1189,55 @@ def _action_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
return self.agent

@override
def invoke(
self,
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Invoke the agent with RunnableConfig support.

This override ensures that RunnableConfig is passed through to tools
during agent execution.
"""
# Store config temporarily for access during tool execution
old_config = getattr(self, "_current_config", _NOTSET)
self._current_config = config
try:
return super().invoke(input, config, **kwargs)
finally:
# Restore previous config
if old_config is _NOTSET:
if hasattr(self, "_current_config"):
delattr(self, "_current_config")
else:
self._current_config = cast("Optional[RunnableConfig]", old_config)

@override
async def ainvoke(
self,
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Async invoke the agent with RunnableConfig support.

This override ensures that RunnableConfig is passed through to tools
during agent execution.
"""
# Store config temporarily for access during tool execution
old_config = getattr(self, "_current_config", _NOTSET)
self._current_config = config
try:
return await super().ainvoke(input, config, **kwargs)
finally:
# Restore previous config
if old_config is _NOTSET:
if hasattr(self, "_current_config"):
delattr(self, "_current_config")
else:
self._current_config = cast("Optional[RunnableConfig]", old_config)

def save(self, file_path: Union[Path, str]) -> None:
"""Raise error - saving not supported for Agent Executors.

Expand All @@ -1192,6 +1247,8 @@ def save(self, file_path: Union[Path, str]) -> None:
Raises:
ValueError: Saving not supported for agent executors.
"""
# mark variable as used for linters
_ = file_path
msg = (
"Saving not supported for agent executors. "
"If you are trying to save the agent, please use the "
Expand Down Expand Up @@ -1227,7 +1284,7 @@ def iter(
AgentExecutorIterator: Agent executor iterator object.
"""
return AgentExecutorIterator(
self,
cast("Any", self),
inputs,
callbacks,
tags=self.tags,
Expand Down Expand Up @@ -1316,6 +1373,7 @@ def _take_next_step(
inputs: dict[str, str],
intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
config: Optional[RunnableConfig] = None,
) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
return self._consume_next_step(
list(
Expand All @@ -1325,6 +1383,7 @@ def _take_next_step(
inputs,
intermediate_steps,
run_manager,
config,
),
),
)
Expand All @@ -1336,6 +1395,7 @@ def _iter_next_step(
inputs: dict[str, str],
intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
config: Optional[RunnableConfig] = None,
) -> Iterator[Union[AgentFinish, AgentAction, AgentStep]]:
"""Take a single step in the thought-action-observation loop.

Expand Down Expand Up @@ -1406,6 +1466,7 @@ def _iter_next_step(
color_mapping,
agent_action,
run_manager,
config,
)

def _perform_agent_action(
Expand All @@ -1414,6 +1475,7 @@ def _perform_agent_action(
color_mapping: dict[str, str],
agent_action: AgentAction,
run_manager: Optional[CallbackManagerForChainRun] = None,
config: Optional[RunnableConfig] = None,
) -> AgentStep:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
Expand All @@ -1431,6 +1493,7 @@ def _perform_agent_action(
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
config=config,
**tool_run_kwargs,
)
else:
Expand All @@ -1454,6 +1517,7 @@ async def _atake_next_step(
inputs: dict[str, str],
intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
config: Optional[RunnableConfig] = None,
) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
return self._consume_next_step(
[
Expand All @@ -1464,6 +1528,7 @@ async def _atake_next_step(
inputs,
intermediate_steps,
run_manager,
config,
)
],
)
Expand All @@ -1475,6 +1540,7 @@ async def _aiter_next_step(
inputs: dict[str, str],
intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
config: Optional[RunnableConfig] = None,
) -> AsyncIterator[Union[AgentFinish, AgentAction, AgentStep]]:
"""Take a single step in the thought-action-observation loop.

Expand Down Expand Up @@ -1546,6 +1612,7 @@ async def _aiter_next_step(
color_mapping,
agent_action,
run_manager,
config,
)
for agent_action in actions
],
Expand All @@ -1561,6 +1628,7 @@ async def _aperform_agent_action(
color_mapping: dict[str, str],
agent_action: AgentAction,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
config: Optional[RunnableConfig] = None,
) -> AgentStep:
if run_manager:
await run_manager.on_agent_action(
Expand All @@ -1582,6 +1650,7 @@ async def _aperform_agent_action(
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
config=config,
**tool_run_kwargs,
)
else:
Expand All @@ -1604,6 +1673,9 @@ def _call(
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> dict[str, Any]:
"""Run text through and get agent response."""
# Get config from instance if available
config = getattr(self, "_current_config", None)

# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging.
Expand All @@ -1624,6 +1696,7 @@ def _call(
inputs,
intermediate_steps,
run_manager=run_manager,
config=config,
)
if isinstance(next_step_output, AgentFinish):
return self._return(
Expand Down Expand Up @@ -1658,6 +1731,9 @@ async def _acall(
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> dict[str, str]:
"""Async run text through and get agent response."""
# Get config from instance if available
config = getattr(self, "_current_config", None)

# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging.
Expand All @@ -1680,6 +1756,7 @@ async def _acall(
inputs,
intermediate_steps,
run_manager=run_manager,
config=config,
)
if isinstance(next_step_output, AgentFinish):
return await self._areturn(
Expand Down Expand Up @@ -1778,7 +1855,7 @@ def stream(
"""
config = ensure_config(config)
iterator = AgentExecutorIterator(
self,
cast("Any", self),
input,
config.get("callbacks"),
tags=config.get("tags"),
Expand Down Expand Up @@ -1810,7 +1887,7 @@ async def astream(

config = ensure_config(config)
iterator = AgentExecutorIterator(
self,
cast("Any", self),
input,
config.get("callbacks"),
tags=config.get("tags"),
Expand Down
Loading