Skip to content

SchemaFromExistingGraphExtractor component #355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""This example demonstrates how to use the SchemaFromExistingGraphExtractor component
to automatically extract a schema from an existing Neo4j database.
"""

import asyncio

import neo4j

from neo4j_graphrag.experimental.components.schema import (
SchemaFromExistingGraphExtractor,
GraphSchema,
)


# Connection settings handed to the neo4j driver in main().
# NOTE(review): the host name suggests the public Neo4j Labs demo server — confirm.
URI = "neo4j+s://demo.neo4jlabs.com"
# Passed as the driver's `auth=` argument.
AUTH = ("recommendations", "recommendations")
# NOTE(review): presumably the name of the database to inspect — confirm.
DATABASE = "recommendations"
# NOTE(review): INDEX is not referenced anywhere in this example; confirm
# whether it is needed or can be dropped.
INDEX = "moviePlotsEmbedding"


async def main() -> None:
    """Extract the schema of the configured Neo4j database and print it.

    Opens a driver on ``URI`` with ``AUTH``, runs a
    ``SchemaFromExistingGraphExtractor`` against ``DATABASE`` and prints the
    resulting ``GraphSchema``. The driver is closed when the ``with`` block
    exits, including on error.
    """
    with neo4j.GraphDatabase.driver(
        URI,
        auth=AUTH,
    ) as driver:
        extractor = SchemaFromExistingGraphExtractor(
            driver,
            # Target the database named in the module constants instead of
            # silently falling back to the server's default database.
            neo4j_database=DATABASE,
        )
        schema: GraphSchema = await extractor.run()
        # schema.store_as_json("my_schema.json")
        print(schema)


if __name__ == "__main__":
    asyncio.run(main())
175 changes: 173 additions & 2 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from __future__ import annotations

import json

import neo4j
import logging
import warnings
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence, Callable
Expand Down Expand Up @@ -44,6 +46,10 @@
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
from neo4j_graphrag.schema import get_structured_schema


logger = logging.getLogger(__name__)


class PropertyType(BaseModel):
Expand Down Expand Up @@ -294,7 +300,12 @@ def from_file(
raise SchemaValidationError(str(e)) from e


class SchemaBuilder(Component):
class BaseSchemaBuilder(Component):
    """Common parent for components whose ``run`` returns a ``GraphSchema``.

    The base implementation of ``run`` only raises ``NotImplementedError``;
    subclasses provide the actual schema-building logic.
    """

    async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
        """Build and return a ``GraphSchema``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class SchemaBuilder(BaseSchemaBuilder):
"""
A builder class for constructing GraphSchema objects from given entities,
relations, and their interrelationships defined in a potential schema.
Expand Down Expand Up @@ -412,7 +423,7 @@ async def run(
)


class SchemaFromTextExtractor(Component):
class SchemaFromTextExtractor(BaseSchemaBuilder):
"""
A component for constructing GraphSchema objects from the output of an LLM after
automatic schema extraction from text.
Expand Down Expand Up @@ -620,3 +631,163 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
"patterns": extracted_patterns,
}
)


