Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed Postgres query parsing issues #1087

Merged
merged 3 commits on Mar 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 49 additions & 78 deletions lightrag/kg/postgres_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,14 +962,7 @@ def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]:
vertices.get(edge["end_id"], {}),
)
else:
if v is None:
d[k] = v
elif isinstance(v, str) and (v.count("{") < 1 and v.count("[") < 1):
d[k] = v
elif isinstance(v, str):
d[k] = json.loads(v)
else:
d[k] = v
d[k] = json.loads(v) if isinstance(v, str) and ("{" in v or "[" in v) else v

return d

Expand Down Expand Up @@ -1411,9 +1404,7 @@ async def embed_nodes(
embed_func = self._node_embed_algorithms[algorithm]
return await embed_func()

async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
"""
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.

Expand All @@ -1426,91 +1417,71 @@ async def get_knowledge_graph(
"""
MAX_GRAPH_NODES = 1000

# Build the query based on whether we want the full graph or a specific subgraph.
if node_label == "*":
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m
LIMIT %d
$$) AS (n agtype, r agtype, m agtype)""" % (
self.graph_name,
MAX_GRAPH_NODES,
)
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity)
OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m
LIMIT {MAX_GRAPH_NODES}
$$) AS (n agtype, r agtype, m agtype)"""
else:
encoded_node_label = self._encode_graph_label(node_label.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d
$$) AS (nodes agtype, relationships agtype)""" % (
self.graph_name,
encoded_node_label,
max_depth,
MAX_GRAPH_NODES,
)
encoded_label = self._encode_graph_label(node_label.strip('"'))
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:Entity {{node_id: "{encoded_label}"}})
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT {MAX_GRAPH_NODES}
$$) AS (nodes agtype, relationships agtype)"""

results = await self._query(query)

nodes = {}
edges = []
unique_edge_ids = set()

for result in results:
if node_label == "*":
if result["n"]:
node = result["n"]
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node

if result["m"]:
node = result["m"]
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["r"]:
edge = result["r"]
src_id = self._decode_graph_label(edge["start_id"])
tgt_id = self._decode_graph_label(edge["end_id"])
edges.append((src_id, tgt_id))
else:
if result["nodes"]:
for node in result["nodes"]:
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node

if result["relationships"]:
for edge in result["relationships"]: # src --DIRECTED--> target
src_id = self._decode_graph_label(edge[0]["node_id"])
tgt_id = self._decode_graph_label(edge[2]["node_id"])
id = src_id + "," + tgt_id
if id in unique_edge_ids:
continue
else:
unique_edge_ids.add(id)
edges.append(
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
)
def add_node(node_data: dict):
    """Record a node under its decoded label, keeping the first occurrence.

    The key is produced by passing ``node_data["node_id"]`` through
    ``self._decode_graph_label``; a node whose decoded key is already
    present in ``nodes`` is ignored.
    """
    decoded_id = self._decode_graph_label(node_data["node_id"])
    # setdefault only inserts when the key is absent, so the first node seen wins.
    nodes.setdefault(decoded_id, node_data)

def add_edge(edge_data: list):
    """Record a directed edge once per (source, target) pair.

    ``edge_data[0]`` and ``edge_data[2]`` are read as the source and target
    node mappings; each must provide a ``"node_id"`` entry, which is decoded
    with ``self._decode_graph_label``. The deduplication key is
    ``"<src>,<tgt>"``, so repeated edges between the same ordered pair of
    nodes collapse into the first one seen.
    """
    source_node = edge_data[0]
    src_id = self._decode_graph_label(source_node["node_id"])
    target_node = edge_data[2]
    tgt_id = self._decode_graph_label(target_node["node_id"])

    edge_key = f"{src_id},{tgt_id}"
    if edge_key in unique_edge_ids:
        # Already recorded this source -> target pair.
        return

    unique_edge_ids.add(edge_key)
    edge_props = {"source": source_node, "target": target_node}
    edges.append((edge_key, src_id, tgt_id, edge_props))

# Process the query results.
if node_label == "*":
for result in results:
if result.get("n"):
add_node(result["n"])
if result.get("m"):
add_node(result["m"])
if result.get("r"):
add_edge(result["r"])
else:
for result in results:
for node in result.get("nodes", []):
add_node(node)
for edge in result.get("relationships", []):
add_edge(edge)

# Construct and return the KnowledgeGraph.
kg = KnowledgeGraph(
nodes=[
KnowledgeGraphNode(
id=node_id, labels=[node_id], properties=nodes[node_id]
)
for node_id in nodes
KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
for node_id, node_data in nodes.items()
],
edges=[
KnowledgeGraphEdge(
id=id, type="DIRECTED", source=src, target=tgt, properties=props
)
for id, src, tgt, props in edges
KnowledgeGraphEdge(id=edge_id, type="DIRECTED", source=src, target=tgt, properties=props)
for edge_id, src, tgt, props in edges
],
)

return kg

async def drop(self) -> None:
"""Drop the storage"""
drop_sql = SQL_TEMPLATES["drop_vdb_entity"]
Expand Down
Loading