diff --git a/src/google/adk/flows/llm_flows/instructions.py b/src/google/adk/flows/llm_flows/instructions.py index 02af9dcaa..7d972c6e7 100644 --- a/src/google/adk/flows/llm_flows/instructions.py +++ b/src/google/adk/flows/llm_flows/instructions.py @@ -109,6 +109,10 @@ def _replace_match(match) -> str: else: raise KeyError(f'Context variable not found: `{var_name}`.') + def _replace_escaped(match) -> str: + return match.group().replace('{{', '{').replace('}}', '}') + + instruction_template = re.sub(r'{{+[^{}]*}}+', _replace_escaped, instruction_template) return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template) diff --git a/src/google/adk/tests/integration/test_context_variable.py b/src/google/adk/tests/integration/test_context_variable.py index ba6af06af..2a1dd2eeb 100644 --- a/src/google/adk/tests/integration/test_context_variable.py +++ b/src/google/adk/tests/integration/test_context_variable.py @@ -16,52 +16,66 @@ import pytest -# Skip until fixed. -pytest.skip(allow_module_level=True) - from .fixture import context_variable_agent from .utils import TestRunner @pytest.mark.parametrize( "agent_runner", - [{"agent": context_variable_agent.agent.state_variable_echo_agent}], + [{"agent": context_variable_agent.context_variable_echo_agent}], indirect=True, ) def test_context_variable_missing(agent_runner: TestRunner): - with pytest.raises(KeyError) as e_info: - agent_runner.run("Hi echo my customer id.") - assert "customerId" in str(e_info.value) + with pytest.raises(KeyError) as e_info: + agent_runner.run("Hi echo my customer id.") + assert "customerId" in str(e_info.value) @pytest.mark.parametrize( "agent_runner", - [{"agent": context_variable_agent.agent.state_variable_update_agent}], + [{"agent": context_variable_agent.context_variable_update_agent}], indirect=True, ) def test_context_variable_update(agent_runner: TestRunner): - _call_function_and_assert( - agent_runner, - "update_fc", - ["RRRR", "3.141529", ["apple", "banana"], [1, 3.14, "hello"]], - "successfully", - ) + _call_function_and_assert( + agent_runner, + "update_fc", + ["RRRR", "3.141529", ["apple", "banana"], [1, 3.14, "hello"]], + "successfully", + ) + + +@pytest.mark.parametrize( + "agent_runner", + [{"agent": context_variable_agent.context_variable_with_complicated_format_agent}], + indirect=True, +) +def test_context_variable_with_complicated_format(agent_runner: TestRunner): + agent_runner.run("Hi echo my customer id.") + model_response_event = agent_runner.get_events()[-1] + assert model_response_event.author == "context_variable_with_complicated_format_agent" + assert model_response_event.content.role == "model" + assert "1234567890" in model_response_event.content.parts[0].text.strip() + assert "30" in model_response_event.content.parts[0].text.strip() + assert "{ non-identifier-float}}" in model_response_event.content.parts[0].text.strip() + assert "{'key1': 'value1'}" in model_response_event.content.parts[0].text.strip() + assert "{{'key2': 'value2'}}" in model_response_event.content.parts[0].text.strip() def _call_function_and_assert( agent_runner: TestRunner, function_name: str, params, expected ): - param_section = ( - " with params" - f" {params if isinstance(params, str) else json.dumps(params)}" - if params is not None - else "" - ) - agent_runner.run( - f"Call {function_name}{param_section} and show me the result" - ) + param_section = ( + " with params" + f" {params if isinstance(params, str) else json.dumps(params)}" + if params is not None + else "" + ) + agent_runner.run( + f"Call {function_name}{param_section} and show me the result" + ) - model_response_event = agent_runner.get_events()[-1] - assert model_response_event.author == "context_variable_update_agent" - assert model_response_event.content.role == "model" - assert expected in model_response_event.content.parts[0].text.strip() + model_response_event = agent_runner.get_events()[-1] + assert model_response_event.author == "context_variable_update_agent" + assert model_response_event.content.role == "model" + assert expected in model_response_event.content.parts[0].text.strip()