Skip to content

Commit a562a31

Browse files
haranrkcopybara-github
authored andcommitted
refactor(interactions): extract _create_interactions transport loop
Extract the shared interactions.create + response-conversion loop out of generate_content_via_interactions so it can be reused for both streaming and non-streaming calls. No behavior change. Co-authored-by: Haran Rajkumar <haranrk@google.com> PiperOrigin-RevId: 937471267
1 parent 8679aa8 commit a562a31

2 files changed

Lines changed: 291 additions & 58 deletions

File tree

src/google/adk/models/interactions_utils.py

Lines changed: 62 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,53 @@ def _get_latest_user_contents(
11661166
return latest_user_contents
11671167

11681168

1169+
async def _create_interactions(
1170+
api_client: Client,
1171+
*,
1172+
create_kwargs: dict[str, Any],
1173+
stream: bool,
1174+
) -> AsyncGenerator[LlmResponse, None]:
1175+
"""Issue ``interactions.create`` and convert the response(s) to LlmResponses.
1176+
1177+
This is the shared transport + conversion loop. The caller assembles
1178+
``create_kwargs`` (``model`` or ``agent``, ``input``, ``tools``, etc.); this
1179+
helper owns issuing the call and mapping the stream to ``LlmResponse``s.
1180+
1181+
Args:
1182+
api_client: The Google GenAI client.
1183+
create_kwargs: Keyword arguments passed verbatim to
1184+
``api_client.aio.interactions.create`` (excluding ``stream``).
1185+
stream: Whether to stream the response.
1186+
1187+
Yields:
1188+
LlmResponse objects converted from interaction responses.
1189+
"""
1190+
current_interaction_id: str | None = None
1191+
1192+
if stream:
1193+
responses = await api_client.aio.interactions.create(
1194+
**create_kwargs, stream=True
1195+
)
1196+
aggregated_parts: list[types.Part] = []
1197+
async for event in responses:
1198+
logger.debug(build_interactions_event_log(event))
1199+
interaction_id = _extract_stream_interaction_id(event)
1200+
if interaction_id:
1201+
current_interaction_id = interaction_id
1202+
llm_response = convert_interaction_event_to_llm_response(
1203+
event, aggregated_parts, current_interaction_id
1204+
)
1205+
if llm_response:
1206+
yield llm_response
1207+
else:
1208+
interaction = await api_client.aio.interactions.create(
1209+
**create_kwargs, stream=False
1210+
)
1211+
logger.info('Interaction response received.')
1212+
logger.debug(build_interactions_response_log(interaction))
1213+
yield convert_interaction_to_llm_response(interaction)
1214+
1215+
11691216
async def generate_content_via_interactions(
11701217
api_client: Client,
11711218
llm_request: LlmRequest,
@@ -1227,49 +1274,18 @@ async def generate_content_via_interactions(
12271274
)
12281275
)
12291276

1230-
# Track the current interaction ID from responses
1231-
current_interaction_id: str | None = None
1232-
1233-
if stream:
1234-
# Streaming mode
1235-
responses = await api_client.aio.interactions.create(
1236-
model=llm_request.model,
1237-
input=input_steps,
1238-
stream=True,
1239-
system_instruction=system_instruction,
1240-
tools=interaction_tools if interaction_tools else None,
1241-
generation_config=generation_config if generation_config else None,
1242-
previous_interaction_id=previous_interaction_id,
1243-
)
1244-
1245-
aggregated_parts: list[types.Part] = []
1246-
async for event in responses:
1247-
# Log the streaming event
1248-
logger.debug(build_interactions_event_log(event))
1249-
1250-
interaction_id = _extract_stream_interaction_id(event)
1251-
if interaction_id:
1252-
current_interaction_id = interaction_id
1253-
llm_response = convert_interaction_event_to_llm_response(
1254-
event, aggregated_parts, current_interaction_id
1255-
)
1256-
if llm_response:
1257-
yield llm_response
1258-
1259-
else:
1260-
# Non-streaming mode
1261-
interaction = await api_client.aio.interactions.create(
1262-
model=llm_request.model,
1263-
input=input_steps,
1264-
stream=False,
1265-
system_instruction=system_instruction,
1266-
tools=interaction_tools if interaction_tools else None,
1267-
generation_config=generation_config if generation_config else None,
1268-
previous_interaction_id=previous_interaction_id,
1269-
)
1270-
1271-
# Log the response
1272-
logger.info('Interaction response received from the model.')
1273-
logger.debug(build_interactions_response_log(interaction))
1274-
1275-
yield convert_interaction_to_llm_response(interaction)
1277+
# Assemble the create() kwargs for the model path and delegate the
1278+
# transport + conversion loop to the shared helper.
1279+
create_kwargs: dict[str, Any] = {
1280+
'model': llm_request.model,
1281+
'input': input_steps,
1282+
'system_instruction': system_instruction,
1283+
'tools': interaction_tools if interaction_tools else None,
1284+
'generation_config': generation_config if generation_config else None,
1285+
'previous_interaction_id': previous_interaction_id,
1286+
}
1287+
1288+
async for llm_response in _create_interactions(
1289+
api_client, create_kwargs=create_kwargs, stream=stream
1290+
):
1291+
yield llm_response

tests/unittests/models/test_interactions_utils.py

Lines changed: 229 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import asyncio
1818
import base64
19+
from collections.abc import AsyncGenerator
1920
from collections.abc import Callable
2021
from datetime import datetime
2122
from datetime import timezone
@@ -24,6 +25,7 @@
2425

2526
from google.adk.models import interactions_utils
2627
from google.adk.models.llm_request import LlmRequest
28+
from google.adk.models.llm_response import LlmResponse
2729
from google.genai import interactions
2830
from google.genai import types
2931
from google.genai.interactions import CodeExecutionResultStep
@@ -61,27 +63,63 @@ async def __anext__(self):
6163

6264

6365
class _FakeInteractions:
64-
"""Minimal fake interactions resource for streaming tests."""
65-
66-
def __init__(self, events: list[object]):
67-
self._events = events
68-
69-
async def create(self, **_kwargs):
70-
return _MockAsyncIterator(self._events)
66+
"""Fake interactions resource for create() tests.
67+
68+
Records each create() call's kwargs (including the ``stream`` flag) so tests
69+
can assert verbatim forwarding. Streaming calls (``stream`` truthy) return an
70+
async iterator over the configured events; non-streaming calls return the
71+
configured Interaction. ``_create_interactions`` always passes ``stream``
72+
explicitly, so there is no need to distinguish "unset" from ``stream=False``.
73+
"""
74+
75+
def __init__(
76+
self,
77+
events: list[object] | None = None,
78+
*,
79+
interaction: Interaction | None = None,
80+
):
81+
self._events = events or []
82+
self._interaction = interaction
83+
self.create_calls: list[dict[str, object]] = []
84+
85+
async def create(self, **kwargs):
86+
self.create_calls.append(kwargs)
87+
if kwargs.get('stream'):
88+
return _MockAsyncIterator(self._events)
89+
return self._interaction
7190

7291

7392
class _FakeAio:
7493
"""Namespace matching the expected api_client.aio shape."""
7594

76-
def __init__(self, events: list[object]):
77-
self.interactions = _FakeInteractions(events)
95+
def __init__(
96+
self,
97+
events: list[object] | None = None,
98+
*,
99+
interaction: Interaction | None = None,
100+
):
101+
self.interactions = _FakeInteractions(events, interaction=interaction)
78102

79103

80104
class _FakeApiClient:
81-
"""Minimal fake API client for generate_content_via_interactions tests."""
105+
"""Minimal fake API client for interactions create() tests.
106+
107+
Streaming calls return an async iterator over the configured events;
108+
non-streaming calls return the configured Interaction. ``create_calls``
109+
exposes the recorded kwargs of each ``interactions.create`` call.
110+
"""
82111

83-
def __init__(self, events: list[object]):
84-
self.aio = _FakeAio(events)
112+
def __init__(
113+
self,
114+
events: list[object] | None = None,
115+
*,
116+
interaction: Interaction | None = None,
117+
):
118+
self.aio = _FakeAio(events, interaction=interaction)
119+
120+
@property
121+
def create_calls(self) -> list[dict[str, object]]:
122+
return self.aio.interactions.create_calls
85123

86124

87125
def _build_llm_request() -> LlmRequest:
@@ -1389,3 +1427,182 @@ def test_generate_content_via_interactions_stream_extracts_interaction_id(
13891427
asyncio.run(_collect_function_call_interaction_ids(streamed_events))
13901428
== expected_ids
13911429
)
1430+
1431+
1432+
def _build_simple_text_stream() -> list[object]:
1433+
"""A minimal streamed interaction: created -> text delta -> completed."""
1434+
now = datetime.now(timezone.utc).isoformat()
1435+
created = InteractionCreatedEvent(
1436+
event_type='interaction.created',
1437+
interaction=InteractionSseEventInteraction(
1438+
id='interaction_xyz',
1439+
created=now,
1440+
updated=now,
1441+
status='requires_action',
1442+
steps=[],
1443+
),
1444+
)
1445+
step_start = StepStart(
1446+
event_type='step.start',
1447+
index=0,
1448+
step=ModelOutputStep(type='model_output'),
1449+
)
1450+
step_delta = StepDelta(
1451+
event_type='step.delta',
1452+
index=0,
1453+
delta={'type': 'text', 'text': 'Sunny in Tokyo.'},
1454+
)
1455+
step_stop = StepStop(event_type='step.stop', index=0)
1456+
completed = InteractionCompletedEvent(
1457+
event_type='interaction.completed',
1458+
interaction=InteractionSseEventInteraction(
1459+
id='interaction_xyz',
1460+
created=now,
1461+
updated=now,
1462+
status='completed',
1463+
steps=[
1464+
ModelOutputStep(
1465+
type='model_output',
1466+
content=[TextContent(type='text', text='Sunny in Tokyo.')],
1467+
)
1468+
],
1469+
),
1470+
)
1471+
return [created, step_start, step_delta, step_stop, completed]
1472+
1473+
1474+
async def _collect_stream_responses(events: list[object]):
1475+
api_client = _FakeApiClient(events)
1476+
llm_request = _build_llm_request()
1477+
responses = []
1478+
async for resp in interactions_utils.generate_content_via_interactions(
1479+
api_client, llm_request, stream=True
1480+
):
1481+
responses.append(resp)
1482+
return responses
1483+
1484+
1485+
async def test_generate_content_via_interactions_stream_characterization():
1486+
"""Streaming yields text responses carrying the interaction id."""
1487+
responses = await _collect_stream_responses(_build_simple_text_stream())
1488+
1489+
assert responses, 'expected at least one streamed LlmResponse'
1490+
assert all(r.interaction_id == 'interaction_xyz' for r in responses)
1491+
joined = ''.join(
1492+
part.text
1493+
for r in responses
1494+
if r.content and r.content.parts
1495+
for part in r.content.parts
1496+
if part.text
1497+
)
1498+
assert 'Sunny in Tokyo.' in joined
1499+
1500+
1501+
def _build_non_streaming_interaction() -> Interaction:
1502+
"""A completed non-streaming Interaction with a single text output."""
1503+
now = datetime.now(timezone.utc).isoformat()
1504+
return Interaction(
1505+
id='interaction_ns',
1506+
status='completed',
1507+
created=now,
1508+
updated=now,
1509+
steps=[
1510+
ModelOutputStep(
1511+
type='model_output',
1512+
content=[TextContent(type='text', text='Sunny in Tokyo.')],
1513+
)
1514+
],
1515+
)
1516+
1517+
1518+
async def _drain(
1519+
responses: AsyncGenerator[LlmResponse, None],
1520+
) -> list[LlmResponse]:
1521+
"""Collect all responses yielded by an async generator."""
1522+
return [resp async for resp in responses]
1523+
1524+
1525+
async def test_create_interactions_streaming_forwards_kwargs_and_converts():
1526+
"""Streaming forwards create_kwargs verbatim (plus stream) and converts."""
1527+
# Arrange.
1528+
api_client = _FakeApiClient(_build_simple_text_stream())
1529+
create_kwargs = {
1530+
'model': 'gemini-2.5-flash',
1531+
'input': [{
1532+
'type': 'user_input',
1533+
'content': [{'type': 'text', 'text': 'Weather in Tokyo?'}],
1534+
}],
1535+
'previous_interaction_id': None,
1536+
}
1537+
1538+
# Act.
1539+
responses = await _drain(
1540+
interactions_utils._create_interactions(
1541+
api_client, create_kwargs=create_kwargs, stream=True
1542+
)
1543+
)
1544+
1545+
# Assert: exactly one create() call forwarding kwargs plus the stream flag.
1546+
assert len(api_client.create_calls) == 1
1547+
assert api_client.create_calls[0] == {**create_kwargs, 'stream': True}
1548+
1549+
# Assert: the streamed events are converted into text responses.
1550+
assert responses, 'expected at least one streamed LlmResponse'
1551+
assert all(r.interaction_id == 'interaction_xyz' for r in responses)
1552+
joined = ''.join(
1553+
part.text
1554+
for r in responses
1555+
if r.content and r.content.parts
1556+
for part in r.content.parts
1557+
if part.text
1558+
)
1559+
assert 'Sunny in Tokyo.' in joined
1560+
1561+
1562+
async def test_create_interactions_non_streaming_forwards_kwargs_and_yields_single_response():
1563+
"""Non-streaming forwards kwargs verbatim and yields a single response."""
1564+
# Arrange.
1565+
interaction = _build_non_streaming_interaction()
1566+
api_client = _FakeApiClient(interaction=interaction)
1567+
create_kwargs = {
1568+
'model': 'gemini-2.5-flash',
1569+
'input': [{
1570+
'type': 'user_input',
1571+
'content': [{'type': 'text', 'text': 'Weather in Tokyo?'}],
1572+
}],
1573+
'previous_interaction_id': None,
1574+
}
1575+
1576+
# Act.
1577+
responses = await _drain(
1578+
interactions_utils._create_interactions(
1579+
api_client, create_kwargs=create_kwargs, stream=False
1580+
)
1581+
)
1582+
1583+
# Assert: exactly one create() call forwarding kwargs plus the stream flag.
1584+
assert len(api_client.create_calls) == 1
1585+
assert api_client.create_calls[0] == {**create_kwargs, 'stream': False}
1586+
1587+
# Assert: a single converted LlmResponse carrying the interaction output.
1588+
assert len(responses) == 1
1589+
assert responses[0].interaction_id == 'interaction_ns'
1590+
assert responses[0].content.parts[0].text == 'Sunny in Tokyo.'
1591+
1592+
1593+
async def test_generate_content_via_interactions_non_streaming_yields_single_response():
1594+
"""The public function yields a single response on the non-streaming path."""
1595+
# Arrange.
1596+
api_client = _FakeApiClient(interaction=_build_non_streaming_interaction())
1597+
1598+
# Act.
1599+
responses = await _drain(
1600+
interactions_utils.generate_content_via_interactions(
1601+
api_client, _build_llm_request(), stream=False
1602+
)
1603+
)
1604+
1605+
# Assert: a single end-to-end converted LlmResponse with the expected text.
1606+
assert len(responses) == 1
1607+
assert responses[0].interaction_id == 'interaction_ns'
1608+
assert responses[0].content.parts[0].text == 'Sunny in Tokyo.'

0 commit comments

Comments
 (0)