diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index ed0eb4af5..2b6f10b27 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -962,14 +962,7 @@ def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]: vertices.get(edge["end_id"], {}), ) else: - if v is None: - d[k] = v - elif isinstance(v, str) and (v.count("{") < 1 and v.count("[") < 1): - d[k] = v - elif isinstance(v, str): - d[k] = json.loads(v) - else: - d[k] = v + d[k] = json.loads(v) if isinstance(v, str) and ("{" in v or "[" in v) else v return d @@ -1411,9 +1404,7 @@ async def embed_nodes( embed_func = self._node_embed_algorithms[algorithm] return await embed_func() - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: """ Retrieve a subgraph containing the specified node and its neighbors up to the specified depth. @@ -1426,29 +1417,22 @@ async def get_knowledge_graph( """ MAX_GRAPH_NODES = 1000 + # Build the query based on whether we want the full graph or a specific subgraph. if node_label == "*": - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity) - OPTIONAL MATCH (n)-[r]->(m:Entity) - RETURN n, r, m - LIMIT %d - $$) AS (n agtype, r agtype, m agtype)""" % ( - self.graph_name, - MAX_GRAPH_NODES, - ) + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:Entity) + OPTIONAL MATCH (n)-[r]->(m:Entity) + RETURN n, r, m + LIMIT {MAX_GRAPH_NODES} + $$) AS (n agtype, r agtype, m agtype)""" else: - encoded_node_label = self._encode_graph_label(node_label.strip('"')) - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) - OPTIONAL MATCH p = (n)-[*..%d]-(m) - RETURN nodes(p) AS nodes, relationships(p) AS relationships - LIMIT %d - $$) AS (nodes agtype, relationships agtype)""" % ( - self.graph_name, - encoded_node_label, - max_depth, - MAX_GRAPH_NODES, - ) + encoded_label = self._encode_graph_label(node_label.strip('"')) + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:Entity {{node_id: "{encoded_label}"}}) + OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) + RETURN nodes(p) AS nodes, relationships(p) AS relationships + LIMIT {MAX_GRAPH_NODES} + $$) AS (nodes agtype, relationships agtype)""" results = await self._query(query) @@ -1456,61 +1440,48 @@ async def get_knowledge_graph( edges = [] unique_edge_ids = set() - for result in results: - if node_label == "*": - if result["n"]: - node = result["n"] - node_id = self._decode_graph_label(node["node_id"]) - if node_id not in nodes: - nodes[node_id] = node - - if result["m"]: - node = result["m"] - node_id = self._decode_graph_label(node["node_id"]) - if node_id not in nodes: - nodes[node_id] = node - if result["r"]: - edge = result["r"] - src_id = self._decode_graph_label(edge["start_id"]) - tgt_id = self._decode_graph_label(edge["end_id"]) - edges.append((src_id, tgt_id)) - else: - if result["nodes"]: - for node in result["nodes"]: - node_id = self._decode_graph_label(node["node_id"]) - if node_id not in nodes: - nodes[node_id] = node - - if result["relationships"]: - for edge in result["relationships"]: # src --DIRECTED--> target - src_id = self._decode_graph_label(edge[0]["node_id"]) - tgt_id = self._decode_graph_label(edge[2]["node_id"]) - id = src_id + "," + tgt_id - if id in unique_edge_ids: - continue - else: - unique_edge_ids.add(id) - edges.append( - (id, src_id, tgt_id, {"source": edge[0], "target": edge[2]}) - ) + def add_node(node_data: dict): + node_id = self._decode_graph_label(node_data["node_id"]) + if node_id not in nodes: + nodes[node_id] = node_data + + def add_edge(edge_data: list): + src_id = self._decode_graph_label(edge_data[0]["node_id"]) + tgt_id = self._decode_graph_label(edge_data[2]["node_id"]) + edge_key = f"{src_id},{tgt_id}" + if edge_key not in unique_edge_ids: + unique_edge_ids.add(edge_key) + edges.append((edge_key, src_id, tgt_id, {"source": edge_data[0], "target": edge_data[2]})) + # Process the query results. + if node_label == "*": + for result in results: + if result.get("n"): + add_node(result["n"]) + if result.get("m"): + add_node(result["m"]) + if result.get("r"): + add_edge(result["r"]) + else: + for result in results: + for node in result.get("nodes", []): + add_node(node) + for edge in result.get("relationships", []): + add_edge(edge) + + # Construct and return the KnowledgeGraph. kg = KnowledgeGraph( nodes=[ - KnowledgeGraphNode( - id=node_id, labels=[node_id], properties=nodes[node_id] - ) - for node_id in nodes + KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data) + for node_id, node_data in nodes.items() ], edges=[ - KnowledgeGraphEdge( - id=id, type="DIRECTED", source=src, target=tgt, properties=props - ) - for id, src, tgt, props in edges + KnowledgeGraphEdge(id=edge_id, type="DIRECTED", source=src, target=tgt, properties=props) + for edge_id, src, tgt, props in edges ], ) return kg - async def drop(self) -> None: """Drop the storage""" drop_sql = SQL_TEMPLATES["drop_vdb_entity"]