Skip to content

Commit 800c857

Browse files
committed
query: add some set_parameter_* for HNSW objectbox#24
set_parameter_string set_parameter_int set_parameter_vector_f32
1 parent a398911 commit 800c857

File tree

5 files changed

+136
-43
lines changed

5 files changed

+136
-43
lines changed

objectbox/c.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import platform
1919
from objectbox.version import Version
2020
from typing import *
21+
import numpy as np
2122

2223
# This file contains C-API bindings based on lib/objectbox.h, linking to the 'objectbox' shared library.
2324
# The bindings are implementing using ctypes, see https://docs.python.org/dev/library/ctypes.html for introduction.
@@ -329,13 +330,20 @@ def c_voidp_as_bytes(voidp, size):
329330
return memoryview(ctypes.cast(voidp, ctypes.POINTER(ctypes.c_ubyte * size))[0]).tobytes()
330331

331332

332-
def py_list_to_c_array(py_list: List[Any], c_type):
333-
""" Converts the given python list into a C array. """
334-
return (c_type * len(py_list))(*py_list)
333+
def py_list_to_c_array(py_list: Union[List[Any], np.ndarray], c_type):
334+
""" Converts the given python list or ndarray into a C array. """
335+
if isinstance(py_list, np.ndarray):
336+
if py_list.ndim != 1:
337+
raise Exception(f"ndarray is expected to be 1-dimensional. Input shape: {py_list.shape}")
338+
return py_list.ctypes.data_as(ctypes.POINTER(c_type))
339+
elif isinstance(py_list, list):
340+
return (c_type * len(py_list))(*py_list)
341+
else:
342+
raise Exception(f"Unsupported Python list type: {type(py_list)}")
335343

336344

337-
def py_list_to_c_pointer(py_list: List[Any], c_type):
338-
""" Converts the given python list into a C array and returns a pointer type. """
345+
def py_list_to_c_pointer(py_list: Union[List[Any], np.ndarray], c_type):
346+
""" Converts the given python list or ndarray into a C array, and returns a pointer type. """
339347
return ctypes.cast(py_list_to_c_array(py_list, c_type), ctypes.POINTER(c_type))
340348

341349

@@ -676,8 +684,16 @@ def py_list_to_c_pointer(py_list: List[Any], c_type):
676684
# OBX_C_API obx_err obx_qb_param_alias(OBX_query_builder* builder, const char* alias);
677685
obx_qb_param_alias = c_fn_rc('obx_qb_param_alias', [OBX_query_builder_p, ctypes.c_char_p])
678686

687+
# OBX_C_API obx_err obx_query_param_string(OBX_query* query, obx_schema_id entity_id, obx_schema_id property_id, const char* value);
688+
obx_query_param_string = c_fn_rc('obx_query_param_string', [OBX_query_p, obx_schema_id, obx_schema_id, ctypes.c_char_p])
689+
690+
# OBX_C_API obx_err obx_query_param_int(OBX_query* query, obx_schema_id entity_id, obx_schema_id property_id, int64_t value);
691+
obx_query_param_int = c_fn_rc('obx_query_param_int', [OBX_query_p, obx_schema_id, obx_schema_id, ctypes.c_int64])
692+
679693
# OBX_C_API obx_err obx_query_param_vector_float32(OBX_query* query, obx_schema_id entity_id, obx_schema_id property_id, const float* value, size_t element_count);
680-
# TODO
694+
obx_query_param_vector_float32 = c_fn_rc('obx_query_param_vector_float32',
695+
[OBX_query_p, obx_schema_id, obx_schema_id, ctypes.POINTER(ctypes.c_float),
696+
ctypes.c_size_t])
681697

682698
# OBX_C_API obx_err obx_query_param_alias_vector_float32(OBX_query* query, const char* alias, const float* value, size_t element_count);
683699
# TODO

objectbox/model/entity.py

+11
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,17 @@ def get_property(self, name: str):
103103
return prop
104104
raise Exception(f"Property \"{name}\" not found in Entity: \"{self.name}\"")
105105

