Skip to content

Commit 81edacd

Browse files
committed
Add SchemaFromExistingGraphExtractor component
Parses the result from get_structured_schema and returns a GraphSchema object
1 parent dffd484 commit 81edacd

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""This example demonstrates how to use the SchemaFromExistingGraphExtractor component
2+
to automatically extract a schema from an existing Neo4j database.
3+
"""
4+
5+
import asyncio
6+
7+
import neo4j
8+
9+
from neo4j_graphrag.experimental.components.schema import (
10+
SchemaFromExistingGraphExtractor,
11+
GraphSchema,
12+
)
13+
14+
15+
# Connection settings for the Neo4j instance this example connects to.
URI = "neo4j+s://demo.neo4jlabs.com"
16+
# Passed as ``auth=`` to the driver; presumably (username, password) - TODO confirm.
AUTH = ("recommendations", "recommendations")
17+
# NOTE(review): DATABASE and INDEX are not referenced anywhere else in this example.
DATABASE = "recommendations"
18+
INDEX = "moviePlotsEmbedding"
19+
20+
21+
async def main() -> None:
    """Extract the schema of the configured graph and print it.

    A driver is opened against ``URI`` with ``AUTH``, handed to a
    ``SchemaFromExistingGraphExtractor``, and the resulting ``GraphSchema``
    is written to standard output. The driver is closed when the ``with``
    block exits.
    """
    db_driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
    with db_driver:
        schema_extractor = SchemaFromExistingGraphExtractor(db_driver)
        extracted: GraphSchema = await schema_extractor.run()
        # To keep a copy on disk: extracted.store_as_json("my_schema.json")
        print(extracted)
32+
33+
34+
# Run the example coroutine only when this file is executed as a script.
if __name__ == "__main__":
35+
asyncio.run(main())

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from __future__ import annotations
1616

1717
import json
18+
19+
import neo4j
1820
import logging
1921
import warnings
2022
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence
@@ -43,6 +45,7 @@
4345
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
4446
from neo4j_graphrag.llm import LLMInterface
4547
from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
48+
from neo4j_graphrag.schema import get_structured_schema
4649

4750

4851
class PropertyType(BaseModel):
@@ -270,7 +273,12 @@ def from_file(
270273
raise SchemaValidationError(str(e)) from e
271274

272275

273-
class SchemaBuilder(Component):
276+
class BaseSchemaBuilder(Component):
    """Common parent for components whose ``run`` returns a ``GraphSchema``.

    This base class provides no implementation of its own: subclasses are
    expected to override ``run``.
    """

    async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
        """Produce a ``GraphSchema``; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
279+
280+
281+
class SchemaBuilder(BaseSchemaBuilder):
274282
"""
275283
A builder class for constructing GraphSchema objects from given entities,
276284
relations, and their interrelationships defined in a potential schema.
@@ -379,7 +387,7 @@ async def run(
379387
return self.create_schema_model(node_types, relationship_types, patterns)
380388

381389

382-
class SchemaFromTextExtractor(Component):
390+
class SchemaFromTextExtractor(BaseSchemaBuilder):
383391
"""
384392
A component for constructing GraphSchema objects from the output of an LLM after
385393
automatic schema extraction from text.
@@ -462,3 +470,75 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
462470
"patterns": extracted_patterns,
463471
}
464472
)
473+
474+
475+
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
    """Build a ``GraphSchema`` object from the schema of an existing graph.

    The schema is read with ``get_structured_schema`` and converted into the
    ``node_types`` / ``relationship_types`` / ``patterns`` layout that is then
    validated by ``GraphSchema``.

    Args:
        driver (neo4j.Driver): Driver used to read the schema.
        neo4j_database (Optional[str]): Name of the database to read the
            schema from. When ``None`` (the default), ``get_structured_schema``
            is called with the driver only, so existing callers are unaffected.
    """

    def __init__(
        self,
        driver: neo4j.Driver,
        neo4j_database: Optional[str] = None,
    ) -> None:
        self.driver = driver
        self.neo4j_database = neo4j_database

    @staticmethod
    def _to_schema_entities(props_by_label: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert a ``{label: [{"property": ..., "type": ...}]}`` mapping into
        a list of ``{"label": ..., "properties": [{"name": ..., "type": ...}]}``
        dictionaries, preserving the mapping's iteration order."""
        return [
            {
                "label": label,
                "properties": [
                    {"name": prop["property"], "type": prop["type"]}
                    for prop in properties
                ],
            }
            for label, properties in props_by_label.items()
        ]

    @staticmethod
    def _declare_if_missing(
        label: str,
        known_labels: set[str],
        entities: List[Dict[str, Any]],
    ) -> None:
        """Append a property-less ``{"label": label}`` entry to ``entities``
        unless ``label`` has already been declared."""
        if label in known_labels:
            return
        known_labels.add(label)
        entities.append({"label": label})

    async def run(self, **kwargs: Any) -> GraphSchema:
        """Read the graph's structured schema and return it as a ``GraphSchema``.

        Args:
            **kwargs (Any): Accepted for interface compatibility; not used.

        Returns:
            GraphSchema: The result of ``GraphSchema.model_validate`` applied
            to the converted node types, relationship types and patterns.
        """
        # NOTE(review): get_structured_schema is invoked synchronously (no
        # await) inside this coroutine.
        if self.neo4j_database is None:
            structured_schema = get_structured_schema(self.driver)
        else:
            structured_schema = get_structured_schema(
                self.driver, database=self.neo4j_database
            )

        node_props = structured_schema["node_props"]
        rel_props = structured_schema["rel_props"]

        node_labels = set(node_props)
        node_types = self._to_schema_entities(node_props)
        rel_labels = set(rel_props)
        relationship_types = self._to_schema_entities(rel_props)

        patterns = [
            (s["start"], s["type"], s["end"])
            for s in structured_schema["relationships"]
        ]

        # Labels and relationship types that only appear in a pattern are
        # absent from "node_props" / "rel_props"; declare them without
        # properties so every pattern refers to a declared type.
        for source, rel, target in patterns:
            self._declare_if_missing(source, node_labels, node_types)
            self._declare_if_missing(target, node_labels, node_types)
            self._declare_if_missing(rel, rel_labels, relationship_types)

        return GraphSchema.model_validate(
            {
                "node_types": node_types,
                "relationship_types": relationship_types,
                "patterns": patterns,
            }
        )

0 commit comments

Comments
 (0)