Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions garak
Submodule garak added at a4e29f
62 changes: 54 additions & 8 deletions nemoguardrails/integrations/langchain/runnable_rails.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from __future__ import annotations

from typing import Any, List, Optional
from typing import Any, List, Literal, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
Expand All @@ -41,14 +41,18 @@ def __init__(
input_key: str = "input",
output_key: str = "output",
verbose: bool = False,
consistent_output_format: Literal["preserve", "always_dict", "always_string"] = "preserve",
) -> None:
self.llm = llm
self.passthrough = passthrough
self.passthrough_runnable = runnable
self.passthrough_user_input_key = input_key
self.passthrough_bot_output_key = output_key
self.verbose = verbose
self.consistent_output_format = consistent_output_format
self.config: Optional[RunnableConfig] = None
self._current_config: Optional[RunnableConfig] = None
self._current_kwargs: dict = {}

# We override the config passthrough.
config.passthrough = passthrough
Expand All @@ -74,7 +78,10 @@ async def passthrough_fn(context: dict, events: List[dict]):
# First, we fetch the input from the context
_input = context.get("passthrough_input")
async_wrapped_invoke = async_wrap(self.passthrough_runnable.invoke)
_output = await async_wrapped_invoke(_input, self.config, **self.kwargs)

# Pass the config and kwargs that were captured in the invoke method
# This ensures that callbacks (like Langfuse tracing) are properly propagated
_output = await async_wrapped_invoke(_input, self._current_config, **self._current_kwargs)

# If the output is a string, we consider it to be the output text
if isinstance(_output, str):
Expand All @@ -86,6 +93,36 @@ async def passthrough_fn(context: dict, events: List[dict]):

self.rails.llm_generation_actions.passthrough_fn = passthrough_fn

def _format_output_consistently(self, output: Any, input_type: Any) -> Any:
"""Format output according to consistent_output_format setting."""
if self.consistent_output_format == "preserve":
return output
elif self.consistent_output_format == "always_dict":
if isinstance(output, str):
return {self.passthrough_bot_output_key: output}
elif isinstance(output, dict):
return output
else:
return {self.passthrough_bot_output_key: str(output)}
elif self.consistent_output_format == "always_string":
if isinstance(output, dict):
# Try to extract string from dict
if self.passthrough_bot_output_key in output:
return output[self.passthrough_bot_output_key]
elif "content" in output:
return output["content"]
elif "output" in output:
return output["output"]
else:
# Return the first string value found or convert to string
for value in output.values():
if isinstance(value, str):
return value
return str(output)
else:
return str(output)
return output

def __or__(self, other):
if isinstance(other, BaseLanguageModel):
self.llm = other
Expand Down Expand Up @@ -188,8 +225,12 @@ def invoke(
) -> Output:
"""Invoke this runnable synchronously."""
input_messages = self._transform_input_to_rails_format(input)
# Store config and kwargs for use in passthrough function
# This ensures callbacks are properly passed to the underlying runnable
self.config = config
self.kwargs = kwargs
self._current_config = config
self._current_kwargs = kwargs
res = self.rails.generate(
messages=input_messages, options=GenerationOptions(output_vars=True)
)
Expand Down Expand Up @@ -222,20 +263,25 @@ def invoke(
elif isinstance(passthrough_output, dict):
passthrough_output[self.passthrough_bot_output_key] = bot_message

return passthrough_output
return self._format_output_consistently(passthrough_output, type(input))
else:
if isinstance(input, ChatPromptValue):
return AIMessage(content=result["content"])
output = AIMessage(content=result["content"])
return self._format_output_consistently(output, type(input))
elif isinstance(input, StringPromptValue):
if isinstance(result, dict):
return result["content"]
output = result["content"]
else:
return result
output = result
return self._format_output_consistently(output, type(input))
elif isinstance(input, dict):
user_input = input["input"]
if isinstance(user_input, str):
return {"output": result["content"]}
output = {"output": result["content"]}
elif isinstance(user_input, list):
return {"output": result}
output = {"output": result}
else:
output = {"output": result["content"]}
return self._format_output_consistently(output, type(input))
else:
raise ValueError(f"Unexpected input type: {type(input)}")
121 changes: 121 additions & 0 deletions tests/test_runnable_rails.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,124 @@ def log(x):
print(result)
assert "LOL" not in result["output"]
assert "can't respond" in result["output"]


def test_runnable_config_callback_passthrough():
    """Check that the RunnableConfig given to ``invoke`` reaches the passthrough runnable.

    A stub runnable records every config it is invoked with. After a single
    call through ``RunnableRails`` the recorded config must carry the same
    callbacks the caller supplied.
    """
    config_received = []

    class CallbackTestRunnable(Runnable):
        # Accept **kwargs like the base ``Runnable.invoke`` does, so extra
        # keyword arguments forwarded to the stub cannot raise a TypeError.
        def invoke(
            self, input: Input, config: Optional[RunnableConfig] = None, **kwargs
        ) -> Output:
            # Capture the config to verify callbacks were passed
            config_received.append(config)
            return {"output": "Test response"}

    # Create a mock callback for testing
    mock_callbacks = ["mock_callback"]
    test_config = RunnableConfig(callbacks=mock_callbacks)

    rails_config = RailsConfig.from_content(config={"models": []})
    runnable_with_rails = RunnableRails(
        rails_config, passthrough=True, runnable=CallbackTestRunnable()
    )

    # Invoke with the config containing callbacks
    result = runnable_with_rails.invoke("test input", config=test_config)

    # Verify that the config with callbacks was passed through
    assert len(config_received) == 1
    assert config_received[0] is not None
    assert config_received[0].get("callbacks") == mock_callbacks
    assert result == {"output": "Test response"}


def test_consistent_output_format_preserve():
    """In "preserve" mode the passthrough output keeps whatever shape it already had."""
    rails_config = RailsConfig.from_content(config={"models": []})

    rails = RunnableRails(
        rails_config,
        passthrough=True,
        runnable=MockRunnable2(),
        consistent_output_format="preserve",
    )

    # Plain passthrough with a string input is expected to come back as a string.
    outcome = rails.invoke("test input")
    assert outcome == "PARIS!!"


def test_consistent_output_format_always_dict():
    """In "always_dict" mode every result comes back as a dictionary."""
    rails_config = RailsConfig.from_content(config={"models": []})
    expected = {"output": "PARIS!!"}

    # First runnable: string result that must be wrapped in a dict.
    # Second runnable: result that is already a dict and must stay one.
    for runnable_cls in (MockRunnable2, MockRunnable):
        rails = RunnableRails(
            rails_config,
            passthrough=True,
            runnable=runnable_cls(),
            consistent_output_format="always_dict",
        )
        outcome = rails.invoke("test input")
        assert outcome == expected


def test_consistent_output_format_always_string():
    """In "always_string" mode every result comes back as a plain string."""
    rails_config = RailsConfig.from_content(config={"models": []})
    expected = "PARIS!!"

    # First runnable: dict result from which the text must be extracted.
    # Second runnable: result that is already a string and must stay one.
    for runnable_cls in (MockRunnable, MockRunnable2):
        rails = RunnableRails(
            rails_config,
            passthrough=True,
            runnable=runnable_cls(),
            consistent_output_format="always_string",
        )
        outcome = rails.invoke("test input")
        assert outcome == expected


def test_consistent_output_format_with_rails_blocking():
    """Test consistent format when rails block the input/output.

    The flow below answers an off-topic question with a fixed refusal. The
    refusal must be delivered in the shape requested through
    ``consistent_output_format``: a bare string for "always_string", and a
    dict for "always_dict".
    """
    llm = FakeLLM(responses=[" ask off topic question", " ask off topic question"])
    # Colang 1.0 is indentation-sensitive: utterances and flow steps must be
    # nested under their ``define`` header, otherwise they are not part of it.
    config = RailsConfig.from_content(
        config={"models": []},
        colang_content="""
        define user ask off topic question
            "Can you help me cook something?"

        define flow
            user ask off topic question
            bot refuse to respond

        define bot refuse to respond
            "I'm sorry, I can't help with that."
        """,
    )

    # Test with always_string - even when rails trigger, should return string
    runnable_with_rails = RunnableRails(
        config,
        llm=llm,
        passthrough=True,
        runnable=MockRunnable(),
        consistent_output_format="always_string",
    )
    result = runnable_with_rails.invoke("This is an off topic question")
    assert isinstance(result, str)
    assert result == "I'm sorry, I can't help with that."

    # Test with always_dict - should return dict format
    runnable_with_rails2 = RunnableRails(
        config,
        llm=llm,
        passthrough=True,
        runnable=MockRunnable(),
        consistent_output_format="always_dict",
    )
    result2 = runnable_with_rails2.invoke("This is an off topic question")
    assert isinstance(result2, dict)
    assert result2 == {"output": "I'm sorry, I can't help with that."}
Loading