106+
def get_property_id(self, prop: Union[int, str, Property]) -> int:
107+
""" A convenient way to get the property ID regardless having its ID, name or Property. """
108+
if isinstance(prop, int):
109+
return prop # We already have it!
110+
elif isinstance(prop, str):
111+
return self.get_property(prop)._id
112+
elif isinstance(prop, Property):
113+
return prop._id
114+
else:
115+
raise Exception(f"Unsupported Property type: {type(prop)}")
116+
106117
def get_value(self, object, prop: Property):
107118
# in case value is not overwritten on the object, it's the Property object itself (= as defined in the Class)
108119
val = getattr(object, prop._name)

objectbox/query.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class Query:
1919
def __init__(self, c_query, box: 'Box'):
2020
self._c_query = c_query
2121
self._box = box
22+
self._entity = self._box._entity
2223
self._ob = box._ob
2324

2425
def find(self) -> list:
@@ -96,8 +97,31 @@ def remove(self) -> int:
9697
obx_query_remove(self._c_query, ctypes.byref(count))
9798
return int(count.value)
9899

99-
def offset(self, offset: int):
100-
return obx_query_offset(self._c_query, offset)
100+
def offset(self, offset: int) -> 'Query':
101+
obx_query_offset(self._c_query, offset)
102+
return self
101103

102-
def limit(self, limit: int):
103-
return obx_query_limit(self._c_query, limit)
104+
def limit(self, limit: int) -> 'Query':
105+
obx_query_limit(self._c_query, limit)
106+
return self
107+
108+
def set_parameter_string(self, prop: Union[int, str, 'Property'], value: str) -> 'Query':
109+
prop_id = self._entity.get_property_id(prop)
110+
obx_query_param_string(self._c_query, self._entity.id, prop_id, c_str(value))
111+
return self
112+
113+
def set_parameter_int(self, prop: Union[int, str, 'Property'], value: int) -> 'Query':
114+
prop_id = self._entity.get_property_id(prop)
115+
obx_query_param_int(self._c_query, self._entity.id, prop_id, value)
116+
return self
117+
118+
def set_parameter_vector_f32(self,
119+
prop: Union[int, str, 'Property'],
120+
value: Union[List[float], np.ndarray]) -> 'Query':
121+
if isinstance(value, np.ndarray) and value.dtype != np.float32:
122+
raise Exception(f"value dtype is expected to be np.float32, got: {value.dtype}")
123+
prop_id = self._entity.get_property_id(prop)
124+
c_value = py_list_to_c_array(value, ctypes.c_float)
125+
num_el = len(value)
126+
obx_query_param_vector_float32(self._c_query, self._entity.id, prop_id, c_value, num_el)
127+
return self

objectbox/query_builder.py

+21-32
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,6 @@ def __init__(self, ob: ObjectBox, box: 'Box'):
1414
self._entity = box._entity
1515
self._c_builder = obx_query_builder(ob._c_store, box._entity.id)
1616

17-
def _get_property_id(self, prop: Union[int, str, Property]) -> int:
18-
if type(prop) is int:
19-
return prop
20-
elif type(prop) is str:
21-
prop = self._entity.get_property(prop)
22-
return prop._id
23-
2417
def close(self) -> int:
2518
return obx_qb_close(self._c_builder)
2619

@@ -31,101 +24,97 @@ def error_message(self) -> str:
3124
return obx_qb_error_message(self._c_builder)
3225

3326
def equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
34-
prop_id = self._get_property_id(prop)
27+
prop_id = self._entity.get_property_id(prop)
3528
cond = obx_qb_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
3629
return cond
3730

3831
def not_equals_string(self, prop: Union[int, str, Property], value: str,
3932
case_sensitive: bool = True) -> obx_qb_cond:
40-
prop_id = self._get_property_id(prop)
33+
prop_id = self._entity.get_property_id(prop)
4134
cond = obx_qb_not_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
4235
return cond
4336

4437
def contains_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
45-
prop_id = self._get_property_id(prop)
38+
prop_id = self._entity.get_property_id(prop)
4639
cond = obx_qb_contains_string(self._c_builder, prop_id, c_str(value), case_sensitive)
4740
return cond
4841