class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
    """A class to build a GraphSchema object from an existing graph.

    Uses the get_structured_schema function to extract existing node labels,
    relationship types, properties and existence constraints.

    By default, the built schema does not allow any additional item (property,
    node label, relationship type or pattern).

    Args:
        driver (neo4j.Driver): connection to the neo4j database.
        additional_properties (bool, default False): see GraphSchema
        additional_node_types (bool, default False): see GraphSchema
        additional_relationship_types (bool, default False): see GraphSchema
        additional_patterns (bool, default False): see GraphSchema
        neo4j_database (Optional[str]): name of the neo4j database to use
    """

    def __init__(
        self,
        driver: neo4j.Driver,
        additional_properties: bool = False,
        additional_node_types: bool = False,
        additional_relationship_types: bool = False,
        additional_patterns: bool = False,
        neo4j_database: Optional[str] = None,
    ) -> None:
        self.driver = driver
        self.database = neo4j_database

        self.additional_properties = additional_properties
        self.additional_node_types = additional_node_types
        self.additional_relationship_types = additional_relationship_types
        self.additional_patterns = additional_patterns

    @staticmethod
    def _extract_required_properties(
        structured_schema: dict[str, Any],
    ) -> list[tuple[str, str]]:
        """Extract a list of (node label (or rel type), property name) for which
        an "EXISTENCE" or "KEY" constraint is defined in the DB.

        Every property listed in a matching constraint is reported, so all
        members of a composite key constraint are included, not just the first.

        Args:

            structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.

        Returns:

            list of tuples of (node label (or rel type), property name)

        """
        schema_metadata = structured_schema.get("metadata", {})
        required: list[tuple[str, str]] = []
        for constraint in schema_metadata.get("constraint", []):
            if constraint["type"] not in (
                "NODE_PROPERTY_EXISTENCE",
                "NODE_KEY",
                "RELATIONSHIP_PROPERTY_EXISTENCE",
                "RELATIONSHIP_KEY",
            ):
                continue
            # Key constraints may span several properties; all of them must
            # exist, so emit one pair per (label, property) combination.
            required.extend(
                (label, prop)
                for label in constraint["labelsOrTypes"]
                for prop in constraint["properties"]
            )
        return required

    @staticmethod
    def _describe_properties(
        label: str,
        properties: list[dict[str, Any]],
        required: set[tuple[str, str]],
    ) -> list[dict[str, Any]]:
        """Convert structured-schema property entries for `label` into schema property dicts."""
        return [
            {
                "name": p["property"],
                "type": p["type"],
                "required": (label, p["property"]) in required,
            }
            for p in properties
        ]

    async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
        """Read the schema of the connected database and return it as a GraphSchema.

        Node labels and relationship types that only appear in a relationship
        pattern (i.e. have no property) are added with a label only, and a
        warning is logged for each when `additional_properties` is False.

        Args:
            *args: ignored.
            **kwargs: ignored.

        Returns:
            GraphSchema: the validated schema built from the database content.
        """
        structured_schema = get_structured_schema(self.driver, database=self.database)
        # A set gives constant-time "is this property required?" lookups.
        required = set(self._extract_required_properties(structured_schema))

        # node label with properties
        node_labels = set(structured_schema["node_props"])
        node_types: list[dict[str, Any]] = [
            {
                "label": label,
                "properties": self._describe_properties(label, properties, required),
                "additional_properties": self.additional_properties,
            }
            for label, properties in structured_schema["node_props"].items()
        ]

        # relationships with properties
        rel_labels = set(structured_schema["rel_props"])
        relationship_types: list[dict[str, Any]] = [
            {
                "label": rel_type,
                "properties": self._describe_properties(
                    rel_type, properties, required
                ),
                # Mirror node types so the configured flag applies to
                # relationship properties as well.
                "additional_properties": self.additional_properties,
            }
            for rel_type, properties in structured_schema["rel_props"].items()
        ]

        patterns = [
            (s["start"], s["type"], s["end"])
            for s in structured_schema["relationships"]
        ]

        # deal with nodes and relationships without properties
        for source, rel, target in patterns:
            for label in (source, target):
                if label in node_labels:
                    continue
                if not self.additional_properties:
                    logger.warning(
                        f"SCHEMA: found node label {label} without property and additional_properties=False: this node label will always be pruned!"
                    )
                node_labels.add(label)
                node_types.append({"label": label})
            if rel not in rel_labels:
                if not self.additional_properties:
                    logger.warning(
                        f"SCHEMA: found relationship type {rel} without property and additional_properties=False: this relationship type will always be pruned!"
                    )
                rel_labels.add(rel)
                relationship_types.append({"label": rel})

        return GraphSchema.model_validate(
            {
                "node_types": node_types,
                "relationship_types": relationship_types,
                "patterns": patterns,
                "additional_node_types": self.additional_node_types,
                "additional_relationship_types": self.additional_relationship_types,
                "additional_patterns": self.additional_patterns,
            }
        )
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
SchemaBuilder,
GraphSchema,
SchemaFromTextExtractor,
BaseSchemaBuilder,
)
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
Expand Down Expand Up @@ -178,7 +179,7 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]:
def _get_chunk_embedder(self) -> TextChunkEmbedder:
return TextChunkEmbedder(embedder=self.get_default_embedder())

