Skip to content

Commit cbdd07d

Browse files
committed
c.py: bubble up HNSW -related functions objectbox#24
1 parent 1395fa9 commit cbdd07d

File tree

1 file changed

+109
-3
lines changed

1 file changed

+109
-3
lines changed

objectbox/c.py

+109-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import platform
1919
from objectbox.version import Version
20+
from typing import *
2021

2122
# This file contains C-API bindings based on lib/objectbox.h, linking to the 'objectbox' shared library.
2223
# The bindings are implementing using ctypes, see https://docs.python.org/dev/library/ctypes.html for introduction.
@@ -72,6 +73,8 @@ def shlib_name(library: str) -> str:
7273
OBXDebugFlags = ctypes.c_int
7374
OBXPutMode = ctypes.c_int
7475
OBXOrderFlags = ctypes.c_int
76+
OBXHnswFlags = ctypes.c_int
77+
OBXHnswDistanceType = ctypes.c_int
7578

7679

7780
class OBX_model(ctypes.Structure):
@@ -115,6 +118,27 @@ class OBX_bytes_array(ctypes.Structure):
115118
OBX_bytes_array_p = ctypes.POINTER(OBX_bytes_array)
116119

117120

121+
class OBX_bytes_score(ctypes.Structure):
122+
_fields_ = [
123+
('data', ctypes.c_void_p),
124+
('size', ctypes.c_size_t),
125+
('score', ctypes.c_double),
126+
]
127+
128+
129+
OBX_bytes_score_p = ctypes.POINTER(OBX_bytes_score)
130+
131+
132+
class OBX_bytes_score_array(ctypes.Structure):
133+
_fields_ = [
134+
('bytes_scores', OBX_bytes_score_p),
135+
('count', ctypes.c_size_t),
136+
]
137+
138+
139+
OBX_bytes_score_array_p = ctypes.POINTER(OBX_bytes_score_array)
140+
141+
118142
class OBX_id_array(ctypes.Structure):
119143
_fields_ = [
120144
('ids', ctypes.POINTER(obx_id)),
@@ -125,6 +149,26 @@ class OBX_id_array(ctypes.Structure):
125149
OBX_id_array_p = ctypes.POINTER(OBX_id_array)
126150

127151

152+
class OBX_id_score(ctypes.Structure):
153+
_fields_ = [
154+
('id', obx_id),
155+
('score', ctypes.c_double)
156+
]
157+
158+
159+
OBX_id_score_p = ctypes.POINTER(OBX_id_score)
160+
161+
162+
class OBX_id_score_array(ctypes.Structure):
163+
_fields_ = [
164+
('ids_scores', ctypes.POINTER(OBX_id_score)),
165+
('count', ctypes.c_size_t)
166+
]
167+
168+
169+
OBX_id_score_array_p = ctypes.POINTER(OBX_id_score_array)
170+
171+
128172
class OBX_txn(ctypes.Structure):
129173
pass
130174

@@ -223,7 +267,7 @@ def check_result(result, func, args):
223267

224268
# creates a global function "name" with the given restype & argtypes, calling C function with the same name.
225269
# restype is used for error checking: if not None, check_result will throw an exception if the result is empty.
226-
def c_fn(name: str, restype: type, argtypes):
270+
def c_fn(name: str, restype: Optional[type], argtypes):
227271
func = C.__getattr__(name)
228272
func.argtypes = argtypes
229273
func.restype = restype
@@ -272,8 +316,38 @@ def c_voidp_as_bytes(voidp, size):
272316
[OBX_model_p, ctypes.c_char_p, OBXPropertyType, obx_schema_id, obx_uid])
273317

