
Commit e11b822

Fix stream error using LiteLLM (#589)
In response to issue #587, I implemented a solution that first checks whether the `usage` and `refusal` attributes exist on the chunk and its `delta` before accessing them. I added a unit test similar to `test_openai_chatcompletions_stream.py`. Let me know if I should change something.

Co-authored-by: Rohan Mehta <[email protected]>
1 parent af80e3a commit e11b822

2 files changed (+290 -2 lines)

src/agents/models/chatcmpl_stream_handler.py (+4 -2)
@@ -56,7 +56,8 @@ async def handle_stream(
                     type="response.created",
                 )

-            usage = chunk.usage
+            # This is always set by the OpenAI API, but not by others e.g. LiteLLM
+            usage = chunk.usage if hasattr(chunk, "usage") else None

             if not chunk.choices or not chunk.choices[0].delta:
                 continue
@@ -112,7 +113,8 @@ async def handle_stream(
                 state.text_content_index_and_output[1].text += delta.content

             # Handle refusals (model declines to answer)
-            if delta.refusal:
+            # This is always set by the OpenAI API, but not by others e.g. LiteLLM
+            if hasattr(delta, "refusal") and delta.refusal:
                 if not state.refusal_content_index_and_output:
                     # Initialize a content tracker for streaming refusal text
                     state.refusal_content_index_and_output = (
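
For illustration only, not part of the commit: a minimal sketch of the failure mode these two guards address. The delta below is a hypothetical stand-in built with `SimpleNamespace` rather than the SDK's `ChoiceDelta`, because the point is an object that simply omits optional fields, as chunks surfaced through LiteLLM can.

from types import SimpleNamespace

# Hypothetical chunk delta from a non-OpenAI backend: the `refusal` field is absent.
delta = SimpleNamespace(content="Hello")

# Unguarded access, as in the old code, raises on such objects.
try:
    delta.refusal
except AttributeError as exc:
    print(f"unguarded access fails: {exc}")

# Guarded access, as in the patched code, degrades gracefully.
refusal = delta.refusal if hasattr(delta, "refusal") else None
usage = getattr(delta, "usage", None)  # getattr with a default is an equivalent idiom
print(refusal, usage)  # -> None None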
New test file (+286 lines)
@@ -0,0 +1,286 @@
from collections.abc import AsyncIterator

import pytest
from openai.types.chat.chat_completion_chunk import (
    ChatCompletionChunk,
    Choice,
    ChoiceDelta,
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)
from openai.types.completion_usage import CompletionUsage
from openai.types.responses import (
    Response,
    ResponseFunctionToolCall,
    ResponseOutputMessage,
    ResponseOutputRefusal,
    ResponseOutputText,
)

from agents.extensions.models.litellm_model import LitellmModel
from agents.extensions.models.litellm_provider import LitellmProvider
from agents.model_settings import ModelSettings
from agents.models.interface import ModelTracing


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_yields_events_for_text_content(monkeypatch) -> None:
    """
    Validate that `stream_response` emits the correct sequence of events when
    streaming a simple assistant message consisting of plain text content.
    We simulate two chunks of text returned from the chat completion stream.
    """
    # Create two chunks that will be emitted by the fake stream.
    chunk1 = ChatCompletionChunk(
        id="chunk-id",
        created=1,
        model="fake",
        object="chat.completion.chunk",
        choices=[Choice(index=0, delta=ChoiceDelta(content="He"))],
    )
    # Mark last chunk with usage so stream_response knows this is final.
    chunk2 = ChatCompletionChunk(
        id="chunk-id",
        created=1,
        model="fake",
        object="chat.completion.chunk",
        choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))],
        usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
    )

    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
        for c in (chunk1, chunk2):
            yield c

    # Patch _fetch_response to inject our fake stream
    async def patched_fetch_response(self, *args, **kwargs):
        # `_fetch_response` is expected to return a Response skeleton and the async stream
        resp = Response(
            id="resp-id",
            created_at=0,
            model="fake-model",
            object="response",
            output=[],
            tool_choice="none",
            tools=[],
            parallel_tool_calls=False,
        )
        return resp, fake_stream()

    monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response)
    model = LitellmProvider().get_model("gpt-4")
    output_events = []
    async for event in model.stream_response(
        system_instructions=None,
        input="",
        model_settings=ModelSettings(),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    ):
        output_events.append(event)
    # We expect a response.created, then a response.output_item.added, content part added,
    # two content delta events (for "He" and "llo"), a content part done, the assistant message
    # output_item.done, and finally response.completed.
    # There should be 8 events in total.
    assert len(output_events) == 8
    # First event indicates creation.
    assert output_events[0].type == "response.created"
    # The output item added and content part added events should mark the assistant message.
    assert output_events[1].type == "response.output_item.added"
    assert output_events[2].type == "response.content_part.added"
    # Two text delta events.
    assert output_events[3].type == "response.output_text.delta"
    assert output_events[3].delta == "He"
    assert output_events[4].type == "response.output_text.delta"
    assert output_events[4].delta == "llo"
    # After streaming, the content part and item should be marked done.
    assert output_events[5].type == "response.content_part.done"
    assert output_events[6].type == "response.output_item.done"
    # Last event indicates completion of the stream.
    assert output_events[7].type == "response.completed"
    # The completed response should have one output message with full text.
    completed_resp = output_events[7].response
    assert isinstance(completed_resp.output[0], ResponseOutputMessage)
    assert isinstance(completed_resp.output[0].content[0], ResponseOutputText)
    assert completed_resp.output[0].content[0].text == "Hello"

    assert completed_resp.usage, "usage should not be None"
    assert completed_resp.usage.input_tokens == 7
    assert completed_resp.usage.output_tokens == 5
    assert completed_resp.usage.total_tokens == 12


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_yields_events_for_refusal_content(monkeypatch) -> None:
    """
    Validate that when the model streams a refusal string instead of normal content,
    `stream_response` emits the appropriate sequence of events including
    `response.refusal.delta` events for each chunk of the refusal message and
    constructs a completed assistant message with a `ResponseOutputRefusal` part.
    """
    # Simulate refusal text coming in two pieces, like content but using the `refusal`
    # field on the delta rather than `content`.
    chunk1 = ChatCompletionChunk(
        id="chunk-id",
        created=1,
        model="fake",
        object="chat.completion.chunk",
        choices=[Choice(index=0, delta=ChoiceDelta(refusal="No"))],
    )
    chunk2 = ChatCompletionChunk(
        id="chunk-id",
        created=1,
        model="fake",
        object="chat.completion.chunk",
        choices=[Choice(index=0, delta=ChoiceDelta(refusal="Thanks"))],
        usage=CompletionUsage(completion_tokens=2, prompt_tokens=2, total_tokens=4),
    )

    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
        for c in (chunk1, chunk2):
            yield c

    async def patched_fetch_response(self, *args, **kwargs):
        resp = Response(
            id="resp-id",
            created_at=0,
            model="fake-model",
            object="response",
            output=[],
            tool_choice="none",
            tools=[],
            parallel_tool_calls=False,
        )
        return resp, fake_stream()

    monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response)
    model = LitellmProvider().get_model("gpt-4")
    output_events = []
    async for event in model.stream_response(
        system_instructions=None,
        input="",
        model_settings=ModelSettings(),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    ):
        output_events.append(event)
    # Expect sequence similar to text: created, output_item.added, content part added,
    # two refusal delta events, content part done, output_item.done, completed.
    assert len(output_events) == 8
    assert output_events[0].type == "response.created"
    assert output_events[1].type == "response.output_item.added"
    assert output_events[2].type == "response.content_part.added"
    assert output_events[3].type == "response.refusal.delta"
    assert output_events[3].delta == "No"
    assert output_events[4].type == "response.refusal.delta"
    assert output_events[4].delta == "Thanks"
    assert output_events[5].type == "response.content_part.done"
    assert output_events[6].type == "response.output_item.done"
    assert output_events[7].type == "response.completed"
    completed_resp = output_events[7].response
    assert isinstance(completed_resp.output[0], ResponseOutputMessage)
    refusal_part = completed_resp.output[0].content[0]
    assert isinstance(refusal_part, ResponseOutputRefusal)
    assert refusal_part.refusal == "NoThanks"


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None:
    """
    Validate that `stream_response` emits the correct sequence of events when
    the model is streaming a function/tool call instead of plain text.
    The function call will be split across two chunks.
    """
    # Simulate a single tool call whose ID stays constant and function name/args built over chunks.
    tool_call_delta1 = ChoiceDeltaToolCall(
        index=0,
        id="tool-id",
        function=ChoiceDeltaToolCallFunction(name="my_", arguments="arg1"),
        type="function",
    )
    tool_call_delta2 = ChoiceDeltaToolCall(
        index=0,
        id="tool-id",
        function=ChoiceDeltaToolCallFunction(name="func", arguments="arg2"),
        type="function",
    )
    chunk1 = ChatCompletionChunk(
        id="chunk-id",
        created=1,
        model="fake",
        object="chat.completion.chunk",
        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))],
    )
    chunk2 = ChatCompletionChunk(
        id="chunk-id",
        created=1,
        model="fake",
        object="chat.completion.chunk",
        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))],
        usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
    )

    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
        for c in (chunk1, chunk2):
            yield c

    async def patched_fetch_response(self, *args, **kwargs):
        resp = Response(
            id="resp-id",
            created_at=0,
            model="fake-model",
            object="response",
            output=[],
            tool_choice="none",
            tools=[],
            parallel_tool_calls=False,
        )
        return resp, fake_stream()

    monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response)
    model = LitellmProvider().get_model("gpt-4")
    output_events = []
    async for event in model.stream_response(
        system_instructions=None,
        input="",
        model_settings=ModelSettings(),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    ):
        output_events.append(event)
    # Sequence should be: response.created, then after loop we expect function call-related events:
    # one response.output_item.added for function call, a response.function_call_arguments.delta,
    # a response.output_item.done, and finally response.completed.
    assert output_events[0].type == "response.created"
    # The next three events are about the tool call.
    assert output_events[1].type == "response.output_item.added"
    # The added item should be a ResponseFunctionToolCall.
    added_fn = output_events[1].item
    assert isinstance(added_fn, ResponseFunctionToolCall)
    assert added_fn.name == "my_func"  # Name should be concatenation of both chunks.
    assert added_fn.arguments == "arg1arg2"
    assert output_events[2].type == "response.function_call_arguments.delta"
    assert output_events[2].delta == "arg1arg2"
    assert output_events[3].type == "response.output_item.done"
    assert output_events[4].type == "response.completed"
    assert output_events[2].delta == "arg1arg2"
    assert output_events[3].type == "response.output_item.done"
    assert output_events[4].type == "response.completed"
    assert added_fn.name == "my_func"  # Name should be concatenation of both chunks.
    assert added_fn.arguments == "arg1arg2"
    assert output_events[2].type == "response.function_call_arguments.delta"
    assert output_events[2].delta == "arg1arg2"
    assert output_events[3].type == "response.output_item.done"
    assert output_events[4].type == "response.completed"
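
The tests above never touch the network: `_fetch_response` is monkeypatched to return a `Response` skeleton plus a local async generator, so only the event-assembly logic of `stream_response` is exercised. As a usage sketch, assuming pytest is invoked from the repository root, the new streaming tests can be selected by the name prefix they share:

import sys

import pytest

# Runs tests whose names contain the given substring; this also matches the
# similarly named tests in test_openai_chatcompletions_stream.py.
sys.exit(pytest.main(["-q", "-k", "stream_response_yields_events"]))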
