diff --git a/tests/llm/test_utils.py b/tests/llm/test_utils.py new file mode 100644 index 00000000..3861989e --- /dev/null +++ b/tests/llm/test_utils.py @@ -0,0 +1,173 @@ +import pytest + +from strix.llm.utils import ( + _fix_stopword, + _truncate_to_first_function, + clean_content, + format_tool_call, + parse_tool_invocations, +) + + +class TestParseToolInvocations: + """Tests for parse_tool_invocations function.""" + + def test_single_function_with_params(self) -> None: + content = """ +value1 +value2 +""" + result = parse_tool_invocations(content) + assert result == [{"toolName": "test_tool", "args": {"arg1": "value1", "arg2": "value2"}}] + + def test_multiple_functions(self) -> None: + content = """ +1 + + +2 +""" + result = parse_tool_invocations(content) + assert result is not None + assert len(result) == 2 + assert result[0]["toolName"] == "tool1" + assert result[1]["toolName"] == "tool2" + + def test_html_entities_unescaped(self) -> None: + content = """ +<div>hello</div> +""" + result = parse_tool_invocations(content) + assert result is not None + assert result[0]["args"]["code"] == "
hello
" + + def test_empty_content_returns_none(self) -> None: + assert parse_tool_invocations("") is None + + def test_no_functions_returns_none(self) -> None: + assert parse_tool_invocations("just some text without functions") is None + + def test_function_without_params(self) -> None: + content = "" + result = parse_tool_invocations(content) + assert result == [{"toolName": "no_args", "args": {}}] + + def test_multiline_param_value(self) -> None: + content = """ +line 1 +line 2 +line 3 +""" + result = parse_tool_invocations(content) + assert result is not None + assert "line 1\nline 2\nline 3" == result[0]["args"]["content"] + + def test_ampersand_entity(self) -> None: + content = """ +foo & bar +""" + result = parse_tool_invocations(content) + assert result is not None + assert result[0]["args"]["query"] == "foo & bar" + + +class TestFixStopword: + """Tests for _fix_stopword function.""" + + def test_complete_function_unchanged(self) -> None: + content = "\n1\n" + assert _fix_stopword(content) == content + + def test_ending_with_incomplete_slash(self) -> None: + content = "\n1\n") + + def test_missing_closing_tag(self) -> None: + content = "\n1" + result = _fix_stopword(content) + assert result.endswith("") + + def test_multiple_functions_no_fix(self) -> None: + content = "" + assert _fix_stopword(content) == content + + +class TestFormatToolCall: + """Tests for format_tool_call function.""" + + def test_basic_formatting(self) -> None: + result = format_tool_call("my_tool", {"arg": "value"}) + assert "" in result + assert "value" in result + assert "" in result + + def test_multiple_params(self) -> None: + result = format_tool_call("tool", {"a": "1", "b": "2"}) + assert "1" in result + assert "2" in result + + def test_empty_args(self) -> None: + result = format_tool_call("empty", {}) + assert result == "\n" + + def test_preserves_newlines_in_value(self) -> None: + result = format_tool_call("tool", {"code": "line1\nline2"}) + assert "line1\nline2" in result + + +class TestCleanContent: + """Tests for clean_content function.""" + + def test_remove_function_tags(self) -> None: + content = "Hello 1 World" + result = clean_content(content) + assert " None: + content = "Start hidden End" + result = clean_content(content) + assert "inter_agent_message" not in result + assert "hidden" not in result + + def test_remove_agent_completion_report(self) -> None: + content = "Begin done Finish" + result = clean_content(content) + assert "agent_completion_report" not in result + assert "done" not in result + + def test_preserve_normal_content(self) -> None: + content = "This is normal text with no special tags." + assert clean_content(content) == content + + def test_empty_content(self) -> None: + assert clean_content("") == "" + assert clean_content(None) == "" # type: ignore[arg-type] + + def test_collapse_multiple_newlines(self) -> None: + content = "line1\n\n\n\nline2" + result = clean_content(content) + assert "\n\n\n" not in result + + +class TestTruncateToFirstFunction: + """Tests for _truncate_to_first_function function.""" + + def test_single_function_unchanged(self) -> None: + content = "text more" + assert _truncate_to_first_function(content) == content + + def test_multiple_functions_keeps_first(self) -> None: + content = "start middle end" + result = _truncate_to_first_function(content) + assert "" in result + assert "" not in result + + def test_no_functions_unchanged(self) -> None: + content = "no functions here" + assert _truncate_to_first_function(content) == content + + def test_empty_content(self) -> None: + assert _truncate_to_first_function("") == ""