Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions cognee/infrastructure/engine/models/DataPoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from uuid import UUID, uuid4
from uuid import UUID, NAMESPACE_OID, uuid4, uuid5
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime, timezone
from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict
from typing import Optional, Any, Dict, List


Expand All @@ -13,6 +13,7 @@ class MetaData(TypedDict):

type: str
index_fields: list[str]
identity_fields: NotRequired[list[str]]


# Updated DataPoint model with versioning and new fields
Expand Down Expand Up @@ -48,9 +49,49 @@ class DataPoint(BaseModel):
belongs_to_set: Optional[List["DataPoint"]] = None

def __init__(self, **data):
    """Initialize the data point, deriving a deterministic id when possible.

    If the caller did not supply an explicit ``id`` and the class declares
    ``identity_fields`` in its metadata default, a deterministic UUID5 is
    derived from those field values. Otherwise the field default (a random
    UUID4) applies.

    NOTE: only the raw keyword arguments are inspected — pydantic field
    defaults are not visible here, so an identity field satisfied only by
    a default still triggers the UUID4 fallback.
    """
    if "id" not in data:
        identity_fields = self.__class__._get_identity_fields()
        if identity_fields:
            identity_id = self.__class__._generate_identity_id(
                identity_fields, data, self.__class__.__name__
            )
            if identity_id is not None:
                data["id"] = identity_id

    super().__init__(**data)
    # Assign directly via object.__setattr__ so the concrete class name is
    # recorded without going through the model's __setattr__ machinery.
    object.__setattr__(self, "type", self.__class__.__name__)

@classmethod
def _get_identity_fields(cls) -> Optional[list[str]]:
    """Return the ``identity_fields`` list from the metadata field default, or None."""
    field_info = cls.model_fields.get("metadata")
    if field_info is None or field_info.default is None:
        return None
    return field_info.default.get("identity_fields")

@staticmethod
def _generate_identity_id(
    identity_fields: list[str], data: dict, class_name: str
) -> Optional[UUID]:
    """Derive a deterministic UUID5 from the values of the identity fields.

    String values are normalized (lowercased, spaces become underscores,
    apostrophes stripped); other values are stringified as-is. Returns
    None when any identity field is absent from ``data``, signalling the
    caller to fall back to the default UUID4.
    """
    normalized: list[str] = []
    for field_name in identity_fields:
        if field_name not in data:
            return None
        raw = data[field_name]
        if isinstance(raw, str):
            normalized.append(raw.lower().replace(" ", "_").replace("'", ""))
        else:
            normalized.append(str(raw))
    # NOTE(review): a value containing "|" could collide across field
    # boundaries (["a|b"] vs ["a", "b"]) — confirm acceptable for inputs.
    identity_string = f"{class_name}:" + "|".join(normalized)
    return uuid5(NAMESPACE_OID, identity_string)

