-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feat: add identity_fields to DataPoint for declarative node deduplication #2125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| from uuid import UUID, uuid4 | ||
|
|
||
| from cognee.infrastructure.engine.models.DataPoint import DataPoint | ||
|
|
||
|
|
||
| class PersonWithIdentity(DataPoint): | ||
| name: str | ||
| metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]} | ||
|
|
||
|
|
||
| class DepartmentWithIdentity(DataPoint): | ||
| name: str | ||
| metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]} | ||
|
|
||
|
|
||
| class MultiFieldIdentity(DataPoint): | ||
| first_name: str | ||
| last_name: str | ||
| metadata: dict = { | ||
| "index_fields": ["first_name"], | ||
| "identity_fields": ["first_name", "last_name"], | ||
| } | ||
|
|
||
|
|
||
| class NoIdentityFields(DataPoint): | ||
| name: str | ||
| metadata: dict = {"index_fields": ["name"]} | ||
|
|
||
|
|
||
| class PartialIdentity(DataPoint): | ||
| name: str | ||
| age: int = 0 | ||
| metadata: dict = {"index_fields": ["name"], "identity_fields": ["name", "age"]} | ||
|
|
||
|
|
||
| class TestSameValuesSameUUID: | ||
| def test_same_name_produces_same_id(self): | ||
| p1 = PersonWithIdentity(name="John") | ||
| p2 = PersonWithIdentity(name="John") | ||
| assert p1.id == p2.id | ||
|
|
||
| def test_deterministic_across_calls(self): | ||
| ids = [PersonWithIdentity(name="Alice").id for _ in range(10)] | ||
| assert len(set(ids)) == 1 | ||
|
|
||
|
|
||
| class TestDifferentValuesDifferentUUID: | ||
| def test_different_names_produce_different_ids(self): | ||
| p1 = PersonWithIdentity(name="John") | ||
| p2 = PersonWithIdentity(name="Jane") | ||
| assert p1.id != p2.id | ||
|
|
||
|
|
||
| class TestCrossTypeSafety: | ||
| def test_same_name_different_class_different_id(self): | ||
| person = PersonWithIdentity(name="Engineering") | ||
| department = DepartmentWithIdentity(name="Engineering") | ||
| assert person.id != department.id | ||
|
|
||
|
|
||
| class TestExplicitIdOverride: | ||
| def test_explicit_id_takes_precedence(self): | ||
| explicit_id = uuid4() | ||
| p = PersonWithIdentity(id=explicit_id, name="John") | ||
| assert p.id == explicit_id | ||
|
|
||
| def test_explicit_id_differs_from_generated(self): | ||
| explicit_id = uuid4() | ||
| p_explicit = PersonWithIdentity(id=explicit_id, name="John") | ||
| p_generated = PersonWithIdentity(name="John") | ||
| assert p_explicit.id != p_generated.id | ||
|
|
||
|
|
||
| class TestMissingIdentityFieldFallback: | ||
| def test_missing_field_falls_back_to_uuid4(self): | ||
| """When an identity field has a default value it is still present in the data, so the ID stays deterministic (no UUID4 fallback).""" | ||
| p1 = PartialIdentity(name="John") | ||
| p2 = PartialIdentity(name="John") | ||
| # age has a default so it IS in data; both should match | ||
| assert p1.id == p2.id | ||
|
|
||
| def test_truly_missing_field(self): | ||
| """A class where identity_fields references a non-existent field.""" | ||
|
|
||
| class BadIdentity(DataPoint): | ||
| name: str | ||
| metadata: dict = { | ||
| "index_fields": ["name"], | ||
| "identity_fields": ["name", "nonexistent"], | ||
| } | ||
|
|
||
| b1 = BadIdentity(name="John") | ||
| b2 = BadIdentity(name="John") | ||
| # nonexistent is not in data, so falls back to UUID4 - different each time | ||
| assert b1.id != b2.id | ||
|
|
||
|
|
||
| class TestNoIdentityFieldsBackwardCompat: | ||
| def test_no_identity_fields_produces_random_uuid(self): | ||
| n1 = NoIdentityFields(name="John") | ||
| n2 = NoIdentityFields(name="John") | ||
| assert n1.id != n2.id | ||
|
|
||
| def test_base_datapoint_no_identity_fields(self): | ||
| dp1 = DataPoint() | ||
| dp2 = DataPoint() | ||
| assert dp1.id != dp2.id | ||
|
|
||
|
|
||
| class TestMultiFieldIdentity: | ||
| def test_same_multi_fields_same_id(self): | ||
| m1 = MultiFieldIdentity(first_name="John", last_name="Doe") | ||
| m2 = MultiFieldIdentity(first_name="John", last_name="Doe") | ||
| assert m1.id == m2.id | ||
|
|
||
| def test_different_multi_fields_different_id(self): | ||
| m1 = MultiFieldIdentity(first_name="John", last_name="Doe") | ||
| m2 = MultiFieldIdentity(first_name="John", last_name="Smith") | ||
| assert m1.id != m2.id | ||
|
|
||
| def test_field_order_matters(self): | ||
| """first_name='A', last_name='B' should differ from first_name='B', last_name='A'.""" | ||
| m1 = MultiFieldIdentity(first_name="A", last_name="B") | ||
| m2 = MultiFieldIdentity(first_name="B", last_name="A") | ||
| assert m1.id != m2.id | ||
|
|
||
|
|
||
| class TestStringNormalization: | ||
| def test_case_insensitive(self): | ||
| p1 = PersonWithIdentity(name="John") | ||
| p2 = PersonWithIdentity(name="JOHN") | ||
| assert p1.id == p2.id | ||
|
|
||
| def test_spaces_normalized(self): | ||
| p1 = PersonWithIdentity(name="John Doe") | ||
| p2 = PersonWithIdentity(name="John_Doe") | ||
| assert p1.id == p2.id | ||
|
|
||
| def test_apostrophes_removed(self): | ||
| p1 = PersonWithIdentity(name="O'Brien") | ||
| p2 = PersonWithIdentity(name="OBrien") | ||
| assert p1.id == p2.id | ||
|
|
||
|
|
||
| class TestIdIsUUID5: | ||
| def test_generated_id_is_valid_uuid(self): | ||
| p = PersonWithIdentity(name="Test") | ||
| assert isinstance(p.id, UUID) | ||
| # UUID5 has version == 5 | ||
| assert p.id.version == 5 | ||
|
|
||
| def test_no_identity_id_is_uuid4(self): | ||
| n = NoIdentityFields(name="Test") | ||
| assert isinstance(n.id, UUID) | ||
| assert n.id.version == 4 | ||
|
|
||
|
|
||
| class TestTypeFieldPreserved: | ||
| def test_type_is_class_name(self): | ||
| p = PersonWithIdentity(name="John") | ||
| assert p.type == "PersonWithIdentity" | ||
|
|
||
| def test_type_not_affected_by_identity_fields(self): | ||
| n = NoIdentityFields(name="John") | ||
| assert n.type == "NoIdentityFields" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,29 +16,27 @@ | |
|
|
||
| class Person(DataPoint): | ||
| name: str | ||
| # Metadata "index_fields" specifies which DataPoint fields should be embedded for vector search | ||
| metadata: dict = {"index_fields": ["name"]} | ||
| # "index_fields": fields to embed for vector search | ||
| # "identity_fields": fields used to generate deterministic IDs (deduplication) | ||
| metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]} | ||
|
|
||
|
|
||
| class Department(DataPoint): | ||
| name: str | ||
| employees: list[Person] | ||
| # Metadata "index_fields" specifies which DataPoint fields should be embedded for vector search | ||
| metadata: dict = {"index_fields": ["name"]} | ||
| metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]} | ||
|
|
||
|
|
||
| class CompanyType(DataPoint): | ||
| name: str = "Company" | ||
| # Metadata "index_fields" specifies which DataPoint fields should be embedded for vector search | ||
| metadata: dict = {"index_fields": ["name"]} | ||
| metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]} | ||
|
|
||
|
|
||
| class Company(DataPoint): | ||
| name: str | ||
| departments: list[Department] | ||
| is_type: CompanyType | ||
| # Metadata "index_fields" specifies which DataPoint fields should be embedded for vector search | ||
| metadata: dict = {"index_fields": ["name"]} | ||
| metadata: dict = {"index_fields": ["name"], "identity_fields": ["name"]} | ||
|
|
||
|
|
||
| class Data(BaseModel): | ||
|
|
@@ -47,41 +45,40 @@ class Data(BaseModel): | |
|
|
||
|
|
||
| def ingest_files(data: List[Data]): | ||
| people_data_points = {} | ||
| departments_data_points = {} | ||
| companies_data_points = {} | ||
| # With identity_fields, DataPoints with the same name automatically get the same UUID. | ||
| # No manual dict-based deduplication needed — just create instances freely. | ||
| all_companies = [] | ||
|
|
||
| for data_item in data: | ||
| people = data_item.payload["people"] | ||
| companies = data_item.payload["companies"] | ||
|
|
||
| # Build departments with their employees | ||
| dept_employees: Dict[str, List[Person]] = {} | ||
| for person in people: | ||
| new_person = Person(name=person["name"]) | ||
| people_data_points[person["name"]] = new_person | ||
| dept_name = person["department"] | ||
| if dept_name not in dept_employees: | ||
| dept_employees[dept_name] = [] | ||
| dept_employees[dept_name].append(Person(name=person["name"])) | ||
|
|
||
| if person["department"] not in departments_data_points: | ||
| departments_data_points[person["department"]] = Department( | ||
| name=person["department"], employees=[new_person] | ||
| ) | ||
| else: | ||
| departments_data_points[person["department"]].employees.append(new_person) | ||
| departments = { | ||
| name: Department(name=name, employees=employees) | ||
| for name, employees in dept_employees.items() | ||
| } | ||
|
|
||
| # Create a single CompanyType node, so we connect all companies to it. | ||
| companyType = CompanyType() | ||
| # Create a single CompanyType node (deterministic ID via identity_fields) | ||
| company_type = CompanyType() | ||
|
|
||
| for company in companies: | ||
| new_company = Company(name=company["name"], departments=[], is_type=companyType) | ||
| companies_data_points[company["name"]] = new_company | ||
|
|
||
| for department_name in company["departments"]: | ||
| if department_name not in departments_data_points: | ||
| departments_data_points[department_name] = Department( | ||
| name=department_name, employees=[] | ||
| ) | ||
|
|
||
| new_company.departments.append(departments_data_points[department_name]) | ||
|
|
||
| return list(companies_data_points.values()) | ||
| company_departments = [ | ||
| departments.get(dept_name, Department(name=dept_name, employees=[])) | ||
| for dept_name in company["departments"] | ||
| ] | ||
| all_companies.append( | ||
| Company(name=company["name"], departments=company_departments, is_type=company_type) | ||
| ) | ||
|
|
||
| return all_companies | ||
|
Comment on lines 47 to +81
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed: `find . -type f -name "pipeline.py" | head -20`
Repository: topoteretes/cognee
Length of output: 144

🏁 Script executed: `cat -n examples/low_level/pipeline.py | head -100`
Repository: topoteretes/cognee
Length of output: 4037

🏁 Script executed: `rg "^def ingest_files|^class Company|^class Data|^from typing" examples/low_level/pipeline.py -A 2 -B 2`
Repository: topoteretes/cognee
Length of output: 700

**Add return type annotation, variable type annotation, and docstring to `ingest_files`.**

The function lacks a return type annotation, a type hint for the `all_companies` variable, and a docstring.

✏️ Suggested update

```diff
-def ingest_files(data: List[Data]):
+def ingest_files(data: List[Data]) -> List[Company]:
+    """Build Company instances from raw payloads using identity-based deduplication."""
     # With identity_fields, DataPoints with the same name automatically get the same UUID.
     # No manual dict-based deduplication needed — just create instances freely.
-    all_companies = []
+    all_companies: List[Company] = []
```

🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| async def main(): | ||
|
|
||
This comment was marked as resolved.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.