274318
# obx_err (OBX_model* model, OBXPropertyFlags flags);
275-
obx_model_property_flags = c_fn_rc('obx_model_property_flags', [
276-
OBX_model_p, OBXPropertyFlags])
319+
obx_model_property_flags = c_fn_rc('obx_model_property_flags', [OBX_model_p, OBXPropertyFlags])
320+
321+
# obx_err obx_model_property_index_id(OBX_model* model, obx_schema_id index_id, obx_uid index_uid)
322+
obx_model_property_index_id = c_fn_rc('obx_model_property_index_id', [OBX_model_p, obx_schema_id, obx_uid])
323+
324+
# obx_err obx_model_property_index_hnsw_dimensions(OBX_model* model, size_t value)
325+
obx_model_property_index_hnsw_dimensions = \
326+
c_fn_rc('obx_model_property_index_hnsw_dimensions', [OBX_model_p, ctypes.c_size_t])
327+
328+
# obx_err obx_model_property_index_hnsw_neighbors_per_node(OBX_model* model, uint32_t value)
329+
obx_model_property_index_hnsw_neighbors_per_node = \
330+
c_fn_rc('obx_model_property_index_hnsw_neighbors_per_node', [OBX_model_p, ctypes.c_uint32])
331+
332+
# obx_err obx_model_property_index_hnsw_indexing_search_count(OBX_model* model, uint32_t value)
333+
obx_model_property_index_hnsw_indexing_search_count = \
334+
c_fn_rc('obx_model_property_index_hnsw_indexing_search_count', [OBX_model_p, ctypes.c_uint32])
335+
336+
# obx_err obx_model_property_index_hnsw_flags(OBX_model* model, OBXHnswFlags value)
337+
obx_model_property_index_hnsw_flags = \
338+
c_fn_rc('obx_model_property_index_hnsw_flags', [OBX_model_p, OBXHnswFlags])
339+
340+
# obx_err obx_model_property_index_hnsw_distance_type(OBX_model* model, OBXHnswDistanceType value)
341+
obx_model_property_index_hnsw_distance_type = \
342+
c_fn_rc('obx_model_property_index_hnsw_distance_type', [OBX_model_p, OBXHnswDistanceType])
343+
344+
# obx_err obx_model_property_index_hnsw_reparation_backlink_probability(OBX_model* model, float value)
345+
obx_model_property_index_hnsw_reparation_backlink_probability = \
346+
c_fn_rc('obx_model_property_index_hnsw_reparation_backlink_probability', [OBX_model_p, ctypes.c_float])
347+
348+
# obx_err obx_model_property_index_hnsw_vector_cache_hint_size_kb(OBX_model* model, size_t value)
349+
obx_model_property_index_hnsw_vector_cache_hint_size_kb = \
350+
c_fn_rc('obx_model_property_index_hnsw_vector_cache_hint_size_kb', [OBX_model_p, ctypes.c_size_t])
277351

278352
# obx_err (OBX_model*, obx_schema_id entity_id, obx_uid entity_uid);
279353
obx_model_last_entity_id = c_fn('obx_model_last_entity_id', None, [
@@ -536,9 +610,20 @@ def c_voidp_as_bytes(voidp, size):
536610
# OBX_C_API obx_err obx_qb_param_alias(OBX_query_builder* builder, const char* alias);
537611
obx_qb_param_alias = c_fn_rc('obx_qb_param_alias', [OBX_query_builder_p, ctypes.c_char_p])
538612

613+
# OBX_C_API obx_err obx_query_param_vector_float32(OBX_query* query, obx_schema_id entity_id, obx_schema_id property_id, const float* value, size_t element_count);
614+
# TODO
615+
616+
# OBX_C_API obx_err obx_query_param_alias_vector_float32(OBX_query* query, const char* alias, const float* value, size_t element_count);
617+
# TODO
618+
539619
# OBX_C_API obx_err obx_qb_order(OBX_query_builder* builder, obx_schema_id property_id, OBXOrderFlags flags);
540620
obx_qb_order = c_fn_rc('obx_qb_order', [OBX_query_builder_p, obx_schema_id, OBXOrderFlags])
541621

622+
# OBX_C_API obx_qb_cond obx_qb_nearest_neighbors_f32(OBX_query_builder* builder, obx_schema_id vector_property_id, const float* query_vector, size_t max_result_count)
623+
obx_qb_nearest_neighbors_f32 = \
624+
c_fn('obx_qb_nearest_neighbors_f32', obx_qb_cond, [OBX_query_builder_p, obx_schema_id,
625+
ctypes.pointer(ctypes.c_float), ctypes.c_size_t])
626+
542627
# OBX_C_API OBX_query* obx_query(OBX_query_builder* builder);
543628
obx_query = c_fn('obx_query', OBX_query_p, [OBX_query_builder_p])
544629

@@ -566,6 +651,9 @@ def c_voidp_as_bytes(voidp, size):
566651
# OBX_C_API obx_err obx_query_find_unique(OBX_query* query, const void** data, size_t* size);
567652
obx_query_find_unique = c_fn_rc('obx_query_find_unique', [OBX_query_p, ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_size_t)])
568653

654+
# OBX_C_API OBX_bytes_score_array* obx_query_find_with_scores(OBX_query* query);
655+
obx_query_find_with_scores = c_fn('obx_query_find_with_scores', OBX_bytes_score_array_p, [OBX_query_p]) # TODO
656+
569657
# typedef bool obx_data_visitor(void* user_data, const void* data, size_t size);
570658

571659
# OBX_C_API obx_err obx_query_visit(OBX_query* query, obx_data_visitor* visitor, void* user_data);
@@ -574,6 +662,9 @@ def c_voidp_as_bytes(voidp, size):
574662
# OBX_C_API OBX_id_array* obx_query_find_ids(OBX_query* query);
575663
obx_query_find_ids = c_fn('obx_query_find_ids', OBX_id_array_p, [OBX_query_p])
576664

665+
# OBX_C_API OBX_id_score_array* obx_query_find_ids_with_scores(OBX_query* query);
666+
obx_query_find_ids_with_scores = c_fn('obx_query_find_ids_with_scores', OBX_id_score_array_p, [OBX_query_p]) # TODO
667+
577668
# OBX_C_API obx_err obx_query_count(OBX_query* query, uint64_t* out_count);
578669
obx_query_count = c_fn_rc('obx_query_count', [OBX_query_p, ctypes.POINTER(ctypes.c_uint64)])
579670

@@ -596,6 +687,12 @@ def c_voidp_as_bytes(voidp, size):
596687
# void (OBX_bytes_array * array);
597688
obx_bytes_array_free = c_fn('obx_bytes_array_free', None, [OBX_bytes_array_p])
598689

690+
# OBX_C_API void obx_bytes_score_array_free(OBX_bytes_score_array* array)
691+
obx_bytes_score_array_free = c_fn('obx_bytes_score_array_free', None, [OBX_bytes_score_array_p])
692+
693+
# OBX_C_API void obx_id_score_array_free(OBX_id_score_array* array)
694+
obx_id_score_array_free = c_fn('obx_id_score_array_free', None, [OBX_id_score_array_p])
695+
599696
OBXPropertyType_Bool = 1
600697
OBXPropertyType_Byte = 2
601698
OBXPropertyType_Short = 3
@@ -669,3 +766,12 @@ def c_voidp_as_bytes(voidp, size):
669766

670767
# null values should be treated equal to zero (scalars only).
671768
OBXOrderFlags_NULLS_ZERO = 16
769+
770+
OBXHnswFlags_NONE = 0
771+
OBXHnswFlags_DEBUG_LOGS = 1
772+
OBXHnswFlags_DEBUG_LOGS_DETAILED = 2
773+
OBXHnswFlags_VECTOR_CACHE_SIMD_PADDING_OFF = 4
774+
OBXHnswFlags_REPARATION_LIMIT_CANDIDATES = 8
775+
776+
OBXHnswDistanceType_UNKNOWN = 0
777+
OBXHnswDistanceType_EUCLIDEAN = 1

0 commit comments

Comments
 (0)