@classmethod
def get_embeddable_data(self, data_point: "DataPoint"):
"""
Expand Down
165 changes: 165 additions & 0 deletions cognee/tests/unit/infrastructure/engine/test_identity_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from uuid import UUID, uuid4

from cognee.infrastructure.engine.models.DataPoint import DataPoint


class PersonWithIdentity(DataPoint):
    """Fixture model: id is derived deterministically from ``name``."""

    name: str
    # "identity_fields" makes equal names map to the same UUID5 id.
    metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]}


class DepartmentWithIdentity(DataPoint):
    """Fixture model: same identity config as PersonWithIdentity but a distinct class."""

    name: str
    # Same identity config as the person fixture — used to prove cross-type safety.
    metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]}


class MultiFieldIdentity(DataPoint):
    """Fixture model: identity derived from two fields, in declaration order."""

    first_name: str
    last_name: str
    metadata: dict = {
        "index_fields": ["first_name"],
        "identity_fields": ["first_name", "last_name"],
    }


class NoIdentityFields(DataPoint):
    """Fixture model without identity_fields: keeps the legacy random-UUID4 id."""

    name: str
    metadata: dict = {"index_fields": ["name"]}


class PartialIdentity(DataPoint):
    """Fixture model: identity_fields includes ``age``, which has a default value."""

    name: str
    # age has a default — it will NOT appear in raw init kwargs unless passed.
    age: int = 0
    metadata: dict = {"index_fields": ["name"], "identity_fields": ["name", "age"]}


class TestSameValuesSameUUID:
    """Identical identity-field values must always map to one deterministic id."""

    def test_same_name_produces_same_id(self):
        first = PersonWithIdentity(name="John")
        second = PersonWithIdentity(name="John")
        assert first.id == second.id

    def test_deterministic_across_calls(self):
        generated = {PersonWithIdentity(name="Alice").id for _ in range(10)}
        assert len(generated) == 1


class TestDifferentValuesDifferentUUID:
    """Distinct identity-field values must yield distinct ids."""

    def test_different_names_produce_different_ids(self):
        john = PersonWithIdentity(name="John")
        jane = PersonWithIdentity(name="Jane")
        assert john.id != jane.id


class TestCrossTypeSafety:
    """Equal field values in different classes must not collide on id."""

    def test_same_name_different_class_different_id(self):
        as_person = PersonWithIdentity(name="Engineering")
        as_department = DepartmentWithIdentity(name="Engineering")
        assert as_person.id != as_department.id


class TestExplicitIdOverride:
    """A caller-supplied id always wins over the derived identity id."""

    def test_explicit_id_takes_precedence(self):
        chosen = uuid4()
        assert PersonWithIdentity(id=chosen, name="John").id == chosen

    def test_explicit_id_differs_from_generated(self):
        with_explicit = PersonWithIdentity(id=uuid4(), name="John")
        with_generated = PersonWithIdentity(name="John")
        assert with_explicit.id != with_generated.id


class TestMissingIdentityFieldFallback:
    """Missing identity data must degrade gracefully to random UUID4 ids."""

    def test_missing_field_falls_back_to_uuid4(self):
        """An identity field satisfied only by a pydantic default falls back to UUID4.

        DataPoint.__init__ inspects the raw kwargs before pydantic applies
        defaults, so ``age`` is NOT present in the data dict here — the
        deterministic path is skipped and each instance gets a fresh UUID4.
        """
        p1 = PartialIdentity(name="John")
        p2 = PartialIdentity(name="John")
        # age is only a default, not a kwarg: no identity id, ids differ
        assert p1.id != p2.id

    def test_truly_missing_field(self):
        """A class whose identity_fields references a non-existent field."""

        class BadIdentity(DataPoint):
            name: str
            metadata: dict = {
                "index_fields": ["name"],
                "identity_fields": ["name", "nonexistent"],
            }

        b1 = BadIdentity(name="John")
        b2 = BadIdentity(name="John")
        # nonexistent is never in the kwargs, so falls back to UUID4 - different each time
        assert b1.id != b2.id


class TestNoIdentityFieldsBackwardCompat:
    """Classes without identity_fields keep the legacy random-UUID behavior."""

    def test_no_identity_fields_produces_random_uuid(self):
        assert NoIdentityFields(name="John").id != NoIdentityFields(name="John").id

    def test_base_datapoint_no_identity_fields(self):
        assert DataPoint().id != DataPoint().id


class TestMultiFieldIdentity:
    """Identity built from several fields, combined in declaration order."""

    def test_same_multi_fields_same_id(self):
        a = MultiFieldIdentity(first_name="John", last_name="Doe")
        b = MultiFieldIdentity(first_name="John", last_name="Doe")
        assert a.id == b.id

    def test_different_multi_fields_different_id(self):
        doe = MultiFieldIdentity(first_name="John", last_name="Doe")
        smith = MultiFieldIdentity(first_name="John", last_name="Smith")
        assert doe.id != smith.id

    def test_field_order_matters(self):
        # Swapped values must produce different identity strings.
        ab = MultiFieldIdentity(first_name="A", last_name="B")
        ba = MultiFieldIdentity(first_name="B", last_name="A")
        assert ab.id != ba.id


class TestStringNormalization:
    """String identity values are normalized before hashing into the id."""

    def test_case_insensitive(self):
        assert PersonWithIdentity(name="John").id == PersonWithIdentity(name="JOHN").id

    def test_spaces_normalized(self):
        assert PersonWithIdentity(name="John Doe").id == PersonWithIdentity(name="John_Doe").id

    def test_apostrophes_removed(self):
        assert PersonWithIdentity(name="O'Brien").id == PersonWithIdentity(name="OBrien").id


class TestIdIsUUID5:
    """Derived ids are version-5 UUIDs; fallback ids are version 4."""

    def test_generated_id_is_valid_uuid(self):
        generated = PersonWithIdentity(name="Test").id
        assert isinstance(generated, UUID)
        assert generated.version == 5

    def test_no_identity_id_is_uuid4(self):
        fallback = NoIdentityFields(name="Test").id
        assert isinstance(fallback, UUID)
        assert fallback.version == 4


class TestTypeFieldPreserved:
    """The `type` attribute must always reflect the concrete class name."""

    def test_type_is_class_name(self):
        assert PersonWithIdentity(name="John").type == "PersonWithIdentity"

    def test_type_not_affected_by_identity_fields(self):
        assert NoIdentityFields(name="John").type == "NoIdentityFields"
63 changes: 30 additions & 33 deletions examples/low_level/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,27 @@

class Person(DataPoint):
    """A person node in the example graph."""

    name: str
    # "index_fields": fields to embed for vector search
    # "identity_fields": fields used to generate deterministic IDs (deduplication)
    metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]}


class Department(DataPoint):
    """A department node holding its employees; deduplicated by name."""

    name: str
    employees: list[Person]
    metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]}


class CompanyType(DataPoint):
    """Type node shared by all companies; identity_fields gives it one stable id."""

    name: str = "Company"
    metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]}


class Company(DataPoint):
    """A company node linked to its departments and the shared CompanyType."""

    name: str
    departments: list[Department]
    is_type: CompanyType
    metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]}


class Data(BaseModel):
Expand All @@ -47,41 +45,40 @@ class Data(BaseModel):


def ingest_files(data: List[Data]) -> List[Company]:
    """Build Company instances from raw payloads using identity-based deduplication.

    With identity_fields, DataPoints created with the same name automatically
    share a deterministic UUID, so no manual dict-based deduplication is
    needed — instances can be created freely.
    """
    all_companies: List[Company] = []

    for data_item in data:
        people = data_item.payload["people"]
        companies = data_item.payload["companies"]

        # Group employees by department name.
        dept_employees: Dict[str, List[Person]] = {}
        for person in people:
            dept_employees.setdefault(person["department"], []).append(
                Person(name=person["name"])
            )

        departments = {
            name: Department(name=name, employees=employees)
            for name, employees in dept_employees.items()
        }

        # A single CompanyType node; its deterministic ID keeps it unique
        # even when this runs once per data item.
        company_type = CompanyType()

        for company in companies:
            # Fall back to an empty department when no employees referenced it.
            company_departments = [
                departments.get(dept_name, Department(name=dept_name, employees=[]))
                for dept_name in company["departments"]
            ]
            all_companies.append(
                Company(
                    name=company["name"],
                    departments=company_departments,
                    is_type=company_type,
                )
            )

    return all_companies
Comment on lines 47 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "pipeline.py" | head -20

Repository: topoteretes/cognee

Length of output: 144


🏁 Script executed:

cat -n examples/low_level/pipeline.py | head -100

Repository: topoteretes/cognee

Length of output: 4037


🏁 Script executed:

rg "^def ingest_files|^class Company|^class Data|^from typing" examples/low_level/pipeline.py -A 2 -B 2

Repository: topoteretes/cognee

Length of output: 700


Add return type annotation, variable type annotation, and docstring to ingest_files.

The function lacks a return type annotation, a type hint for the all_companies variable, and a docstring. This violates the project's requirements for type hints (mypy checks enabled) and for documenting function definitions.

✏️ Suggested update
-def ingest_files(data: List[Data]):
+def ingest_files(data: List[Data]) -> List[Company]:
+    """Build Company instances from raw payloads using identity-based deduplication."""
     # With identity_fields, DataPoints with the same name automatically get the same UUID.
     # No manual dict-based deduplication needed — just create instances freely.
-    all_companies = []
+    all_companies: List[Company] = []
🤖 Prompt for AI Agents
In `@examples/low_level/pipeline.py` around lines 47 - 81, Add a docstring to
ingest_files explaining its purpose and inputs, annotate its signature with the
correct return type (e.g., -> List[Company]) and add an explicit type for the
local variable all_companies (e.g., all_companies: List[Company]) so mypy can
validate; update imports if necessary to reference Company/Department/Person
types used in the annotations and keep the implementation unchanged (function
name ingest_files, local variable all_companies).



async def main():
Expand Down
Loading