
Commit 1cb21f1

Fix tests
1 parent 7db25a8 · commit 1cb21f1

2 files changed: +12 -8 lines changed

2 files changed

+12
-8
lines changed

src/neo4j_graphrag/experimental/pipeline/pipeline.py

+4 -4
@@ -249,8 +249,8 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
                         caption=f"{node.component.__class__.__name__}: {n}({comp_inputs})",
                         size=20,  # Component nodes are larger
                         color="#4C8BF5",  # Blue for component nodes
-                        caption_align=CaptionAlignment.CENTER,
-                        caption_size=12,
+                        caption_alignment=CaptionAlignment.CENTER,
+                        caption_size=3,
                     )
                 )
                 node_counter += 1
@@ -265,8 +265,8 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
                         caption=o,
                         size=10,  # Output nodes are smaller
                         color="#34A853",  # Green for output nodes
-                        caption_align=CaptionAlignment.CENTER,
-                        caption_size=10,
+                        caption_alignment=CaptionAlignment.CENTER,
+                        caption_size=3,
                     )
                 )
                 # Connect component to its output
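Note: the change above renames the keyword passed to the visualization node constructor (caption_alignment instead of caption_align) and lowers caption_size to 3 for both component and output nodes. The snippet below is an illustrative sketch only, not part of the commit; the import paths, id, and caption values are assumptions, while the keyword arguments mirror the "+" lines of the diff.

```python
# Illustrative sketch (not from this commit): building one component node the way
# get_neo4j_viz_graph() does after the fix. Import paths and the id/caption values
# are assumptions; the keyword names and values mirror the diff above.
from neo4j_viz import Node                      # assumed import path
from neo4j_viz.node import CaptionAlignment     # assumed import path

component_node = Node(
    id=0,                                        # hypothetical node id
    caption="MyComponent: run(text)",            # hypothetical caption
    size=20,                                     # component nodes are larger
    color="#4C8BF5",                             # blue for component nodes
    caption_alignment=CaptionAlignment.CENTER,   # renamed from caption_align
    caption_size=3,                              # lowered from 12 in this commit
)
```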

src/neo4j_graphrag/llm/anthropic_llm.py

+8 -4
@@ -112,8 +112,10 @@ def invoke(
         """
         try:
             if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
-            messages = self.get_messages(input, message_history)
+                message_history.add_message(LLMMessage(role="user", content=input))
+                messages = message_history
+            else:
+                messages = self.get_messages(input, message_history)
             response = self.client.messages.create(
                 model=self.model_name,
                 system=system_instruction or self.anthropic.NOT_GIVEN,
@@ -148,8 +150,10 @@ async def ainvoke(
         """
         try:
             if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
-            messages = self.get_messages(input, message_history)
+                message_history.add_message(LLMMessage(role="user", content=input))
+                messages = message_history
+            else:
+                messages = self.get_messages(input, message_history)
             response = await self.async_client.messages.create(
                 model=self.model_name,
                 system=system_instruction or self.anthropic.NOT_GIVEN,
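Note: the behavioural change is the same in invoke() and ainvoke(): when the caller passes a MessageHistory object, the user turn is appended to it and the history object itself is forwarded, while plain lists (or None) still go through get_messages(). The self-contained sketch below reproduces that branching with placeholder types; it is not the library's code, and the stand-in classes are assumptions for illustration only.

```python
# Self-contained sketch (not from the repo) of the branching introduced by this fix.
# LLMMessage and MessageHistory here are placeholder stand-ins, not the real types
# from neo4j_graphrag; only the control flow mirrors the diff above.
from dataclasses import dataclass, field


@dataclass
class LLMMessage:                 # placeholder for the library's LLMMessage
    role: str
    content: str


@dataclass
class MessageHistory:             # placeholder for the library's MessageHistory
    messages: list[LLMMessage] = field(default_factory=list)

    def add_message(self, message: LLMMessage) -> None:
        self.messages.append(message)


def resolve_messages(input: str, message_history, get_messages):
    """Mirror the new if/else: history objects are mutated and passed through,
    anything else still goes through get_messages()."""
    if isinstance(message_history, MessageHistory):
        message_history.add_message(LLMMessage(role="user", content=input))
        return message_history
    return get_messages(input, message_history)


history = MessageHistory()
resolve_messages("Hello", history, lambda i, h: [{"role": "user", "content": i}])
print(history.messages)  # [LLMMessage(role='user', content='Hello')]
```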
