|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +import json |
17 | 18 | from unittest.mock import Mock |
18 | 19 |
|
19 | 20 | from a2a.types import Artifact |
|
25 | 26 | from a2a.types import TaskStatus |
26 | 27 | from a2a.types import TaskStatusUpdateEvent |
27 | 28 | from a2a.types import TextPart |
| 29 | +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_END_TAG |
28 | 30 | from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY |
| 31 | +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_START_TAG |
| 32 | +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_TEXT_MIME_TYPE |
29 | 33 | from google.adk.a2a.converters.to_adk_event import convert_a2a_artifact_update_to_event |
30 | 34 | from google.adk.a2a.converters.to_adk_event import convert_a2a_message_to_event |
31 | 35 | from google.adk.a2a.converters.to_adk_event import convert_a2a_status_update_to_event |
@@ -465,6 +469,99 @@ def test_convert_a2a_task_to_event_no_text_parts(self): |
465 | 469 | assert event.content is not None |
466 | 470 | assert event.content.parts == [mock_image_part] |
467 | 471 |
|
| 472 | + def test_convert_a2a_task_to_event_data_part_input_required(self): |
| 473 | + """Input-required prompt carried in a data part becomes a function call.""" |
| 474 | + part1 = Mock(spec=A2APart) |
| 475 | + part1.root = Mock() # Not a TextPart. |
| 476 | + part1.root.metadata = {} |
| 477 | + |
| 478 | + task = Task( |
| 479 | + id="task-1", |
| 480 | + context_id="context-1", |
| 481 | + kind="task", |
| 482 | + status=TaskStatus( |
| 483 | + state=TaskState.input_required, |
| 484 | + timestamp="now", |
| 485 | + message=Message( |
| 486 | + message_id="m1", |
| 487 | + role="agent", |
| 488 | + parts=[part1], |
| 489 | + ), |
| 490 | + ), |
| 491 | + ) |
| 492 | + |
| 493 | + prompt = { |
| 494 | + "id": "abc123", |
| 495 | + "text": "Please confirm this action. Do you want to continue?", |
| 496 | + } |
| 497 | + data_part_json = json.dumps({"data": prompt, "kind": "data"}).encode( |
| 498 | + "utf-8" |
| 499 | + ) |
| 500 | + mock_data_blob_part = genai_types.Part( |
| 501 | + inline_data=genai_types.Blob( |
| 502 | + mime_type=A2A_DATA_PART_TEXT_MIME_TYPE, |
| 503 | + data=A2A_DATA_PART_START_TAG |
| 504 | + + data_part_json |
| 505 | + + A2A_DATA_PART_END_TAG, |
| 506 | + ) |
| 507 | + ) |
| 508 | + |
| 509 | + event = convert_a2a_task_to_event( |
| 510 | + task, |
| 511 | + author="test-author", |
| 512 | + invocation_context=self.mock_context, |
| 513 | + part_converter=Mock(return_value=[mock_data_blob_part]), |
| 514 | + ) |
| 515 | + |
| 516 | + assert event is not None |
| 517 | + assert event.content is not None |
| 518 | + assert ( |
| 519 | + event.content.parts[0].function_call.name |
| 520 | + == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT |
| 521 | + ) |
| 522 | + assert event.content.parts[0].function_call.args["input_required"] == prompt |
| 523 | + assert event.long_running_tool_ids |
| 524 | + |
| 525 | + def test_convert_a2a_task_to_event_data_part_malformed_json(self): |
| 526 | + """A malformed data-part blob is left untouched (no crash, no fc).""" |
| 527 | + part1 = Mock(spec=A2APart) |
| 528 | + part1.root = Mock() # Not a TextPart. |
| 529 | + part1.root.metadata = {} |
| 530 | + |
| 531 | + task = Task( |
| 532 | + id="task-1", |
| 533 | + context_id="context-1", |
| 534 | + kind="task", |
| 535 | + status=TaskStatus( |
| 536 | + state=TaskState.input_required, |
| 537 | + timestamp="now", |
| 538 | + message=Message( |
| 539 | + message_id="m1", |
| 540 | + role="agent", |
| 541 | + parts=[part1], |
| 542 | + ), |
| 543 | + ), |
| 544 | + ) |
| 545 | + |
| 546 | + mock_bad_blob_part = genai_types.Part( |
| 547 | + inline_data=genai_types.Blob( |
| 548 | + mime_type=A2A_DATA_PART_TEXT_MIME_TYPE, |
| 549 | + data=A2A_DATA_PART_START_TAG + b"not-json" + A2A_DATA_PART_END_TAG, |
| 550 | + ) |
| 551 | + ) |
| 552 | + |
| 553 | + event = convert_a2a_task_to_event( |
| 554 | + task, |
| 555 | + author="test-author", |
| 556 | + invocation_context=self.mock_context, |
| 557 | + part_converter=Mock(return_value=[mock_bad_blob_part]), |
| 558 | + ) |
| 559 | + |
| 560 | + assert event is not None |
| 561 | + assert event.content is not None |
| 562 | + assert event.content.parts == [mock_bad_blob_part] |
| 563 | + assert not event.long_running_tool_ids |
| 564 | + |
468 | 565 | def test_convert_a2a_status_update_to_event_success(self): |
469 | 566 | """Test successful conversion of A2A status update to Event.""" |
470 | 567 | a2a_part = Mock(spec=A2APart) |
|
0 commit comments