 from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM
 from neo4j_graphrag.llm.types import ToolCallResponse
 from neo4j_graphrag.tool import Tool
+from neo4j_graphrag.types import LLMMessage


 def get_mock_openai() -> MagicMock:
@@ -50,7 +51,9 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None:


 @patch("builtins.__import__")
-def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
+@patch("neo4j_graphrag.llm.openai_llm.legacy_inputs_to_messages")
+def test_openai_llm_with_message_history_happy_path(mock_inputs: Mock, mock_import: Mock) -> None:
+    mock_inputs.return_value = [LLMMessage(role="user", content="text")]
     mock_openai = get_mock_openai()
     mock_import.return_value = mock_openai
     mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
@@ -63,18 +66,10 @@ def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
     ]
     question = "What about next season?"

-    res = llm.invoke(question, message_history)  # type: ignore
-    assert isinstance(res, LLMResponse)
-    assert res.content == "openai chat response"
-    message_history.append({"role": "user", "content": question})
-    # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions
-    llm.client.chat.completions.create.assert_called_once()  # type: ignore
-    # Check call arguments individually
-    call_args = llm.client.chat.completions.create.call_args[  # type: ignore
-        1
-    ]  # Get the keyword arguments
-    assert call_args["messages"] == message_history
-    assert call_args["model"] == "gpt"
+    with patch.object(llm, "_invoke") as mock_invoke:
+        llm.invoke(question, message_history)  # type: ignore
+    mock_invoke.assert_called_once_with([LLMMessage(role="user", content="text")])
+    mock_inputs.assert_called_once_with(input=question, message_history=message_history)


 @patch("builtins.__import__")
@@ -404,5 +399,6 @@ def test_azure_openai_llm_with_message_history_validation_error(
     question = "What about next season?"

     with pytest.raises(LLMGenerationError) as exc_info:
-        llm.invoke(question, message_history)  # type: ignore
+        r = llm.invoke(question, message_history)  # type: ignore
+        print(r)
     assert "Input should be a valid string" in str(exc_info.value)