Skip to content

Commit 2d06ac8

Browse files
authored
C++ backend for Binary Indexed Trees (#561)
1 parent c17fe9c commit 2d06ac8

File tree

4 files changed

+179
-4
lines changed

4 files changed

+179
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
#ifndef TREES_BINARYINDEXEDTREE_HPP
2+
#define TREES_BINARYINDEXEDTREE_HPP
3+
4+
#define PY_SSIZE_T_CLEAN
5+
#include <Python.h>
6+
#include <structmember.h>
7+
#include <cstdlib>
8+
#include "../../../utils/_backend/cpp/utils.hpp"
9+
#include "../../../utils/_backend/cpp/TreeNode.hpp"
10+
#include "../../../linear_data_structures/_backend/cpp/arrays/ArrayForTrees.hpp"
11+
#include "../../../linear_data_structures/_backend/cpp/arrays/DynamicOneDimensionalArray.hpp"
12+
13+
// Copied binary trees and changed the name to BinaryIndexedTree
14+
// Start from the struct
15+
16+
typedef struct {
17+
PyObject_HEAD
18+
OneDimensionalArray* array;
19+
PyObject* tree;
20+
PyObject* flag;
21+
} BinaryIndexedTree;
22+
23+
static void BinaryIndexedTree_dealloc(BinaryIndexedTree *self) {
24+
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
25+
}
26+
27+
static PyObject* BinaryIndexedTree_update(BinaryIndexedTree* self, PyObject *args) {
28+
long index = PyLong_AsLong(PyObject_GetItem(args, PyZero));
29+
long value = PyLong_AsLong(PyObject_GetItem(args, PyOne));
30+
long _index = index;
31+
long _value = value;
32+
if (PyList_GetItem(self->flag, index) == PyZero) {
33+
PyList_SetItem(self->flag, index, PyOne);
34+
index += 1;
35+
while (index < self->array->_size + 1) {
36+
long curr = PyLong_AsLong(PyList_GetItem(self->tree, index));
37+
PyList_SetItem(self->tree, index, PyLong_FromLong(curr + value));
38+
index = index + (index & (-1*index));
39+
}
40+
}
41+
else {
42+
value = value - PyLong_AsLong(self->array->_data[index]);
43+
index += 1;
44+
while (index < self->array->_size + 1) {
45+
long curr = PyLong_AsLong(PyList_GetItem(self->tree, index));
46+
PyList_SetItem(self->tree, index, PyLong_FromLong(curr + value));
47+
index = index + (index & (-1*index));
48+
}
49+
}
50+
self->array->_data[_index] = PyLong_FromLong(_value);
51+
Py_RETURN_NONE;
52+
}
53+
54+
static PyObject* BinaryIndexedTree___new__(PyTypeObject* type, PyObject *args, PyObject *kwds) {
55+
BinaryIndexedTree *self;
56+
self = reinterpret_cast<BinaryIndexedTree*>(type->tp_alloc(type, 0));
57+
58+
// Python code is such that arguments are: type(array[0]) and array
59+
60+
if (PyType_Ready(&OneDimensionalArrayType) < 0) { // This has to be present to finalize a type object. This should be called on all type objects to finish their initialization.
61+
return NULL;
62+
}
63+
PyObject* _one_dimensional_array = OneDimensionalArray___new__(&OneDimensionalArrayType, args, kwds);
64+
if ( !_one_dimensional_array ) {
65+
return NULL;
66+
}
67+
self->array = reinterpret_cast<OneDimensionalArray*>(_one_dimensional_array);
68+
self->tree = PyList_New(self->array->_size+2);
69+
for(int i=0;i<self->array->_size+2;i++){
70+
PyList_SetItem(self->tree, i, PyZero);
71+
}
72+
self->flag = PyList_New(self->array->_size);
73+
for(int i=0;i<self->array->_size;i++){
74+
PyList_SetItem(self->flag, i, PyZero);
75+
BinaryIndexedTree_update(self, Py_BuildValue("(OO)", PyLong_FromLong(i), self->array->_data[i]));
76+
}
77+
78+
return reinterpret_cast<PyObject*>(self);
79+
}
80+
81+
static PyObject* BinaryIndexedTree_get_prefix_sum(BinaryIndexedTree* self, PyObject *args) {
82+
long index = PyLong_AsLong(PyObject_GetItem(args, PyZero));
83+
index += 1;
84+
long sum = 0;
85+
while (index > 0) {
86+
sum += PyLong_AsLong(PyList_GetItem(self->tree, index));
87+
index = index - (index & (-1*index));
88+
}
89+
90+
return PyLong_FromLong(sum);
91+
}
92+
93+
static PyObject* BinaryIndexedTree_get_sum(BinaryIndexedTree* self, PyObject *args) {
94+
long left_index = PyLong_AsLong(PyObject_GetItem(args, PyZero));
95+
long right_index = PyLong_AsLong(PyObject_GetItem(args, PyOne));
96+
if (left_index >= 1) {
97+
long l1 = PyLong_AsLong(BinaryIndexedTree_get_prefix_sum(self, Py_BuildValue("(O)", PyLong_FromLong(right_index))));
98+
long l2 = PyLong_AsLong(BinaryIndexedTree_get_prefix_sum(self, Py_BuildValue("(O)", PyLong_FromLong(left_index - 1))));
99+
return PyLong_FromLong(l1 - l2);
100+
}
101+
else {
102+
return BinaryIndexedTree_get_prefix_sum(self, Py_BuildValue("(O)", PyLong_FromLong(right_index)));
103+
}
104+
}
105+
106+
107+
static struct PyMethodDef BinaryIndexedTree_PyMethodDef[] = {
108+
{"update", (PyCFunction) BinaryIndexedTree_update, METH_VARARGS, NULL},
109+
{"get_prefix_sum", (PyCFunction) BinaryIndexedTree_get_prefix_sum, METH_VARARGS, NULL},
110+
{"get_sum", (PyCFunction) BinaryIndexedTree_get_sum, METH_VARARGS, NULL},
111+
{NULL}
112+
};
113+
114+
static PyMemberDef BinaryIndexedTree_PyMemberDef[] = {
115+
{"array", T_OBJECT_EX, offsetof(BinaryIndexedTree, array), 0, "array"},
116+
{"tree", T_OBJECT_EX, offsetof(BinaryIndexedTree, tree), 0, "tree"},
117+
{"flag", T_OBJECT_EX, offsetof(BinaryIndexedTree, flag), 0, "flag"},
118+
{NULL} /* Sentinel */
119+
};
120+
121+
122+
static PyTypeObject BinaryIndexedTreeType = {
123+
/* tp_name */ PyVarObject_HEAD_INIT(NULL, 0) "BinaryIndexedTree",
124+
/* tp_basicsize */ sizeof(BinaryIndexedTree),
125+
/* tp_itemsize */ 0,
126+
/* tp_dealloc */ (destructor) BinaryIndexedTree_dealloc,
127+
/* tp_print */ 0,
128+
/* tp_getattr */ 0,
129+
/* tp_setattr */ 0,
130+
/* tp_reserved */ 0,
131+
/* tp_repr */ 0,
132+
/* tp_as_number */ 0,
133+
/* tp_as_sequence */ 0,
134+
/* tp_as_mapping */ 0,
135+
/* tp_hash */ 0,
136+
/* tp_call */ 0,
137+
/* tp_str */ 0,
138+
/* tp_getattro */ 0,
139+
/* tp_setattro */ 0,
140+
/* tp_as_buffer */ 0,
141+
/* tp_flags */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
142+
/* tp_doc */ 0,
143+
/* tp_traverse */ 0,
144+
/* tp_clear */ 0,
145+
/* tp_richcompare */ 0,
146+
/* tp_weaklistoffset */ 0,
147+
/* tp_iter */ 0,
148+
/* tp_iternext */ 0,
149+
/* tp_methods */ BinaryIndexedTree_PyMethodDef,
150+
/* tp_members */ BinaryIndexedTree_PyMemberDef,
151+
/* tp_getset */ 0,
152+
/* tp_base */ &PyBaseObject_Type,
153+
/* tp_dict */ 0,
154+
/* tp_descr_get */ 0,
155+
/* tp_descr_set */ 0,
156+
/* tp_dictoffset */ 0,
157+
/* tp_init */ 0,
158+
/* tp_alloc */ 0,
159+
/* tp_new */ BinaryIndexedTree___new__,
160+
};
161+
162+
#endif

pydatastructs/trees/_backend/cpp/trees.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "BinaryTreeTraversal.hpp"
55
#include "SelfBalancingBinaryTree.hpp"
66
#include "RedBlackTree.hpp"
7+
#include "BinaryIndexedTree.hpp"
78

89
static struct PyModuleDef trees_struct = {
910
PyModuleDef_HEAD_INIT,
@@ -47,5 +48,11 @@ PyMODINIT_FUNC PyInit__trees(void) {
4748
Py_INCREF(&RedBlackTreeType);
4849
PyModule_AddObject(trees, "RedBlackTree", reinterpret_cast<PyObject*>(&RedBlackTreeType));
4950

51+
if (PyType_Ready(&BinaryIndexedTreeType) < 0) {
52+
return NULL;
53+
}
54+
Py_INCREF(&BinaryIndexedTreeType);
55+
PyModule_AddObject(trees, "BinaryIndexedTree", reinterpret_cast<PyObject*>(&BinaryIndexedTreeType));
56+
5057
return trees;
5158
}

pydatastructs/trees/binary_trees.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1760,8 +1760,9 @@ class BinaryIndexedTree(object):
17601760
__slots__ = ['tree', 'array', 'flag']
17611761

17621762
def __new__(cls, array, **kwargs):
1763-
raise_if_backend_is_not_python(
1764-
cls, kwargs.get('backend', Backend.PYTHON))
1763+
backend = kwargs.get('backend', Backend.PYTHON)
1764+
if backend == Backend.CPP:
1765+
return _trees.BinaryIndexedTree(type(array[0]), array, **kwargs)
17651766
obj = object.__new__(cls)
17661767
obj.array = OneDimensionalArray(type(array[0]), array)
17671768
obj.tree = [0] * (obj.array._size + 2)

pydatastructs/trees/tests/test_binary_trees.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,11 @@ def test_select_rank(expected_output):
361361
test_select_rank([])
362362

363363

364-
def test_BinaryIndexedTree():
364+
def _test_BinaryIndexedTree(backend):
365365

366366
FT = BinaryIndexedTree
367367

368-
t = FT([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
368+
t = FT([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], backend=backend)
369369

370370
assert t.get_sum(0, 2) == 6
371371
assert t.get_sum(0, 4) == 15
@@ -375,6 +375,11 @@ def test_BinaryIndexedTree():
375375
assert t.get_sum(0, 4) == 114
376376
assert t.get_sum(1, 9) == 54
377377

378+
def test_BinaryIndexedTree():
379+
_test_BinaryIndexedTree(Backend.PYTHON)
380+
381+
def test_cpp_BinaryIndexedTree():
382+
_test_BinaryIndexedTree(Backend.CPP)
378383

379384
def test_CartesianTree():
380385
tree = CartesianTree()

0 commit comments

Comments
 (0)