From f15ce67d0ce1d61fe799a525c003d4bba0d1eab7 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Sun, 14 Sep 2025 23:03:30 -0500 Subject: [PATCH 01/23] Check in rails/ type-fixes --- nemoguardrails/actions/llm/generation.py | 2 +- nemoguardrails/context.py | 31 ++- nemoguardrails/rails/llm/buffer.py | 17 +- nemoguardrails/rails/llm/config.py | 8 +- nemoguardrails/rails/llm/llmrails.py | 314 ++++++++++++++++++----- nemoguardrails/rails/llm/options.py | 23 +- nemoguardrails/utils.py | 2 +- 7 files changed, 305 insertions(+), 92 deletions(-) diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 74fa763c5..377b0bc5e 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -82,7 +82,7 @@ class LLMGenerationActions: def __init__( self, config: RailsConfig, - llm: Union[BaseLLM, BaseChatModel], + llm: Optional[Union[BaseLLM, BaseChatModel]], llm_task_manager: LLMTaskManager, get_embedding_search_provider_instance: Callable[ [Optional[EmbeddingSearchProvider]], EmbeddingsIndex diff --git a/nemoguardrails/context.py b/nemoguardrails/context.py index 0659faafb..6688af919 100644 --- a/nemoguardrails/context.py +++ b/nemoguardrails/context.py @@ -14,25 +14,42 @@ # limitations under the License. import contextvars -from typing import Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional -streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None) +if TYPE_CHECKING: + from nemoguardrails.logging.explain import ExplainInfo + from nemoguardrails.rails.llm.options import GenerationOptions, LLMStats + from nemoguardrails.streaming import StreamingHandler + +streaming_handler_var: contextvars.ContextVar[ + Optional["StreamingHandler"] +] = contextvars.ContextVar("streaming_handler", default=None) # The object that holds additional explanation information. -explain_info_var = contextvars.ContextVar("explain_info", default=None) +explain_info_var: contextvars.ContextVar[ + Optional["ExplainInfo"] +] = contextvars.ContextVar("explain_info", default=None) # The current LLM call. -llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None) +llm_call_info_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + "llm_call_info", default=None +) # All the generation options applicable to the current context. -generation_options_var = contextvars.ContextVar("generation_options", default=None) +generation_options_var: contextvars.ContextVar[ + Optional["GenerationOptions"] +] = contextvars.ContextVar("generation_options", default=None) # The stats about the LLM calls. -llm_stats_var = contextvars.ContextVar("llm_stats", default=None) +llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar( + "llm_stats", default=None +) # The raw LLM request that comes from the user. # This is used in passthrough mode. -raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None) +raw_llm_request: contextvars.ContextVar[Optional[Any]] = contextvars.ContextVar( + "raw_llm_request", default=None +) reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( "reasoning_trace", default=None diff --git a/nemoguardrails/rails/llm/buffer.py b/nemoguardrails/rails/llm/buffer.py index 30e48c4e3..541f52915 100644 --- a/nemoguardrails/rails/llm/buffer.py +++ b/nemoguardrails/rails/llm/buffer.py @@ -14,7 +14,10 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from typing import AsyncGenerator, List, NamedTuple +from typing import TYPE_CHECKING, AsyncGenerator, List, NamedTuple + +if TYPE_CHECKING: + from collections.abc import AsyncIterator from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig @@ -111,9 +114,7 @@ def format_chunks(self, chunks: List[str]) -> str: ... @abstractmethod - async def process_stream( - self, streaming_handler - ) -> AsyncGenerator[ChunkBatch, None]: + async def process_stream(self, streaming_handler): """Process streaming chunks and yield chunk batches. This is the main method that concrete buffer strategies must implement. @@ -138,9 +139,9 @@ async def process_stream( ... print(f"Processing: {context_formatted}") ... print(f"User: {user_formatted}") """ - ... + yield ChunkBatch([], []) # pragma: no cover - async def __call__(self, streaming_handler) -> AsyncGenerator[ChunkBatch, None]: + async def __call__(self, streaming_handler): """Callable interface that delegates to process_stream. It delegates to the `process_stream` method and can @@ -256,9 +257,7 @@ def from_config(cls, config: OutputRailsStreamingConfig): buffer_context_size=config.context_size, buffer_chunk_size=config.chunk_size ) - async def process_stream( - self, streaming_handler - ) -> AsyncGenerator[ChunkBatch, None]: + async def process_stream(self, streaming_handler): """Process streaming chunks using rolling buffer strategy. This method implements the rolling buffer logic, accumulating chunks diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index eac54fd37..f84139cf5 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -1128,7 +1128,9 @@ def _load_path( # the first .railsignore file found from cwd down to its subdirectories railsignore_path = utils.get_railsignore_path(config_path) - ignore_patterns = utils.get_railsignore_patterns(railsignore_path) + ignore_patterns = ( + utils.get_railsignore_patterns(railsignore_path) if railsignore_path else set() + ) if os.path.isdir(config_path): for root, _, files in os.walk(config_path, followlinks=True): @@ -1245,8 +1247,8 @@ def _parse_colang_files_recursively( current_file, current_path = colang_files[len(parsed_colang_files)] with open(current_path, "r", encoding="utf-8") as f: + content = f.read() try: - content = f.read() _parsed_config = parse_colang_file( current_file, content=content, version=colang_version ) @@ -1748,7 +1750,7 @@ def streaming_supported(self): # if we have output rails streaming enabled # we keep it in case it was needed when we have # support per rails - if self.rails.output.streaming.enabled: + if self.rails.output.streaming and self.rails.output.streaming.enabled: return True return False diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index fb1bcdf19..c7c292cd7 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -24,7 +24,18 @@ import threading import time from functools import partial -from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Type, Union, cast +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, + cast, +) from langchain_core.language_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM @@ -205,17 +216,18 @@ def __init__( # We check if the configuration or any of the imported ones have config.py modules. 
config_modules = [] - for _path in list(self.config.imported_paths.values()) + [ - self.config.config_path - ]: + for _path in list( + self.config.imported_paths.values() if self.config.imported_paths else [] + ) + [self.config.config_path]: if _path: filepath = os.path.join(_path, "config.py") if os.path.exists(filepath): filename = os.path.basename(filepath) spec = importlib.util.spec_from_file_location(filename, filepath) - config_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(config_module) - config_modules.append(config_module) + if spec and spec.loader: + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) + config_modules.append(config_module) # First, we initialize the runtime. if config.colang_version == "1.0": @@ -393,8 +405,8 @@ def _configure_main_llm_streaming( if not self.config.streaming: return - if "streaming" in llm.model_fields: - llm.streaming = True + if hasattr(llm, "streaming"): + setattr(llm, "streaming", True) self.main_llm_supports_streaming = True else: self.main_llm_supports_streaming = False @@ -509,6 +521,151 @@ def _init_llms(self): self.runtime.register_action_param("llms", llms) + def _create_isolated_llms_for_actions(self): + """Create isolated LLM copies for all actions that accept 'llm' parameter.""" + if not self.llm: + log.debug("No main LLM available for creating isolated copies") + return + + try: + actions_needing_llms = self._detect_llm_requiring_actions() + log.info( + "%d actions requiring isolated LLMs: %s", + len(actions_needing_llms), + list(actions_needing_llms), + ) + + created_count = 0 + + configured_actions_names = [] + try: + if self.config.flows: + get_action_details = partial( + get_action_details_from_flow_id, flows=self.config.flows + ) + for flow_id in self.config.rails.input.flows: + action_name, _ = get_action_details(flow_id) + configured_actions_names.append(action_name) + for flow_id in self.config.rails.output.flows: + action_name, _ = get_action_details(flow_id) + configured_actions_names.append(action_name) + else: + # for configurations without flow definitions, use all actions that need LLMs + log.info( + "No flow definitions found, creating isolated LLMs for all actions requiring them" + ) + configured_actions_names = list(actions_needing_llms) + except Exception as e: + # if flow matching fails, fall back to all actions that need LLMs + log.info( + "Flow matching failed (%s), creating isolated LLMs for all actions requiring them", + e, + ) + configured_actions_names = list(actions_needing_llms) + + for action_name in configured_actions_names: + if action_name not in actions_needing_llms: + continue + if f"{action_name}_llm" not in self.runtime.registered_action_params: + isolated_llm = self._create_action_llm_copy(self.llm, action_name) + if isolated_llm: + self.runtime.register_action_param( + f"{action_name}_llm", isolated_llm + ) + created_count += 1 + log.debug("Created isolated LLM for action: %s", action_name) + else: + log.debug( + "Action %s already has dedicated LLM, skipping isolation", + action_name, + ) + + log.info("Created %d isolated LLM instances for actions", created_count) + + except Exception as e: + log.warning("Failed to create isolated LLMs for actions: %s", e) + + def _detect_llm_requiring_actions(self): + """Auto-detect actions that have 'llm' parameter.""" + import inspect + + actions_needing_llms = set() + + if ( + not hasattr(self.runtime, "action_dispatcher") + or not self.runtime.action_dispatcher + ): + log.debug("Action 
dispatcher not available") + return actions_needing_llms + + for ( + action_name, + action_info, + ) in self.runtime.action_dispatcher.registered_actions.items(): + action_func = self._get_action_function(action_info) + if not action_func: + continue + + try: + sig = inspect.signature(action_func) + if "llm" in sig.parameters: + actions_needing_llms.add(action_name) + log.debug("Action %s has 'llm' parameter", action_name) + + except Exception as e: + log.debug("Could not inspect action %s: %s", action_name, e) + + return actions_needing_llms + + def _get_action_function(self, action_info): + """Extract the actual function from action info.""" + return action_info if callable(action_info) else None + + def _create_action_llm_copy( + self, main_llm: Union[BaseLLM, BaseChatModel], action_name: str + ) -> Optional[Union[BaseLLM, BaseChatModel]]: + """Create an isolated copy of main LLM for a specific action.""" + import copy + + try: + # shallow copy to preserve HTTP clients, credentials, etc. + # but create new instance to avoid shared state + isolated_llm = copy.copy(main_llm) + + # isolate model_kwargs to prevent shared mutable state + if ( + hasattr(isolated_llm, "model_kwargs") + and getattr(isolated_llm, "model_kwargs", None) is not None + ): + setattr( + isolated_llm, + "model_kwargs", + getattr(isolated_llm, "model_kwargs").copy(), + ) + + log.debug( + "Successfully created isolated LLM copy for action: %s", action_name + ) + return isolated_llm + + except Exception as e: + error_msg = ( + "Failed to create isolated LLM instance for action '%s'. " + "This is required to prevent parameter contamination between different actions. " + "\n\nPossible solutions:" + "\n1. If using a custom LLM class, ensure it supports copy.copy() operation" + "\n2. Check that your LLM configuration doesn't contain non-copyable objects" + "\n3. Consider using a dedicated LLM configuration for action '%s'" + "\n\nOriginal error: %s" + "\n\nTo use a dedicated LLM for this action, add to your config:" + "\nmodels:" + "\n - type: %s" + "\n engine: " + "\n model: " + ) % (action_name, action_name, e, action_name) + log.error(error_msg) + raise RuntimeError(error_msg) + def _get_embeddings_search_provider_instance( self, esp_config: Optional[EmbeddingSearchProvider] = None ) -> EmbeddingsIndex: @@ -776,7 +933,8 @@ async def generate_async( options = GenerationOptions(**options) # Save the generation options in the current async context. - generation_options_var.set(options) + # At this point, options is either None or GenerationOptions + generation_options_var.set(options if not isinstance(options, dict) else None) if streaming_handler: streaming_handler_var.set(streaming_handler) @@ -796,16 +954,25 @@ async def generate_async( # If we have generation options, we also add them to the context if options: messages = [ - {"role": "context", "content": {"generation_options": options.dict()}} - ] + messages + { + "role": "context", + "content": { + "generation_options": getattr( + options, "dict", lambda: options + )() + }, + } + ] + (messages or []) # If the last message is from the assistant, rather than the user, then # we move that to the `$bot_message` variable. This is to enable a more # convenient interface. 
(only when dialog rails are disabled) if ( - messages[-1]["role"] == "assistant" + messages + and messages[-1]["role"] == "assistant" and options - and options.rails.dialog is False + and hasattr(options, "rails") + and getattr(getattr(options, "rails", None), "dialog", None) is False ): # We already have the first message with a context update, so we use that messages[0]["content"]["bot_message"] = messages[-1]["content"] @@ -822,7 +989,7 @@ async def generate_async( processing_log = [] # The array of events corresponding to the provided sequence of messages. - events = self._get_events_for_messages(messages, state) + events = self._get_events_for_messages(messages or [], state) if self.config.colang_version == "1.0": # If we had a state object, we also need to prepend the events from the state. @@ -846,10 +1013,10 @@ async def generate_async( # Push an error chunk instead of None. error_message = str(e) error_dict = extract_error_json(error_message) - error_payload = json.dumps(error_dict) + error_payload: str = json.dumps(error_dict) await streaming_handler.push_chunk(error_payload) # push a termination signal - await streaming_handler.push_chunk(END_OF_STREAM) + await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore # Re-raise the exact exception raise else: @@ -920,7 +1087,7 @@ async def generate_async( response_events.append(event) if exception: - new_message = {"role": "exception", "content": exception} + new_message: dict = {"role": "exception", "content": exception} else: # Ensure all items in responses are strings @@ -928,7 +1095,7 @@ async def generate_async( str(response) if not isinstance(response, str) else response for response in responses ] - new_message = {"role": "assistant", "content": "\n".join(responses)} + new_message: dict = {"role": "assistant", "content": "\n".join(responses)} if response_tool_calls: new_message["tool_calls"] = response_tool_calls if response_events: @@ -941,7 +1108,7 @@ async def generate_async( # If a state object is not used, then we use the implicit caching if state is None: # Save the new events in the history and update the cache - cache_key = get_history_cache_key(messages + [new_message]) + cache_key = get_history_cache_key((messages or []) + [new_message]) self.events_history_cache[cache_key] = events else: output_state = {"events": events} @@ -964,7 +1131,7 @@ async def generate_async( streaming_handler = streaming_handler_var.get() if streaming_handler: # print("Closing the stream handler explicitly") - await streaming_handler.push_chunk(END_OF_STREAM) + await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore # IF tracing is enabled we need to set GenerationLog attrs original_log_options = None @@ -1004,11 +1171,15 @@ async def generate_async( if reasoning_trace := get_and_clear_reasoning_trace_contextvar(): if prompt: - res.response = reasoning_trace + res.response + # For prompt mode, response should be a string + if isinstance(res.response, str): + res.response = reasoning_trace + res.response else: - res.response[0]["content"] = ( - reasoning_trace + res.response[0]["content"] - ) + # For message mode, response should be a list + if isinstance(res.response, list) and len(res.response) > 0: + res.response[0]["content"] = ( + reasoning_trace + res.response[0]["content"] + ) if tool_calls: res.tool_calls = tool_calls @@ -1018,13 +1189,12 @@ async def generate_async( if self.config.colang_version == "1.0": # If output variables are specified, we extract their values - if options.output_vars: + if getattr(options, 
"output_vars", None): context = compute_context(events) - if isinstance(options.output_vars, list): + output_vars = getattr(options, "output_vars", None) + if isinstance(output_vars, list): # If we have only a selection of keys, we filter to only that. - res.output_data = { - k: context.get(k) for k in options.output_vars - } + res.output_data = {k: context.get(k) for k in output_vars} else: # Otherwise, we return the full context res.output_data = context @@ -1032,37 +1202,41 @@ async def generate_async( _log = compute_generation_log(processing_log) # Include information about activated rails and LLM calls if requested - if options.log.activated_rails or options.log.llm_calls: + log_options = getattr(options, "log", None) + if log_options and ( + getattr(log_options, "activated_rails", False) + or getattr(log_options, "llm_calls", False) + ): res.log = GenerationLog() # We always include the stats res.log.stats = _log.stats - if options.log.activated_rails: + if getattr(log_options, "activated_rails", False): res.log.activated_rails = _log.activated_rails - if options.log.llm_calls: + if getattr(log_options, "llm_calls", False): res.log.llm_calls = [] for activated_rail in _log.activated_rails: for executed_action in activated_rail.executed_actions: res.log.llm_calls.extend(executed_action.llm_calls) # Include internal events if requested - if options.log.internal_events: + if getattr(log_options, "internal_events", False): if res.log is None: res.log = GenerationLog() res.log.internal_events = new_events # Include the Colang history if requested - if options.log.colang_history: + if getattr(log_options, "colang_history", False): if res.log is None: res.log = GenerationLog() res.log.colang_history = get_colang_history(events) # Include the raw llm output if requested - if options.llm_output: + if getattr(options, "llm_output", False): # Currently, we include the output from the generation LLM calls. for activated_rail in _log.activated_rails: if activated_rail.type == "generation": @@ -1070,22 +1244,23 @@ async def generate_async( for llm_call in executed_action.llm_calls: res.llm_output = llm_call.raw_response else: - if options.output_vars: + if getattr(options, "output_vars", None): raise ValueError( "The `output_vars` option is not supported for Colang 2.0 configurations." ) - if ( - options.log.activated_rails - or options.log.llm_calls - or options.log.internal_events - or options.log.colang_history + log_options = getattr(options, "log", None) + if log_options and ( + getattr(log_options, "activated_rails", False) + or getattr(log_options, "llm_calls", False) + or getattr(log_options, "internal_events", False) + or getattr(log_options, "colang_history", False) ): raise ValueError( "The `log` option is not supported for Colang 2.0 configurations." ) - if options.llm_output: + if getattr(options, "llm_output", False): raise ValueError( "The `llm_output` option is not supported for Colang 2.0 configurations." 
) @@ -1119,20 +1294,26 @@ async def generate_async( if original_log_options: if not any( ( - original_log_options.internal_events, - original_log_options.activated_rails, - original_log_options.llm_calls, - original_log_options.colang_history, + getattr(original_log_options, "internal_events", False), + getattr(original_log_options, "activated_rails", False), + getattr(original_log_options, "llm_calls", False), + getattr(original_log_options, "colang_history", False), ) ): res.log = None else: - if not original_log_options.internal_events: - res.log.internal_events = [] - if not original_log_options.activated_rails: - res.log.activated_rails = [] - if not original_log_options.llm_calls: - res.log.llm_calls = [] + # Ensure res.log exists before setting attributes + if res.log is not None: + if not getattr( + original_log_options, "internal_events", False + ): + res.log.internal_events = [] + if not getattr( + original_log_options, "activated_rails", False + ): + res.log.activated_rails = [] + if not getattr(original_log_options, "llm_calls", False): + res.log.llm_calls = [] return res else: @@ -1161,7 +1342,10 @@ def stream_async( # if an external generator is provided, use it directly if generator: - if self.config.rails.output.streaming.enabled: + if ( + self.config.rails.output.streaming + and self.config.rails.output.streaming.enabled + ): return self._run_output_rails_in_streaming( streaming_handler=generator, messages=messages, @@ -1194,7 +1378,7 @@ async def _generation_task(): error_dict = extract_error_json(error_message) error_payload = json.dumps(error_dict) await streaming_handler.push_chunk(error_payload) - await streaming_handler.push_chunk(END_OF_STREAM) + await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore task = asyncio.create_task(_generation_task()) @@ -1212,7 +1396,10 @@ def task_done_callback(task): # when we have output rails we wrap the streaming handler # if len(self.config.rails.output.flows) > 0: # - if self.config.rails.output.streaming.enabled: + if ( + self.config.rails.output.streaming + and self.config.rails.output.streaming.enabled + ): # returns an async generator return self._run_output_rails_in_streaming( streaming_handler=streaming_handler, @@ -1367,7 +1554,7 @@ def process_events( self.process_events_async(events, state, blocking) ) - def register_action(self, action: callable, name: Optional[str] = None) -> Self: + def register_action(self, action: Callable, name: Optional[str] = None) -> Self: """Register a custom action for the rails configuration.""" self.runtime.register_action(action, name) return self @@ -1377,12 +1564,12 @@ def register_action_param(self, name: str, value: Any) -> Self: self.runtime.register_action_param(name, value) return self - def register_filter(self, filter_fn: callable, name: Optional[str] = None) -> Self: + def register_filter(self, filter_fn: Callable, name: Optional[str] = None) -> Self: """Register a custom filter for the rails configuration.""" self.runtime.llm_task_manager.register_filter(filter_fn, name) return self - def register_output_parser(self, output_parser: callable, name: str) -> Self: + def register_output_parser(self, output_parser: Callable, name: str) -> Self: """Register a custom output parser for the rails configuration.""" self.runtime.llm_task_manager.register_output_parser(output_parser, name) return self @@ -1427,6 +1614,8 @@ def register_embedding_provider( def explain(self) -> ExplainInfo: """Helper function to return the latest ExplainInfo object.""" + if self.explain_info is None: 
+ self.explain_info = self._ensure_explain_info() return self.explain_info def __getstate__(self): @@ -1545,6 +1734,8 @@ def _prepare_params( } output_rails_streaming_config = self.config.rails.output.streaming + if output_rails_streaming_config is None: + raise ValueError("Output rails streaming config is not available") buffer_strategy = get_buffer_strategy(output_rails_streaming_config) output_rails_flows_id = self.config.rails.output.flows stream_first = stream_first or output_rails_streaming_config.stream_first @@ -1619,9 +1810,10 @@ def _prepare_params( pass else: # if there are any stop events, content was blocked or internal error occurred - if result.events: + result_events = getattr(result, "events", None) + if result_events: # extract the flow info from the first stop event - stop_event = result.events[0] + stop_event = result_events[0] blocked_flow = stop_event.get("flow_id", "output rails") error_type = stop_event.get("error_type") diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py index dd9f87099..3e3bec0e9 100644 --- a/nemoguardrails/rails/llm/options.py +++ b/nemoguardrails/rails/llm/options.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Generation options give more control over the generation and the result. +"""Generation options give more control over the generation and the result. For example, to run only the input rails:: @@ -233,7 +233,7 @@ class ActivatedRail(BaseModel): ) decisions: List[str] = Field( default_factory=list, - descriptino="A sequence of decisions made by the rail, e.g., 'bot refuse to respond', 'stop', 'continue'.", + description="A sequence of decisions made by the rail, e.g., 'bot refuse to respond', 'stop', 'continue'.", ) executed_actions: List[ExecutedAction] = Field( default_factory=list, description="The list of actions executed by the rail." 
@@ -327,7 +327,7 @@ def print_summary(self): duration = 0 print(f"- Total time: {self.stats.total_duration:.2f}s") - if self.stats.input_rails_duration: + if self.stats.input_rails_duration and self.stats.total_duration: _pc = round( 100 * self.stats.input_rails_duration / self.stats.total_duration, 2 ) @@ -335,7 +335,7 @@ def print_summary(self): duration += self.stats.input_rails_duration print(f" - [{self.stats.input_rails_duration:.2f}s][{_pc}%]: INPUT Rails") - if self.stats.dialog_rails_duration: + if self.stats.dialog_rails_duration and self.stats.total_duration: _pc = round( 100 * self.stats.dialog_rails_duration / self.stats.total_duration, 2 ) @@ -345,7 +345,7 @@ def print_summary(self): print( f" - [{self.stats.dialog_rails_duration:.2f}s][{_pc}%]: DIALOG Rails" ) - if self.stats.generation_rails_duration: + if self.stats.generation_rails_duration and self.stats.total_duration: _pc = round( 100 * self.stats.generation_rails_duration / self.stats.total_duration, 2, @@ -356,7 +356,7 @@ def print_summary(self): print( f" - [{self.stats.generation_rails_duration:.2f}s][{_pc}%]: GENERATION Rails" ) - if self.stats.output_rails_duration: + if self.stats.output_rails_duration and self.stats.total_duration: _pc = round( 100 * self.stats.output_rails_duration / self.stats.total_duration, 2 ) @@ -367,12 +367,12 @@ def print_summary(self): f" - [{self.stats.output_rails_duration:.2f}s][{_pc}%]: OUTPUT Rails" ) - processing_overhead = self.stats.total_duration - duration + processing_overhead = (self.stats.total_duration or 0) - duration if processing_overhead >= 0.01: _pc = round(100 - pc, 2) print(f" - [{processing_overhead:.2f}s][{_pc}%]: Processing overhead ") - if self.stats.llm_calls_count > 0: + if self.stats.llm_calls_count and self.stats.llm_calls_count > 0: print( f"- {self.stats.llm_calls_count} LLM calls, " f"{self.stats.llm_calls_duration:.2f}s total duration, " @@ -391,7 +391,10 @@ def print_summary(self): for action in activated_rail.executed_actions: llm_calls_count += len(action.llm_calls) llm_calls_durations.extend( - [f"{round(llm_call.duration, 2)}s" for llm_call in action.llm_calls] + [ + f"{round(llm_call.duration or 0, 2)}s" + for llm_call in action.llm_calls + ] ) print( f"- [{activated_rail.duration:.2f}s] {activated_rail.type.upper()} ({activated_rail.name}): " @@ -431,4 +434,4 @@ class GenerationResponse(BaseModel): if __name__ == "__main__": - print(GenerationOptions(**{"rails": {"input": False}})) + print(GenerationOptions(rails=GenerationRailsOptions(input=False))) diff --git a/nemoguardrails/utils.py b/nemoguardrails/utils.py index a337a978f..bc27a6c74 100644 --- a/nemoguardrails/utils.py +++ b/nemoguardrails/utils.py @@ -375,7 +375,7 @@ def get_railsignore_patterns(railsignore_path: Path) -> Set[str]: return ignored_patterns -def is_ignored_by_railsignore(filename: str, ignore_patterns: str) -> bool: +def is_ignored_by_railsignore(filename: str, ignore_patterns: Set[str]) -> bool: """Verify if a filename should be ignored by a railsignore pattern""" ignore = False From 7424e88fa13931ab67444661c9b328ddc9fac433 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 15:16:27 -0500 Subject: [PATCH 02/23] Add coverage to rails/llm/options.py --- tests/test_llm_options.py | 64 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/test_llm_options.py diff --git a/tests/test_llm_options.py b/tests/test_llm_options.py new file mode 100644 index 000000000..72226afda 
--- /dev/null +++ b/tests/test_llm_options.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the GenerationLog print_summary method.""" + +from typing import Any, Dict, List, Optional +from unittest.mock import Mock + +import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from pydantic import BaseModel, Field + +from nemoguardrails.rails.llm.config import RailsConfig +from nemoguardrails.rails.llm.llmrails import LLMRails +from nemoguardrails.rails.llm.options import GenerationLog, GenerationStats + + +def test_generation_log_print_summary(capsys): + """Test printing rails stats with dummy data""" + + stats = GenerationStats( + input_rails_duration=1.0, + dialog_rails_duration=2.0, + generation_rails_duration=3.0, + output_rails_duration=4.0, + total_duration=10.0, # Sum of all previous rail durations + llm_calls_duration=8.0, # Less than total duration + llm_calls_count=4, # Input, dialog, generation and output calls + llm_calls_total_prompt_tokens=1000, + llm_calls_total_completion_tokens=2000, + llm_calls_total_tokens=3000, # Sum of prompt and completion tokens + ) + + generation_log = GenerationLog(activated_rails=[], stats=stats) + + generation_log.print_summary() + capture = capsys.readouterr() + capture_lines = capture.out.splitlines() + + # Check the correct times were printed + assert capture_lines[1] == "# General stats" + assert capture_lines[3] == "- Total time: 10.00s" + assert capture_lines[4] == " - [1.00s][10.0%]: INPUT Rails" + assert capture_lines[5] == " - [2.00s][20.0%]: DIALOG Rails" + assert capture_lines[6] == " - [3.00s][30.0%]: GENERATION Rails" + assert capture_lines[7] == " - [4.00s][40.0%]: OUTPUT Rails" + assert ( + capture_lines[8] + == "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens." + ) From 747b54036d798d729ce71666cdd8cdcfac2b26d8 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:09:21 -0500 Subject: [PATCH 03/23] Add coverage to _configure_main_llm_streaming() --- tests/rails/llm/test_config.py | 35 ++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py index 7b4a3cfe1..e12efb4e6 100644 --- a/tests/rails/llm/test_config.py +++ b/tests/rails/llm/test_config.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License.
+from unittest.mock import MagicMock + import pytest +from langchain.llms.base import BaseLLM from pydantic import ValidationError from nemoguardrails.rails.llm.config import ( @@ -23,6 +26,7 @@ RailsConfig, TaskPrompt, ) +from nemoguardrails.rails.llm.llmrails import LLMRails def test_task_prompt_valid_content(): @@ -307,3 +311,34 @@ def test_rails_config_none_config_path(): result2 = config3 + config4 assert result2.config_path == "" + + +def test_llm_rails_configure_streaming_with_attr(): + """Check the LLM streaming attribute is set when RailsConfig enables streaming""" + + mock_llm = MagicMock(spec=BaseLLM) + config = RailsConfig( + models=[], + streaming=True, + ) + + rails = LLMRails(config, llm=mock_llm) + setattr(mock_llm, "streaming", None) + rails._configure_main_llm_streaming(llm=mock_llm) + + assert mock_llm.streaming + + +def test_llm_rails_configure_streaming_without_attr(caplog): + """Check a warning is logged when the main LLM does not support streaming""" + + mock_llm = MagicMock(spec=BaseLLM) + config = RailsConfig( + models=[], + streaming=True, + ) + + rails = LLMRails(config, llm=mock_llm) + rails._configure_main_llm_streaming(mock_llm) + + assert caplog.messages[-1] == "Provided main LLM does not support streaming." From 74a0cabf4a8045b915d4a7a2c1431097826ffacb Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 19:40:03 -0500 Subject: [PATCH 04/23] Added test coverage for nemoguardrails/rails/llm/config.py --- nemoguardrails/rails/llm/options.py | 4 ++- tests/rails/llm/test_config.py | 51 +++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py index 3e3bec0e9..f9a2f77dd 100644 --- a/nemoguardrails/rails/llm/options.py +++ b/nemoguardrails/rails/llm/options.py @@ -434,4 +434,6 @@ class GenerationResponse(BaseModel): if __name__ == "__main__": - print(GenerationOptions(rails=GenerationRailsOptions(input=False))) + print( + GenerationOptions(rails=GenerationRailsOptions(input=False)) + ) # pragma: no cover (Can't run as script for test coverage) diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py index e12efb4e6..f79dbc0ad 100644 --- a/tests/rails/llm/test_config.py +++ b/tests/rails/llm/test_config.py @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from unittest.mock import MagicMock import pytest from langchain.llms.base import BaseLLM from pydantic import ValidationError -from nemoguardrails.rails.llm.config import ( - Document, - Instruction, - Model, - RailsConfig, - TaskPrompt, -) +from nemoguardrails.rails.llm.config import Model, RailsConfig, TaskPrompt from nemoguardrails.rails.llm.llmrails import LLMRails def test_task_prompt_valid_content(): @@ -342,3 +337,45 @@ def test_llm_rails_configure_streaming_without_attr(caplog): rails._configure_main_llm_streaming(mock_llm) assert caplog.messages[-1] == "Provided main LLM does not support streaming."
+ + +def test_rails_config_streaming_supported_no_output_flows(): + """Check the `streaming_supported` property is True when no output flows are configured, regardless of RailsConfig.streaming""" + + config = RailsConfig( + models=[], + streaming=False, + ) + assert config.streaming_supported + + +def test_rails_config_flows_streaming_supported_true(): + """Check the `streaming_supported` property is True when output rails streaming is enabled""" + + rails = { + "output": { + "flows": ["content_safety_check_output"], + "streaming": {"enabled": True}, + } + } + prompts = [{"task": "content safety check output", "content": "..."}] + rails_config = RailsConfig.model_validate( + {"models": [], "rails": rails, "prompts": prompts} + ) + assert rails_config.streaming_supported + + +def test_rails_config_flows_streaming_supported_false(): + """Check the `streaming_supported` property is False when output rails streaming is disabled""" + + rails = { + "output": { + "flows": ["content_safety_check_output"], + "streaming": {"enabled": False}, + } + } + prompts = [{"task": "content safety check output", "content": "..."}] + rails_config = RailsConfig.model_validate( + {"models": [], "rails": rails, "prompts": prompts} + ) + assert not rails_config.streaming_supported From ec00d0f716cc227813fff712d1be912ac75916ca Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 21:19:16 -0500 Subject: [PATCH 05/23] Pass OutputRailsStreamingConfig into _run_output_rails_in_streaming as a mandatory argument, since both call sites already check for None --- nemoguardrails/rails/llm/llmrails.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index c7c292cd7..5a10c7703 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -80,7 +80,11 @@ from nemoguardrails.logging.verbose import set_verbose from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop from nemoguardrails.rails.llm.buffer import get_buffer_strategy -from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig +from nemoguardrails.rails.llm.config import ( + EmbeddingSearchProvider, + OutputRailsStreamingConfig, + RailsConfig, +) from nemoguardrails.rails.llm.options import ( GenerationLog, GenerationOptions, @@ -1348,6 +1352,7 @@ def stream_async( ): return self._run_output_rails_in_streaming( streaming_handler=generator, + output_rails_streaming_config=self.config.rails.output.streaming, messages=messages, prompt=prompt, ) @@ -1403,6 +1408,7 @@ def task_done_callback(task): # returns an async generator return self._run_output_rails_in_streaming( streaming_handler=streaming_handler, + output_rails_streaming_config=self.config.rails.output.streaming, messages=messages, prompt=prompt, ) @@ -1631,6 +1637,7 @@ def __setstate__(self, state): async def _run_output_rails_in_streaming( self, streaming_handler: AsyncIterator[str], + output_rails_streaming_config: OutputRailsStreamingConfig, prompt: Optional[str] = None, messages: Optional[List[dict]] = None, stream_first: Optional[bool] = None, @@ -1733,9 +1740,6 @@ def _prepare_params( **action_params, } - output_rails_streaming_config = self.config.rails.output.streaming - if output_rails_streaming_config is None: - raise ValueError("Output rails streaming config is not available") buffer_strategy = get_buffer_strategy(output_rails_streaming_config)
output_rails_flows_id = self.config.rails.output.flows stream_first = stream_first or output_rails_streaming_config.stream_first From 6948258bd3254108e8a5af33d6e34aff25bf09e0 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 23 Sep 2025 11:29:30 -0500 Subject: [PATCH 06/23] Add coverage to explain_info() --- tests/test_llmrails.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py index f97389284..9b8a2b300 100644 --- a/tests/test_llmrails.py +++ b/tests/test_llmrails.py @@ -15,11 +15,13 @@ import os from typing import Any, Dict, List, Optional, Union -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest +from langchain_core.language_models import BaseChatModel from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails.logging.explain import ExplainInfo from nemoguardrails.rails.llm.config import Model from nemoguardrails.rails.llm.llmrails import get_action_details_from_flow_id from tests.utils import FakeLLM, clean_events, event_sequence_conforms @@ -1170,3 +1172,18 @@ def dummy_parser(text): assert "chained_action" in rails.runtime.action_dispatcher.registered_actions assert "chained_param" in rails.runtime.registered_action_params assert rails.runtime.registered_action_params["chained_param"] == "param_value" + + +def test_explain_calls_ensure_explain_info(): + """Make sure if no `explain_info` attribute is present in LLMRails it's populated with + an empty ExplainInfo object""" + + mock_llm = MagicMock(spec=BaseChatModel) + config = RailsConfig.from_content(config={"models": []}) + rails = LLMRails(config=config, llm=mock_llm) + rails.generate(messages=[{"role": "user", "content": "Hi!"}]) + + rails.explain_info = None + info = rails.explain() + assert info == ExplainInfo() + assert rails.explain_info == ExplainInfo() From 0c4dc8e34a226299e653fff644f991323abcff19 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 23 Sep 2025 14:56:45 -0500 Subject: [PATCH 07/23] Fix rebase mistake where _create_isolated_llms_for_actions() was still included --- nemoguardrails/rails/llm/llmrails.py | 145 --------------------------- 1 file changed, 145 deletions(-) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 5a10c7703..bd3108071 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -525,151 +525,6 @@ def _init_llms(self): self.runtime.register_action_param("llms", llms) - def _create_isolated_llms_for_actions(self): - """Create isolated LLM copies for all actions that accept 'llm' parameter.""" - if not self.llm: - log.debug("No main LLM available for creating isolated copies") - return - - try: - actions_needing_llms = self._detect_llm_requiring_actions() - log.info( - "%d actions requiring isolated LLMs: %s", - len(actions_needing_llms), - list(actions_needing_llms), - ) - - created_count = 0 - - configured_actions_names = [] - try: - if self.config.flows: - get_action_details = partial( - get_action_details_from_flow_id, flows=self.config.flows - ) - for flow_id in self.config.rails.input.flows: - action_name, _ = get_action_details(flow_id) - configured_actions_names.append(action_name) - for flow_id in self.config.rails.output.flows: - action_name, _ = get_action_details(flow_id) - configured_actions_names.append(action_name) - else: - # for configurations without flow 
definitions, use all actions that need LLMs - log.info( - "No flow definitions found, creating isolated LLMs for all actions requiring them" - ) - configured_actions_names = list(actions_needing_llms) - except Exception as e: - # if flow matching fails, fall back to all actions that need LLMs - log.info( - "Flow matching failed (%s), creating isolated LLMs for all actions requiring them", - e, - ) - configured_actions_names = list(actions_needing_llms) - - for action_name in configured_actions_names: - if action_name not in actions_needing_llms: - continue - if f"{action_name}_llm" not in self.runtime.registered_action_params: - isolated_llm = self._create_action_llm_copy(self.llm, action_name) - if isolated_llm: - self.runtime.register_action_param( - f"{action_name}_llm", isolated_llm - ) - created_count += 1 - log.debug("Created isolated LLM for action: %s", action_name) - else: - log.debug( - "Action %s already has dedicated LLM, skipping isolation", - action_name, - ) - - log.info("Created %d isolated LLM instances for actions", created_count) - - except Exception as e: - log.warning("Failed to create isolated LLMs for actions: %s", e) - - def _detect_llm_requiring_actions(self): - """Auto-detect actions that have 'llm' parameter.""" - import inspect - - actions_needing_llms = set() - - if ( - not hasattr(self.runtime, "action_dispatcher") - or not self.runtime.action_dispatcher - ): - log.debug("Action dispatcher not available") - return actions_needing_llms - - for ( - action_name, - action_info, - ) in self.runtime.action_dispatcher.registered_actions.items(): - action_func = self._get_action_function(action_info) - if not action_func: - continue - - try: - sig = inspect.signature(action_func) - if "llm" in sig.parameters: - actions_needing_llms.add(action_name) - log.debug("Action %s has 'llm' parameter", action_name) - - except Exception as e: - log.debug("Could not inspect action %s: %s", action_name, e) - - return actions_needing_llms - - def _get_action_function(self, action_info): - """Extract the actual function from action info.""" - return action_info if callable(action_info) else None - - def _create_action_llm_copy( - self, main_llm: Union[BaseLLM, BaseChatModel], action_name: str - ) -> Optional[Union[BaseLLM, BaseChatModel]]: - """Create an isolated copy of main LLM for a specific action.""" - import copy - - try: - # shallow copy to preserve HTTP clients, credentials, etc. - # but create new instance to avoid shared state - isolated_llm = copy.copy(main_llm) - - # isolate model_kwargs to prevent shared mutable state - if ( - hasattr(isolated_llm, "model_kwargs") - and getattr(isolated_llm, "model_kwargs", None) is not None - ): - setattr( - isolated_llm, - "model_kwargs", - getattr(isolated_llm, "model_kwargs").copy(), - ) - - log.debug( - "Successfully created isolated LLM copy for action: %s", action_name - ) - return isolated_llm - - except Exception as e: - error_msg = ( - "Failed to create isolated LLM instance for action '%s'. " - "This is required to prevent parameter contamination between different actions. " - "\n\nPossible solutions:" - "\n1. If using a custom LLM class, ensure it supports copy.copy() operation" - "\n2. Check that your LLM configuration doesn't contain non-copyable objects" - "\n3. 
Consider using a dedicated LLM configuration for action '%s'" - "\n\nOriginal error: %s" - "\n\nTo use a dedicated LLM for this action, add to your config:" - "\nmodels:" - "\n - type: %s" - "\n engine: " - "\n model: " - ) % (action_name, action_name, e, action_name) - log.error(error_msg) - raise RuntimeError(error_msg) - def _get_embeddings_search_provider_instance( self, esp_config: Optional[EmbeddingSearchProvider] = None ) -> EmbeddingsIndex: From ba4504e0c419b4da09bc2db22c9f5cd93faad19a Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 23 Sep 2025 15:01:29 -0500 Subject: [PATCH 08/23] Add pre-and-post LLM call count to work around FakeLLM / TestChat race conditions --- tests/test_retrieve_relevant_chunks.py | 32 ++++++++++++++++---------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/tests/test_retrieve_relevant_chunks.py b/tests/test_retrieve_relevant_chunks.py index 7d1044661..72258ef48 100644 --- a/tests/test_retrieve_relevant_chunks.py +++ b/tests/test_retrieve_relevant_chunks.py @@ -12,15 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest +from langchain_core.language_models import BaseChatModel from nemoguardrails import LLMRails, RailsConfig from nemoguardrails.kb.kb import KnowledgeBase from tests.utils import TestChat -config = RailsConfig.from_content( +RAILS_CONFIG = RailsConfig.from_content( """ import llm import core @@ -55,7 +56,7 @@ def test_relevant_chunk_inserted_in_prompt(): ] chat = TestChat( - config, + RAILS_CONFIG, llm_completions=[ " user express greeting", ' bot respond to aditional context\nbot action: "Hello is there anything else" ', @@ -70,19 +71,21 @@ def test_relevant_chunk_inserted_in_prompt(): {"role": "user", "content": "Hi!"}, ] - new_message = rails.generate(messages=messages) + before_llm_calls = len(rails.explain().llm_calls) + _ = rails.generate(messages=messages) + after_llm_calls = len(rails.explain().llm_calls) + llm_call_count = after_llm_calls - before_llm_calls info = rails.explain() - assert len(info.llm_calls) == 2 - assert "Test Body" in info.llm_calls[1].prompt - - assert "markdown" in info.llm_calls[1].prompt - assert "context" in info.llm_calls[1].prompt + assert llm_call_count == 2 + assert "Test Body" in info.llm_calls[-1].prompt + assert "markdown" in info.llm_calls[-1].prompt + assert "context" in info.llm_calls[-1].prompt def test_relevant_chunk_inserted_in_prompt_no_kb(): chat = TestChat( - config, + RAILS_CONFIG, llm_completions=[ " user express greeting", ' bot respond to aditional context\nbot action: "Hello is there anything else" ', @@ -92,8 +95,13 @@ def test_relevant_chunk_inserted_in_prompt_no_kb(): messages = [ {"role": "user", "content": "Hi!"}, ] - new_message = rails.generate(messages=messages) + + before_llm_calls = len(rails.explain().llm_calls) + _ = rails.generate(messages=messages) + after_llm_calls = len(rails.explain().llm_calls) + llm_call_count = after_llm_calls - before_llm_calls + info = rails.explain() - assert len(info.llm_calls) == 2 + assert llm_call_count == 2 assert "markdown" not in info.llm_calls[1].prompt assert "context" not in info.llm_calls[1].prompt From 3513d703c7c6e3e0c5a2a8dc254adbab8875e33c Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 
23 Sep 2025 16:05:59 -0500 Subject: [PATCH 09/23] Address feedback from Pouyan and Traian --- nemoguardrails/context.py | 19 +++++---- nemoguardrails/rails/llm/buffer.py | 7 ++- nemoguardrails/rails/llm/config.py | 2 +- nemoguardrails/rails/llm/llmrails.py | 2 +- nemoguardrails/rails/llm/options.py | 2 +- tests/test_generation_options.py | 41 +++++++++++++++++- tests/test_llm_options.py | 64 ---------------------------- 7 files changed, 59 insertions(+), 78 deletions(-) delete mode 100644 tests/test_llm_options.py diff --git a/nemoguardrails/context.py b/nemoguardrails/context.py index 6688af919..2e7d34b82 100644 --- a/nemoguardrails/context.py +++ b/nemoguardrails/context.py @@ -14,11 +14,14 @@ # limitations under the License. import contextvars -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from nemoguardrails.logging.explain import LLMCallInfo if TYPE_CHECKING: from nemoguardrails.logging.explain import ExplainInfo - from nemoguardrails.rails.llm.options import GenerationOptions, LLMStats + from nemoguardrails.logging.stats import LLMStats + from nemoguardrails.rails.llm.options import GenerationOptions from nemoguardrails.streaming import StreamingHandler streaming_handler_var: contextvars.ContextVar[ @@ -31,9 +34,9 @@ ] = contextvars.ContextVar("explain_info", default=None) # The current LLM call. -llm_call_info_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( - "llm_call_info", default=None -) +llm_call_info_var: contextvars.ContextVar[ + Optional[LLMCallInfo] +] = contextvars.ContextVar("llm_call_info", default=None) # All the generation options applicable to the current context. generation_options_var: contextvars.ContextVar[ @@ -47,9 +50,9 @@ # The raw LLM request that comes from the user. # This is used in passthrough mode. -raw_llm_request: contextvars.ContextVar[Optional[Any]] = contextvars.ContextVar( - "raw_llm_request", default=None -) +raw_llm_request: contextvars.ContextVar[ + Optional[Union[str, List[Dict[str, Any]]]] +] = contextvars.ContextVar("raw_llm_request", default=None) reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( "reasoning_trace", default=None diff --git a/nemoguardrails/rails/llm/buffer.py b/nemoguardrails/rails/llm/buffer.py index 541f52915..62ad77291 100644 --- a/nemoguardrails/rails/llm/buffer.py +++ b/nemoguardrails/rails/llm/buffer.py @@ -114,7 +114,9 @@ def format_chunks(self, chunks: List[str]) -> str: ... @abstractmethod - async def process_stream(self, streaming_handler): + async def process_stream( + self, streaming_handler + ) -> AsyncGenerator[ChunkBatch, None]: """Process streaming chunks and yield chunk batches. This is the main method that concrete buffer strategies must implement. @@ -139,7 +141,8 @@ async def process_stream(self, streaming_handler): ... print(f"Processing: {context_formatted}") ... print(f"User: {user_formatted}") """ - yield ChunkBatch([], []) # pragma: no cover + raise NotImplementedError # pragma: no cover + yield async def __call__(self, streaming_handler): """Callable interface that delegates to process_stream. 
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index f84139cf5..6c5073a78 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -487,7 +487,7 @@ class OutputRails(BaseModel): description="The names of all the flows that implement output rails.", ) - streaming: Optional[OutputRailsStreamingConfig] = Field( + streaming: OutputRailsStreamingConfig = Field( default_factory=OutputRailsStreamingConfig, description="Configuration for streaming output rails.", ) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index bd3108071..811e1dd8b 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -831,7 +831,7 @@ async def generate_async( and messages[-1]["role"] == "assistant" and options and hasattr(options, "rails") - and getattr(getattr(options, "rails", None), "dialog", None) is False + and getattr(getattr(options, "rails"), "dialog", None) is False ): # We already have the first message with a context update, so we use that messages[0]["content"]["bot_message"] = messages[-1]["content"] diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py index f9a2f77dd..8bc7c0ca0 100644 --- a/nemoguardrails/rails/llm/options.py +++ b/nemoguardrails/rails/llm/options.py @@ -372,7 +372,7 @@ def print_summary(self): _pc = round(100 - pc, 2) print(f" - [{processing_overhead:.2f}s][{_pc}%]: Processing overhead ") - if self.stats.llm_calls_count and self.stats.llm_calls_count > 0: + if self.stats.llm_calls_count: print( f"- {self.stats.llm_calls_count} LLM calls, " f"{self.stats.llm_calls_duration:.2f}s total duration, " diff --git a/tests/test_generation_options.py b/tests/test_generation_options.py index 06895aa87..a8aeff02b 100644 --- a/tests/test_generation_options.py +++ b/tests/test_generation_options.py @@ -18,7 +18,11 @@ import pytest from nemoguardrails import LLMRails, RailsConfig -from nemoguardrails.rails.llm.options import GenerationResponse +from nemoguardrails.rails.llm.options import ( + GenerationLog, + GenerationResponse, + GenerationStats, +) from tests.utils import TestChat @@ -313,3 +317,38 @@ def test_only_input_output_validation(): assert res.response == [ {"content": "I'm sorry, I can't respond to that.", "role": "assistant"} ] + + +def test_generation_log_print_summary(capsys): + """Test printing rails stats with dummy data""" + + stats = GenerationStats( + input_rails_duration=1.0, + dialog_rails_duration=2.0, + generation_rails_duration=3.0, + output_rails_duration=4.0, + total_duration=10.0, # Sum of all previous rail durations + llm_calls_duration=8.0, # Less than total duration + llm_calls_count=4, # Input, dialog, generation and output calls + llm_calls_total_prompt_tokens=1000, + llm_calls_total_completion_tokens=2000, + llm_calls_total_tokens=3000, # Sum of prompt and completion tokens + ) + + generation_log = GenerationLog(activated_rails=[], stats=stats) + + generation_log.print_summary() + capture = capsys.readouterr() + capture_lines = capture.out.splitlines() + + # Check the correct times were printed + assert capture_lines[1] == "# General stats" + assert capture_lines[3] == "- Total time: 10.00s" + assert capture_lines[4] == " - [1.00s][10.0%]: INPUT Rails" + assert capture_lines[5] == " - [2.00s][20.0%]: DIALOG Rails" + assert capture_lines[6] == " - [3.00s][30.0%]: GENERATION Rails" + assert capture_lines[7] == " - [4.00s][40.0%]: OUTPUT Rails" + assert ( + capture_lines[8] + == "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens." + ) diff --git a/tests/test_llm_options.py b/tests/test_llm_options.py deleted file mode 100644 index 72226afda..000000000 --- a/tests/test_llm_options.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for the GenerationLog print_summary method.""" - -from typing import Any, Dict, List, Optional -from unittest.mock import Mock - -import pytest -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import BaseMessage -from langchain_core.outputs import ChatGeneration, ChatResult -from pydantic import BaseModel, Field - -from nemoguardrails.rails.llm.config import RailsConfig -from nemoguardrails.rails.llm.llmrails import LLMRails -from nemoguardrails.rails.llm.options import GenerationLog, GenerationStats - - -def test_generation_log_print_summary(capsys): - """Test printing rails stats with dummy data""" - - stats = GenerationStats( - input_rails_duration=1.0, - dialog_rails_duration=2.0, - generation_rails_duration=3.0, - output_rails_duration=4.0, - total_duration=10.0, # Sum of all previous rail durations - llm_calls_duration=8.0, # Less than total duration - llm_calls_count=4, # Input, dialog, generation and output calls - llm_calls_total_prompt_tokens=1000, - llm_calls_total_completion_tokens=2000, - llm_calls_total_tokens=3000, # Sum of prompt and completion tokens - ) - - generation_log = GenerationLog(activated_rails=[], stats=stats) - - generation_log.print_summary() - capture = capsys.readouterr() - capture_lines = capture.out.splitlines() - - # Check the correct times were printed - assert capture_lines[1] == "# General stats" - assert capture_lines[3] == "- Total time: 10.00s" - assert capture_lines[4] == " - [1.00s][10.0%]: INPUT Rails" - assert capture_lines[5] == " - [2.00s][20.0%]: DIALOG Rails" - assert capture_lines[6] == " - [3.00s][30.0%]: GENERATION Rails" - assert capture_lines[7] == " - [4.00s][40.0%]: OUTPUT Rails" - assert ( - capture_lines[8] - == "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens."
- ) From c0ef65e704eaef39923e8e383b7d99eaacdc2716 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:11:02 -0500 Subject: [PATCH 10/23] Forgot to simplify the line 1753 in rails/llm/config.py --- nemoguardrails/rails/llm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 6c5073a78..c8b9b3e95 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -1750,7 +1750,7 @@ def streaming_supported(self): # if we have output rails streaming enabled # we keep it in case it was needed when we have # support per rails - if self.rails.output.streaming and self.rails.output.streaming.enabled: + if self.rails.output.streaming.enabled: return True return False From b87fb4a7d7f1e599b0c46d287312020566a8cefd Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 2 Sep 2025 11:17:40 -0500 Subject: [PATCH 11/23] Dummy commit to set up the chore/type-clean-guardrails PR and branch --- nemoguardrails/actions/llm/generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 377b0bc5e..8031cc715 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -136,7 +136,7 @@ async def init(self): self._init_flows_index(), ) - def _extract_user_message_example(self, flow: Flow): + def _extract_user_message_example(self, flow: Flow) -> None: """Heuristic to extract user message examples from a flow.""" elements = [ item From 7784ac34399b3c7a391e105af22d392f008ff021 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Sun, 14 Sep 2025 23:03:30 -0500 Subject: [PATCH 12/23] Check in rails/ type-fixes --- nemoguardrails/rails/llm/buffer.py | 7 ++----- nemoguardrails/rails/llm/config.py | 2 +- nemoguardrails/rails/llm/llmrails.py | 3 +++ 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nemoguardrails/rails/llm/buffer.py b/nemoguardrails/rails/llm/buffer.py index 62ad77291..541f52915 100644 --- a/nemoguardrails/rails/llm/buffer.py +++ b/nemoguardrails/rails/llm/buffer.py @@ -114,9 +114,7 @@ def format_chunks(self, chunks: List[str]) -> str: ... @abstractmethod - async def process_stream( - self, streaming_handler - ) -> AsyncGenerator[ChunkBatch, None]: + async def process_stream(self, streaming_handler): """Process streaming chunks and yield chunk batches. This is the main method that concrete buffer strategies must implement. @@ -141,8 +139,7 @@ async def process_stream( ... print(f"Processing: {context_formatted}") ... print(f"User: {user_formatted}") """ - raise NotImplementedError # pragma: no cover - yield + yield ChunkBatch([], []) # pragma: no cover async def __call__(self, streaming_handler): """Callable interface that delegates to process_stream. 
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index c8b9b3e95..6c5073a78 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -1750,7 +1750,7 @@ def streaming_supported(self): # if we have output rails streaming enabled # we keep it in case it was needed when we have # support per rails - if self.rails.output.streaming.enabled: + if self.rails.output.streaming and self.rails.output.streaming.enabled: return True return False diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 811e1dd8b..e259fd556 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -1595,6 +1595,9 @@ def _prepare_params( **action_params, } + output_rails_streaming_config = self.config.rails.output.streaming + if output_rails_streaming_config is None: + raise ValueError("Output rails streaming config is not available") buffer_strategy = get_buffer_strategy(output_rails_streaming_config) output_rails_flows_id = self.config.rails.output.flows stream_first = stream_first or output_rails_streaming_config.stream_first From bcb522dcdd65812f7f6fbb2d78dd72d7f2ee4bbc Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 13:39:51 -0500 Subject: [PATCH 13/23] Revert "Dummy commit to set up the chore/type-clean-guardrails PR and branch" This reverts commit 71d00f083fb59bda34c82b82eea85602c1710265. --- nemoguardrails/actions/llm/generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 8031cc715..377b0bc5e 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -136,7 +136,7 @@ async def init(self): self._init_flows_index(), ) - def _extract_user_message_example(self, flow: Flow) -> None: + def _extract_user_message_example(self, flow: Flow): """Heuristic to extract user message examples from a flow.""" elements = [ item From 3673cc1be9e39f598517015eacafb31bafe814c0 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 15:16:27 -0500 Subject: [PATCH 14/23] Add coverage to rails/llm/options.py --- tests/test_llm_options.py | 64 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/test_llm_options.py diff --git a/tests/test_llm_options.py b/tests/test_llm_options.py new file mode 100644 index 000000000..72226afda --- /dev/null +++ b/tests/test_llm_options.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for LLM isolation with models that don't have model_kwargs field.""" + +from typing import Any, Dict, List, Optional +from unittest.mock import Mock + +import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from pydantic import BaseModel, Field + +from nemoguardrails.rails.llm.config import RailsConfig +from nemoguardrails.rails.llm.llmrails import LLMRails +from nemoguardrails.rails.llm.options import GenerationLog, GenerationStats + + +def test_generation_log_print_summary(capsys): + """Test printing rais stats with dummy data""" + + stats = GenerationStats( + input_rails_duration=1.0, + dialog_rails_duration=2.0, + generation_rails_duration=3.0, + output_rails_duration=4.0, + total_duration=10.0, # Sum of all previous rail durations + llm_calls_duration=8.0, # Less than total duration + llm_calls_count=4, # Input, dialog, generation and output calls + llm_calls_total_prompt_tokens=1000, + llm_calls_total_completion_tokens=2000, + llm_calls_total_tokens=3000, # Sum of prompt and completion tokens + ) + + generation_log = GenerationLog(activated_rails=[], stats=stats) + + generation_log.print_summary() + capture = capsys.readouterr() + capture_lines = capture.out.splitlines() + + # Check the correct times were printed + assert capture_lines[1] == "# General stats" + assert capture_lines[3] == "- Total time: 10.00s" + assert capture_lines[4] == " - [1.00s][10.0%]: INPUT Rails" + assert capture_lines[5] == " - [2.00s][20.0%]: DIALOG Rails" + assert capture_lines[6] == " - [3.00s][30.0%]: GENERATION Rails" + assert capture_lines[7] == " - [4.00s][40.0%]: OUTPUT Rails" + assert ( + capture_lines[8] + == "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens." 
+ ) From ecd87ca9c9f2f99cf2766488d4585a21dce49071 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:09:21 -0500 Subject: [PATCH 15/23] Add coverage to _configure_main_llm_streaming() --- tests/rails/llm/test_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py index f79dbc0ad..71a42ac1a 100644 --- a/tests/rails/llm/test_config.py +++ b/tests/rails/llm/test_config.py @@ -307,7 +307,6 @@ def test_rails_config_none_config_path(): result2 = config3 + config4 assert result2.config_path == "" - def test_llm_rails_configure_streaming_with_attr(): """Check LLM has the streaming attribute set if RailsConfig has it""" From 483c547dba6ab9cbc8094c869e4b205fe968618c Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 21:19:16 -0500 Subject: [PATCH 16/23] Pass OutputRailsStreamingConfig into _run_output_rails_in_streaming as mandatory argument as all 2 calls check for None --- nemoguardrails/rails/llm/llmrails.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index e259fd556..811e1dd8b 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -1595,9 +1595,6 @@ def _prepare_params( **action_params, } - output_rails_streaming_config = self.config.rails.output.streaming - if output_rails_streaming_config is None: - raise ValueError("Output rails streaming config is not available") buffer_strategy = get_buffer_strategy(output_rails_streaming_config) output_rails_flows_id = self.config.rails.output.flows stream_first = stream_first or output_rails_streaming_config.stream_first From 5e46e5a65ea96a4b281e983cb8a24c0728d6a0ad Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:03:13 +0200 Subject: [PATCH 17/23] review: type hint fixes fix --- nemoguardrails/rails/llm/llmrails.py | 135 ++++++++++++++------------- nemoguardrails/rails/llm/options.py | 3 +- 2 files changed, 72 insertions(+), 66 deletions(-) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 811e1dd8b..086d480e4 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -776,6 +776,19 @@ async def generate_async( The completion (when a prompt is provided) or the next message. System messages are not yet supported.""" + # convert options to gen_options of type GenerationOptions + gen_options: Optional[GenerationOptions] = None + + if prompt is None and messages is None: + raise ValueError("Either prompt or messages must be provided.") + + if prompt is not None and messages is not None: + raise ValueError("Only one of prompt or messages can be provided.") + + if prompt is not None: + # Currently, we transform the prompt request into a single turn conversation + messages = [{"role": "user", "content": prompt}] + # If a state object is specified, then we switch to "generation options" mode. # This is because we want the output to be a GenerationResponse which will contain # the output state. @@ -785,15 +798,25 @@ async def generate_async( state = json_to_state(state["state"]) if options is None: - options = GenerationOptions() - - # We allow options to be specified both as a dict and as an object. 
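[Review note] The hunk being rewritten here collapses the user-facing `options` argument (a plain dict, a `GenerationOptions` instance, or `None`) into a single typed `gen_options` value. A condensed sketch of that dict-or-object normalization, written as a hypothetical standalone helper rather than the inline branching the patch actually uses:

    from typing import Optional, Union

    from nemoguardrails.rails.llm.options import GenerationOptions

    def normalize_options(
        options: Union[dict, GenerationOptions, None],
    ) -> Optional[GenerationOptions]:
        # A dict is validated into the Pydantic model; an instance passes
        # through unchanged; anything else is rejected early.
        if options is None:
            return None
        if isinstance(options, dict):
            return GenerationOptions(**options)
        if isinstance(options, GenerationOptions):
            return options
        raise TypeError("options must be a dict or GenerationOptions")

Keeping the normalized value under a separate name lets pyright see `gen_options` as `Optional[GenerationOptions]` for the rest of the method, instead of the `Union[dict, GenerationOptions, None]` it infers for the raw `options` argument.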
- if options and isinstance(options, dict): - options = GenerationOptions(**options) + gen_options = GenerationOptions() + elif isinstance(options, dict): + gen_options = GenerationOptions(**options) + else: + gen_options = options + else: + # We allow options to be specified both as a dict and as an object. + if options and isinstance(options, dict): + gen_options = GenerationOptions(**options) + elif isinstance(options, GenerationOptions): + gen_options = options + elif options is None: + gen_options = None + else: + raise TypeError("options must be a dict or GenerationOptions") # Save the generation options in the current async context. - # At this point, options is either None or GenerationOptions - generation_options_var.set(options if not isinstance(options, dict) else None) + # At this point, gen_options is either None or GenerationOptions + generation_options_var.set(gen_options) if streaming_handler: streaming_handler_var.set(streaming_handler) @@ -803,23 +826,14 @@ async def generate_async( # requests are made. self.explain_info = self._ensure_explain_info() - if prompt is not None: - # Currently, we transform the prompt request into a single turn conversation - messages = [{"role": "user", "content": prompt}] - raw_llm_request.set(prompt) - else: - raw_llm_request.set(messages) + raw_llm_request.set(messages) # If we have generation options, we also add them to the context - if options: + if gen_options: messages = [ { "role": "context", - "content": { - "generation_options": getattr( - options, "dict", lambda: options - )() - }, + "content": {"generation_options": gen_options.model_dump()}, } ] + (messages or []) @@ -848,7 +862,7 @@ async def generate_async( processing_log = [] # The array of events corresponding to the provided sequence of messages. - events = self._get_events_for_messages(messages or [], state) + events = self._get_events_for_messages(messages, state) # type: ignore if self.config.colang_version == "1.0": # If we had a state object, we also need to prepend the events from the state. 
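[Review note] The removed `getattr(options, "dict", lambda: options)()` call was only needed while `options` could still be a plain dict at this point; once normalization guarantees a model instance, the Pydantic v2 `model_dump()` serializer is called directly. Roughly what the prepended context turn looks like (message content invented for illustration):

    from nemoguardrails.rails.llm.options import GenerationOptions

    gen_options = GenerationOptions(llm_output=True)
    messages = [{"role": "user", "content": "Hello!"}]

    # Prepend a context turn so downstream rails can read the active
    # generation options from the conversation itself.
    messages = [
        {
            "role": "context",
            "content": {"generation_options": gen_options.model_dump()},
        }
    ] + messages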
@@ -967,7 +981,7 @@ async def generate_async( # If a state object is not used, then we use the implicit caching if state is None: # Save the new events in the history and update the cache - cache_key = get_history_cache_key((messages or []) + [new_message]) + cache_key = get_history_cache_key((messages) + [new_message]) # type: ignore self.events_history_cache[cache_key] = events else: output_state = {"events": events} @@ -995,33 +1009,29 @@ async def generate_async( # IF tracing is enabled we need to set GenerationLog attrs original_log_options = None if self.config.tracing.enabled: - if options is None: - options = GenerationOptions() + if gen_options is None: + gen_options = GenerationOptions() else: - # create a copy of the options to avoid modifying the original - if isinstance(options, GenerationOptions): - options = options.model_copy(deep=True) - else: - # If options is a dict, convert it to GenerationOptions - options = GenerationOptions(**options) - original_log_options = options.log.model_copy(deep=True) + # create a copy of the gen_options to avoid modifying the original + gen_options = gen_options.model_copy(deep=True) + original_log_options = gen_options.log.model_copy(deep=True) # enable log options # it is aggressive, but these are required for tracing if ( - not options.log.activated_rails - or not options.log.llm_calls - or not options.log.internal_events + not gen_options.log.activated_rails + or not gen_options.log.llm_calls + or not gen_options.log.internal_events ): - options.log.activated_rails = True - options.log.llm_calls = True - options.log.internal_events = True + gen_options.log.activated_rails = True + gen_options.log.llm_calls = True + gen_options.log.internal_events = True tool_calls = extract_tool_calls_from_events(new_events) llm_metadata = get_and_clear_response_metadata_contextvar() # If we have generation options, we prepare a GenerationResponse instance. - if options: + if gen_options: # If a prompt was used, we only need to return the content of the message. if prompt: res = GenerationResponse(response=new_message["content"]) @@ -1048,9 +1058,9 @@ async def generate_async( if self.config.colang_version == "1.0": # If output variables are specified, we extract their values - if getattr(options, "output_vars", None): + if gen_options and gen_options.output_vars: context = compute_context(events) - output_vars = getattr(options, "output_vars", None) + output_vars = gen_options.output_vars if isinstance(output_vars, list): # If we have only a selection of keys, we filter to only that. 
res.output_data = {k: context.get(k) for k in output_vars} @@ -1061,41 +1071,40 @@ async def generate_async( _log = compute_generation_log(processing_log) # Include information about activated rails and LLM calls if requested - log_options = getattr(options, "log", None) + log_options = gen_options.log if gen_options else None if log_options and ( - getattr(log_options, "activated_rails", False) - or getattr(log_options, "llm_calls", False) + log_options.activated_rails or log_options.llm_calls ): res.log = GenerationLog() # We always include the stats res.log.stats = _log.stats - if getattr(log_options, "activated_rails", False): + if log_options.activated_rails: res.log.activated_rails = _log.activated_rails - if getattr(log_options, "llm_calls", False): + if log_options.llm_calls: res.log.llm_calls = [] for activated_rail in _log.activated_rails: for executed_action in activated_rail.executed_actions: res.log.llm_calls.extend(executed_action.llm_calls) # Include internal events if requested - if getattr(log_options, "internal_events", False): + if log_options and log_options.internal_events: if res.log is None: res.log = GenerationLog() res.log.internal_events = new_events # Include the Colang history if requested - if getattr(log_options, "colang_history", False): + if log_options and log_options.colang_history: if res.log is None: res.log = GenerationLog() res.log.colang_history = get_colang_history(events) # Include the raw llm output if requested - if getattr(options, "llm_output", False): + if gen_options and gen_options.llm_output: # Currently, we include the output from the generation LLM calls. for activated_rail in _log.activated_rails: if activated_rail.type == "generation": @@ -1103,23 +1112,23 @@ async def generate_async( for llm_call in executed_action.llm_calls: res.llm_output = llm_call.raw_response else: - if getattr(options, "output_vars", None): + if gen_options and gen_options.output_vars: raise ValueError( "The `output_vars` option is not supported for Colang 2.0 configurations." ) - log_options = getattr(options, "log", None) + log_options = gen_options.log if gen_options else None if log_options and ( - getattr(log_options, "activated_rails", False) - or getattr(log_options, "llm_calls", False) - or getattr(log_options, "internal_events", False) - or getattr(log_options, "colang_history", False) + log_options.activated_rails + or log_options.llm_calls + or log_options.internal_events + or log_options.colang_history ): raise ValueError( "The `log` option is not supported for Colang 2.0 configurations." ) - if getattr(options, "llm_output", False): + if gen_options and gen_options.llm_output: raise ValueError( "The `llm_output` option is not supported for Colang 2.0 configurations." 
) @@ -1153,25 +1162,21 @@ async def generate_async( if original_log_options: if not any( ( - getattr(original_log_options, "internal_events", False), - getattr(original_log_options, "activated_rails", False), - getattr(original_log_options, "llm_calls", False), - getattr(original_log_options, "colang_history", False), + original_log_options.internal_events, + original_log_options.activated_rails, + original_log_options.llm_calls, + original_log_options.colang_history, ) ): res.log = None else: # Ensure res.log exists before setting attributes if res.log is not None: - if not getattr( - original_log_options, "internal_events", False - ): + if not original_log_options.internal_events: res.log.internal_events = [] - if not getattr( - original_log_options, "activated_rails", False - ): + if not original_log_options.activated_rails: res.log.activated_rails = [] - if not getattr(original_log_options, "llm_calls", False): + if not original_log_options.llm_calls: res.log.llm_calls = [] return res diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py index 8bc7c0ca0..67bf9c76a 100644 --- a/nemoguardrails/rails/llm/options.py +++ b/nemoguardrails/rails/llm/options.py @@ -76,6 +76,7 @@ # {..., log: {"llm_calls": [...]}} """ + from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Field, root_validator @@ -156,7 +157,7 @@ class GenerationOptions(BaseModel): default=None, description="Additional parameters that should be used for the LLM call", ) - llm_output: Optional[bool] = Field( + llm_output: bool = Field( default=False, description="Whether the response should also include any custom LLM output.", ) From a27814484acae6ce30018a66e2995c5bb8f3ec51 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 24 Sep 2025 10:08:49 -0500 Subject: [PATCH 18/23] Fixed bad merge of GenerationOptions in llmrails.py (code referred to options, not gen_options) --- nemoguardrails/rails/llm/llmrails.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 086d480e4..fe56bcf08 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -843,9 +843,8 @@ async def generate_async( if ( messages and messages[-1]["role"] == "assistant" - and options - and hasattr(options, "rails") - and getattr(getattr(options, "rails"), "dialog", None) is False + and gen_options + and gen_options.rails.dialog is False ): # We already have the first message with a context update, so we use that messages[0]["content"]["bot_message"] = messages[-1]["content"] From e4826175049f308d21070e8bf7076b9ddceee7b9 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 24 Sep 2025 10:24:08 -0500 Subject: [PATCH 19/23] Add pre-commit pyright hook for nemoguardrails/rails/ --- .pre-commit-config.yaml | 9 ++++++++- pyproject.toml | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a5268ed5..48d882884 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,14 @@ repos: args: - --license-filepath - LICENSE.md - + - repo: local + hooks: + - id: pyright + name: pyright + entry: poetry run pyright + language: system + types: [python] + pass_filenames: false # Deactivating this for now. 
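[Review note] The new hook runs `poetry run pyright` on every commit; the class of error it guards against is exactly the unguarded Optional access that earlier patches in this series toggled on `OutputRails.streaming`. A minimal sketch of the narrowing pyright expects, using pared-down stand-ins for the real config models:

    from typing import Optional

    from pydantic import BaseModel

    class OutputRailsStreamingConfig(BaseModel):
        enabled: bool = False

    class OutputRails(BaseModel):
        streaming: Optional[OutputRailsStreamingConfig] = None

    def streaming_supported(rails: OutputRails) -> bool:
        # On an Optional field, bare `rails.streaming.enabled` is a type
        # error; the `and` guard narrows the type before the access.
        return bool(rails.streaming and rails.streaming.enabled)

Because the hook sets `pass_filenames: false`, pre-commit invokes pyright once over its configured include paths rather than once per changed file.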
# - repo: https://github.com/pycqa/pylint # rev: v2.17.0 diff --git a/pyproject.toml b/pyproject.toml index 418616a58..86aa7fcad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,7 +151,11 @@ pytest-profiling = "^1.7.0" yara-python = "^4.5.1" opentelemetry-api = "^1.34.1" opentelemetry-sdk = "^1.34.1" +pyright = "^1.1.405" +# Directories in which to run Pyright type-checking +[tool.pyright] +include = ["nemoguardrails/rails/**"] [tool.poetry.group.docs] optional = true From 244ff734f3fd7a5a402a7f40384a675ed0ed017b Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 24 Sep 2025 10:27:47 -0500 Subject: [PATCH 20/23] Regenerate poetry.lock --- poetry.lock | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index eeced1c7e..b5eedf3d0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4298,6 +4298,26 @@ files = [ [package.extras] dev = ["build", "flake8", "mypy", "pytest", "twine"] +[[package]] +name = "pyright" +version = "1.1.405" +description = "Command line wrapper for pyright" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyright-1.1.405-py3-none-any.whl", hash = "sha256:a2cb13700b5508ce8e5d4546034cb7ea4aedb60215c6c33f56cec7f53996035a"}, + {file = "pyright-1.1.405.tar.gz", hash = "sha256:5c2a30e1037af27eb463a1cc0b9f6d65fec48478ccf092c1ac28385a15c55763"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" +typing-extensions = ">=4.1" + +[package.extras] +all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] +nodejs = ["nodejs-wheel-binaries"] + [[package]] name = "pytest" version = "8.4.1" @@ -6448,4 +6468,4 @@ tracing = ["aiofiles", "opentelemetry-api"] [metadata] lock-version = "2.0" python-versions = ">=3.9,!=3.9.7,<3.14" -content-hash = "6654d6115d5142024695ff1a736cc3d133842421b1282f5c3ba413b6a0250118" +content-hash = "313705d475a9cb177efa633c193da9315388aa99832b9c5b429fafb5b3da44b0" From 360f4ac30f1a8212c209648e25d51fecb1955ce9 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:49:58 +0200 Subject: [PATCH 21/23] revert to original state at HEAD of develop and fix the type hint --- nemoguardrails/rails/llm/buffer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/nemoguardrails/rails/llm/buffer.py b/nemoguardrails/rails/llm/buffer.py index 541f52915..fdbd5ba08 100644 --- a/nemoguardrails/rails/llm/buffer.py +++ b/nemoguardrails/rails/llm/buffer.py @@ -14,10 +14,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, AsyncGenerator, List, NamedTuple - -if TYPE_CHECKING: - from collections.abc import AsyncIterator +from typing import AsyncGenerator, List, NamedTuple from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig @@ -114,7 +111,9 @@ def format_chunks(self, chunks: List[str]) -> str: ... @abstractmethod - async def process_stream(self, streaming_handler): + async def process_stream( + self, streaming_handler + ) -> AsyncGenerator[ChunkBatch, None]: """Process streaming chunks and yield chunk batches. This is the main method that concrete buffer strategies must implement. @@ -139,9 +138,10 @@ async def process_stream(self, streaming_handler): ... print(f"Processing: {context_formatted}") ... 
print(f"User: {user_formatted}") """ - yield ChunkBatch([], []) # pragma: no cover + raise NotImplementedError + yield - async def __call__(self, streaming_handler): + async def __call__(self, streaming_handler) -> AsyncGenerator[ChunkBatch, None]: """Callable interface that delegates to process_stream. It delegates to the `process_stream` method and can @@ -257,7 +257,9 @@ def from_config(cls, config: OutputRailsStreamingConfig): buffer_context_size=config.context_size, buffer_chunk_size=config.chunk_size ) - async def process_stream(self, streaming_handler): + async def process_stream( + self, streaming_handler + ) -> AsyncGenerator[ChunkBatch, None]: """Process streaming chunks using rolling buffer strategy. This method implements the rolling buffer logic, accumulating chunks From f0bc730b75182f5d92b302cad912b30eba19f92c Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:51:07 +0200 Subject: [PATCH 22/23] delete duplicate test --- tests/test_llm_options.py | 64 --------------------------------------- 1 file changed, 64 deletions(-) delete mode 100644 tests/test_llm_options.py diff --git a/tests/test_llm_options.py b/tests/test_llm_options.py deleted file mode 100644 index 72226afda..000000000 --- a/tests/test_llm_options.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for LLM isolation with models that don't have model_kwargs field.""" - -from typing import Any, Dict, List, Optional -from unittest.mock import Mock - -import pytest -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import BaseMessage -from langchain_core.outputs import ChatGeneration, ChatResult -from pydantic import BaseModel, Field - -from nemoguardrails.rails.llm.config import RailsConfig -from nemoguardrails.rails.llm.llmrails import LLMRails -from nemoguardrails.rails.llm.options import GenerationLog, GenerationStats - - -def test_generation_log_print_summary(capsys): - """Test printing rais stats with dummy data""" - - stats = GenerationStats( - input_rails_duration=1.0, - dialog_rails_duration=2.0, - generation_rails_duration=3.0, - output_rails_duration=4.0, - total_duration=10.0, # Sum of all previous rail durations - llm_calls_duration=8.0, # Less than total duration - llm_calls_count=4, # Input, dialog, generation and output calls - llm_calls_total_prompt_tokens=1000, - llm_calls_total_completion_tokens=2000, - llm_calls_total_tokens=3000, # Sum of prompt and completion tokens - ) - - generation_log = GenerationLog(activated_rails=[], stats=stats) - - generation_log.print_summary() - capture = capsys.readouterr() - capture_lines = capture.out.splitlines() - - # Check the correct times were printed - assert capture_lines[1] == "# General stats" - assert capture_lines[3] == "- Total time: 10.00s" - assert capture_lines[4] == " - [1.00s][10.0%]: INPUT Rails" - assert capture_lines[5] == " - [2.00s][20.0%]: DIALOG Rails" - assert capture_lines[6] == " - [3.00s][30.0%]: GENERATION Rails" - assert capture_lines[7] == " - [4.00s][40.0%]: OUTPUT Rails" - assert ( - capture_lines[8] - == "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens." - ) From d9c035843a2abd58157a2368765fc78152a8a14c Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:54:22 +0200 Subject: [PATCH 23/23] fix black style --- tests/rails/llm/test_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py index 71a42ac1a..f79dbc0ad 100644 --- a/tests/rails/llm/test_config.py +++ b/tests/rails/llm/test_config.py @@ -307,6 +307,7 @@ def test_rails_config_none_config_path(): result2 = config3 + config4 assert result2.config_path == "" + def test_llm_rails_configure_streaming_with_attr(): """Check LLM has the streaming attribute set if RailsConfig has it"""