Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def get_retrieved_objects(
# Normalize single query to batch for uniform processing
effective_batch = [query] if query else query_batch

triplets_batch = await self.get_triplets(query_batch=effective_batch)
triplets_batch = await self.get_triplets_batch(effective_batch)
if not triplets_batch:
return []

Expand Down Expand Up @@ -106,7 +106,7 @@ async def _run_extension_round(self, states: dict):
system_prompt=self.system_prompt,
)

new_triplets_batch = await self.get_triplets(query_batch=list(completions))
new_triplets_batch = await self.get_triplets_batch(list(completions))
for q, new_triplets in zip(active_queries, new_triplets_batch):
states[q].merge_triplets(new_triplets)

Expand Down
4 changes: 2 additions & 2 deletions cognee/modules/retrieval/graph_completion_cot_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ async def _run_cot_completion(
async def _fetch_initial_triplets_and_context(self, states: dict):
"""Fetch triplets and resolve context text for all queries."""
queries = list(states.keys())
triplets_batch = await self.get_triplets(query_batch=queries)
triplets_batch = await self.get_triplets_batch(queries)
context_batch = await asyncio.gather(
*[self.resolve_edges_to_text(t) for t in triplets_batch]
)
Expand Down Expand Up @@ -243,7 +243,7 @@ def _build_followup_prompts(self, states, reasoning_batch):
async def _merge_followup_triplets(self, states: dict, followup_questions: List[str]):
"""Fetch triplets for follow-up questions and merge with existing state."""
queries = list(states.keys())
new_triplets_batch = await self.get_triplets(query_batch=followup_questions)
new_triplets_batch = await self.get_triplets_batch(followup_questions)

for q, new_triplets in zip(queries, new_triplets_batch):
states[q].merge_triplets(new_triplets)
Expand Down
20 changes: 20 additions & 0 deletions cognee/modules/retrieval/graph_completion_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,26 @@ async def get_triplets(
triplet_distance_penalty=self.triplet_distance_penalty,
)

async def get_triplets_batch(
    self,
    queries: List[str],
) -> List[List[Edge]]:
    """
    Retrieve triplets for a list of queries, one result list per query.

    A single query is delegated to single-query mode (``query=``), which
    computes relevant node IDs and filters the graph projection. Multiple
    queries use batch mode (``query_batch=``). An empty ``queries`` list
    returns ``[]`` without calling ``get_triplets``.

    Parameters:
        queries (List[str]): Queries to retrieve triplets for.

    Returns:
        List[List[Edge]]: One list of edges per query.
    """
    # Imported locally so this method does not depend on a module-level import.
    from typing import cast

    if not queries:
        # Nothing to look up; avoid passing an empty batch to the search.
        return []

    # get_triplets is annotated with a Union return type; cast narrows it to
    # the shape each mode produces so the declared return type holds.
    if len(queries) == 1:
        single_result = cast(List[Edge], await self.get_triplets(query=queries[0]))
        return [single_result]
    return cast(List[List[Edge]], await self.get_triplets(query_batch=queries))
Comment on lines +163 to +181
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Type annotation mismatch and missing empty-list guard in get_triplets_batch

Two issues:

  1. mypy incompatibility: get_triplets is typed -> Union[List[Edge], List[List[Edge]]], so both return paths fail mypy's return-type check against the declared -> List[List[Edge]]:

    • Single-query branch: [triplets] has inferred type List[Union[List[Edge], List[List[Edge]]]].
    • Multi-query branch: direct return of Union[List[Edge], List[List[Edge]]].
  2. No guard for empty queries: when len(queries) == 0, the call falls through to get_triplets(query_batch=[]) whose behaviour with an empty batch is undefined in brute_force_triplet_search. As a public method, this edge case should be defended.

🛠️ Proposed fix
+from typing import cast

 async def get_triplets_batch(
     self,
     queries: List[str],
 ) -> List[List[Edge]]:
+    if not queries:
+        return []
     if len(queries) == 1:
-        triplets = await self.get_triplets(query=queries[0])
+        triplets = cast(List[Edge], await self.get_triplets(query=queries[0]))
         return [triplets]
-    return await self.get_triplets(query_batch=queries)
+    return cast(List[List[Edge]], await self.get_triplets(query_batch=queries))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cognee/modules/retrieval/graph_completion_retriever.py` around lines 163 -
181, The method get_triplets_batch must guard the empty-input case and normalize
the heterogeneous return shape of get_triplets so the declared return type
List[List[Edge]] is satisfied for mypy: first, if queries is empty return an
empty list immediately; second, when len(queries)==1 call
get_triplets(query=...) and normalize its result so you always return a
List[List[Edge]] (if get_triplets returns List[Edge] wrap it as [result], if it
returns List[List[Edge]] use it but ensure you return exactly one inner list);
third, when calling get_triplets(query_batch=...) assert/coerce the batch result
to List[List[Edge]] (if the call yields a flat List[Edge] wrap it into a
single-item list-per-query mapping) so both branches have the same concrete type
and mypy passes. Ensure you reference get_triplets and get_triplets_batch while
making these checks and conversions.


async def get_context_from_objects(
self,
query: Optional[str] = None,
Expand Down
Loading