57 changes: 25 additions & 32 deletions libs/langchain_v1/langchain/agents/factory.py
@@ -13,7 +13,6 @@
get_type_hints,
)

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph._internal._runnable import RunnableCallable
@@ -42,13 +41,15 @@
ResponseFormat,
StructuredOutputValidationError,
ToolStrategy,
_supports_provider_strategy,
)
from langchain.chat_models import init_chat_model
from langchain.tools.tool_node import ToolCallWithContext, _ToolNode

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Sequence

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.runnables import Runnable
from langgraph.cache.base import BaseCache
from langgraph.graph.state import CompiledStateGraph
@@ -347,29 +348,6 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
return []


def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
"""Check if a model supports provider-specific structured output.

Args:
model: Model name string or `BaseChatModel` instance.

Returns:
`True` if the model supports provider-specific structured output, `False` otherwise.
"""
model_name: str | None = None
if isinstance(model, str):
model_name = model
elif isinstance(model, BaseChatModel):
model_name = getattr(model, "model_name", None)

return (
"grok" in model_name.lower()
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
if model_name
else False
)


def _handle_structured_output_error(
exception: Exception,
response_format: ResponseFormat,
@@ -932,16 +910,34 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |

# Determine effective response format (auto-detect if needed)
effective_response_format: ResponseFormat | None
model_name: str = cast(
"str",
(
request.model
if isinstance(request.model, str)
else getattr(request.model, "model_name", "")
),
)
if isinstance(request.response_format, AutoStrategy):
# User provided raw schema via AutoStrategy - auto-detect best strategy based on model
if _supports_provider_strategy(request.model):
if _supports_provider_strategy(model_name):
# Model supports provider strategy - use it
effective_response_format = ProviderStrategy(schema=request.response_format.schema)
else:
# Model doesn't support provider strategy - use ToolStrategy
effective_response_format = ToolStrategy(schema=request.response_format.schema)
elif isinstance(request.response_format, ProviderStrategy):
if not _supports_provider_strategy(model_name):
msg = (
f"Cannot use ProviderStrategy with {model_name}. "
"Supported models: OpenAI (gpt-5, gpt-4.1, gpt-oss, o3-pro, o3-mini), "
"X.AI (Grok). "
"Consider using a raw schema (which auto-selects the best strategy) or "
"explicitly use `ToolStrategy` for unsupported providers."
)
raise ValueError(msg)
effective_response_format = request.response_format
else:
# User explicitly specified a strategy - preserve it
effective_response_format = request.response_format

# Build final tools list including structured output tools
@@ -957,12 +953,9 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
if isinstance(effective_response_format, ProviderStrategy):
# Use provider-specific structured output
kwargs = effective_response_format.to_model_kwargs()
return (
request.model.bind_tools(
final_tools, strict=True, **kwargs, **request.model_settings
),
effective_response_format,
)
return request.model.bind_tools(
final_tools, **kwargs, **request.model_settings
), effective_response_format

if isinstance(effective_response_format, ToolStrategy):
# Current implementation requires that tools used for structured output
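Taken together, the factory change moves strategy resolution onto the resolved model's *name* rather than the model object, which is what lets middleware swap models before the check runs. A minimal standalone sketch of that decision flow (editor's illustration, not part of the diff; `resolve_response_format` is a hypothetical name, and it assumes `AutoStrategy`, `ProviderStrategy`, `ToolStrategy`, and `_supports_provider_strategy` are importable from `langchain.agents.structured_output` as in the imports above):

```python
# Illustrative sketch of the inline logic in _get_bound_model; not the real API.
from langchain.agents.structured_output import (
    AutoStrategy,
    ProviderStrategy,
    ToolStrategy,
    _supports_provider_strategy,
)


def resolve_response_format(model, response_format):
    """Pick the effective strategy from the resolved model's name."""
    # String specs pass through; BaseChatModel instances expose `model_name`.
    model_name = model if isinstance(model, str) else getattr(model, "model_name", "")
    if isinstance(response_format, AutoStrategy):
        # Raw schema: prefer the provider-native API when the model supports it.
        if _supports_provider_strategy(model_name):
            return ProviderStrategy(schema=response_format.schema)
        return ToolStrategy(schema=response_format.schema)
    if isinstance(response_format, ProviderStrategy) and not _supports_provider_strategy(
        model_name
    ):
        # An explicit ProviderStrategy on an unsupported model fails fast.
        raise ValueError(f"Cannot use ProviderStrategy with {model_name}.")
    # Explicit strategies are otherwise preserved as-is.
    return response_format
```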
91 changes: 86 additions & 5 deletions libs/langchain_v1/langchain/agents/structured_output.py
@@ -31,6 +31,23 @@
SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]


def _supports_provider_strategy(model_name: str) -> bool:
[Review comment, Collaborator Author]: this is fickle. should we make it public?

"""Check if a model supports provider-specific structured output.

Args:
model_name: Model name string.

Returns:
`True` if the model supports provider-specific structured output, `False` otherwise.
"""
return (
"grok" in model_name.lower()
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
if model_name
else False
)
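
Since the check is a plain substring match, its behavior is easy to pin down. A few illustrative cases (editor's sketch based on the code above, not part of the diff):

```python
# Case-insensitive match for Grok; exact substrings for the OpenAI allow-list.
assert _supports_provider_strategy("grok-2") is True
assert _supports_provider_strategy("openai:gpt-4.1") is True  # "gpt-4.1" is a substring
assert _supports_provider_strategy("gpt-4o") is False  # not in the allow-list
assert _supports_provider_strategy("") is False  # empty name short-circuits to False
```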


class StructuredOutputError(Exception):
"""Base class for structured output errors."""

@@ -238,7 +255,56 @@ def _iter_variants(schema: Any) -> Iterable[Any]:

@dataclass(init=False)
class ProviderStrategy(Generic[SchemaT]):
"""Use the model provider's native structured output method."""
"""Use the model provider's native structured output method.

`ProviderStrategy` uses provider-specific structured output APIs that enforce
JSON schema validation at the model level. This provides stronger guarantees
than tool-based approaches but is only supported by certain providers.

Supported Providers:
- **OpenAI**: All models that support structured outputs (requires `strict=True`)
- **X.AI (Grok)**: All models that support structured outputs (requires `strict=True`)

Important:
When using `ProviderStrategy`, the agent will validate at runtime that the
model provider is supported. If you're using an unsupported provider, consider:

- Using a **raw schema** (recommended): Automatically selects the best strategy
based on model capabilities
- Using **`ToolStrategy`**: Explicitly use tool-based structured output for any
provider

Example:
```python
from langchain.agents import create_agent
from langchain.agents.structured_output import ProviderStrategy
from pydantic import BaseModel


class WeatherResponse(BaseModel):
temperature: float
condition: str


# Explicitly use provider strategy (only for OpenAI/Grok)
agent = create_agent(
model="openai:gpt-4", tools=[], response_format=ProviderStrategy(WeatherResponse)
)

# Or use raw schema for automatic strategy selection (recommended)
# This will auto-select ProviderStrategy for OpenAI/Grok, ToolStrategy for others
agent = create_agent(
model="openai:gpt-4",
tools=[],
response_format=WeatherResponse, # Auto-selects best strategy
)
```

Note:
`ProviderStrategy` can be used with middleware that changes the model at runtime.
Validation occurs after the model is resolved, allowing dynamic model selection
while ensuring provider compatibility.
"""

schema: type[SchemaT]
"""Schema for native mode."""
@@ -255,17 +321,32 @@ def __init__(
self.schema_spec = _SchemaSpec(schema)

def to_model_kwargs(self) -> dict[str, Any]:
"""Convert to kwargs to bind to a model to force structured output."""
# OpenAI:
# - see https://platform.openai.com/docs/guides/structured-outputs
"""Convert to kwargs to bind to a model to force structured output.

Returns:
Model kwargs with `response_format` and `strict=True`.
"""
# Provider-specific structured output:
# - OpenAI: https://platform.openai.com/docs/guides/structured-outputs
# - Uses strict=True for schema validation
# - X.AI (Grok): https://docs.x.ai/docs/guides/structured-outputs
# - Uses strict=True for schema validation (required)
response_format = {
"type": "json_schema",
"json_schema": {
"name": self.schema_spec.name,
"schema": self.schema_spec.json_schema,
},
}
return {"response_format": response_format}

# Set strict=True for OpenAI and X.AI (Grok) models
# Both providers require strict=True for structured output
kwargs: dict[str, Any] = {"response_format": response_format, "strict": True}

return kwargs
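
To make the output concrete, here is roughly what `to_model_kwargs()` returns for a small Pydantic schema (editor's illustration; the nested JSON schema body is elided and its exact fields depend on the Pydantic version):

```python
from pydantic import BaseModel

from langchain.agents.structured_output import ProviderStrategy


class Weather(BaseModel):
    temperature: float
    condition: str


kwargs = ProviderStrategy(Weather).to_model_kwargs()
# Expected shape:
# {
#     "response_format": {
#         "type": "json_schema",
#         "json_schema": {"name": "Weather", "schema": {...}},
#     },
#     "strict": True,
# }
```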


@dataclass
8 changes: 8 additions & 0 deletions libs/langchain_v1/tests/unit_tests/agents/model.py
@@ -11,6 +11,7 @@

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.language_models.base import LangSmithParams
from langchain_core.messages import (
AIMessage,
BaseMessage,
@@ -29,6 +30,8 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
structured_response: StructuredResponseT | None = None
index: int = 0
tool_style: Literal["openai", "anthropic"] = "openai"
ls_provider: str = "openai"
model_name: str = "fake-model"

def _generate(
self,
@@ -52,6 +55,7 @@ def _generate(
tool_calls = []

if is_native and not tool_calls:
content_obj = {}
if isinstance(self.structured_response, BaseModel):
content_obj = self.structured_response.model_dump()
elif is_dataclass(self.structured_response):
@@ -73,6 +77,10 @@
def _llm_type(self) -> str:
return "fake-tool-call-model"

def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
"""Get LangSmith parameters for this model."""
return LangSmithParams(ls_provider=self.ls_provider, ls_model_type="chat")

def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
35 changes: 12 additions & 23 deletions libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
@@ -619,7 +619,9 @@ def test_pydantic_model(self) -> None:
]

model = FakeToolCallingModel[WeatherBaseModel](
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
tool_calls=tool_calls,
structured_response=EXPECTED_WEATHER_PYDANTIC,
model_name="gpt-4.1",
)

agent = create_agent(
@@ -637,7 +639,9 @@ def test_dataclass(self) -> None:
]

model = FakeToolCallingModel[WeatherDataclass](
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
tool_calls=tool_calls,
structured_response=EXPECTED_WEATHER_DATACLASS,
model_name="gpt-4.1",
)

agent = create_agent(
@@ -657,7 +661,7 @@ def test_typed_dict(self) -> None:
]

model = FakeToolCallingModel[WeatherTypedDict](
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
)

agent = create_agent(
@@ -675,7 +679,7 @@ def test_json_schema(self) -> None:
]

model = FakeToolCallingModel[dict](
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
)

agent = create_agent(
@@ -697,13 +701,13 @@ def test_middleware_model_swap_provider_to_tool_strategy(self) -> None:
on the middleware-modified model (not the original), ensuring the correct strategy is
selected based on the final model's capabilities.
"""
from unittest.mock import patch
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel

# Custom model that we'll use to test whether the tool strategy is applied
# correctly at runtime.
# Custom model that we'll use to test whether the provider strategy is applied
# correctly at runtime. Use a model_name that supports provider strategy.
class CustomModel(GenericFakeChatModel):
model_name: str = "gpt-4.1"
tool_bindings: list[Any] = []

def bind_tools(
@@ -736,14 +740,6 @@
request.model = model
return handler(request)

# Track which model is checked for provider strategy support
calls = []

def mock_supports_provider_strategy(model) -> bool:
"""Track which model is checked and return True for ProviderStrategy."""
calls.append(model)
return True

# Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy)
# This should auto-detect strategy based on model capabilities
agent = create_agent(
@@ -754,14 +750,7 @@ def mock_supports_provider_strategy(model) -> bool:
middleware=[ModelSwappingMiddleware()],
)

with patch(
"langchain.agents.factory._supports_provider_strategy",
side_effect=mock_supports_provider_strategy,
):
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})

# Verify strategy resolution was deferred: check was called once during _get_bound_model
assert len(calls) == 1
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})

# Verify successful parsing of JSON as structured output via ProviderStrategy
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC