Skip to content

Commit a398911

Browse files
committed
query: add any/all objectbox#24
1 parent 557738b commit a398911

File tree

4 files changed

+212
-121
lines changed

4 files changed

+212
-121
lines changed

objectbox/c.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def shlib_name(library: str) -> str:
6868
obx_id = ctypes.c_uint64
6969
obx_qb_cond = ctypes.c_int
7070

71+
obx_qb_cond_p = ctypes.POINTER(obx_qb_cond)
72+
7173
# enums
7274
OBXPropertyType = ctypes.c_int
7375
OBXPropertyFlags = ctypes.c_int
@@ -327,6 +329,16 @@ def c_voidp_as_bytes(voidp, size):
327329
return memoryview(ctypes.cast(voidp, ctypes.POINTER(ctypes.c_ubyte * size))[0]).tobytes()
328330

329331

332+
def py_list_to_c_array(py_list: List[Any], c_type):
333+
""" Converts the given python list into a C array. """
334+
return (c_type * len(py_list))(*py_list)
335+
336+
337+
def py_list_to_c_pointer(py_list: List[Any], c_type):
338+
""" Converts the given python list into a C array and returns a pointer type. """
339+
return ctypes.cast(py_list_to_c_array(py_list, c_type), ctypes.POINTER(c_type))
340+
341+
330342
# OBX_model* (void);
331343
obx_model = c_fn('obx_model', OBX_model_p, [])
332344

@@ -656,10 +668,10 @@ def c_voidp_as_bytes(voidp, size):
656668
[OBX_query_builder_p, obx_schema_id, ctypes.c_void_p, ctypes.c_size_t])
657669

658670
# OBX_C_API obx_qb_cond obx_qb_all(OBX_query_builder* builder, const obx_qb_cond conditions[], size_t count);
659-
obx_qb_all = c_fn('obx_qb_all', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond, ctypes.c_size_t])
671+
obx_qb_all = c_fn('obx_qb_all', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond_p, ctypes.c_size_t])
660672

661673
# OBX_C_API obx_qb_cond obx_qb_any(OBX_query_builder* builder, const obx_qb_cond conditions[], size_t count);
662-
obx_qb_any = c_fn('obx_qb_any', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond, ctypes.c_size_t])
674+
obx_qb_any = c_fn('obx_qb_any', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond_p, ctypes.c_size_t])
663675

664676
# OBX_C_API obx_err obx_qb_param_alias(OBX_query_builder* builder, const char* alias);
665677
obx_qb_param_alias = c_fn_rc('obx_qb_param_alias', [OBX_query_builder_p, ctypes.c_char_p])

objectbox/query_builder.py

+66-51
Original file line numberDiff line numberDiff line change
@@ -30,85 +30,90 @@ def error_code(self) -> int:
3030
def error_message(self) -> str:
3131
return obx_qb_error_message(self._c_builder)
3232

33-
def equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
33+
def equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
3434
prop_id = self._get_property_id(prop)
35-
obx_qb_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
36-
return self
35+
cond = obx_qb_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
36+
return cond
3737

38-
def not_equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
38+
def not_equals_string(self, prop: Union[int, str, Property], value: str,
39+
case_sensitive: bool = True) -> obx_qb_cond:
3940
prop_id = self._get_property_id(prop)
40-
obx_qb_not_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
41-
return self
41+
cond = obx_qb_not_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
42+
return cond
4243

43-
def contains_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
44+
def contains_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
4445
prop_id = self._get_property_id(prop)
45-
obx_qb_contains_string(self._c_builder, prop_id, c_str(value), case_sensitive)
46-
return self
46+
cond = obx_qb_contains_string(self._c_builder, prop_id, c_str(value), case_sensitive)
47+
return cond
4748

48-
def starts_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
49+
def starts_with_string(self, prop: Union[int, str, Property], value: str,
50+
case_sensitive: bool = True) -> obx_qb_cond:
4951
prop_id = self._get_property_id(prop)
50-
obx_qb_starts_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
51-
return self
52+
cond = obx_qb_starts_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
53+
return cond
5254

53-
def ends_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
55+
def ends_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
5456
prop_id = self._get_property_id(prop)
55-
obx_qb_ends_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
56-
return self
57+
cond = obx_qb_ends_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
58+
return cond
5759

58-
def greater_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
60+
def greater_than_string(self, prop: Union[int, str, Property], value: str,
61+
case_sensitive: bool = True) -> obx_qb_cond:
5962
prop_id = self._get_property_id(prop)
60-
obx_qb_greater_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
61-
return self
63+
cond = obx_qb_greater_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
64+
return cond
6265

63-
def greater_or_equal_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
66+
def greater_or_equal_string(self, prop: Union[int, str, Property], value: str,
67+
case_sensitive: bool = True) -> obx_qb_cond:
6468
prop_id = self._get_property_id(prop)
65-
obx_qb_greater_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
66-
return self
69+
cond = obx_qb_greater_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
70+
return cond
6771

68-
def less_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
72+
def less_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
6973
prop_id = self._get_property_id(prop)
70-
obx_qb_less_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
71-
return self
74+
cond = obx_qb_less_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
75+
return cond
7276

73-
def less_or_equal_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
77+
def less_or_equal_string(self, prop: Union[int, str, Property], value: str,
78+
case_sensitive: bool = True) -> obx_qb_cond:
7479
prop_id = self._get_property_id(prop)
75-
obx_qb_less_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
76-
return self
80+
cond = obx_qb_less_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
81+
return cond
7782

78-
def equals_int(self, prop: Union[int, str, Property], value: int):
83+
def equals_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
7984
prop_id = self._get_property_id(prop)
80-
obx_qb_equals_int(self._c_builder, prop_id, value)
81-
return self
85+
cond = obx_qb_equals_int(self._c_builder, prop_id, value)
86+
return cond
8287

83-
def not_equals_int(self, prop: Union[int, str, Property], value: int):
88+
def not_equals_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
8489
prop_id = self._get_property_id(prop)
85-
obx_qb_not_equals_int(self._c_builder, prop_id, value)
86-
return self
90+
cond = obx_qb_not_equals_int(self._c_builder, prop_id, value)
91+
return cond
8792

88-
def greater_than_int(self, prop: Union[int, str, Property], value: int):
93+
def greater_than_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
8994
prop_id = self._get_property_id(prop)
90-
obx_qb_greater_than_int(self._c_builder, prop_id, value)
91-
return self
95+
cond = obx_qb_greater_than_int(self._c_builder, prop_id, value)
96+
return cond
9297

93-
def greater_or_equal_int(self, prop: Union[int, str, Property], value: int):
98+
def greater_or_equal_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
9499
prop_id = self._get_property_id(prop)
95-
obx_qb_greater_or_equal_int(self._c_builder, prop_id, value)
96-
return self
100+
cond = obx_qb_greater_or_equal_int(self._c_builder, prop_id, value)
101+
return cond
97102

98-
def less_than_int(self, prop: Union[int, str, Property], value: int):
103+
def less_than_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
99104
prop_id = self._get_property_id(prop)
100-
obx_qb_less_than_int(self._c_builder, prop_id, value)
101-
return self
105+
cond = obx_qb_less_than_int(self._c_builder, prop_id, value)
106+
return cond
102107

103-
def less_or_equal_int(self, prop: Union[int, str, Property], value: int):
108+
def less_or_equal_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
104109
prop_id = self._get_property_id(prop)
105-
obx_qb_less_or_equal_int(self._c_builder, prop_id, value)
106-
return self
110+
cond = obx_qb_less_or_equal_int(self._c_builder, prop_id, value)
111+
return cond
107112

108-
def between_2ints(self, prop: Union[int, str, Property], value_a: int, value_b: int):
113+
def between_2ints(self, prop: Union[int, str, Property], value_a: int, value_b: int) -> obx_qb_cond:
109114
prop_id = self._get_property_id(prop)
110-
obx_qb_between_2ints(self._c_builder, prop_id, value_a, value_b)
111-
return self
115+
cond = obx_qb_between_2ints(self._c_builder, prop_id, value_a, value_b)
116+
return cond
112117

113118
def nearest_neighbors_f32(self, prop: Union[int, str, Property], query_vector: Union[np.ndarray, List[float]],
114119
element_count: int):
@@ -117,11 +122,21 @@ def nearest_neighbors_f32(self, prop: Union[int, str, Property], query_vector: U
117122
raise Exception(f"query_vector dtype must be float32")
118123
query_vector_data = query_vector.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
119124
else: # List[float]
120-
query_vector_data = (ctypes.c_float * len(query_vector))(*query_vector)
125+
query_vector_data = py_list_to_c_array(query_vector, ctypes.c_float)
121126

122127
prop_id = self._get_property_id(prop)
123-
obx_qb_nearest_neighbors_f32(self._c_builder, prop_id, query_vector_data, element_count)
124-
return self
128+
cond = obx_qb_nearest_neighbors_f32(self._c_builder, prop_id, query_vector_data, element_count)
129+
return cond
130+
131+
def any(self, conditions: List[obx_qb_cond]) -> obx_qb_cond:
132+
c_conditions = py_list_to_c_pointer(conditions, obx_qb_cond)
133+
cond = obx_qb_any(self._c_builder, c_conditions, len(conditions))
134+
return cond
135+
136+
def all(self, conditions: List[obx_qb_cond]) -> obx_qb_cond:
137+
c_conditions = py_list_to_c_pointer(conditions, obx_qb_cond)
138+
cond = obx_qb_all(self._c_builder, c_conditions, len(conditions))
139+
return cond
125140

126141
def build(self) -> Query:
127142
c_query = obx_query(self._c_builder)

tests/test_hnsw.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def _test_random_points(num_points: int, num_query_points: int, seed: Optional[i
5757
assert len(expected_result) == k
5858

5959
# Run ANN with OBX
60-
query_builder = QueryBuilder(db, box)
61-
query_builder.nearest_neighbors_f32("vector", query_point, k)
62-
query = query_builder.build()
60+
qb = box.query()
61+
qb.nearest_neighbors_f32("vector", query_point, k)
62+
query = qb.build()
6363
obx_result = [id_ for id_, score in query.find_ids_with_scores()] # Ignore score
6464
assert len(obx_result) == k
6565

@@ -100,10 +100,10 @@ def test_combined_nn_search():
100100
assert box.count() == 9
101101

102102
# Test condition + NN search
103-
query = box.query() \
104-
.nearest_neighbors_f32("vector", [4.1, 4.2], 6) \
105-
.contains_string("name", "red", case_sensitive=False) \
106-
.build()
103+
qb = box.query()
104+
qb.nearest_neighbors_f32("vector", [4.1, 4.2], 6)
105+
qb.contains_string("name", "red", case_sensitive=False)
106+
query = qb.build()
107107
# 4, 5, 3, 6, 2, 7
108108
# Filtered: 3, 6, 7
109109
search_results = query.find_with_scores()
@@ -120,20 +120,20 @@ def test_combined_nn_search():
120120
assert search_results[0][0].name == "Red apple"
121121

122122
# Regular condition + NN search
123-
query = box.query() \
124-
.nearest_neighbors_f32("vector", [9.2, 8.9], 7) \
125-
.starts_with_string("name", "Blue", case_sensitive=True) \
126-
.build()
123+
qb = box.query()
124+
qb.nearest_neighbors_f32("vector", [9.2, 8.9], 7)
125+
qb.starts_with_string("name", "Blue", case_sensitive=True)
126+
query = qb.build()
127127

128128
search_results = query.find_with_scores()
129129
assert len(search_results) == 1
130130
assert search_results[0][0].name == "Blue sea"
131131

132132
# Regular condition + NN search
133-
query = box.query() \
134-
.nearest_neighbors_f32("vector", [7.7, 7.7], 8) \
135-
.contains_string("name", "blue", case_sensitive=False) \
136-
.build()
133+
qb = box.query()
134+
qb.nearest_neighbors_f32("vector", [7.7, 7.7], 8)
135+
qb.contains_string("name", "blue", case_sensitive=False)
136+
query = qb.build()
137137
# 8, 7, 9, 6, 5, 4, 3, 2
138138
# Filtered: 9, 5, 4, 2
139139
search_results = query.find_ids_with_scores()

0 commit comments

Comments
 (0)