Patch summary (free-form text placed ahead of the first file header).

1. .github/workflows/_test.yml
   The "Upload coverage to Codecov" step now uses codecov/codecov-action@v4
   in place of @v5.

2. nemoguardrails/rails/llm/llmrails.py (_prepare_params)
   The action_params default changes from a mutable {} literal to None, and
   the function now substitutes placeholders on a shallow copy
   (dict(action_params or {})), so replacing "$bot_message" with the current
   chunk no longer overwrites the dict that was passed in.

3. tests/test_streaming_output_rails.py
   Three async tests are appended:
   - test_streaming_action_params_not_stale_across_chunks: the registered
     action must see "alpha" in its first call and "beta" in its second.
   - test_streaming_action_params_original_flow_config_not_mutated: the dict
     returned by get_action_details_from_flow_id must still hold
     "$bot_message" after streaming.
   - test_streaming_user_message_param_substituted: a "$user_message"
     parameter must arrive as the user text (plain string or message dict).

NOTE(review): line breaks and indentation were restored from a
whitespace-collapsed copy of this patch. Added/removed line counts agree with
every hunk header, but the leading whitespace of context lines is inferred;
confirm against the target files before applying.

diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml
index 8669971d09..50651e82c4 100644
--- a/.github/workflows/_test.yml
+++ b/.github/workflows/_test.yml
@@ -104,7 +104,7 @@ jobs:
 
       - name: Upload coverage to Codecov
         if: inputs.with-coverage
-        uses: codecov/codecov-action@v5
+        uses: codecov/codecov-action@v4
         with:
           directory: ./coverage/reports/
           env_vars: PYTHON
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 63c0fa27b7..06ed0fb639 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -1847,7 +1847,7 @@ def _prepare_params(
             bot_response_chunk: str,
             prompt: Optional[str] = None,
             messages: Optional[List[dict]] = None,
-            action_params: Dict[str, Any] = {},
+            action_params: Optional[Dict[str, Any]] = None,
         ):
             context_message = _get_last_context_message(messages)
             user_message = prompt or _get_latest_user_message(messages)
