Skip to content

Commit b9e7fca

Browse files
google-genai-botcopybara-github
authored andcommitted
fix(a2a): render HITL interrupt when prompt is in a data part
A2A input/auth-required prompts sent as a DataPart became an opaque inline_data JSON blob, so no HITL function call was produced and the client rendered nothing. Extract the prompt from the data part so these tasks surface a proper HITL function call. Adds unit tests. PiperOrigin-RevId: 933586203
1 parent f0ec997 commit b9e7fca

2 files changed

Lines changed: 140 additions & 5 deletions

File tree

src/google/adk/a2a/converters/to_adk_event.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
from ...events.event import Event
3636
from ...events.event_actions import EventActions
3737
from ..experimental import a2a_experimental
38+
from .part_converter import A2A_DATA_PART_END_TAG
3839
from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY
40+
from .part_converter import A2A_DATA_PART_START_TAG
41+
from .part_converter import A2A_DATA_PART_TEXT_MIME_TYPE
3942
from .part_converter import A2APartToGenAIPartConverter
4043
from .part_converter import convert_a2a_part_to_genai_part
4144
from .utils import _get_adk_metadata_key
@@ -283,6 +286,40 @@ def _merge_event_actions(
283286
return EventActions.model_validate(merged_actions_data)
284287

285288

289+
def _extract_user_input_prompt(part: genai_types.Part) -> Any:
290+
"""Extracts a prompt from a converted ADK part."""
291+
if part.text:
292+
return part.text
293+
294+
blob = part.inline_data
295+
if (
296+
blob is None
297+
or blob.data is None
298+
or blob.mime_type != A2A_DATA_PART_TEXT_MIME_TYPE
299+
or not blob.data.startswith(A2A_DATA_PART_START_TAG)
300+
or not blob.data.endswith(A2A_DATA_PART_END_TAG)
301+
):
302+
return None
303+
304+
raw_json = blob.data[
305+
len(A2A_DATA_PART_START_TAG) : -len(A2A_DATA_PART_END_TAG)
306+
]
307+
try:
308+
data_part = json.loads(raw_json)
309+
except (ValueError, TypeError) as e:
310+
logger.warning("Failed to parse A2A data part JSON for HITL prompt: %s", e)
311+
return None
312+
313+
if not isinstance(data_part, dict):
314+
logger.warning(
315+
"Unexpected A2A data part JSON of type %s for HITL prompt",
316+
type(data_part).__name__,
317+
)
318+
return None
319+
320+
return data_part.get("data")
321+
322+
286323
def _create_mock_function_call_for_required_user_input(
287324
state: TaskState,
288325
output_parts: list[genai_types.Part],
@@ -308,15 +345,16 @@ def _create_mock_function_call_for_required_user_input(
308345
else:
309346
return output_parts, long_running_function_ids
310347

311-
# Find the last text part from the bottom to replace it with a function call.
312-
# In case of input-required / auth-required events, the LLM should stop the
313-
# production of other parts.
348+
# Find the last part with a usable prompt from the bottom to replace it with a
349+
# function call. In case of input-required / auth-required events, the LLM
350+
# should stop the production of other parts.
314351
for i in range(len(output_parts) - 1, -1, -1):
315-
if output_parts[i].text:
352+
prompt = _extract_user_input_prompt(output_parts[i])
353+
if prompt:
316354
function_call = genai_types.FunctionCall(
317355
id=str(uuid.uuid4()),
318356
name=function_name,
319-
args={args_key: output_parts[i].text},
357+
args={args_key: prompt},
320358
)
321359
long_running_function_ids = set()
322360
long_running_function_ids.add(function_call.id)

tests/unittests/a2a/converters/test_to_adk.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import json
1718
from unittest.mock import Mock
1819

1920
from a2a.types import Artifact
@@ -25,7 +26,10 @@
2526
from a2a.types import TaskStatus
2627
from a2a.types import TaskStatusUpdateEvent
2728
from a2a.types import TextPart
29+
from google.adk.a2a.converters.part_converter import A2A_DATA_PART_END_TAG
2830
from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY
31+
from google.adk.a2a.converters.part_converter import A2A_DATA_PART_START_TAG
32+
from google.adk.a2a.converters.part_converter import A2A_DATA_PART_TEXT_MIME_TYPE
2933
from google.adk.a2a.converters.to_adk_event import convert_a2a_artifact_update_to_event
3034
from google.adk.a2a.converters.to_adk_event import convert_a2a_message_to_event
3135
from google.adk.a2a.converters.to_adk_event import convert_a2a_status_update_to_event
@@ -465,6 +469,99 @@ def test_convert_a2a_task_to_event_no_text_parts(self):
465469
assert event.content is not None
466470
assert event.content.parts == [mock_image_part]
467471

472+
def test_convert_a2a_task_to_event_data_part_input_required(self):
473+
"""Input-required prompt carried in a data part becomes a function call."""
474+
part1 = Mock(spec=A2APart)
475+
part1.root = Mock() # Not a TextPart.
476+
part1.root.metadata = {}
477+
478+
task = Task(
479+
id="task-1",
480+
context_id="context-1",
481+
kind="task",
482+
status=TaskStatus(
483+
state=TaskState.input_required,
484+
timestamp="now",
485+
message=Message(
486+
message_id="m1",
487+
role="agent",
488+
parts=[part1],
489+
),
490+
),
491+
)
492+
493+
prompt = {
494+
"id": "abc123",
495+
"text": "Please confirm this action. Do you want to continue?",
496+
}
497+
data_part_json = json.dumps({"data": prompt, "kind": "data"}).encode(
498+
"utf-8"
499+
)
500+
mock_data_blob_part = genai_types.Part(
501+
inline_data=genai_types.Blob(
502+
mime_type=A2A_DATA_PART_TEXT_MIME_TYPE,
503+
data=A2A_DATA_PART_START_TAG
504+
+ data_part_json
505+
+ A2A_DATA_PART_END_TAG,
506+
)
507+
)
508+
509+
event = convert_a2a_task_to_event(
510+
task,
511+
author="test-author",
512+
invocation_context=self.mock_context,
513+
part_converter=Mock(return_value=[mock_data_blob_part]),
514+
)
515+
516+
assert event is not None
517+
assert event.content is not None
518+
assert (
519+
event.content.parts[0].function_call.name
520+
== MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT
521+
)
522+
assert event.content.parts[0].function_call.args["input_required"] == prompt
523+
assert event.long_running_tool_ids
524+
525+
def test_convert_a2a_task_to_event_data_part_malformed_json(self):
526+
"""A malformed data-part blob is left untouched (no crash, no fc)."""
527+
part1 = Mock(spec=A2APart)
528+
part1.root = Mock() # Not a TextPart.
529+
part1.root.metadata = {}
530+
531+
task = Task(
532+
id="task-1",
533+
context_id="context-1",
534+
kind="task",
535+
status=TaskStatus(
536+
state=TaskState.input_required,
537+
timestamp="now",
538+
message=Message(
539+
message_id="m1",
540+
role="agent",
541+
parts=[part1],
542+
),
543+
),
544+
)
545+
546+
mock_bad_blob_part = genai_types.Part(
547+
inline_data=genai_types.Blob(
548+
mime_type=A2A_DATA_PART_TEXT_MIME_TYPE,
549+
data=A2A_DATA_PART_START_TAG + b"not-json" + A2A_DATA_PART_END_TAG,
550+
)
551+
)
552+
553+
event = convert_a2a_task_to_event(
554+
task,
555+
author="test-author",
556+
invocation_context=self.mock_context,
557+
part_converter=Mock(return_value=[mock_bad_blob_part]),
558+
)
559+
560+
assert event is not None
561+
assert event.content is not None
562+
assert event.content.parts == [mock_bad_blob_part]
563+
assert not event.long_running_tool_ids
564+
468565
def test_convert_a2a_status_update_to_event_success(self):
469566
"""Test successful conversion of A2A status update to Event."""
470567
a2a_part = Mock(spec=A2APart)

0 commit comments

Comments
 (0)