diff --git a/nemoguardrails/llm/prompts.py b/nemoguardrails/llm/prompts.py
index 8f00b2b55..1f407a17a 100644
--- a/nemoguardrails/llm/prompts.py
+++ b/nemoguardrails/llm/prompts.py
@@ -134,6 +134,17 @@ def get_task_model(config: RailsConfig, task: Union[str, Task]) -> Optional[Mode
     # Fetch current task parameters like name, models to use, and the prompting mode
     task_name = str(task.value) if isinstance(task, Task) else task
 
+    # Check if the task name contains a model specification (e.g., "content_safety_check_input $model=content_safety")
+    if "$model=" in task_name:
+        # Extract the model type from the task name
+        model_type = task_name.split("$model=")[-1].strip()
+        # Look for a model with this specific type
+        if config.models:
+            _models = [model for model in config.models if model.type == model_type]
+            if _models:
+                return _models[0]
+
+    # If no model specification or no matching model found, fall back to the original logic
     if config.models:
         _models = [model for model in config.models if model.type == task_name]
         if not _models:
diff --git a/tests/test_llm_task_manager.py b/tests/test_llm_task_manager.py
index 7897e55b6..8443bcc8a 100644
--- a/tests/test_llm_task_manager.py
+++ b/tests/test_llm_task_manager.py
@@ -532,3 +532,37 @@ def test_get_task_model_fallback_to_main():
     result = get_task_model(config, "some_other_task")
     assert result is not None
     assert result.type == "main"
+
+
+def test_get_task_model_with_model_specification():
+    """Test that get_task_model correctly extracts model type from task names with $model= specification."""
+    config = RailsConfig.parse_object(
+        {
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "openai",
+                    "model": "gpt-3.5-turbo",
+                },
+                {
+                    "type": "content_safety",
+                    "engine": "openai",
+                    "model": "gpt-4",
+                },
+            ]
+        }
+    )
+
+    # Test with a task name that contains $model= specification
+    result = get_task_model(config, "content_safety_check_input $model=content_safety")
+    assert result is not None
+    assert result.type == "content_safety"
+    assert result.engine == "openai"
+    assert result.model == "gpt-4"
+
+    # Test fallback to main model when specified model type doesn't exist
+    result = get_task_model(config, "unknown_task $model=nonexistent")
+    assert result is not None
+    assert result.type == "main"
+    assert result.engine == "openai"
+    assert result.model == "gpt-3.5-turbo"