diff --git a/garak b/garak
new file mode 160000
index 000000000..a4e29f929
--- /dev/null
+++ b/garak
@@ -0,0 +1 @@
+Subproject commit a4e29f929a0247682d2bc82a615d7c9eb6a6936f
diff --git a/nemoguardrails/integrations/langchain/runnable_rails.py b/nemoguardrails/integrations/langchain/runnable_rails.py
index 1eb282848..961f8af4f 100644
--- a/nemoguardrails/integrations/langchain/runnable_rails.py
+++ b/nemoguardrails/integrations/langchain/runnable_rails.py
@@ -15,7 +15,7 @@
 
 from __future__ import annotations
 
-from typing import Any, List, Optional
+from typing import Any, List, Literal, Optional
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@@ -41,6 +41,7 @@ def __init__(
         input_key: str = "input",
         output_key: str = "output",
         verbose: bool = False,
+        consistent_output_format: Literal["preserve", "always_dict", "always_string"] = "preserve",
     ) -> None:
         self.llm = llm
         self.passthrough = passthrough
@@ -48,7 +49,10 @@ def __init__(
         self.passthrough_user_input_key = input_key
         self.passthrough_bot_output_key = output_key
         self.verbose = verbose
+        self.consistent_output_format = consistent_output_format
         self.config: Optional[RunnableConfig] = None
+        self._current_config: Optional[RunnableConfig] = None
+        self._current_kwargs: dict = {}
 
         # We override the config passthrough.
         config.passthrough = passthrough
@@ -74,7 +78,10 @@ async def passthrough_fn(context: dict, events: List[dict]):
                 # First, we fetch the input from the context
                 _input = context.get("passthrough_input")
                 async_wrapped_invoke = async_wrap(self.passthrough_runnable.invoke)
-                _output = await async_wrapped_invoke(_input, self.config, **self.kwargs)
+
+                # Pass the config and kwargs that were captured in the invoke method
+                # This ensures that callbacks (like Langfuse tracing) are properly propagated
+                _output = await async_wrapped_invoke(_input, self._current_config, **self._current_kwargs)
 
                 # If the output is a string, we consider it to be the output text
                 if isinstance(_output, str):
@@ -86,6 +93,36 @@
 
             self.rails.llm_generation_actions.passthrough_fn = passthrough_fn
 
+    def _format_output_consistently(self, output: Any, input_type: Any) -> Any:
+        """Format output according to consistent_output_format setting."""
+        if self.consistent_output_format == "preserve":
+            return output
+        elif self.consistent_output_format == "always_dict":
+            if isinstance(output, str):
+                return {self.passthrough_bot_output_key: output}
+            elif isinstance(output, dict):
+                return output
+            else:
+                return {self.passthrough_bot_output_key: str(output)}
+        elif self.consistent_output_format == "always_string":
+            if isinstance(output, dict):
+                # Try to extract string from dict
+                if self.passthrough_bot_output_key in output:
+                    return output[self.passthrough_bot_output_key]
+                elif "content" in output:
+                    return output["content"]
+                elif "output" in output:
+                    return output["output"]
+                else:
+                    # Return the first string value found or convert to string
+                    for value in output.values():
+                        if isinstance(value, str):
+                            return value
+                    return str(output)
+            else:
+                return str(output)
+        return output
+
     def __or__(self, other):
         if isinstance(other, BaseLanguageModel):
             self.llm = other
@@ -188,8 +225,12 @@ def invoke(
     ) -> Output:
         """Invoke this runnable synchronously."""
         input_messages = self._transform_input_to_rails_format(input)
+        # Store config and kwargs for use in passthrough function
+        # This ensures callbacks are properly passed to the underlying runnable
         self.config = config
         self.kwargs = kwargs
+        self._current_config = config
+        self._current_kwargs = kwargs
         res = self.rails.generate(
             messages=input_messages, options=GenerationOptions(output_vars=True)
         )
@@ -222,20 +263,25 @@ def invoke(
             elif isinstance(passthrough_output, dict):
                 passthrough_output[self.passthrough_bot_output_key] = bot_message
 
-            return passthrough_output
+            return self._format_output_consistently(passthrough_output, type(input))
         else:
             if isinstance(input, ChatPromptValue):
-                return AIMessage(content=result["content"])
+                output = AIMessage(content=result["content"])
+                return self._format_output_consistently(output, type(input))
             elif isinstance(input, StringPromptValue):
                 if isinstance(result, dict):
-                    return result["content"]
+                    output = result["content"]
                 else:
-                    return result
+                    output = result
+                return self._format_output_consistently(output, type(input))
             elif isinstance(input, dict):
                 user_input = input["input"]
                 if isinstance(user_input, str):
-                    return {"output": result["content"]}
+                    output = {"output": result["content"]}
                 elif isinstance(user_input, list):
-                    return {"output": result}
+                    output = {"output": result}
+                else:
+                    output = {"output": result["content"]}
+                return self._format_output_consistently(output, type(input))
             else:
                 raise ValueError(f"Unexpected input type: {type(input)}")
diff --git a/tests/test_runnable_rails.py b/tests/test_runnable_rails.py
index 10b33c056..453cb685e 100644
--- a/tests/test_runnable_rails.py
+++ b/tests/test_runnable_rails.py
@@ -658,3 +658,124 @@ def log(x):
     print(result)
     assert "LOL" not in result["output"]
     assert "can't respond" in result["output"]
+
+
+def test_runnable_config_callback_passthrough():
+    """Test that RunnableConfig with callbacks is properly passed to passthrough runnable."""
+    config_received = []
+
+    class CallbackTestRunnable(Runnable):
+        def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
+            # Capture the config to verify callbacks were passed
+            config_received.append(config)
+            return {"output": "Test response"}
+
+    # Create a mock callback for testing
+    mock_callbacks = ["mock_callback"]
+    test_config = RunnableConfig(callbacks=mock_callbacks)
+
+    rails_config = RailsConfig.from_content(config={"models": []})
+    runnable_with_rails = RunnableRails(
+        rails_config, passthrough=True, runnable=CallbackTestRunnable()
+    )
+
+    # Invoke with the config containing callbacks
+    result = runnable_with_rails.invoke("test input", config=test_config)
+
+    # Verify that the config with callbacks was passed through
+    assert len(config_received) == 1
+    assert config_received[0] is not None
+    assert config_received[0].get("callbacks") == mock_callbacks
+    assert result == {"output": "Test response"}
+
+
+def test_consistent_output_format_preserve():
+    """Test that preserve mode maintains original inconsistent behavior."""
+    config = RailsConfig.from_content(config={"models": []})
+
+    # Test normal passthrough (should return string)
+    runnable_with_rails = RunnableRails(
+        config, passthrough=True, runnable=MockRunnable2(),
+        consistent_output_format="preserve"
+    )
+    result = runnable_with_rails.invoke("test input")
+    assert result == "PARIS!!"  # String format preserved
+
+
+def test_consistent_output_format_always_dict():
+    """Test that always_dict mode forces dictionary format."""
+    config = RailsConfig.from_content(config={"models": []})
+
+    # Test with string output - should be converted to dict
+    runnable_with_rails = RunnableRails(
+        config, passthrough=True, runnable=MockRunnable2(),
+        consistent_output_format="always_dict"
+    )
+    result = runnable_with_rails.invoke("test input")
+    assert result == {"output": "PARIS!!"}  # Converted to dict
+
+    # Test with dict output - should remain dict
+    runnable_with_rails2 = RunnableRails(
+        config, passthrough=True, runnable=MockRunnable(),
+        consistent_output_format="always_dict"
+    )
+    result2 = runnable_with_rails2.invoke("test input")
+    assert result2 == {"output": "PARIS!!"}  # Already dict format
+
+
+def test_consistent_output_format_always_string():
+    """Test that always_string mode forces string format when possible."""
+    config = RailsConfig.from_content(config={"models": []})
+
+    # Test with dict output - should be converted to string
+    runnable_with_rails = RunnableRails(
+        config, passthrough=True, runnable=MockRunnable(),
+        consistent_output_format="always_string"
+    )
+    result = runnable_with_rails.invoke("test input")
+    assert result == "PARIS!!"  # Extracted from dict
+
+    # Test with string output - should remain string
+    runnable_with_rails2 = RunnableRails(
+        config, passthrough=True, runnable=MockRunnable2(),
+        consistent_output_format="always_string"
+    )
+    result2 = runnable_with_rails2.invoke("test input")
+    assert result2 == "PARIS!!"  # Already string
+
+
+def test_consistent_output_format_with_rails_blocking():
+    """Test consistent format when rails block the input/output."""
+    llm = FakeLLM(responses=[" ask off topic question", " ask off topic question"])
+    config = RailsConfig.from_content(
+        config={"models": []},
+        colang_content="""
+        define user ask off topic question
+            "Can you help me cook something?"
+
+        define flow
+            user ask off topic question
+            bot refuse to respond
+
+        define bot refuse to respond
+            "I'm sorry, I can't help with that."
+        """,
+    )
+
+    # Test with always_string - even when rails trigger, should return string
+    runnable_with_rails = RunnableRails(
+        config, llm=llm, passthrough=True, runnable=MockRunnable(),
+        consistent_output_format="always_string"
+    )
+    result = runnable_with_rails.invoke("This is an off topic question")
+    assert isinstance(result, str)
+    assert result == "I'm sorry, I can't help with that."
+
+    # Test with always_dict - should return dict format
+    runnable_with_rails2 = RunnableRails(
+        config, llm=llm, passthrough=True, runnable=MockRunnable(),
+        consistent_output_format="always_dict"
+    )
+    result2 = runnable_with_rails2.invoke("This is an off topic question")
+    assert isinstance(result2, dict)
+    assert result2 == {"output": "I'm sorry, I can't help with that."}
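
Usage sketch (illustrative, not part of the patch): assuming the diff above is applied, a RunnableConfig with callbacks now reaches the wrapped runnable via _current_config/_current_kwargs, and consistent_output_format controls the shape of the return value. PrintingHandler and the RunnableLambda chain below are hypothetical stand-ins modeled on the tests above; a real tracing handler (e.g. one from Langfuse) would be forwarded the same way.

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.runnables import RunnableConfig, RunnableLambda

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails


class PrintingHandler(BaseCallbackHandler):
    # Stand-in for a real tracing handler; just logs chain starts.
    def on_chain_start(self, serialized, inputs, **kwargs):
        print("chain started:", inputs)


config = RailsConfig.from_content(config={"models": []})

# Illustrative application chain; in practice this would be a prompt | llm chain.
app_chain = RunnableLambda(lambda data: {"output": "Paris."})

guarded = RunnableRails(
    config,
    passthrough=True,
    runnable=app_chain,
    consistent_output_format="always_string",  # new option introduced by this patch
)

# The RunnableConfig (and its callbacks) passed here is stored in
# _current_config/_current_kwargs and forwarded when app_chain is invoked.
result = guarded.invoke(
    {"input": "What is the capital of France?"},
    config=RunnableConfig(callbacks=[PrintingHandler()]),
)
print(result)  # with always_string, a plain string even though app_chain returned a dict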