From 03d2bfbd9a00ffdbffce02d27263803a375c8b93 Mon Sep 17 00:00:00 2001 From: Jeffrey Carpenter Date: Tue, 2 Sep 2025 14:28:18 -0700 Subject: [PATCH] fix: resolve content safety model resolution issue - Fix get_task_model function to properly parse task names with $model= specifications - Extract model type from task names like 'content_safety_check_input $model=content_safety' - Use extracted model type to find correct model configuration instead of defaulting to main model - Add comprehensive test coverage for the new functionality - Maintain backward compatibility for existing task names Fixes issue where content safety actions would fail with error: 'Could not find prompt for task content_safety_check_input $model=content_safety and model [main_model_name]' This fix ensures that when a task specifies a model type via $model=, the system correctly uses that model type for prompt resolution rather than falling back to the main model. --- nemoguardrails/llm/prompts.py | 11 +++++++++++ tests/test_llm_task_manager.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/nemoguardrails/llm/prompts.py b/nemoguardrails/llm/prompts.py index 8f00b2b55..1f407a17a 100644 --- a/nemoguardrails/llm/prompts.py +++ b/nemoguardrails/llm/prompts.py @@ -134,6 +134,17 @@ def get_task_model(config: RailsConfig, task: Union[str, Task]) -> Optional[Mode # Fetch current task parameters like name, models to use, and the prompting mode task_name = str(task.value) if isinstance(task, Task) else task + # Check if the task name contains a model specification (e.g., "content_safety_check_input $model=content_safety") + if "$model=" in task_name: + # Extract the model type from the task name + model_type = task_name.split("$model=")[-1].strip() + # Look for a model with this specific type + if config.models: + _models = [model for model in config.models if model.type == model_type] + if _models: + return _models[0] + + # If no model specification or no matching model found, fall back 
to the original logic if config.models: _models = [model for model in config.models if model.type == task_name] if not _models: diff --git a/tests/test_llm_task_manager.py b/tests/test_llm_task_manager.py index 7897e55b6..8443bcc8a 100644 --- a/tests/test_llm_task_manager.py +++ b/tests/test_llm_task_manager.py @@ -532,3 +532,37 @@ def test_get_task_model_fallback_to_main(): result = get_task_model(config, "some_other_task") assert result is not None assert result.type == "main" + + +def test_get_task_model_with_model_specification(): + """Test that get_task_model correctly extracts model type from task names with $model= specification.""" + config = RailsConfig.parse_object( + { + "models": [ + { + "type": "main", + "engine": "openai", + "model": "gpt-3.5-turbo", + }, + { + "type": "content_safety", + "engine": "openai", + "model": "gpt-4", + }, + ] + } + ) + + # Test with a task name that contains $model= specification + result = get_task_model(config, "content_safety_check_input $model=content_safety") + assert result is not None + assert result.type == "content_safety" + assert result.engine == "openai" + assert result.model == "gpt-4" + + # Test fallback to main model when specified model type doesn't exist + result = get_task_model(config, "unknown_task $model=nonexistent") + assert result is not None + assert result.type == "main" + assert result.engine == "openai" + assert result.model == "gpt-3.5-turbo"