4942
def starts_with_string(self, prop: Union[int, str, Property], value: str,
5043
case_sensitive: bool = True) -> obx_qb_cond:
51-
prop_id = self._get_property_id(prop)
44+
prop_id = self._entity.get_property_id(prop)
5245
cond = obx_qb_starts_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
5346
return cond
5447

5548
def ends_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
56-
prop_id = self._get_property_id(prop)
49+
prop_id = self._entity.get_property_id(prop)
5750
cond = obx_qb_ends_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
5851
return cond
5952

6053
def greater_than_string(self, prop: Union[int, str, Property], value: str,
6154
case_sensitive: bool = True) -> obx_qb_cond:
62-
prop_id = self._get_property_id(prop)
55+
prop_id = self._entity.get_property_id(prop)
6356
cond = obx_qb_greater_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
6457
return cond
6558

6659
def greater_or_equal_string(self, prop: Union[int, str, Property], value: str,
6760
case_sensitive: bool = True) -> obx_qb_cond:
68-
prop_id = self._get_property_id(prop)
61+
prop_id = self._entity.get_property_id(prop)
6962
cond = obx_qb_greater_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
7063
return cond
7164

7265
def less_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
73-
prop_id = self._get_property_id(prop)
66+
prop_id = self._entity.get_property_id(prop)
7467
cond = obx_qb_less_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
7568
return cond
7669

7770
def less_or_equal_string(self, prop: Union[int, str, Property], value: str,
7871
case_sensitive: bool = True) -> obx_qb_cond:
79-
prop_id = self._get_property_id(prop)
72+
prop_id = self._entity.get_property_id(prop)
8073
cond = obx_qb_less_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
8174
return cond
8275

8376
def equals_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
84-
prop_id = self._get_property_id(prop)
77+
prop_id = self._entity.get_property_id(prop)
8578
cond = obx_qb_equals_int(self._c_builder, prop_id, value)
8679
return cond
8780

8881
def not_equals_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
89-
prop_id = self._get_property_id(prop)
82+
prop_id = self._entity.get_property_id(prop)
9083
cond = obx_qb_not_equals_int(self._c_builder, prop_id, value)
9184
return cond
9285

9386
def greater_than_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
94-
prop_id = self._get_property_id(prop)
87+
prop_id = self._entity.get_property_id(prop)
9588
cond = obx_qb_greater_than_int(self._c_builder, prop_id, value)
9689
return cond
9790

9891
def greater_or_equal_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
99-
prop_id = self._get_property_id(prop)
92+
prop_id = self._entity.get_property_id(prop)
10093
cond = obx_qb_greater_or_equal_int(self._c_builder, prop_id, value)
10194
return cond
10295

10396
def less_than_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
104-
prop_id = self._get_property_id(prop)
97+
prop_id = self._entity.get_property_id(prop)
10598
cond = obx_qb_less_than_int(self._c_builder, prop_id, value)
10699
return cond
107100

108101
def less_or_equal_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
109-
prop_id = self._get_property_id(prop)
102+
prop_id = self._entity.get_property_id(prop)
110103
cond = obx_qb_less_or_equal_int(self._c_builder, prop_id, value)
111104
return cond
112105

113106
def between_2ints(self, prop: Union[int, str, Property], value_a: int, value_b: int) -> obx_qb_cond:
114-
prop_id = self._get_property_id(prop)
107+
prop_id = self._entity.get_property_id(prop)
115108
cond = obx_qb_between_2ints(self._c_builder, prop_id, value_a, value_b)
116109
return cond
117110

