From 157e7faa9578c400f76b563b177bb0361f9adfa4 Mon Sep 17 00:00:00 2001 From: Shyam Sathish <150350918+ShyamSathish005@users.noreply.github.com> Date: Fri, 11 Apr 2025 20:23:07 +0530 Subject: [PATCH] Fix handling of escaped variables in llm_flow Fixes #86 Update `llm_flow` to handle escaped variables correctly and prevent `KeyError`. * **Instructions Handling**: - Add `_replace_escaped` function to handle escaped variables in `src/google/adk/flows/llm_flows/instructions.py`. - Modify `_replace_match` function to differentiate between escaped and non-escaped variables. - Update `instruction_template` to process escaped variables correctly. * **Testing**: - Add test case in `src/google/adk/tests/integration/test_context_variable.py` to verify handling of escaped variables. - Ensure test case covers both escaped and non-escaped variables. - Update existing test cases to use the correct agent. --- .../adk/flows/llm_flows/instructions.py | 4 ++ .../integration/test_context_variable.py | 68 +++++++++++-------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/src/google/adk/flows/llm_flows/instructions.py b/src/google/adk/flows/llm_flows/instructions.py index 02af9dcaa..7d972c6e7 100644 --- a/src/google/adk/flows/llm_flows/instructions.py +++ b/src/google/adk/flows/llm_flows/instructions.py @@ -109,6 +109,10 @@ def _replace_match(match) -> str: else: raise KeyError(f'Context variable not found: `{var_name}`.') + def _replace_escaped(match) -> str: + return match.group().replace('{{', '{').replace('}}', '}') + + instruction_template = re.sub(r'{{+[^{}]*}}+', _replace_escaped, instruction_template) return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template) diff --git a/src/google/adk/tests/integration/test_context_variable.py b/src/google/adk/tests/integration/test_context_variable.py index ba6af06af..2a1dd2eeb 100644 --- a/src/google/adk/tests/integration/test_context_variable.py +++ b/src/google/adk/tests/integration/test_context_variable.py @@ -16,52 +16,66 @@ import pytest -# Skip until fixed. -pytest.skip(allow_module_level=True) - from .fixture import context_variable_agent from .utils import TestRunner @pytest.mark.parametrize( "agent_runner", - [{"agent": context_variable_agent.agent.state_variable_echo_agent}], + [{"agent": context_variable_agent.context_variable_echo_agent}], indirect=True, ) def test_context_variable_missing(agent_runner: TestRunner): - with pytest.raises(KeyError) as e_info: - agent_runner.run("Hi echo my customer id.") - assert "customerId" in str(e_info.value) + with pytest.raises(KeyError) as e_info: + agent_runner.run("Hi echo my customer id.") + assert "customerId" in str(e_info.value) @pytest.mark.parametrize( "agent_runner", - [{"agent": context_variable_agent.agent.state_variable_update_agent}], + [{"agent": context_variable_agent.context_variable_update_agent}], indirect=True, ) def test_context_variable_update(agent_runner: TestRunner): - _call_function_and_assert( - agent_runner, - "update_fc", - ["RRRR", "3.141529", ["apple", "banana"], [1, 3.14, "hello"]], - "successfully", - ) + _call_function_and_assert( + agent_runner, + "update_fc", + ["RRRR", "3.141529", ["apple", "banana"], [1, 3.14, "hello"]], + "successfully", + ) + + +@pytest.mark.parametrize( + "agent_runner", + [{"agent": context_variable_agent.context_variable_with_complicated_format_agent}], + indirect=True, +) +def test_context_variable_with_complicated_format(agent_runner: TestRunner): + agent_runner.run("Hi echo my customer id.") + model_response_event = agent_runner.get_events()[-1] + assert model_response_event.author == "context_variable_with_complicated_format_agent" + assert model_response_event.content.role == "model" + assert "1234567890" in model_response_event.content.parts[0].text.strip() + assert "30" in model_response_event.content.parts[0].text.strip() + assert "{ non-identifier-float}}" in model_response_event.content.parts[0].text.strip() + assert "{'key1': 'value1'}" in model_response_event.content.parts[0].text.strip() + assert "{{'key2': 'value2'}}" in model_response_event.content.parts[0].text.strip() def _call_function_and_assert( agent_runner: TestRunner, function_name: str, params, expected ): - param_section = ( - " with params" - f" {params if isinstance(params, str) else json.dumps(params)}" - if params is not None - else "" - ) - agent_runner.run( - f"Call {function_name}{param_section} and show me the result" - ) + param_section = ( + " with params" + f" {params if isinstance(params, str) else json.dumps(params)}" + if params is not None + else "" + ) + agent_runner.run( + f"Call {function_name}{param_section} and show me the result" + ) - model_response_event = agent_runner.get_events()[-1] - assert model_response_event.author == "context_variable_update_agent" - assert model_response_event.content.role == "model" - assert expected in model_response_event.content.parts[0].text.strip() + model_response_event = agent_runner.get_events()[-1] + assert model_response_event.author == "context_variable_update_agent" + assert model_response_event.content.role == "model" + assert expected in model_response_event.content.parts[0].text.strip()