
Commit 4faf82f

WIP: update tests
1 parent: 1cf93fc

File tree

- src/neo4j_graphrag/llm/utils.py
- tests/unit/llm/test_base.py
- tests/unit/llm/test_openai_llm.py

3 files changed: 54 additions & 15 deletions


src/neo4j_graphrag/llm/utils.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def legacy_inputs_to_messages(
         if isinstance(message_history, MessageHistory):
             messages = message_history.messages
         else:  # list[LLMMessage]
-            messages = message_history
+            messages = [LLMMessage(**m) for m in message_history]
     else:
         messages = []
     if system_instruction is not None:
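The one-line change above makes legacy_inputs_to_messages normalise a plain-dict history into LLMMessage objects instead of passing the dicts through. A minimal sketch of the effect, assuming only that LLMMessage accepts role/content keyword arguments (as the tests below rely on):

# Sketch, not the library source: the behavioural difference introduced above.
from neo4j_graphrag.types import LLMMessage

raw_history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]

# Before: messages = message_history  -> callers received the raw dicts.
# After: each dict is unpacked into an LLMMessage, so a malformed entry
# (e.g. a missing "role") fails here rather than deeper in the pipeline.
messages = [LLMMessage(**m) for m in raw_history]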

tests/unit/llm/test_base.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+from typing import Generator, Type
+from unittest.mock import Mock, patch
+
+from pytest import fixture
+
+from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.types import LLMMessage
+
+
+@fixture(scope="module")  # type: ignore[misc]
+def llm_interface() -> Generator[Type[LLMInterface], None, None]:
+    real_abstract_methods = LLMInterface.__abstractmethods__
+    LLMInterface.__abstractmethods__ = frozenset()
+
+    class CustomLLMInterface(LLMInterface):
+        pass
+
+    yield CustomLLMInterface
+
+    LLMInterface.__abstractmethods__ = real_abstract_methods
+
+
+@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages")
+def test_base_llm_interface_invoke_with_input_as_str(mock_inputs: Mock, llm_interface: Type[LLMInterface]) -> None:
+    mock_inputs.return_value = [LLMMessage(role="user", content="return value of the legacy_inputs_to_messages function")]
+    llm = llm_interface(model_name="test")
+    message_history = [
+        LLMMessage(**{"role": "user", "content": "When does the sun come up in the summer?"}),
+        LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}),
+    ]
+    question = "What about next season?"
+    system_instruction = "You are a genius."
+
+    with patch.object(llm, "_invoke") as mock_invoke:
+        llm.invoke(question, message_history, system_instruction)
+    mock_invoke.assert_called_once_with(
+        [LLMMessage(role="user", content="return value of the legacy_inputs_to_messages function")]
+    )
+    mock_inputs.assert_called_once_with(
+        question,
+        message_history,
+        system_instruction,
+    )
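The llm_interface fixture above leans on a CPython detail: instantiating an ABC is blocked by the class's __abstractmethods__ set, so temporarily replacing that set with an empty frozenset makes the abstract LLMInterface instantiable, and restoring it afterwards keeps other tests unaffected. A standalone sketch of the pattern with a toy ABC (the names here are illustrative, not from the library):

from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def _invoke(self, messages: list) -> str: ...

    def invoke(self, text: str) -> str:
        # Concrete logic worth testing without writing a real subclass.
        return self._invoke([text])


real_abstract_methods = Base.__abstractmethods__
Base.__abstractmethods__ = frozenset()  # pretend nothing is abstract
try:
    instance = Base()  # would normally raise TypeError
finally:
    Base.__abstractmethods__ = real_abstract_methods  # restore, as the fixture does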

tests/unit/llm/test_openai_llm.py

Lines changed: 10 additions & 14 deletions
@@ -21,6 +21,7 @@
 from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM
 from neo4j_graphrag.llm.types import ToolCallResponse
 from neo4j_graphrag.tool import Tool
+from neo4j_graphrag.types import LLMMessage


 def get_mock_openai() -> MagicMock:
@@ -50,7 +51,9 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None:


 @patch("builtins.__import__")
-def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
+@patch("neo4j_graphrag.llm.openai_llm.legacy_inputs_to_messages")
+def test_openai_llm_with_message_history_happy_path(mock_inputs: Mock, mock_import: Mock) -> None:
+    mock_inputs.return_value = [LLMMessage(role="user", content="text")]
     mock_openai = get_mock_openai()
     mock_import.return_value = mock_openai
     mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
@@ -63,18 +66,10 @@ def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
     ]
     question = "What about next season?"

-    res = llm.invoke(question, message_history)  # type: ignore
-    assert isinstance(res, LLMResponse)
-    assert res.content == "openai chat response"
-    message_history.append({"role": "user", "content": question})
-    # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions
-    llm.client.chat.completions.create.assert_called_once()  # type: ignore
-    # Check call arguments individually
-    call_args = llm.client.chat.completions.create.call_args[  # type: ignore
-        1
-    ]  # Get the keyword arguments
-    assert call_args["messages"] == message_history
-    assert call_args["model"] == "gpt"
+    with patch.object(llm, "_invoke") as mock_invoke:
+        llm.invoke(question, message_history)  # type: ignore
+    mock_invoke.assert_called_once_with([LLMMessage(role="user", content="text")])
+    mock_inputs.assert_called_once_with(input=question, message_history=message_history)


 @patch("builtins.__import__")
@@ -404,5 +399,6 @@ def test_azure_openai_llm_with_message_history_validation_error(
     question = "What about next season?"

     with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)  # type: ignore
+        r = llm.invoke(question, message_history)  # type: ignore
+        print(r)
     assert "Input should be a valid string" in str(exc_info.value)
