
Commit ba2a78c
fix mypy errors
1 parent: 8aa4684

9 files changed (+30, -22 lines)


src/neo4j_graphrag/embeddings/cohere.py (+1, -1)

@@ -21,7 +21,7 @@
 try:
     import cohere
 except ImportError:
-    cohere = None
+    cohere = None  # type: ignore[assignment]
 
 
 class CohereEmbeddings(Embedder):
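Why the ignore is needed: mypy infers the type of `cohere` from the successful import (a module object), so rebinding it to `None` in the except branch is an incompatible assignment; the narrow `# type: ignore[assignment]` suppresses exactly that error code without hiding others. A minimal sketch of the whole pattern, with a hypothetical runtime guard in `__init__` (the class body is not part of this diff):

```python
from typing import Any

try:
    import cohere
except ImportError:
    # Rebinding a module name to None is mypy error code `assignment`.
    cohere = None  # type: ignore[assignment]


class CohereEmbeddings:
    def __init__(self, model: str = "embed-english-v3.0", **kwargs: Any) -> None:
        # Hypothetical guard: fail fast at construction time if the
        # optional dependency is missing.
        if cohere is None:
            raise ImportError("Could not import cohere. Try `pip install cohere`.")
        self.model = model
        self.client = cohere.Client(**kwargs)
```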

src/neo4j_graphrag/embeddings/mistral.py (+4, -1)

@@ -24,7 +24,10 @@
 try:
     from mistralai import Mistral
 except ImportError:
-    Mistral = None
+    # Define placeholder type for type checking
+    class Mistral:  # type: ignore
+        pass
+    Mistral = None  # type: ignore
 
 
 class MistralAIEmbeddings(Embedder):
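The MistralAI embeddings hunk takes a different tack: it first defines a placeholder class so the name `Mistral` resolves to a type during checking, then still rebinds it to `None` so any existing `if Mistral is None` runtime guard keeps working; both statements conflict with the imported name, hence the blanket ignores. A sketch of how a consumer might use the guard (the `__init__` body below is an assumption, not shown in this diff):

```python
from typing import Any, Optional

try:
    from mistralai import Mistral
except ImportError:
    # Placeholder type so annotations referencing Mistral still resolve.
    class Mistral:  # type: ignore
        pass

    Mistral = None  # type: ignore


class MistralAIEmbeddings:
    def __init__(self, api_key: Optional[str] = None, **kwargs: Any) -> None:
        # Hypothetical guard mirroring the usual optional-import check.
        if Mistral is None:
            raise ImportError("Could not import mistralai. Try `pip install mistralai`.")
        self.mistral = Mistral(api_key=api_key, **kwargs)
```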

src/neo4j_graphrag/llm/anthropic_llm.py (+1, -1)

@@ -91,7 +91,7 @@ def get_messages(
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return messages  # type: ignore[return-value]
 
     def invoke(
         self,
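The `return-value` ignore is needed because `get_messages` promises the SDK's typed message structure while `messages` is assembled as plain dicts; the shapes agree at runtime, but mypy cannot prove it. The same change appears in cohere_llm.py and ollama_llm.py below. A standalone reproduction of the mismatch (the `MessageParam` TypedDict is a stand-in for the SDK type, assumed for illustration):

```python
from typing import Any, Iterable, TypedDict, cast


class MessageParam(TypedDict):
    # Stand-in for the SDK's typed message dict; assumed shape.
    role: str
    content: str


def get_messages(
    input: str, message_history: list[dict[str, Any]]
) -> Iterable[MessageParam]:
    messages: list[dict[str, Any]] = []
    messages.extend(cast(Iterable[dict[str, Any]], message_history))
    messages.append({"role": "user", "content": input})
    # A dict[str, Any] is not a MessageParam to mypy, even though each
    # entry has exactly the right keys at runtime.
    return messages  # type: ignore[return-value]
```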

src/neo4j_graphrag/llm/cohere_llm.py (+1, -1)

@@ -94,7 +94,7 @@ def get_messages(
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return messages  # type: ignore[return-value]
 
     def invoke(
         self,

src/neo4j_graphrag/llm/mistralai_llm.py (+7, -2)

@@ -35,8 +35,13 @@
     from mistralai import Messages, Mistral
     from mistralai.models.sdkerror import SDKError
 except ImportError:
-    Mistral = None
-    SDKError = None
+    # Define placeholder types for type checking
+    class Mistral:  # type: ignore
+        pass
+    class SDKError(Exception):  # type: ignore
+        pass
+    Mistral = None  # type: ignore
+    SDKError = None  # type: ignore
 
 
 class MistralAILLM(LLMInterface):
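Note that the `SDKError` placeholder derives from `Exception` while `Mistral` does not: `SDKError` appears in an `except` clause, and mypy rejects catching a type that is not derived from `BaseException`. A sketch of the call-site shape this enables (the function below is hypothetical; only the guarded import mirrors the diff):

```python
try:
    from mistralai import Mistral
    from mistralai.models.sdkerror import SDKError
except ImportError:
    class Mistral:  # type: ignore
        pass

    # Deriving from Exception keeps `except SDKError:` type-checkable.
    class SDKError(Exception):  # type: ignore
        pass

    Mistral = None  # type: ignore
    SDKError = None  # type: ignore


def invoke_sketch(client: Mistral, prompt: str) -> str:
    # Hypothetical call site; the real invoke() differs.
    try:
        response = client.chat.complete(
            model="mistral-small-latest",
            messages=[{"role": "user", "content": prompt}],
        )
    except SDKError as e:
        raise RuntimeError("MistralAI call failed") from e
    return response.choices[0].message.content
```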

src/neo4j_graphrag/llm/ollama_llm.py (+1, -1)

@@ -76,7 +76,7 @@ def get_messages(
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return messages  # type: ignore[return-value]
 
     def invoke(
         self,

tests/unit/llm/test_anthropic_llm.py (+7, -7)

@@ -49,7 +49,7 @@ def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None:
     input_text = "may thy knife chip and shatter"
     response = llm.invoke(input_text)
     assert response.content == "generated text"
-    llm.client.messages.create.assert_called_once_with(
+    llm.client.messages.create.assert_called_once_with(  # type: ignore[attr-defined]
         messages=[{"role": "user", "content": input_text}],
         model="claude-3-opus-20240229",
         system=anthropic.NOT_GIVEN,
@@ -81,7 +81,7 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock)
     response = llm.invoke(question, message_history)
     assert response.content == "generated text"
     message_history.add_message(LLMMessage(role="user", content=question))
-    llm.client.messages.create.assert_called_once_with(
+    llm.client.messages.create.assert_called_once_with(  # type: ignore[attr-defined]
         messages=message_history,
         model="claude-3-opus-20240229",
         system=anthropic.NOT_GIVEN,
@@ -107,14 +107,14 @@ def test_anthropic_invoke_with_system_instruction(
     assert isinstance(response, LLMResponse)
     assert response.content == "generated text"
     messages = [{"role": "user", "content": question}]
-    llm.client.messages.create.assert_called_with(
+    llm.client.messages.create.assert_called_with(  # type: ignore[attr-defined]
         model="claude-3-opus-20240229",
         system=system_instruction,
         messages=messages,
         **model_params,
     )
 
-    assert llm.client.messages.create.call_count == 1
+    assert llm.client.messages.create.call_count == 1  # type: ignore[attr-defined]
 
 
 def test_anthropic_invoke_with_message_history_and_system_instruction(
@@ -145,14 +145,14 @@ def test_anthropic_invoke_with_message_history_and_system_instruction(
     assert isinstance(response, LLMResponse)
     assert response.content == "generated text"
     message_history.add_message(LLMMessage(role="user", content=question))
-    llm.client.messages.create.assert_called_with(
+    llm.client.messages.create.assert_called_with(  # type: ignore[attr-defined]
         model="claude-3-opus-20240229",
         system=system_instruction,
         messages=message_history,
         **model_params,
     )
 
-    assert llm.client.messages.create.call_count == 1
+    assert llm.client.messages.create.call_count == 1  # type: ignore[attr-defined]
 
 
 def test_anthropic_invoke_with_message_history_validation_error(
@@ -190,7 +190,7 @@ async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None:
     input_text = "may thy knife chip and shatter"
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
-    llm.async_client.messages.create.assert_awaited_once_with(
+    llm.async_client.messages.create.assert_awaited_once_with(  # type: ignore[attr-defined]
         model="claude-3-opus-20240229",
         system=anthropic.NOT_GIVEN,
         messages=[{"role": "user", "content": input_text}],
tests/unit/llm/test_mistralai_llm.py (+3, -3)

@@ -71,7 +71,7 @@ def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None:
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.complete.assert_called_once_with(
+    llm.client.chat.complete.assert_called_once_with(  # type: ignore[attr-defined]
         messages=messages,
         model=model,
     )
@@ -103,12 +103,12 @@ def test_mistralai_llm_invoke_with_message_history_and_system_instruction(
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.complete.assert_called_once_with(
+    llm.client.chat.complete.assert_called_once_with(  # type: ignore[attr-defined]
         messages=messages,
         model=model,
     )
 
-    assert llm.client.chat.complete.call_count == 1
+    assert llm.client.chat.complete.call_count == 1  # type: ignore[attr-defined]
 
 
 @patch("neo4j_graphrag.llm.mistralai_llm.Mistral")

tests/unit/llm/test_ollama_llm.py (+5, -5)

@@ -55,7 +55,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None:
     messages = [
         {"role": "user", "content": question},
     ]
-    llm.client.chat.assert_called_once_with(
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
         model=model, messages=messages, options=model_params
     )
 
@@ -80,7 +80,7 @@ def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) ->
     assert response.content == "ollama chat response"
     messages = [{"role": "system", "content": system_instruction}]
     messages.append({"role": "user", "content": question})
-    llm.client.chat.assert_called_once_with(
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
         model=model, messages=messages, options=model_params
     )
 
@@ -108,7 +108,7 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non
     assert response.content == "ollama chat response"
     messages = [m for m in message_history]
     messages.append({"role": "user", "content": question})
-    llm.client.chat.assert_called_once_with(
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
         model=model, messages=messages, options=model_params
     )
 
@@ -144,10 +144,10 @@ def test_ollama_invoke_with_message_history_and_system_instruction(
     messages = [{"role": "system", "content": system_instruction}]
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
-    llm.client.chat.assert_called_once_with(
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
         model=model, messages=messages, options=model_params
     )
-    assert llm.client.chat.call_count == 1
+    assert llm.client.chat.call_count == 1  # type: ignore[attr-defined]
 
 
 @patch("builtins.__import__")