def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]:
def _get_schema(self) -> BaseSchemaBuilder:
"""
Get the appropriate schema component based on configuration.
Return SchemaFromTextExtractor for automatic extraction or SchemaBuilder for manual schema.
Expand Down
66 changes: 65 additions & 1 deletion tests/unit/experimental/components/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import json
from typing import Tuple, Any
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, patch, Mock

import pytest
from pydantic import ValidationError
Expand All @@ -29,6 +29,7 @@
RelationshipType,
SchemaFromTextExtractor,
GraphSchema,
SchemaFromExistingGraphExtractor,
)
import os
import tempfile
Expand Down Expand Up @@ -957,3 +958,66 @@ async def test_schema_from_text_filters_relationships_without_labels(
assert len(schema.patterns) == 2
assert ("Person", "WORKS_FOR", "Organization") in schema.patterns
assert ("Person", "MANAGES", "Organization") in schema.patterns


@pytest.mark.asyncio
@patch("neo4j_graphrag.experimental.components.schema.get_structured_schema")
async def test_schema_from_existing_graph(mock_get_structured_schema: Mock) -> None:
    """Check the schema built from a mocked structured schema.

    The mocked database has one node label and one relationship type with
    properties, one pattern that introduces a property-less node label and
    relationship type, and one existence constraint on ``Person.id``.
    """
    person_id_constraint = {
        "id": 7,
        "name": "person_id",
        "type": "NODE_PROPERTY_EXISTENCE",
        "entityType": "NODE",
        "labelsOrTypes": ["Person"],
        "properties": ["id"],
        "ownedIndex": "person_id",
        "propertyType": None,
    }
    person_name_index = {
        "label": "Person",
        "properties": ["name"],
        "size": 2,
        "type": "RANGE",
        "valuesSelectivity": 1.0,
        "distinctValues": 2.0,
    }
    mock_get_structured_schema.return_value = {
        "node_props": {
            "Person": [
                {"property": "id", "type": "INTEGER"},
                {"property": "name", "type": "STRING"},
            ]
        },
        "rel_props": {"KNOWS": [{"property": "fromDate", "type": "DATE"}]},
        "relationships": [
            {"start": "Person", "type": "KNOWS", "end": "Person"},
            {"start": "Person", "type": "LIVES_IN", "end": "City"},
        ],
        "metadata": {
            "constraint": [person_id_constraint],
            "index": [person_name_index],
        },
    }

    extractor = SchemaFromExistingGraphExtractor(driver=Mock())
    schema = await extractor.run()
    assert isinstance(schema, GraphSchema)

    # Node types: "Person" comes from node_props, "City" only from a pattern.
    assert len(schema.node_types) == 2
    person = schema.node_type_from_label("Person")
    assert person is not None
    id_candidates = [prop for prop in person.properties if prop.name == "id"]
    person_id = id_candidates[0]
    assert person_id.required is True

    assert schema.node_type_from_label("City") is not None

    # Relationship types: "KNOWS" has properties, "LIVES_IN" only a pattern.
    assert len(schema.relationship_types) == 2
    assert schema.relationship_type_from_label("KNOWS") is not None
    assert schema.relationship_type_from_label("LIVES_IN") is not None

    expected_patterns = (
        ("Person", "KNOWS", "Person"),
        ("Person", "LIVES_IN", "City"),
    )
    assert schema.patterns == expected_patterns