118111
def nearest_neighbors_f32(self, prop: Union[int, str, Property], query_vector: Union[np.ndarray, List[float]],
119112
element_count: int):
120-
if isinstance(query_vector, np.ndarray):
121-
if query_vector.dtype != np.float32:
122-
raise Exception(f"query_vector dtype must be float32")
123-
query_vector_data = query_vector.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
124-
else: # List[float]
125-
query_vector_data = py_list_to_c_array(query_vector, ctypes.c_float)
126-
127-
prop_id = self._get_property_id(prop)
128-
cond = obx_qb_nearest_neighbors_f32(self._c_builder, prop_id, query_vector_data, element_count)
113+
if isinstance(query_vector, np.ndarray) and query_vector.dtype != np.float32:
114+
raise Exception(f"query_vector dtype is expected to be np.float32, got: {query_vector.dtype}")
115+
prop_id = self._entity.get_property_id(prop)
116+
c_query_vector = py_list_to_c_array(query_vector, ctypes.c_float)
117+
cond = obx_qb_nearest_neighbors_f32(self._c_builder, prop_id, c_query_vector, element_count)
129118
return cond
130119

131120
def any(self, conditions: List[obx_qb_cond]) -> obx_qb_cond:

tests/test_query.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import objectbox
2+
from objectbox import *
23
from objectbox.model import *
34
from objectbox.c import *
5+
from objectbox.query import *
46
import pytest
57
from tests.common import (load_empty_test_objectbox, create_test_objectbox, autocleanup)
6-
from tests.model import TestEntity
8+
from tests.model import *
79

810

911
def test_basics():
@@ -224,3 +226,54 @@ def test_any_all():
224226
assert ids == [2, 3]
225227

226228

229+
def test_set_parameter():
230+
db = create_test_objectbox()
231+
232+
box_test_entity = objectbox.Box(db, TestEntity)
233+
box_test_entity.put(TestEntity(str="Foo", int64=2, int32=703, int8=101))
234+
box_test_entity.put(TestEntity(str="FooBar", int64=10, int32=49, int8=45))
235+
box_test_entity.put(TestEntity(str="Bar", int64=10, int32=226, int8=126))
236+
box_test_entity.put(TestEntity(str="Foster", int64=2, int32=301, int8=42))
237+
box_test_entity.put(TestEntity(str="Fox", int64=10, int32=157, int8=11))
238+
box_test_entity.put(TestEntity(str="Barrakuda", int64=4, int32=386, int8=60))
239+
240+
box_vector_entity = objectbox.Box(db, VectorEntity)
241+
box_vector_entity.put(VectorEntity(name="Object 1", vector=[1, 1]))
242+
box_vector_entity.put(VectorEntity(name="Object 2", vector=[2, 2]))
243+
box_vector_entity.put(VectorEntity(name="Object 3", vector=[3, 3]))
244+
box_vector_entity.put(VectorEntity(name="Object 4", vector=[4, 4]))
245+
box_vector_entity.put(VectorEntity(name="Object 5", vector=[5, 5]))
246+
247+
qb = box_test_entity.query()
248+
qb.starts_with_string("str", "fo", case_sensitive=False)
249+
qb.greater_than_int("int32", 150)
250+
qb.greater_than_int("int64", 0)
251+
query = qb.build()
252+
assert query.find_ids() == [1, 4, 5]
253+
254+
# Test set_parameter_string
255+
query.set_parameter_string("str", "bar")
256+
assert query.find_ids() == [3, 6]
257+
258+
# Test set_parameter_int
259+
query.set_parameter_int("int64", 8)
260+
assert query.find_ids() == [3]
261+
262+
qb = box_vector_entity.query()
263+
qb.nearest_neighbors_f32("vector", [3.4, 3.4], 3)
264+
query = qb.build()
265+
assert query.find_ids() == sorted([3, 4, 2])
266+
267+
# set_parameter_vector_f32
268+
# set_parameter_int (NN count)
269+
query.set_parameter_vector_f32("vector", [4.9, 4.9])
270+
assert query.find_ids() == sorted([5, 4, 3])
271+
272+
query.set_parameter_vector_f32("vector", [0, 0])
273+
assert query.find_ids() == sorted([1, 2, 3])
274+
275+
query.set_parameter_vector_f32("vector", [2.5, 2.1])
276+
assert query.find_ids() == sorted([2, 3, 1])
277+
278+
query.set_parameter_int("vector", 2)
279+
assert query.find_ids() == sorted([2, 3])

0 commit comments

Comments
 (0)