@@ -1862,9 +1862,12 @@ def _prepare_params(
 
             model_name = flow_id.split("$")[-1].split("=")[-1].strip('"')
 
-            # we pass action params that are defined in the flow
-            # caveate, e.g. prmpt_security uses bot_response=$bot_message
-            # to resolve replace placeholders in action_params
+            # Shallow-copy before substituting placeholders so we never mutate
+            # the original dict returned by get_action_details_from_flow_id.
+            # Without this copy, "$bot_message" gets replaced by the first
+            # chunk's text and every subsequent chunk receives the stale value.
+            action_params = dict(action_params or {})
+
             for key, value in action_params.items():
                 if value == "$bot_message":
                     action_params[key] = bot_response_chunk
diff --git a/tests/test_streaming_output_rails.py b/tests/test_streaming_output_rails.py
index a9791eba76..96f66932c4 100644
--- a/tests/test_streaming_output_rails.py
+++ b/tests/test_streaming_output_rails.py
@@ -450,3 +450,176 @@ async def self_check_output(**kwargs):
         tokens.append(token)
 
     assert "".join(tokens) == "This is a complete response in a single chunk."
+
+
+@pytest.mark.asyncio
+async def test_streaming_action_params_not_stale_across_chunks():
+    """Regression test for GH-1935: explicit $bot_message action params must be
+    re-substituted from the original placeholder on every chunk batch, not reuse
+    the stale text substituted in the first batch."""
+
+    received_chunks = []
+
+    config = RailsConfig.from_content(
+        config={
+            "models": [],
+            "rails": {
+                "output": {
+                    "flows": ["record chunk"],
+                    "streaming": {
+                        "enabled": True,
+                        "chunk_size": 1,
+                        "context_size": 0,
+                        "stream_first": False,
+                    },
+                }
+            },
+            "streaming": True,
+        },
+        colang_content="""
+        define flow record chunk
+          execute record_chunk(chunk=$bot_message)
+        """,
+    )
+
+    rails = LLMRails(config)
+
+    @action(name="record_chunk", is_system_action=True)
+    async def record_chunk(chunk=None, **kwargs):
+        received_chunks.append(chunk)
+        return True
+
+    rails.register_action(record_chunk, "record_chunk")
+
+    async def two_word_generator():
+        yield "alpha "
+        yield "beta"
+
+    async for _ in rails.stream_async(
+        generator=two_word_generator(),
+        messages=[{"role": "user", "content": "Hello"}],
+    ):
+        pass
+
+    assert len(received_chunks) >= 2, f"Expected at least 2 action invocations (one per chunk), got {received_chunks}"
+    assert "alpha" in received_chunks[0], f"First chunk unexpected: {received_chunks[0]!r}"
+    assert "beta" in received_chunks[1], (
+        f"Stale action_params bug: second batch got {received_chunks[1]!r} instead of 'beta'"
+    )
+
+    await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}, return_exceptions=True)
+
+
+@pytest.mark.asyncio
+async def test_streaming_action_params_original_flow_config_not_mutated():
+    """After streaming, the parsed flow config's action_params dict must still
+    hold the '$bot_message' placeholder — _prepare_params must never mutate it."""
+
+    from nemoguardrails.rails.llm.utils import get_action_details_from_flow_id
+
+    config = RailsConfig.from_content(
+        config={
+            "models": [],
+            "rails": {
+                "output": {
+                    "flows": ["record chunk"],
+                    "streaming": {
+                        "enabled": True,
+                        "chunk_size": 1,
+                        "context_size": 0,
+                        "stream_first": False,
+                    },
+                }
+            },
+            "streaming": True,
+        },
+        colang_content="""
+        define flow record chunk
+          execute record_chunk(chunk=$bot_message)
+        """,
+    )
+
+    _, original_params = get_action_details_from_flow_id("record chunk", config.flows)
+    assert original_params.get("chunk") == "$bot_message", (
+        "Test setup error: expected '$bot_message' placeholder in parsed action_params"
+    )
+
+    rails = LLMRails(config)
+
+    @action(name="record_chunk", is_system_action=True)
+    async def record_chunk(chunk=None, **kwargs):
+        return True
+
+    rails.register_action(record_chunk, "record_chunk")
+
+    async def single_word_gen():
+        yield "hello"
+
+    async for _ in rails.stream_async(
+        generator=single_word_gen(),
+        messages=[{"role": "user", "content": "Hi"}],
+    ):
+        pass
+
+    assert original_params.get("chunk") == "$bot_message", (
+        "Stale action_params bug: _prepare_params mutated the original flow config dict; "
+        f"'$bot_message' was replaced with {original_params.get('chunk')!r}"
+    )
+
+    await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}, return_exceptions=True)
+
+
+@pytest.mark.asyncio
+async def test_streaming_user_message_param_substituted():
+    """$user_message in action params is substituted correctly (covers the user_message
+    substitution branch in _prepare_params alongside the $bot_message branch)."""
+
+    received_user_messages = []
+
+    config = RailsConfig.from_content(
+        config={
+            "models": [],
+            "rails": {
+                "output": {
+                    "flows": ["echo user"],
+                    "streaming": {
+                        "enabled": True,
+                        "chunk_size": 1,
+                        "context_size": 0,
+                        "stream_first": False,
+                    },
+                }
+            },
+            "streaming": True,
+        },
+        colang_content="""
+        define flow echo user
+          execute echo_user(user=$user_message)
+        """,
+    )
+
+    rails = LLMRails(config)
+
+    @action(name="echo_user", is_system_action=True)
+    async def echo_user(user=None, **kwargs):
+        received_user_messages.append(user)
+        return True
+
+    rails.register_action(echo_user, "echo_user")
+
+    async def one_word_gen():
+        yield "response"
+
+    async for _ in rails.stream_async(
+        generator=one_word_gen(),
+        messages=[{"role": "user", "content": "hello there"}],
+    ):
+        pass
+
+    assert len(received_user_messages) >= 1
+    received = received_user_messages[0]
+    assert (isinstance(received, dict) and received.get("content") == "hello there") or received == "hello there", (
+        f"$user_message was not substituted correctly: {received!r}"
+    )
+
+    await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}, return_exceptions=True)