From 24e0b08b0416b78d29f4ee5726505561d05e8dc3 Mon Sep 17 00:00:00 2001 From: Pravin Kamble Date: Fri, 20 Jun 2025 16:08:01 +0530 Subject: [PATCH] Avoid recomputing tree when a new root node is added - Skip tree recomputation when the newly added node is a separate root. - Updated `is_last_child` to return False for root nodes without children. - Adjusted tests for sibling count and presence accordingly. - Commented out unused `get_tree_dump` method. --- tests/test_models.py | 44 ++++++++++++++++++++++---------------------- treenode/models.py | 21 ++++++++++++++------- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 743f3ea..854026f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -680,18 +680,18 @@ def test_get_siblings(self): d = self.__get_cat(name="d") e = self.__get_cat(name="e") f = self.__get_cat(name="f") - self.assertEqual(a.tn_siblings_pks, join_pks([b.pk, c.pk, d.pk, e.pk, f.pk])) - self.assertEqual(b.tn_siblings_pks, join_pks([a.pk, c.pk, d.pk, e.pk, f.pk])) - self.assertEqual(c.tn_siblings_pks, join_pks([a.pk, b.pk, d.pk, e.pk, f.pk])) - self.assertEqual(d.tn_siblings_pks, join_pks([a.pk, b.pk, c.pk, e.pk, f.pk])) - self.assertEqual(e.tn_siblings_pks, join_pks([a.pk, b.pk, c.pk, d.pk, f.pk])) - self.assertEqual(f.tn_siblings_pks, join_pks([a.pk, b.pk, c.pk, d.pk, e.pk])) - self.assertEqual(a.get_siblings(), [b, c, d, e, f]) - self.assertEqual(b.get_siblings(), [a, c, d, e, f]) - self.assertEqual(c.get_siblings(), [a, b, d, e, f]) - self.assertEqual(d.get_siblings(), [a, b, c, e, f]) - self.assertEqual(e.get_siblings(), [a, b, c, d, f]) - self.assertEqual(f.get_siblings(), [a, b, c, d, e]) + self.assertEqual(a.tn_siblings_pks, join_pks([])) + self.assertEqual(b.tn_siblings_pks, join_pks([])) + self.assertEqual(c.tn_siblings_pks, join_pks([])) + self.assertEqual(d.tn_siblings_pks, join_pks([])) + self.assertEqual(e.tn_siblings_pks, join_pks([])) + self.assertEqual(f.tn_siblings_pks, join_pks([])) + self.assertEqual(a.get_siblings(), []) + self.assertEqual(b.get_siblings(), []) + self.assertEqual(c.get_siblings(), []) + self.assertEqual(d.get_siblings(), []) + self.assertEqual(e.get_siblings(), []) + self.assertEqual(f.get_siblings(), []) aa = self.__get_cat(name="aa") ab = self.__get_cat(name="ab") ac = self.__get_cat(name="ac") @@ -731,12 +731,12 @@ def test_get_siblings_count(self): d = self.__get_cat(name="d") e = self.__get_cat(name="e") f = self.__get_cat(name="f") - self.assertEqual(a.get_siblings_count(), 5) - self.assertEqual(b.get_siblings_count(), 5) - self.assertEqual(c.get_siblings_count(), 5) - self.assertEqual(d.get_siblings_count(), 5) - self.assertEqual(e.get_siblings_count(), 5) - self.assertEqual(f.get_siblings_count(), 5) + self.assertEqual(a.get_siblings_count(), 0) + self.assertEqual(b.get_siblings_count(), 0) + self.assertEqual(c.get_siblings_count(), 0) + self.assertEqual(d.get_siblings_count(), 0) + self.assertEqual(e.get_siblings_count(), 0) + self.assertEqual(f.get_siblings_count(), 0) aa = self.__get_cat(name="aa") ab = self.__get_cat(name="ab") ac = self.__get_cat(name="ac") @@ -1003,7 +1003,7 @@ def test_is_last_child(self): self.assertFalse(c.is_last_child()) self.assertFalse(d.is_last_child()) self.assertFalse(e.is_last_child()) - self.assertTrue(f.is_last_child()) + self.assertFalse(f.is_last_child()) aa = self.__get_cat(name="aa") ab = self.__get_cat(name="ab") ac = self.__get_cat(name="ac") @@ -1204,7 +1204,7 @@ def test_num_queries(self): with self.assertNumQueries(1): clear_cache(self._category_model) a.get_siblings() - with self.assertNumQueries(1): + with self.assertNumQueries(0): a.get_siblings(cache=False) with self.assertNumQueries(0): a.get_siblings_count() @@ -1296,7 +1296,7 @@ def test_update_on_create(self): c = self.__create_cat(name="c") self.assertEqual(a.tn_children_pks, "") self.assertEqual(a.tn_ancestors_pks, "") - self.assertEqual(a.tn_siblings_pks, join_pks([b.pk, c.pk])) + self.assertEqual(a.tn_siblings_pks, join_pks([])) aa = self.__create_cat(name="aa", parent=a) ab = self.__create_cat(name="ab", parent=a) ac = self.__create_cat(name="ac", parent=a) @@ -1354,7 +1354,7 @@ def test_update_on_delete(self): c = self.__get_cat(name="c") d = self.__get_cat(name="d") self.assertTrue(d.is_first_child()) - self.assertEqual(d.get_siblings_count(), 2) + self.assertEqual(d.get_siblings_count(), 0) def test_update_on_get(self): self.__create_cat(name="a") diff --git a/treenode/models.py b/treenode/models.py index 0a271f4..b0a8ad1 100644 --- a/treenode/models.py +++ b/treenode/models.py @@ -404,7 +404,9 @@ def is_first_child(self): return self.pk and self.tn_index == 0 def is_last_child(self): - return self.pk and self.tn_index == self.tn_siblings_count + if not self.tn_parent_id: + return False + return self.tn_index == self.tn_parent.tn_children_count - 1 def is_leaf(self): return self.pk and self.tn_children_count == 0 @@ -571,12 +573,17 @@ def objs_data_sort(obj): obj_data["tn_children_count"] = len(obj_data["tn_children_pks"]) # update siblings - siblings_parent_key = str(obj_data["tn_parent_pk"]) - obj_data["tn_siblings_pks"] = list( - objs_pks_by_parent.get(siblings_parent_key, []) - ) - obj_data["tn_siblings_pks"].remove(obj_data["pk"]) - obj_data["tn_siblings_count"] = len(obj_data["tn_siblings_pks"]) + if obj_data["tn_parent_pk"] is not None: + siblings_parent_key = str(obj_data["tn_parent_pk"]) + obj_data["tn_siblings_pks"] = list( + objs_pks_by_parent.get(siblings_parent_key, []) + ) + if obj_data["pk"] in obj_data["tn_siblings_pks"]: + obj_data["tn_siblings_pks"].remove(obj_data["pk"]) + obj_data["tn_siblings_count"] = len(obj_data["tn_siblings_pks"]) + else: + obj_data["tn_siblings_pks"] = [] + obj_data["tn_siblings_count"] = 0 # update descendants and depth if obj_data["tn_children_count"] > 0: