
Commit 0edd0e9

Improve the API and update the tests
1 parent a0dae47 commit 0edd0e9


10 files changed (+147, -217 lines)


infera/bindings/include/rust.h

Lines changed: 11 additions & 13 deletions

@@ -24,25 +24,23 @@ void infera_free(char *ptr);
 const char *infera_last_error(void);

 // Version and autoload functions
-char *infera_version(void);
-char *infera_autoload_dir(const char *path);
+char *infera_get_version(void);
+char *infera_set_autoload_dir(const char *path);

 // Model management functions
-int32_t infera_load_onnx_model(const char *name, const char *path);
-int32_t infera_unload_onnx_model(const char *name);
+int32_t infera_load_model(const char *name, const char *path);
+int32_t infera_unload_model(const char *name);

 // Inference functions
-InferaInferenceResult infera_run_inference(const char *model_name,
-                                           const float *data, size_t rows,
-                                           size_t cols);
-InferaInferenceResult infera_predict_blob(const char *model_name,
-                                          const uint8_t *blob_data,
-                                          size_t blob_len);
+InferaInferenceResult infera_predict(const char *model_name, const float *data,
+                                     size_t rows, size_t cols);
+InferaInferenceResult infera_predict_from_blob(const char *model_name,
+                                               const uint8_t *blob_data,
+                                               size_t blob_len);

 // Utility functions
-char *infera_list_models(void);
-char *infera_model_info(const char *model_name);
-char *infera_get_model_metadata(const char *model_name); // changed return type
+char *infera_get_loaded_models(void);
+char *infera_get_model_info(const char *model_name);

 // Memory cleanup functions
 void infera_free_result(InferaInferenceResult result);
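
Taken together, the renamed C API reads naturally from a host program. A minimal sketch of a C++ caller, assuming the generated header is included as "rust.h" and that InferaInferenceResult exposes the status, data, and cols fields the extension code relies on; the model name and path are placeholders:

#include <cstddef>
#include <cstdio>

#include "rust.h"  // generated by cbindgen; include path is an assumption

int main() {
	// Load a model under a name, then run a single-row prediction.
	if (infera_load_model("demo", "./model.onnx") != 0) {
		std::fprintf(stderr, "load failed: %s\n", infera_last_error());
		return 1;
	}

	// One row of four features; rows/cols mirror infera_predict's signature.
	const float features[] = {5.1f, 3.5f, 1.4f, 0.2f};
	InferaInferenceResult res = infera_predict("demo", features, 1, 4);
	if (res.status == 0) {
		// The extension reads res.data row-major with res.cols outputs per row.
		for (std::size_t i = 0; i < res.cols; i++) {
			std::printf("output[%zu] = %f\n", i, static_cast<double>(res.data[i]));
		}
	} else {
		std::fprintf(stderr, "inference failed: %s\n", infera_last_error());
	}
	infera_free_result(res);

	// JSON strings returned by the API are released with infera_free().
	char *models = infera_get_loaded_models();
	std::printf("loaded models: %s\n", models);
	infera_free(models);

	infera_unload_model("demo");
	return 0;
}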

infera/bindings/infera_extension.cpp

Lines changed: 68 additions & 97 deletions

@@ -28,32 +28,38 @@ static std::string GetInferaError() {
 	return err ? std::string(err) : std::string("unknown error");
 }

-static void PragmaAutoloadDir(ClientContext &context, const FunctionParameters &parameters) {
-	if (parameters.values.empty() || parameters.values[0].IsNull()) {
-		return;
+static void SetAutoloadDir(DataChunk &args, ExpressionState &state, Vector &result) {
+	if (args.ColumnCount() != 1) {
+		throw InvalidInputException("infera_set_autoload_dir(path) expects exactly 1 argument");
+	}
+	if (args.size() == 0) { return; }
+	auto path_val = args.data[0].GetValue(0);
+	if (path_val.IsNull()) {
+		throw InvalidInputException("Path cannot be NULL");
 	}
-	std::string path = parameters.values[0].ToString();
-	char *result_json_c = infera_autoload_dir(path.c_str());
+	std::string path_str = path_val.ToString();
+	char *result_json_c = infera_set_autoload_dir(path_str.c_str());
+	result.SetVectorType(VectorType::CONSTANT_VECTOR);
+	ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, result_json_c);
+	ConstantVector::SetNull(result, false);
 	infera_free(result_json_c);
 }

-static void InferaVersion(DataChunk &args, ExpressionState &state, Vector &result) {
-	char *info_json_c = infera_version();
+static void GetVersion(DataChunk &args, ExpressionState &state, Vector &result) {
+	char *info_json_c = infera_get_version();
 	result.SetVectorType(VectorType::CONSTANT_VECTOR);
 	ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, info_json_c);
 	ConstantVector::SetNull(result, false);
 	infera_free(info_json_c);
 }

-static void LoadOnnxModel(DataChunk &args, ExpressionState &state, Vector &result) {
+static void LoadModel(DataChunk &args, ExpressionState &state, Vector &result) {
 	if (args.ColumnCount() != 2) {
-		throw InvalidInputException("load_onnx_model(model_name, path) expects exactly 2 arguments");
+		throw InvalidInputException("infera_load_model(model_name, path) expects exactly 2 arguments");
 	}
-	auto &model_name_vec = args.data[0];
-	auto &path_vec = args.data[1];
 	if (args.size() == 0) { return; }
-	auto model_name = model_name_vec.GetValue(0);
-	auto path = path_vec.GetValue(0);
+	auto model_name = args.data[0].GetValue(0);
+	auto path = args.data[1].GetValue(0);
 	if (model_name.IsNull() || path.IsNull()) {
 		throw InvalidInputException("Model name and path cannot be NULL");
 	}

@@ -62,31 +68,30 @@ static void LoadOnnxModel(DataChunk &args, ExpressionState &state, Vector &result)
 	if (model_name_str.empty()) {
 		throw InvalidInputException("Model name cannot be empty");
 	}
-	int rc = infera_load_onnx_model(model_name_str.c_str(), path_str.c_str());
+	int rc = infera_load_model(model_name_str.c_str(), path_str.c_str());
 	bool success = rc == 0;
 	if (!success) {
-		throw InvalidInputException("Failed to load ONNX model '" + model_name_str + "': " + GetInferaError());
+		throw InvalidInputException("Failed to load model '" + model_name_str + "': " + GetInferaError());
 	}
 	result.SetVectorType(VectorType::CONSTANT_VECTOR);
 	ConstantVector::GetData<bool>(result)[0] = success;
 	ConstantVector::SetNull(result, false);
 }

-static void UnloadOnnxModel(DataChunk &args, ExpressionState &state, Vector &result) {
+static void UnloadModel(DataChunk &args, ExpressionState &state, Vector &result) {
 	if (args.ColumnCount() != 1) {
-		throw InvalidInputException("unload_onnx_model(model_name) expects exactly 1 argument");
+		throw InvalidInputException("infera_unload_model(model_name) expects exactly 1 argument");
 	}
-	auto &model_name_vec = args.data[0];
 	if (args.size() == 0) { return; }
-	auto model_name = model_name_vec.GetValue(0);
+	auto model_name = args.data[0].GetValue(0);
 	if (model_name.IsNull()) {
 		throw InvalidInputException("Model name cannot be NULL");
 	}
 	std::string model_name_str = model_name.ToString();
-	int rc = infera_unload_onnx_model(model_name_str.c_str());
+	int rc = infera_unload_model(model_name_str.c_str());
 	bool success = (rc == 0);
 	if (!success) {
-		throw InvalidInputException("Failed to unload ONNX model '" + model_name_str + "': " + GetInferaError());
+		throw InvalidInputException("Failed to unload model '" + model_name_str + "': " + GetInferaError());
 	}
 	result.SetVectorType(VectorType::CONSTANT_VECTOR);
 	ConstantVector::GetData<bool>(result)[0] = success;

@@ -117,24 +122,28 @@ static void ExtractFeatures(DataChunk &args, std::vector<float> &features) {
 	}
 }

-static void OnnxPredict(DataChunk &args, ExpressionState &state, Vector &result) {
+static std::string ValidateAndGetModelName(DataChunk &args, const std::string &func_name) {
 	if (args.ColumnCount() < 2) {
-		throw InvalidInputException("onnx_predict(model_name, feature1, ...) requires at least 2 arguments");
+		throw InvalidInputException(func_name + "(model_name, feature1, ...) requires at least 2 arguments");
 	}
-	if (args.size() == 0) { return; }
-	auto &model_name_vec = args.data[0];
-	auto model_name_val = model_name_vec.GetValue(0);
+	auto model_name_val = args.data[0].GetValue(0);
 	if (model_name_val.IsNull()) {
 		throw InvalidInputException("Model name cannot be NULL");
 	}
-	std::string model_name_str = model_name_val.ToString();
+	return model_name_val.ToString();
+}
+
+static void Predict(DataChunk &args, ExpressionState &state, Vector &result) {
+	if (args.size() == 0) { return; }
+	std::string model_name_str = ValidateAndGetModelName(args, "infera_predict");
+
 	const idx_t batch_size = args.size();
 	const idx_t feature_count = args.ColumnCount() - 1;

 	std::vector<float> features;
 	ExtractFeatures(args, features);

-	InferaInferenceResult res = infera_run_inference(model_name_str.c_str(), features.data(), batch_size, feature_count);
+	InferaInferenceResult res = infera_predict(model_name_str.c_str(), features.data(), batch_size, feature_count);
 	if (res.status != 0) {
 		throw InvalidInputException("Inference failed for model '" + model_name_str + "': " + GetInferaError());
 	}

@@ -151,9 +160,9 @@ static void OnnxPredict(DataChunk &args, ExpressionState &state, Vector &result)
 	infera_free_result(res);
 }

-static void InferaPredictBlob(DataChunk &args, ExpressionState &state, Vector &result) {
+static void PredictFromBlob(DataChunk &args, ExpressionState &state, Vector &result) {
 	if (args.ColumnCount() != 2) {
-		throw InvalidInputException("infera_predict_blob(model_name, input_blob) requires 2 arguments");
+		throw InvalidInputException("infera_predict_from_blob(model_name, input_blob) requires 2 arguments");
 	}
 	if (args.size() == 0) { return; }
 	result.SetVectorType(VectorType::FLAT_VECTOR);

@@ -168,7 +177,7 @@ static void InferaPredictBlob(DataChunk &args, ExpressionState &state, Vector &result)
 	string_t blob_str_t = blob_val.GetValueUnsafe<string_t>();
 	auto blob_ptr = reinterpret_cast<const uint8_t *>(blob_str_t.GetDataUnsafe());
 	auto blob_len = blob_str_t.GetSize();
-	InferaInferenceResult res = infera_predict_blob(model_name_str.c_str(), blob_ptr, blob_len);
+	InferaInferenceResult res = infera_predict_from_blob(model_name_str.c_str(), blob_ptr, blob_len);
 	if (res.status != 0) {
 		infera_free_result(res);
 		throw InvalidInputException("Inference failed for model '" + model_name_str + "': " + GetInferaError());

@@ -184,49 +193,25 @@ static void InferaPredictBlob(DataChunk &args, ExpressionState &state, Vector &result)
 	result.Verify(args.size());
 }

-static void ListModels(DataChunk &args, ExpressionState &state, Vector &result) {
-	char *models_json = infera_list_models();
+static void GetLoadedModels(DataChunk &args, ExpressionState &state, Vector &result) {
+	char *models_json = infera_get_loaded_models();
 	result.SetVectorType(VectorType::CONSTANT_VECTOR);
 	ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, models_json);
 	ConstantVector::SetNull(result, false);
 	infera_free(models_json);
 }

-static void ModelInfo(DataChunk &args, ExpressionState &state, Vector &result) {
-	if (args.ColumnCount() != 1) {
-		throw InvalidInputException("model_info(model_name) expects exactly 1 argument");
-	}
+static void PredictMulti(DataChunk &args, ExpressionState &state, Vector &result) {
 	if (args.size() == 0) { return; }
-	auto model_name = args.data[0].GetValue(0);
-	if (model_name.IsNull()) {
-		throw InvalidInputException("Model name cannot be NULL");
-	}
-	std::string model_name_str = model_name.ToString();
-	char *info_json = infera_model_info(model_name_str.c_str());
-	result.SetVectorType(VectorType::CONSTANT_VECTOR);
-	ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, info_json);
-	ConstantVector::SetNull(result, false);
-	infera_free(info_json);
-}
+	std::string model_name_str = ValidateAndGetModelName(args, "infera_predict_multi");

-static void OnnxPredictMulti(DataChunk &args, ExpressionState &state, Vector &result) {
-	if (args.ColumnCount() < 2) {
-		throw InvalidInputException("onnx_predict_multi(model_name, feature1, ...) requires at least 2 arguments");
-	}
-	if (args.size() == 0) { return; }
-	auto &model_name_vec = args.data[0];
-	auto model_name_val = model_name_vec.GetValue(0);
-	if (model_name_val.IsNull()) {
-		throw InvalidInputException("Model name cannot be NULL");
-	}
-	std::string model_name_str = model_name_val.ToString();
 	const idx_t batch_size = args.size();
 	const idx_t feature_count = args.ColumnCount() - 1;

 	std::vector<float> features;
 	ExtractFeatures(args, features);

-	InferaInferenceResult res = infera_run_inference(model_name_str.c_str(), features.data(), batch_size, feature_count);
+	InferaInferenceResult res = infera_predict(model_name_str.c_str(), features.data(), batch_size, feature_count);
 	if (res.status != 0) {
 		infera_free_result(res);
 		throw InvalidInputException("Inference failed for model '" + model_name_str + "': " + GetInferaError());

@@ -240,28 +225,31 @@ static void OnnxPredictMulti(DataChunk &args, ExpressionState &state, Vector &result)
 	auto result_data = FlatVector::GetData<string_t>(result);
 	const size_t output_cols = res.cols;
 	for (idx_t row_idx = 0; row_idx < batch_size; row_idx++) {
-		std::string json_result = "[";
+		std::ostringstream oss;
+		oss << "[";
 		for (size_t col_idx = 0; col_idx < output_cols; col_idx++) {
-			if (col_idx > 0) { json_result += ","; }
-			json_result += std::to_string(res.data[row_idx * output_cols + col_idx]);
+			if (col_idx > 0) {
+				oss << ",";
+			}
+			oss << res.data[row_idx * output_cols + col_idx];
 		}
-		json_result += "]";
-		result_data[row_idx] = StringVector::AddString(result, json_result);
+		oss << "]";
+		result_data[row_idx] = StringVector::AddString(result, oss.str());
 	}
 	infera_free_result(res);
 }

-static void ModelMetadataFunc(DataChunk &args, ExpressionState &state, Vector &result) {
+static void GetModelInfo(DataChunk &args, ExpressionState &state, Vector &result) {
 	if (args.ColumnCount() != 1) {
-		throw InvalidInputException("model_metadata(model_name) expects exactly 1 argument");
+		throw InvalidInputException("infera_get_model_info(model_name) expects exactly 1 argument");
 	}
 	if (args.size() == 0) { return; }
 	auto model_name = args.data[0].GetValue(0);
 	if (model_name.IsNull()) {
 		throw InvalidInputException("Model name cannot be NULL");
 	}
 	std::string model_name_str = model_name.ToString();
-	char *json_meta = infera_get_model_metadata(model_name_str.c_str());
+	char *json_meta = infera_get_model_info(model_name_str.c_str());

 	result.SetVectorType(VectorType::CONSTANT_VECTOR);
 	ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, json_meta);

@@ -270,43 +258,26 @@ static void ModelMetadataFunc(DataChunk &args, ExpressionState &state, Vector &result)
 }

 static void LoadInternal(ExtensionLoader &loader) {
-	ScalarFunction load_onnx_model_func("load_onnx_model", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LoadOnnxModel);
-	loader.RegisterFunction(load_onnx_model_func);
+	loader.RegisterFunction(ScalarFunction("infera_load_model", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LoadModel));
+	loader.RegisterFunction(ScalarFunction("infera_unload_model", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, UnloadModel));

-	ScalarFunction unload_onnx_model_func("unload_onnx_model", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, UnloadOnnxModel);
-	loader.RegisterFunction(unload_onnx_model_func);
-
-	// Removed deprecated LogicalType::VARARG usage. Register multiple arities instead.
-	const idx_t MAX_FEATURES = 63; // features (total args = 1 + features)
+	const idx_t MAX_FEATURES = 63;
 	for (idx_t feature_count = 1; feature_count <= MAX_FEATURES; feature_count++) {
 		vector<LogicalType> arg_types;
 		arg_types.reserve(feature_count + 1);
-		arg_types.push_back(LogicalType::VARCHAR); // model name
+		arg_types.push_back(LogicalType::VARCHAR);
 		for (idx_t i = 0; i < feature_count; i++) {
-			arg_types.push_back(LogicalType::FLOAT); // DuckDB will auto-cast other numerics
+			arg_types.push_back(LogicalType::FLOAT);
 		}
-		loader.RegisterFunction(ScalarFunction("onnx_predict", arg_types, LogicalType::FLOAT, OnnxPredict));
-		loader.RegisterFunction(ScalarFunction("onnx_predict_multi", arg_types, LogicalType::VARCHAR, OnnxPredictMulti));
+		loader.RegisterFunction(ScalarFunction("infera_predict", arg_types, LogicalType::FLOAT, Predict));
+		loader.RegisterFunction(ScalarFunction("infera_predict_multi", arg_types, LogicalType::VARCHAR, PredictMulti));
 	}

-	ScalarFunction infera_predict_blob_func("infera_predict_blob", {LogicalType::VARCHAR, LogicalType::BLOB}, LogicalType::LIST(LogicalType::FLOAT), InferaPredictBlob);
-	loader.RegisterFunction(infera_predict_blob_func);
-
-	ScalarFunction list_models_func("list_models", {}, LogicalType::VARCHAR, ListModels);
-	loader.RegisterFunction(list_models_func);
-
-	ScalarFunction model_info_func("model_info", {LogicalType::VARCHAR}, LogicalType::VARCHAR, ModelInfo);
-	loader.RegisterFunction(model_info_func);
-
-	ScalarFunction model_metadata_func("model_metadata", {LogicalType::VARCHAR}, LogicalType::VARCHAR, ModelMetadataFunc);
-	loader.RegisterFunction(model_metadata_func);
-
-	ScalarFunction infera_version_func("infera_version", {}, LogicalType::VARCHAR, InferaVersion);
-	loader.RegisterFunction(infera_version_func);
-
-	auto autoload_pragma = PragmaFunction::PragmaCall("infera_autoload_dir", PragmaAutoloadDir,
-	                                                  {LogicalType::VARCHAR}, LogicalType::INVALID);
-	loader.RegisterFunction(autoload_pragma);
+	loader.RegisterFunction(ScalarFunction("infera_predict_from_blob", {LogicalType::VARCHAR, LogicalType::BLOB}, LogicalType::LIST(LogicalType::FLOAT), PredictFromBlob));
+	loader.RegisterFunction(ScalarFunction("infera_get_loaded_models", {}, LogicalType::VARCHAR, GetLoadedModels));
+	loader.RegisterFunction(ScalarFunction("infera_get_model_info", {LogicalType::VARCHAR}, LogicalType::VARCHAR, GetModelInfo));
+	loader.RegisterFunction(ScalarFunction("infera_get_version", {}, LogicalType::VARCHAR, GetVersion));
+	loader.RegisterFunction(ScalarFunction("infera_set_autoload_dir", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetAutoloadDir));
 }

 void InferaExtension::Load(ExtensionLoader &loader) { LoadInternal(loader); }
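
SetAutoloadDir, GetVersion, GetLoadedModels, and GetModelInfo all repeat the same constant-vector JSON-string pattern; it could be factored out the same way ValidateAndGetModelName factors the model-name checks. A hypothetical helper, not part of this commit:

// Hypothetical helper (not in this commit): copy a malloc'd JSON string from the
// Rust core into a single-value constant result vector, then release it.
static void SetConstantJsonResult(Vector &result, char *json_c) {
	result.SetVectorType(VectorType::CONSTANT_VECTOR);
	ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, json_c);
	ConstantVector::SetNull(result, false);
	infera_free(json_c);
}

// GetVersion would then shrink to a one-liner:
static void GetVersion(DataChunk &args, ExpressionState &state, Vector &result) {
	SetConstantJsonResult(result, infera_get_version());
}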

infera/cbindgen.toml

Lines changed: 8 additions & 8 deletions

@@ -33,14 +33,14 @@ prefix = ""
 # Export settings
 [export]
 include = [
-    "infera_autoload_dir",
-    "infera_version",
-    "infera_load_onnx_model",
-    "infera_unload_onnx_model",
-    "infera_run_inference",
-    "infera_predict_blob",
-    "infera_get_model_metadata",
-    "infera_list_models",
+    "infera_set_autoload_dir",
+    "infera_get_version",
+    "infera_load_model",
+    "infera_unload_model",
+    "infera_predict",
+    "infera_predict_from_blob",
+    "infera_get_model_info",
+    "infera_get_loaded_models",
     "infera_last_error",
     "infera_free",
     "infera_free_result",
