Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,7 +1847,7 @@ def _prepare_params(
bot_response_chunk: str,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
action_params: Dict[str, Any] = {},
action_params: Optional[Dict[str, Any]] = None,
):
context_message = _get_last_context_message(messages)
user_message = prompt or _get_latest_user_message(messages)
Expand All @@ -1862,9 +1862,12 @@ def _prepare_params(

model_name = flow_id.split("$")[-1].split("=")[-1].strip('"')

# we pass action params that are defined in the flow
# caveate, e.g. prmpt_security uses bot_response=$bot_message
# to resolve replace placeholders in action_params
# Shallow-copy before substituting placeholders so we never mutate
# the original dict returned by get_action_details_from_flow_id.
# Without this copy, "$bot_message" gets replaced by the first
# chunk's text and every subsequent chunk receives the stale value.
action_params = dict(action_params or {})

for key, value in action_params.items():
if value == "$bot_message":
action_params[key] = bot_response_chunk
Expand Down
175 changes: 175 additions & 0 deletions tests/test_streaming_output_rails.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,178 @@ async def self_check_output(**kwargs):
tokens.append(token)

assert "".join(tokens) == "This is a complete response in a single chunk."


@pytest.mark.asyncio
async def test_streaming_action_params_not_stale_across_chunks():
"""Regression test for GH-1935: explicit $bot_message action params must be
re-substituted from the original placeholder on every chunk batch, not reuse
the stale text substituted in the first batch."""

received_chunks = []

config = RailsConfig.from_content(
config={
"models": [],
"rails": {
"output": {
"flows": ["record chunk"],
"streaming": {
"enabled": True,
"chunk_size": 1,
"context_size": 0,
"stream_first": False,
},
}
},
"streaming": True,
},
colang_content="""
define flow record chunk
execute record_chunk(chunk=$bot_message)
""",
)

rails = LLMRails(config)

@action(name="record_chunk", is_system_action=True)
async def record_chunk(chunk=None, **kwargs):
received_chunks.append(chunk)
return True

rails.register_action(record_chunk, "record_chunk")

async def two_word_generator():
yield "alpha "
yield "beta"

async for _ in rails.stream_async(
generator=two_word_generator(),
messages=[{"role": "user", "content": "Hello"}],
):
pass

assert len(received_chunks) >= 2, (
f"Expected at least 2 action invocations (one per chunk), got {received_chunks}"
)
assert "alpha" in received_chunks[0], f"First chunk unexpected: {received_chunks[0]!r}"
assert "beta" in received_chunks[1], (
f"Stale action_params bug: second batch got {received_chunks[1]!r} instead of 'beta'"
)

await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
Comment thread
nac7 marked this conversation as resolved.
Outdated


@pytest.mark.asyncio
async def test_streaming_action_params_original_flow_config_not_mutated():
"""After streaming, the parsed flow config's action_params dict must still
hold the '$bot_message' placeholder — _prepare_params must never mutate it."""

from nemoguardrails.rails.llm.utils import get_action_details_from_flow_id

config = RailsConfig.from_content(
config={
"models": [],
"rails": {
"output": {
"flows": ["record chunk"],
"streaming": {
"enabled": True,
"chunk_size": 1,
"context_size": 0,
"stream_first": False,
},
}
},
"streaming": True,
},
colang_content="""
define flow record chunk
execute record_chunk(chunk=$bot_message)
""",
)

_, original_params = get_action_details_from_flow_id("record chunk", config.flows)
assert original_params.get("chunk") == "$bot_message", (
"Test setup error: expected '$bot_message' placeholder in parsed action_params"
)

rails = LLMRails(config)

@action(name="record_chunk", is_system_action=True)
async def record_chunk(chunk=None, **kwargs):
return True

rails.register_action(record_chunk, "record_chunk")

async def single_word_gen():
yield "hello"

async for _ in rails.stream_async(
generator=single_word_gen(),
messages=[{"role": "user", "content": "Hi"}],
):
pass

assert original_params.get("chunk") == "$bot_message", (
"Stale action_params bug: _prepare_params mutated the original flow config dict; "
f"'$bot_message' was replaced with {original_params.get('chunk')!r}"
)

await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})


@pytest.mark.asyncio
async def test_streaming_user_message_param_substituted():
"""$user_message in action params is substituted correctly (covers the user_message
substitution branch in _prepare_params alongside the $bot_message branch)."""

received_user_messages = []

config = RailsConfig.from_content(
config={
"models": [],
"rails": {
"output": {
"flows": ["echo user"],
"streaming": {
"enabled": True,
"chunk_size": 1,
"context_size": 0,
"stream_first": False,
},
}
},
"streaming": True,
},
colang_content="""
define flow echo user
execute echo_user(user=$user_message)
""",
)

rails = LLMRails(config)

@action(name="echo_user", is_system_action=True)
async def echo_user(user=None, **kwargs):
received_user_messages.append(user)
return True

rails.register_action(echo_user, "echo_user")

async def one_word_gen():
yield "response"

async for _ in rails.stream_async(
generator=one_word_gen(),
messages=[{"role": "user", "content": "hello there"}],
):
pass

assert len(received_user_messages) >= 1
received = received_user_messages[0]
assert isinstance(received, dict) and received.get("content") == "hello there" or received == "hello there", (
Comment thread
nac7 marked this conversation as resolved.
Outdated
f"$user_message was not substituted correctly: {received!r}"
)

await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
Loading