Skip to content

Commit 02e99d4

Browse files
committed
Improve test coverage for utils and base modules
1 parent 4faf82f commit 02e99d4

File tree

4 files changed

+195
-33
lines changed

4 files changed

+195
-33
lines changed

src/neo4j_graphrag/llm/base.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from abc import ABC, abstractmethod
1818
from typing import Any, List, Optional, Sequence, Union
1919

20+
from pydantic import ValidationError
21+
2022
from neo4j_graphrag.message_history import MessageHistory
2123
from neo4j_graphrag.types import LLMMessage
2224
from .rate_limit import rate_limit_handler
@@ -31,6 +33,7 @@
3133

3234
from .rate_limit import RateLimitHandler
3335
from .utils import legacy_inputs_to_messages
36+
from ..exceptions import LLMGenerationError
3437

3538

3639
class LLMInterface(ABC):
@@ -65,7 +68,12 @@ def invoke(
6568
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
6669
system_instruction: Optional[str] = None,
6770
) -> LLMResponse:
68-
messages = legacy_inputs_to_messages(input, message_history, system_instruction)
71+
try:
72+
messages = legacy_inputs_to_messages(
73+
input, message_history, system_instruction
74+
)
75+
except ValidationError as e:
76+
raise LLMGenerationError("Input validation failed") from e
6977
return self._invoke(messages)
7078

