Skip to content

Commit 955c262

Browse files
committed
Add basic Parqeut Writer component
1 parent b34419b commit 955c262

File tree

4 files changed

+248
-2
lines changed

4 files changed

+248
-2
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import asyncio
2+
3+
from neo4j_graphrag.experimental.components.kg_writer import (
4+
KGWriterModel,
5+
ParquetWriter,
6+
)
7+
from neo4j_graphrag.experimental.components.schema import GraphSchema
8+
from neo4j_graphrag.experimental.components.types import (
9+
Neo4jGraph,
10+
Neo4jNode,
11+
Neo4jRelationship,
12+
)
13+
14+
OUTPUT_FOLDER = "output"
15+
16+
17+
async def main(graph: Neo4jGraph) -> KGWriterModel:
18+
writer = ParquetWriter(output_folder=OUTPUT_FOLDER)
19+
result = await writer.run(graph=graph)
20+
return result
21+
22+
23+
if __name__ == "__main__":
24+
graph = Neo4jGraph(
25+
nodes=[
26+
Neo4jNode(
27+
id="p-0",
28+
label="Person",
29+
properties={
30+
"name": "Alice",
31+
"eyeColor": "blue",
32+
},
33+
),
34+
Neo4jNode(
35+
id="p-1",
36+
label="Person",
37+
properties={
38+
"name": "Robert",
39+
"eyeColor": "brown",
40+
"nickName": "Bob",
41+
},
42+
),
43+
Neo4jNode(
44+
id="l-0",
45+
label="Location",
46+
properties={
47+
"name": "Wonderland",
48+
},
49+
),
50+
],
51+
relationships=[
52+
Neo4jRelationship(
53+
type="KNOWS",
54+
start_node_id="p-0",
55+
end_node_id="p-1",
56+
properties={"reason": "Cryptography"},
57+
)
58+
],
59+
)
60+
61+
schema = GraphSchema.model_validate(
62+
{
63+
"node_types": [
64+
"Location",
65+
{
66+
"label": "Person",
67+
"properties": [
68+
{
69+
"name": "name",
70+
"type": "STRING",
71+
"required": True,
72+
}
73+
],
74+
"additional_properties": True,
75+
},
76+
],
77+
"relationship_types": [
78+
"KNOWS",
79+
],
80+
}
81+
)
82+
83+
res = asyncio.run(main(graph=graph))
84+
print(res)
85+
# import pandas as pd
86+
#
87+
# df = pd.read_parquet(f'{OUTPUT_FOLDER}/nodes/Person.parquet')
88+
# print(df.head(10))
89+
# df = pd.read_parquet(f'{OUTPUT_FOLDER}/nodes/Location.parquet')
90+
# print(df.head(10))
91+
# df = pd.read_parquet(f'{OUTPUT_FOLDER}/relationships/KNOWS.parquet')
92+
# print(df.head(10))

poetry.lock

Lines changed: 58 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ scipy = [
6161
{ version = "^1.15.0", python = ">=3.13,<3.14" }
6262
]
6363
tenacity = "^9.1.2"
64+
pyarrow = "^21.0.0"
6465

6566
[tool.poetry.group.dev.dependencies]
6667
urllib3 = "<2"

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from __future__ import annotations
16+
import pandas as pd
1617

1718
import logging
1819
from abc import abstractmethod
20+
from collections import defaultdict
21+
from pathlib import Path
1922
from typing import Any, Generator, Literal, Optional
2023

2124
import neo4j
@@ -83,6 +86,100 @@ async def run(
8386
pass
8487

8588

89+
class ParquetWriter(KGWriter):
90+
"""Writes a knowledge graph to Parquet files.
91+
92+
Args:
93+
node_file_path (str): The file path to write the nodes Parquet file.
94+
relationship_file_path (str): The file path to write the relationships Parquet file.
95+
"""
96+
97+
def __init__(
98+
self,
99+
output_folder: str,
100+
):
101+
self._output_folder = Path(output_folder)
102+
103+
@staticmethod
104+
def _nodes_to_rows(
105+
nodes: list[Neo4jNode],
106+
label_to_rows: dict[str, list[dict[str, Any]]],
107+
lexical_graph_config: LexicalGraphConfig,
108+
) -> None:
109+
for node in nodes:
110+
row: dict[str, Any] = dict()
111+
112+
labels = [node.label]
113+
if node.label not in lexical_graph_config.lexical_graph_node_labels:
114+
labels.append("__Entity__")
115+
116+
row["id"] = node.id
117+
row["labels"] = labels
118+
row.update(node.properties)
119+
if node.embedding_properties is not None:
120+
row.update(node.embedding_properties)
121+
122+
label_to_rows[node.label].append(row)
123+
124+
@staticmethod
125+
def _relationships_to_rows(
126+
relationships: list[Neo4jRelationship],
127+
type_to_rows: dict[str, list[dict[str, Any]]],
128+
) -> None:
129+
for rel in relationships:
130+
row: dict[str, Any] = dict()
131+
132+
row["from"] = rel.start_node_id
133+
row["to"] = rel.end_node_id
134+
row["type"] = rel.type
135+
row.update(rel.properties)
136+
if rel.embedding_properties is not None:
137+
row.update(rel.embedding_properties)
138+
139+
type_to_rows[rel.type].append(row)
140+
141+
@validate_call
142+
async def run(
143+
self,
144+
graph: Neo4jGraph,
145+
lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
146+
) -> KGWriterModel:
147+
"""Writes a knowledge graph to Parquet files.
148+
149+
Args:
150+
graph (Neo4jGraph): The knowledge graph to write to Parquet files.
151+
lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph.
152+
"""
153+
(self._output_folder / "nodes").mkdir(parents=True, exist_ok=True)
154+
(self._output_folder / "relationships").mkdir(parents=True, exist_ok=True)
155+
156+
label_to_rows: defaultdict[str, list[dict[str, Any]]] = defaultdict(list)
157+
type_to_rows: defaultdict[str, list[dict[str, Any]]] = defaultdict(list)
158+
159+
# TODO: Parallelize?
160+
self._nodes_to_rows(graph.nodes, label_to_rows, lexical_graph_config)
161+
self._relationships_to_rows(graph.relationships, type_to_rows)
162+
163+
for label, node_df in label_to_rows.items():
164+
pd.DataFrame(node_df).to_parquet(
165+
self._output_folder / "nodes" / f"{label}.parquet", index=False
166+
)
167+
168+
for rtype, rel_df in type_to_rows.items():
169+
pd.DataFrame(rel_df).to_parquet(
170+
self._output_folder / "relationships" / f"{rtype}.parquet", index=False
171+
)
172+
173+
return KGWriterModel(
174+
status="SUCCESS",
175+
metadata={
176+
"node_count": len(graph.nodes),
177+
"relationship_count": len(graph.relationships),
178+
"output_folder": str(self._output_folder),
179+
},
180+
)
181+
182+
86183
class Neo4jWriter(KGWriter):
87184
"""Writes a knowledge graph to a Neo4j database.
88185

0 commit comments

Comments
 (0)