|
16 | 16 |
|
17 | 17 | import asyncio |
18 | 18 | import base64 |
| 19 | +from collections.abc import AsyncGenerator |
19 | 20 | from collections.abc import Callable |
20 | 21 | from datetime import datetime |
21 | 22 | from datetime import timezone |
|
24 | 25 |
|
25 | 26 | from google.adk.models import interactions_utils |
26 | 27 | from google.adk.models.llm_request import LlmRequest |
| 28 | +from google.adk.models.llm_response import LlmResponse |
27 | 29 | from google.genai import interactions |
28 | 30 | from google.genai import types |
29 | 31 | from google.genai.interactions import CodeExecutionResultStep |
@@ -61,27 +63,63 @@ async def __anext__(self): |
61 | 63 |
|
62 | 64 |
|
63 | 65 | class _FakeInteractions: |
64 | | - """Minimal fake interactions resource for streaming tests.""" |
65 | | - |
66 | | - def __init__(self, events: list[object]): |
67 | | - self._events = events |
68 | | - |
69 | | - async def create(self, **_kwargs): |
70 | | - return _MockAsyncIterator(self._events) |
| 66 | + """Fake interactions resource for create() tests. |
| 67 | +
|
| 68 | + Records each create() call's kwargs (including the ``stream`` flag) so tests |
| 69 | + can assert verbatim forwarding. Streaming calls (``stream`` truthy) return an |
| 70 | + async iterator over the configured events; non-streaming calls return the |
| 71 | + configured Interaction. ``_create_interactions`` always passes ``stream`` |
| 72 | + explicitly, so there is no need to distinguish "unset" from ``stream=False``. |
| 73 | + """ |
| 74 | + |
| 75 | + def __init__( |
| 76 | + self, |
| 77 | + events: list[object] | None = None, |
| 78 | + *, |
| 79 | + interaction: Interaction | None = None, |
| 80 | + ): |
| 81 | + self._events = events or [] |
| 82 | + self._interaction = interaction |
| 83 | + self.create_calls: list[dict[str, object]] = [] |
| 84 | + |
| 85 | + async def create(self, **kwargs): |
| 86 | + self.create_calls.append(kwargs) |
| 87 | + if kwargs.get('stream'): |
| 88 | + return _MockAsyncIterator(self._events) |
| 89 | + return self._interaction |
71 | 90 |
|
72 | 91 |
|
73 | 92 | class _FakeAio: |
74 | 93 | """Namespace matching the expected api_client.aio shape.""" |
75 | 94 |
|
76 | | - def __init__(self, events: list[object]): |
77 | | - self.interactions = _FakeInteractions(events) |
| 95 | + def __init__( |
| 96 | + self, |
| 97 | + events: list[object] | None = None, |
| 98 | + *, |
| 99 | + interaction: Interaction | None = None, |
| 100 | + ): |
| 101 | + self.interactions = _FakeInteractions(events, interaction=interaction) |
78 | 102 |
|
79 | 103 |
|
80 | 104 | class _FakeApiClient: |
81 | | - """Minimal fake API client for generate_content_via_interactions tests.""" |
| 105 | + """Minimal fake API client for interactions create() tests. |
| 106 | +
|
| 107 | + Streaming calls return an async iterator over the configured events; |
| 108 | + non-streaming calls return the configured Interaction. ``create_calls`` |
| 109 | + exposes the recorded kwargs of each ``interactions.create`` call. |
| 110 | + """ |
82 | 111 |
|
83 | | - def __init__(self, events: list[object]): |
84 | | - self.aio = _FakeAio(events) |
| 112 | + def __init__( |
| 113 | + self, |
| 114 | + events: list[object] | None = None, |
| 115 | + *, |
| 116 | + interaction: Interaction | None = None, |
| 117 | + ): |
| 118 | + self.aio = _FakeAio(events, interaction=interaction) |
| 119 | + |
| 120 | + @property |
| 121 | + def create_calls(self) -> list[dict[str, object]]: |
| 122 | + return self.aio.interactions.create_calls |
85 | 123 |
|
86 | 124 |
|
87 | 125 | def _build_llm_request() -> LlmRequest: |
@@ -1389,3 +1427,182 @@ def test_generate_content_via_interactions_stream_extracts_interaction_id( |
1389 | 1427 | asyncio.run(_collect_function_call_interaction_ids(streamed_events)) |
1390 | 1428 | == expected_ids |
1391 | 1429 | ) |
| 1430 | + |
| 1431 | + |
| 1432 | +def _build_simple_text_stream() -> list[object]: |
| 1433 | + """A minimal streamed interaction: created -> text delta -> completed.""" |
| 1434 | + now = datetime.now(timezone.utc).isoformat() |
| 1435 | + created = InteractionCreatedEvent( |
| 1436 | + event_type='interaction.created', |
| 1437 | + interaction=InteractionSseEventInteraction( |
| 1438 | + id='interaction_xyz', |
| 1439 | + created=now, |
| 1440 | + updated=now, |
| 1441 | + status='requires_action', |
| 1442 | + steps=[], |
| 1443 | + ), |
| 1444 | + ) |
| 1445 | + step_start = StepStart( |
| 1446 | + event_type='step.start', |
| 1447 | + index=0, |
| 1448 | + step=ModelOutputStep(type='model_output'), |
| 1449 | + ) |
| 1450 | + step_delta = StepDelta( |
| 1451 | + event_type='step.delta', |
| 1452 | + index=0, |
| 1453 | + delta={'type': 'text', 'text': 'Sunny in Tokyo.'}, |
| 1454 | + ) |
| 1455 | + step_stop = StepStop(event_type='step.stop', index=0) |
| 1456 | + completed = InteractionCompletedEvent( |
| 1457 | + event_type='interaction.completed', |
| 1458 | + interaction=InteractionSseEventInteraction( |
| 1459 | + id='interaction_xyz', |
| 1460 | + created=now, |
| 1461 | + updated=now, |
| 1462 | + status='completed', |
| 1463 | + steps=[ |
| 1464 | + ModelOutputStep( |
| 1465 | + type='model_output', |
| 1466 | + content=[TextContent(type='text', text='Sunny in Tokyo.')], |
| 1467 | + ) |
| 1468 | + ], |
| 1469 | + ), |
| 1470 | + ) |
| 1471 | + return [created, step_start, step_delta, step_stop, completed] |
| 1472 | + |
| 1473 | + |
| 1474 | +async def _collect_stream_responses(events: list[object]): |
| 1475 | + api_client = _FakeApiClient(events) |
| 1476 | + llm_request = _build_llm_request() |
| 1477 | + responses = [] |
| 1478 | + async for resp in interactions_utils.generate_content_via_interactions( |
| 1479 | + api_client, llm_request, stream=True |
| 1480 | + ): |
| 1481 | + responses.append(resp) |
| 1482 | + return responses |
| 1483 | + |
| 1484 | + |
| 1485 | +async def test_generate_content_via_interactions_stream_characterization(): |
| 1486 | + """Streaming yields text responses carrying the interaction id.""" |
| 1487 | + responses = await _collect_stream_responses(_build_simple_text_stream()) |
| 1488 | + |
| 1489 | + assert responses, 'expected at least one streamed LlmResponse' |
| 1490 | + assert all(r.interaction_id == 'interaction_xyz' for r in responses) |
| 1491 | + joined = ''.join( |
| 1492 | + part.text |
| 1493 | + for r in responses |
| 1494 | + if r.content and r.content.parts |
| 1495 | + for part in r.content.parts |
| 1496 | + if part.text |
| 1497 | + ) |
| 1498 | + assert 'Sunny in Tokyo.' in joined |
| 1499 | + |
| 1500 | + |
| 1501 | +def _build_non_streaming_interaction() -> Interaction: |
| 1502 | + """A completed non-streaming Interaction with a single text output.""" |
| 1503 | + now = datetime.now(timezone.utc).isoformat() |
| 1504 | + return Interaction( |
| 1505 | + id='interaction_ns', |
| 1506 | + status='completed', |
| 1507 | + created=now, |
| 1508 | + updated=now, |
| 1509 | + steps=[ |
| 1510 | + ModelOutputStep( |
| 1511 | + type='model_output', |
| 1512 | + content=[TextContent(type='text', text='Sunny in Tokyo.')], |
| 1513 | + ) |
| 1514 | + ], |
| 1515 | + ) |
| 1516 | + |
| 1517 | + |
| 1518 | +async def _drain( |
| 1519 | + responses: AsyncGenerator[LlmResponse, None], |
| 1520 | +) -> list[LlmResponse]: |
| 1521 | + """Collect all responses yielded by an async generator.""" |
| 1522 | + return [resp async for resp in responses] |
| 1523 | + |
| 1524 | + |
| 1525 | +async def test_create_interactions_streaming_forwards_kwargs_and_converts(): |
| 1526 | + """Streaming forwards create_kwargs verbatim (plus stream) and converts.""" |
| 1527 | + # Arrange. |
| 1528 | + api_client = _FakeApiClient(_build_simple_text_stream()) |
| 1529 | + create_kwargs = { |
| 1530 | + 'model': 'gemini-2.5-flash', |
| 1531 | + 'input': [{ |
| 1532 | + 'type': 'user_input', |
| 1533 | + 'content': [{'type': 'text', 'text': 'Weather in Tokyo?'}], |
| 1534 | + }], |
| 1535 | + 'previous_interaction_id': None, |
| 1536 | + } |
| 1537 | + |
| 1538 | + # Act. |
| 1539 | + responses = await _drain( |
| 1540 | + interactions_utils._create_interactions( |
| 1541 | + api_client, create_kwargs=create_kwargs, stream=True |
| 1542 | + ) |
| 1543 | + ) |
| 1544 | + |
| 1545 | + # Assert: exactly one create() call forwarding kwargs plus the stream flag. |
| 1546 | + assert len(api_client.create_calls) == 1 |
| 1547 | + assert api_client.create_calls[0] == {**create_kwargs, 'stream': True} |
| 1548 | + |
| 1549 | + # Assert: the streamed events are converted into text responses. |
| 1550 | + assert responses, 'expected at least one streamed LlmResponse' |
| 1551 | + assert all(r.interaction_id == 'interaction_xyz' for r in responses) |
| 1552 | + joined = ''.join( |
| 1553 | + part.text |
| 1554 | + for r in responses |
| 1555 | + if r.content and r.content.parts |
| 1556 | + for part in r.content.parts |
| 1557 | + if part.text |
| 1558 | + ) |
| 1559 | + assert 'Sunny in Tokyo.' in joined |
| 1560 | + |
| 1561 | + |
| 1562 | +async def test_create_interactions_non_streaming_forwards_kwargs_and_yields_single_response(): |
| 1563 | + """Non-streaming forwards kwargs verbatim and yields a single response.""" |
| 1564 | + # Arrange. |
| 1565 | + interaction = _build_non_streaming_interaction() |
| 1566 | + api_client = _FakeApiClient(interaction=interaction) |
| 1567 | + create_kwargs = { |
| 1568 | + 'model': 'gemini-2.5-flash', |
| 1569 | + 'input': [{ |
| 1570 | + 'type': 'user_input', |
| 1571 | + 'content': [{'type': 'text', 'text': 'Weather in Tokyo?'}], |
| 1572 | + }], |
| 1573 | + 'previous_interaction_id': None, |
| 1574 | + } |
| 1575 | + |
| 1576 | + # Act. |
| 1577 | + responses = await _drain( |
| 1578 | + interactions_utils._create_interactions( |
| 1579 | + api_client, create_kwargs=create_kwargs, stream=False |
| 1580 | + ) |
| 1581 | + ) |
| 1582 | + |
| 1583 | + # Assert: exactly one create() call forwarding kwargs plus the stream flag. |
| 1584 | + assert len(api_client.create_calls) == 1 |
| 1585 | + assert api_client.create_calls[0] == {**create_kwargs, 'stream': False} |
| 1586 | + |
| 1587 | + # Assert: a single converted LlmResponse carrying the interaction output. |
| 1588 | + assert len(responses) == 1 |
| 1589 | + assert responses[0].interaction_id == 'interaction_ns' |
| 1590 | + assert responses[0].content.parts[0].text == 'Sunny in Tokyo.' |
| 1591 | + |
| 1592 | + |
| 1593 | +async def test_generate_content_via_interactions_non_streaming_yields_single_response(): |
| 1594 | + """The public function yields a single response on the non-streaming path.""" |
| 1595 | + # Arrange. |
| 1596 | + api_client = _FakeApiClient(interaction=_build_non_streaming_interaction()) |
| 1597 | + |
| 1598 | + # Act. |
| 1599 | + responses = await _drain( |
| 1600 | + interactions_utils.generate_content_via_interactions( |
| 1601 | + api_client, _build_llm_request(), stream=False |
| 1602 | + ) |
| 1603 | + ) |
| 1604 | + |
| 1605 | + # Assert: a single end-to-end converted LlmResponse with the expected text. |
| 1606 | + assert len(responses) == 1 |
| 1607 | + assert responses[0].interaction_id == 'interaction_ns' |
| 1608 | + assert responses[0].content.parts[0].text == 'Sunny in Tokyo.' |
0 commit comments