7179
@abstractmethod
@@ -138,7 +146,12 @@ def invoke_with_tools(
138146
LLMGenerationError: If anything goes wrong.
139147
NotImplementedError: If the LLM provider does not support tool calling.
140148
"""
141-
messages = legacy_inputs_to_messages(input, message_history, system_instruction)
149+
try:
150+
messages = legacy_inputs_to_messages(
151+
input, message_history, system_instruction
152+
)
153+
except ValidationError as e:
154+
raise LLMGenerationError("Input validation failed") from e
142155
return self._invoke_with_tools(messages, tools)
143156

144157
def _invoke_with_tools(

src/neo4j_graphrag/llm/utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
115
import warnings
216
from typing import Union, Optional
317

18+
from pydantic import TypeAdapter
19+
420
from neo4j_graphrag.message_history import MessageHistory
521
from neo4j_graphrag.types import LLMMessage
622

@@ -12,6 +28,9 @@ def system_instruction_from_messages(messages: list[LLMMessage]) -> str | None:
1228
return None
1329

1430

31+
llm_messages_adapter = TypeAdapter(list[LLMMessage])
32+
33+
1534
def legacy_inputs_to_messages(
1635
input: Union[str, list[LLMMessage], MessageHistory],
1736
message_history: Optional[Union[list[LLMMessage], MessageHistory]] = None,
@@ -21,7 +40,7 @@ def legacy_inputs_to_messages(
2140
if isinstance(message_history, MessageHistory):
2241
messages = message_history.messages
2342
else: # list[LLMMessage]
24-
messages = [LLMMessage(**m) for m in message_history]
43+
messages = llm_messages_adapter.validate_python(message_history)
2544
else:
2645
messages = []
2746
if system_instruction is not None:

tests/unit/llm/test_base.py

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import Type, Generator, Optional, Any
1+
from typing import Type, Generator
22
from unittest.mock import patch, Mock
33

4+
import pytest
45
from joblib.testing import fixture
6+
from pydantic import ValidationError
57

8+
from neo4j_graphrag.exceptions import LLMGenerationError
69
from neo4j_graphrag.llm import LLMInterface
710
from neo4j_graphrag.types import LLMMessage
811

@@ -21,11 +24,20 @@ class CustomLLMInterface(LLMInterface):
2124

2225

2326
@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages")
24-
def test_base_llm_interface_invoke_with_input_as_str(mock_inputs: Mock, llm_interface: Type[LLMInterface]) -> None:
25-
mock_inputs.return_value = [LLMMessage(role="user", content="return value of the legacy_inputs_to_messages function")]
27+
def test_base_llm_interface_invoke_with_input_as_str(
28+
mock_inputs: Mock, llm_interface: Type[LLMInterface]
29+
) -> None:
30+
mock_inputs.return_value = [
31+
LLMMessage(
32+
role="user",
33+
content="return value of the legacy_inputs_to_messages function",
34+
)
35+
]
2636
llm = llm_interface(model_name="test")
2737
message_history = [
28-
LLMMessage(**{"role": "user", "content": "When does the sun come up in the summer?"}),
38+
LLMMessage(
39+
**{"role": "user", "content": "When does the sun come up in the summer?"}
40+
),
2941
LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}),
3042
]
3143
question = "What about next season?"
@@ -34,10 +46,91 @@ def test_base_llm_interface_invoke_with_input_as_str(mock_inputs: Mock, llm_inte
3446
with patch.object(llm, "_invoke") as mock_invoke:
3547
llm.invoke(question, message_history, system_instruction)
3648
mock_invoke.assert_called_once_with(
37-
[LLMMessage(role="user", content="return value of the legacy_inputs_to_messages function")]
49+
[
50+
LLMMessage(
51+
role="user",
52+
content="return value of the legacy_inputs_to_messages function",
53+
)
54+
]
3855
)
3956
mock_inputs.assert_called_once_with(
4057
question,
4158
message_history,
4259
system_instruction,
4360
)
61+
62+
63+
@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages")
64+
def test_base_llm_interface_invoke_with_invalid_inputs(
65+
mock_inputs: Mock, llm_interface: Type[LLMInterface]
66+
) -> None:
67+
mock_inputs.side_effect = [
68+
ValidationError.from_exception_data("Invalid data", line_errors=[])
69+
]
70+
llm = llm_interface(model_name="test")
71+
question = "What about next season?"
72+
73+
with pytest.raises(LLMGenerationError, match="Input validation failed"):
74+
llm.invoke(question)
75+
mock_inputs.assert_called_once_with(
76+
question,
77+
None,
78+
None,
79+
)
80+
81+
82+
@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages")
83+
def test_base_llm_interface_invoke_with_tools_with_input_as_str(
84+
mock_inputs: Mock, llm_interface: Type[LLMInterface]
85+
) -> None:
86+
mock_inputs.return_value = [
87+
LLMMessage(
88+
role="user",
89+
content="return value of the legacy_inputs_to_messages function",
90+
)
91+
]
92+
llm = llm_interface(model_name="test")
93+
message_history = [
94+
LLMMessage(
95+
**{"role": "user", "content": "When does the sun come up in the summer?"}
96+
),
97+
LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}),
98+
]
99+
question = "What about next season?"
100+
system_instruction = "You are a genius."
101+
102+
with patch.object(llm, "_invoke_with_tools") as mock_invoke:
103+
llm.invoke_with_tools(question, [], message_history, system_instruction)
104+
mock_invoke.assert_called_once_with(
105+
[
106+
LLMMessage(
107+
role="user",
108+
content="return value of the legacy_inputs_to_messages function",
109+
)
110+
],
111+
[], # tools
112+
)
113+
mock_inputs.assert_called_once_with(
114+
question,
115+
message_history,
116+
system_instruction,
117+
)
118+
119+
120+
@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages")
121+
def test_base_llm_interface_invoke_with_tools_with_invalid_inputs(
122+
mock_inputs: Mock, llm_interface: Type[LLMInterface]
123+
) -> None:
124+
mock_inputs.side_effect = [
125+
ValidationError.from_exception_data("Invalid data", line_errors=[])
126+
]
127+
llm = llm_interface(model_name="test")
128+
question = "What about next season?"
129+
130+
with pytest.raises(LLMGenerationError, match="Input validation failed"):
131+
llm.invoke_with_tools(question, [])
132+
mock_inputs.assert_called_once_with(
133+
question,
134+
None,
135+
None,
136+
)

tests/unit/llm/test_utils.py

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,29 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
115
import pytest
16+
from pydantic import ValidationError
217

3-
from neo4j_graphrag.llm.utils import system_instruction_from_messages, \
4-
legacy_inputs_to_messages
18+
from neo4j_graphrag.llm.utils import (
19+
system_instruction_from_messages,
20+
legacy_inputs_to_messages,
21+
)
522
from neo4j_graphrag.message_history import InMemoryMessageHistory
623
from neo4j_graphrag.types import LLMMessage
724

825

9-
def test_system_instruction_from_messages():
26+
def test_system_instruction_from_messages() -> None:
1027
messages = [
1128
LLMMessage(role="system", content="text"),
1229
]
@@ -21,66 +38,76 @@ def test_system_instruction_from_messages():
2138
assert system_instruction_from_messages(messages) is None
2239

2340

24-
def test_legacy_inputs_to_messages_only_input_as_llm_message_list():
25-
messages = legacy_inputs_to_messages(input=[
26-
LLMMessage(role="user", content="text"),
27-
])
41+
def test_legacy_inputs_to_messages_only_input_as_llm_message_list() -> None:
42+
messages = legacy_inputs_to_messages(
43+
input=[
44+
LLMMessage(role="user", content="text"),
45+
]
46+
)
2847
assert messages == [
2948
LLMMessage(role="user", content="text"),
3049
]
3150

3251

33-
def test_legacy_inputs_to_messages_only_input_as_message_history():
34-
messages = legacy_inputs_to_messages(input=InMemoryMessageHistory(
35-
messages=[
36-
LLMMessage(role="user", content="text"),
37-
]
38-
))
52+
def test_legacy_inputs_to_messages_only_input_as_message_history() -> None:
53+
messages = legacy_inputs_to_messages(
54+
input=InMemoryMessageHistory(
55+
messages=[
56+
LLMMessage(role="user", content="text"),
57+
]
58+
)
59+
)
3960
assert messages == [
4061
LLMMessage(role="user", content="text"),
4162
]
4263

4364

44-
def test_legacy_inputs_to_messages_only_input_as_str():
65+
def test_legacy_inputs_to_messages_only_input_as_str() -> None:
4566
messages = legacy_inputs_to_messages(input="text")
4667
assert messages == [
4768
LLMMessage(role="user", content="text"),
4869
]
4970

5071

51-
def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_llm_message_list():
72+
def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_llm_message_list() -> (
73+
None
74+
):
5275
messages = legacy_inputs_to_messages(
5376
input="text",
5477
message_history=[
5578
LLMMessage(role="assistant", content="How can I assist you today?"),
56-
]
79+
],
5780
)
5881
assert messages == [
5982
LLMMessage(role="assistant", content="How can I assist you today?"),
6083
LLMMessage(role="user", content="text"),
6184
]
6285

6386

64-
def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_message_history():
87+
def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_message_history() -> (
88+
None
89+
):
6590
messages = legacy_inputs_to_messages(
6691
input="text",
67-
message_history=InMemoryMessageHistory(messages=[
68-
LLMMessage(role="assistant", content="How can I assist you today?"),
69-
])
92+
message_history=InMemoryMessageHistory(
93+
messages=[
94+
LLMMessage(role="assistant", content="How can I assist you today?"),
95+
]
96+
),
7097
)
7198
assert messages == [
7299
LLMMessage(role="assistant", content="How can I assist you today?"),
73100
LLMMessage(role="user", content="text"),
74101
]
75102

76103

77-
def test_legacy_inputs_to_messages_with_explicit_system_instruction():
104+
def test_legacy_inputs_to_messages_with_explicit_system_instruction() -> None:
78105
messages = legacy_inputs_to_messages(
79106
input="text",
80107
message_history=[
81108
LLMMessage(role="assistant", content="How can I assist you today?"),
82109
],
83-
system_instruction="You are a genius."
110+
system_instruction="You are a genius.",
84111
)
85112
assert messages == [
86113
LLMMessage(role="system", content="You are a genius."),
@@ -89,19 +116,29 @@ def test_legacy_inputs_to_messages_with_explicit_system_instruction():
89116
]
90117

91118

92-
def test_legacy_inputs_to_messages_do_not_duplicate_system_instruction():
119+
def test_legacy_inputs_to_messages_do_not_duplicate_system_instruction() -> None:
93120
with pytest.warns(
94121
UserWarning,
95-
match="system_instruction provided but ignored as the message history already contains a system message"
122+
match="system_instruction provided but ignored as the message history already contains a system message",
96123
):
97124
messages = legacy_inputs_to_messages(
98125
input="text",
99126
message_history=[
100127
LLMMessage(role="system", content="You are super smart."),
101128
],
102-
system_instruction="You are a genius."
129+
system_instruction="You are a genius.",
103130
)
104131
assert messages == [
105132
LLMMessage(role="system", content="You are super smart."),
106133
LLMMessage(role="user", content="text"),
107134
]
135+
136+
137+
def test_legacy_inputs_to_messages_wrong_type_in_message_list() -> None:
138+
with pytest.raises(ValidationError, match="Input should be a valid string"):
139+
legacy_inputs_to_messages(
140+
input="text",
141+
message_history=[
142+
{"role": "system", "content": 10}, # type: ignore
143+
],
144+
)

0 commit comments

Comments
 (0)