@@ -28,32 +28,38 @@ static std::string GetInferaError() {
2828 return err ? std::string (err) : std::string (" unknown error" );
2929}
3030
31- static void PragmaAutoloadDir (ClientContext &context, const FunctionParameters ¶meters) {
32- if (parameters.values .empty () || parameters.values [0 ].IsNull ()) {
33- return ;
31+ static void SetAutoloadDir (DataChunk &args, ExpressionState &state, Vector &result) {
32+ if (args.ColumnCount () != 1 ) {
33+ throw InvalidInputException (" infera_set_autoload_dir(path) expects exactly 1 argument" );
34+ }
35+ if (args.size () == 0 ) { return ; }
36+ auto path_val = args.data [0 ].GetValue (0 );
37+ if (path_val.IsNull ()) {
38+ throw InvalidInputException (" Path cannot be NULL" );
3439 }
35- std::string path = parameters.values [0 ].ToString ();
36- char *result_json_c = infera_autoload_dir (path.c_str ());
40+ std::string path_str = path_val.ToString ();
41+ char *result_json_c = infera_set_autoload_dir (path_str.c_str ());
42+ result.SetVectorType (VectorType::CONSTANT_VECTOR);
43+ ConstantVector::GetData<string_t >(result)[0 ] = StringVector::AddString (result, result_json_c);
44+ ConstantVector::SetNull (result, false );
3745 infera_free (result_json_c);
3846}
3947
40- static void InferaVersion (DataChunk &args, ExpressionState &state, Vector &result) {
41- char *info_json_c = infera_version ();
48+ static void GetVersion (DataChunk &args, ExpressionState &state, Vector &result) {
49+ char *info_json_c = infera_get_version ();
4250 result.SetVectorType (VectorType::CONSTANT_VECTOR);
4351 ConstantVector::GetData<string_t >(result)[0 ] = StringVector::AddString (result, info_json_c);
4452 ConstantVector::SetNull (result, false );
4553 infera_free (info_json_c);
4654}
4755
48- static void LoadOnnxModel (DataChunk &args, ExpressionState &state, Vector &result) {
56+ static void LoadModel (DataChunk &args, ExpressionState &state, Vector &result) {
4957 if (args.ColumnCount () != 2 ) {
50- throw InvalidInputException (" load_onnx_model (model_name, path) expects exactly 2 arguments" );
58+ throw InvalidInputException (" infera_load_model (model_name, path) expects exactly 2 arguments" );
5159 }
52- auto &model_name_vec = args.data [0 ];
53- auto &path_vec = args.data [1 ];
5460 if (args.size () == 0 ) { return ; }
55- auto model_name = model_name_vec .GetValue (0 );
56- auto path = path_vec .GetValue (0 );
61+ auto model_name = args. data [ 0 ] .GetValue (0 );
62+ auto path = args. data [ 1 ] .GetValue (0 );
5763 if (model_name.IsNull () || path.IsNull ()) {
5864 throw InvalidInputException (" Model name and path cannot be NULL" );
5965 }
@@ -62,31 +68,30 @@ static void LoadOnnxModel(DataChunk &args, ExpressionState &state, Vector &resul
6268 if (model_name_str.empty ()) {
6369 throw InvalidInputException (" Model name cannot be empty" );
6470 }
65- int rc = infera_load_onnx_model (model_name_str.c_str (), path_str.c_str ());
71+ int rc = infera_load_model (model_name_str.c_str (), path_str.c_str ());
6672 bool success = rc == 0 ;
6773 if (!success) {
68- throw InvalidInputException (" Failed to load ONNX model '" + model_name_str + " ': " + GetInferaError ());
74+ throw InvalidInputException (" Failed to load model '" + model_name_str + " ': " + GetInferaError ());
6975 }
7076 result.SetVectorType (VectorType::CONSTANT_VECTOR);
7177 ConstantVector::GetData<bool >(result)[0 ] = success;
7278 ConstantVector::SetNull (result, false );
7379}
7480
75- static void UnloadOnnxModel (DataChunk &args, ExpressionState &state, Vector &result) {
81+ static void UnloadModel (DataChunk &args, ExpressionState &state, Vector &result) {
7682 if (args.ColumnCount () != 1 ) {
77- throw InvalidInputException (" unload_onnx_model (model_name) expects exactly 1 argument" );
83+ throw InvalidInputException (" infera_unload_model (model_name) expects exactly 1 argument" );
7884 }
79- auto &model_name_vec = args.data [0 ];
8085 if (args.size () == 0 ) { return ; }
81- auto model_name = model_name_vec .GetValue (0 );
86+ auto model_name = args. data [ 0 ] .GetValue (0 );
8287 if (model_name.IsNull ()) {
8388 throw InvalidInputException (" Model name cannot be NULL" );
8489 }
8590 std::string model_name_str = model_name.ToString ();
86- int rc = infera_unload_onnx_model (model_name_str.c_str ());
91+ int rc = infera_unload_model (model_name_str.c_str ());
8792 bool success = (rc == 0 );
8893 if (!success) {
89- throw InvalidInputException (" Failed to unload ONNX model '" + model_name_str + " ': " + GetInferaError ());
94+ throw InvalidInputException (" Failed to unload model '" + model_name_str + " ': " + GetInferaError ());
9095 }
9196 result.SetVectorType (VectorType::CONSTANT_VECTOR);
9297 ConstantVector::GetData<bool >(result)[0 ] = success;
@@ -117,24 +122,28 @@ static void ExtractFeatures(DataChunk &args, std::vector<float> &features) {
117122 }
118123}
119124
120- static void OnnxPredict (DataChunk &args, ExpressionState &state, Vector &result ) {
125+ static std::string ValidateAndGetModelName (DataChunk &args, const std::string &func_name ) {
121126 if (args.ColumnCount () < 2 ) {
122- throw InvalidInputException (" onnx_predict (model_name, feature1, ...) requires at least 2 arguments" );
127+ throw InvalidInputException (func_name + " (model_name, feature1, ...) requires at least 2 arguments" );
123128 }
124- if (args.size () == 0 ) { return ; }
125- auto &model_name_vec = args.data [0 ];
126- auto model_name_val = model_name_vec.GetValue (0 );
129+ auto model_name_val = args.data [0 ].GetValue (0 );
127130 if (model_name_val.IsNull ()) {
128131 throw InvalidInputException (" Model name cannot be NULL" );
129132 }
130- std::string model_name_str = model_name_val.ToString ();
133+ return model_name_val.ToString ();
134+ }
135+
136+ static void Predict (DataChunk &args, ExpressionState &state, Vector &result) {
137+ if (args.size () == 0 ) { return ; }
138+ std::string model_name_str = ValidateAndGetModelName (args, " infera_predict" );
139+
131140 const idx_t batch_size = args.size ();
132141 const idx_t feature_count = args.ColumnCount () - 1 ;
133142
134143 std::vector<float > features;
135144 ExtractFeatures (args, features);
136145
137- InferaInferenceResult res = infera_run_inference (model_name_str.c_str (), features.data (), batch_size, feature_count);
146+ InferaInferenceResult res = infera_predict (model_name_str.c_str (), features.data (), batch_size, feature_count);
138147 if (res.status != 0 ) {
139148 throw InvalidInputException (" Inference failed for model '" + model_name_str + " ': " + GetInferaError ());
140149 }
@@ -151,9 +160,9 @@ static void OnnxPredict(DataChunk &args, ExpressionState &state, Vector &result)
151160 infera_free_result (res);
152161}
153162
154- static void InferaPredictBlob (DataChunk &args, ExpressionState &state, Vector &result) {
163+ static void PredictFromBlob (DataChunk &args, ExpressionState &state, Vector &result) {
155164 if (args.ColumnCount () != 2 ) {
156- throw InvalidInputException (" infera_predict_blob (model_name, input_blob) requires 2 arguments" );
165+ throw InvalidInputException (" infera_predict_from_blob (model_name, input_blob) requires 2 arguments" );
157166 }
158167 if (args.size () == 0 ) { return ; }
159168 result.SetVectorType (VectorType::FLAT_VECTOR);
@@ -168,7 +177,7 @@ static void InferaPredictBlob(DataChunk &args, ExpressionState &state, Vector &r
168177 string_t blob_str_t = blob_val.GetValueUnsafe <string_t >();
169178 auto blob_ptr = reinterpret_cast <const uint8_t *>(blob_str_t .GetDataUnsafe ());
170179 auto blob_len = blob_str_t .GetSize ();
171- InferaInferenceResult res = infera_predict_blob (model_name_str.c_str (), blob_ptr, blob_len);
180+ InferaInferenceResult res = infera_predict_from_blob (model_name_str.c_str (), blob_ptr, blob_len);
172181 if (res.status != 0 ) {
173182 infera_free_result (res);
174183 throw InvalidInputException (" Inference failed for model '" + model_name_str + " ': " + GetInferaError ());
@@ -184,49 +193,25 @@ static void InferaPredictBlob(DataChunk &args, ExpressionState &state, Vector &r
184193 result.Verify (args.size ());
185194}
186195
187- static void ListModels (DataChunk &args, ExpressionState &state, Vector &result) {
188- char *models_json = infera_list_models ();
196+ static void GetLoadedModels (DataChunk &args, ExpressionState &state, Vector &result) {
197+ char *models_json = infera_get_loaded_models ();
189198 result.SetVectorType (VectorType::CONSTANT_VECTOR);
190199 ConstantVector::GetData<string_t >(result)[0 ] = StringVector::AddString (result, models_json);
191200 ConstantVector::SetNull (result, false );
192201 infera_free (models_json);
193202}
194203
195- static void ModelInfo (DataChunk &args, ExpressionState &state, Vector &result) {
196- if (args.ColumnCount () != 1 ) {
197- throw InvalidInputException (" model_info(model_name) expects exactly 1 argument" );
198- }
204+ static void PredictMulti (DataChunk &args, ExpressionState &state, Vector &result) {
199205 if (args.size () == 0 ) { return ; }
200- auto model_name = args.data [0 ].GetValue (0 );
201- if (model_name.IsNull ()) {
202- throw InvalidInputException (" Model name cannot be NULL" );
203- }
204- std::string model_name_str = model_name.ToString ();
205- char *info_json = infera_model_info (model_name_str.c_str ());
206- result.SetVectorType (VectorType::CONSTANT_VECTOR);
207- ConstantVector::GetData<string_t >(result)[0 ] = StringVector::AddString (result, info_json);
208- ConstantVector::SetNull (result, false );
209- infera_free (info_json);
210- }
206+ std::string model_name_str = ValidateAndGetModelName (args, " infera_predict_multi" );
211207
212- static void OnnxPredictMulti (DataChunk &args, ExpressionState &state, Vector &result) {
213- if (args.ColumnCount () < 2 ) {
214- throw InvalidInputException (" onnx_predict_multi(model_name, feature1, ...) requires at least 2 arguments" );
215- }
216- if (args.size () == 0 ) { return ; }
217- auto &model_name_vec = args.data [0 ];
218- auto model_name_val = model_name_vec.GetValue (0 );
219- if (model_name_val.IsNull ()) {
220- throw InvalidInputException (" Model name cannot be NULL" );
221- }
222- std::string model_name_str = model_name_val.ToString ();
223208 const idx_t batch_size = args.size ();
224209 const idx_t feature_count = args.ColumnCount () - 1 ;
225210
226211 std::vector<float > features;
227212 ExtractFeatures (args, features);
228213
229- InferaInferenceResult res = infera_run_inference (model_name_str.c_str (), features.data (), batch_size, feature_count);
214+ InferaInferenceResult res = infera_predict (model_name_str.c_str (), features.data (), batch_size, feature_count);
230215 if (res.status != 0 ) {
231216 infera_free_result (res);
232217 throw InvalidInputException (" Inference failed for model '" + model_name_str + " ': " + GetInferaError ());
@@ -240,28 +225,31 @@ static void OnnxPredictMulti(DataChunk &args, ExpressionState &state, Vector &re
240225 auto result_data = FlatVector::GetData<string_t >(result);
241226 const size_t output_cols = res.cols ;
242227 for (idx_t row_idx = 0 ; row_idx < batch_size; row_idx++) {
243- std::string json_result = " [" ;
228+ std::ostringstream oss;
229+ oss << " [" ;
244230 for (size_t col_idx = 0 ; col_idx < output_cols; col_idx++) {
245- if (col_idx > 0 ) { json_result += " ," ; }
246- json_result += std::to_string (res.data [row_idx * output_cols + col_idx]);
231+ if (col_idx > 0 ) {
232+ oss << " ," ;
233+ }
234+ oss << res.data [row_idx * output_cols + col_idx];
247235 }
248- json_result += " ]" ;
249- result_data[row_idx] = StringVector::AddString (result, json_result );
236+ oss << " ]" ;
237+ result_data[row_idx] = StringVector::AddString (result, oss. str () );
250238 }
251239 infera_free_result (res);
252240}
253241
254- static void ModelMetadataFunc (DataChunk &args, ExpressionState &state, Vector &result) {
242+ static void GetModelInfo (DataChunk &args, ExpressionState &state, Vector &result) {
255243 if (args.ColumnCount () != 1 ) {
256- throw InvalidInputException (" model_metadata (model_name) expects exactly 1 argument" );
244+ throw InvalidInputException (" infera_get_model_info (model_name) expects exactly 1 argument" );
257245 }
258246 if (args.size () == 0 ) { return ; }
259247 auto model_name = args.data [0 ].GetValue (0 );
260248 if (model_name.IsNull ()) {
261249 throw InvalidInputException (" Model name cannot be NULL" );
262250 }
263251 std::string model_name_str = model_name.ToString ();
264- char *json_meta = infera_get_model_metadata (model_name_str.c_str ());
252+ char *json_meta = infera_get_model_info (model_name_str.c_str ());
265253
266254 result.SetVectorType (VectorType::CONSTANT_VECTOR);
267255 ConstantVector::GetData<string_t >(result)[0 ] = StringVector::AddString (result, json_meta);
@@ -270,43 +258,26 @@ static void ModelMetadataFunc(DataChunk &args, ExpressionState &state, Vector &r
270258}
271259
272260static void LoadInternal (ExtensionLoader &loader) {
273- ScalarFunction load_onnx_model_func ( " load_onnx_model " , {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LoadOnnxModel );
274- loader.RegisterFunction (load_onnx_model_func );
261+ loader. RegisterFunction ( ScalarFunction ( " infera_load_model " , {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LoadModel) );
262+ loader.RegisterFunction (ScalarFunction ( " infera_unload_model " , {LogicalType::VARCHAR}, LogicalType::BOOLEAN, UnloadModel) );
275263
276- ScalarFunction unload_onnx_model_func (" unload_onnx_model" , {LogicalType::VARCHAR}, LogicalType::BOOLEAN, UnloadOnnxModel);
277- loader.RegisterFunction (unload_onnx_model_func);
278-
279- // Removed deprecated LogicalType::VARARG usage. Register multiple arities instead.
280- const idx_t MAX_FEATURES = 63 ; // features (total args = 1 + features)
264+ const idx_t MAX_FEATURES = 63 ;
281265 for (idx_t feature_count = 1 ; feature_count <= MAX_FEATURES; feature_count++) {
282266 vector<LogicalType> arg_types;
283267 arg_types.reserve (feature_count + 1 );
284- arg_types.push_back (LogicalType::VARCHAR); // model name
268+ arg_types.push_back (LogicalType::VARCHAR);
285269 for (idx_t i = 0 ; i < feature_count; i++) {
286- arg_types.push_back (LogicalType::FLOAT); // DuckDB will auto-cast other numerics
270+ arg_types.push_back (LogicalType::FLOAT);
287271 }
288- loader.RegisterFunction (ScalarFunction (" onnx_predict " , arg_types, LogicalType::FLOAT, OnnxPredict ));
289- loader.RegisterFunction (ScalarFunction (" onnx_predict_multi " , arg_types, LogicalType::VARCHAR, OnnxPredictMulti ));
272+ loader.RegisterFunction (ScalarFunction (" infera_predict " , arg_types, LogicalType::FLOAT, Predict ));
273+ loader.RegisterFunction (ScalarFunction (" infera_predict_multi " , arg_types, LogicalType::VARCHAR, PredictMulti ));
290274 }
291275
292- ScalarFunction infera_predict_blob_func (" infera_predict_blob" , {LogicalType::VARCHAR, LogicalType::BLOB}, LogicalType::LIST (LogicalType::FLOAT), InferaPredictBlob);
293- loader.RegisterFunction (infera_predict_blob_func);
294-
295- ScalarFunction list_models_func (" list_models" , {}, LogicalType::VARCHAR, ListModels);
296- loader.RegisterFunction (list_models_func);
297-
298- ScalarFunction model_info_func (" model_info" , {LogicalType::VARCHAR}, LogicalType::VARCHAR, ModelInfo);
299- loader.RegisterFunction (model_info_func);
300-
301- ScalarFunction model_metadata_func (" model_metadata" , {LogicalType::VARCHAR}, LogicalType::VARCHAR, ModelMetadataFunc);
302- loader.RegisterFunction (model_metadata_func);
303-
304- ScalarFunction infera_version_func (" infera_version" , {}, LogicalType::VARCHAR, InferaVersion);
305- loader.RegisterFunction (infera_version_func);
306-
307- auto autoload_pragma = PragmaFunction::PragmaCall (" infera_autoload_dir" , PragmaAutoloadDir,
308- {LogicalType::VARCHAR}, LogicalType::INVALID);
309- loader.RegisterFunction (autoload_pragma);
276+ loader.RegisterFunction (ScalarFunction (" infera_predict_from_blob" , {LogicalType::VARCHAR, LogicalType::BLOB}, LogicalType::LIST (LogicalType::FLOAT), PredictFromBlob));
277+ loader.RegisterFunction (ScalarFunction (" infera_get_loaded_models" , {}, LogicalType::VARCHAR, GetLoadedModels));
278+ loader.RegisterFunction (ScalarFunction (" infera_get_model_info" , {LogicalType::VARCHAR}, LogicalType::VARCHAR, GetModelInfo));
279+ loader.RegisterFunction (ScalarFunction (" infera_get_version" , {}, LogicalType::VARCHAR, GetVersion));
280+ loader.RegisterFunction (ScalarFunction (" infera_set_autoload_dir" , {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetAutoloadDir));
310281}
311282
312283void InferaExtension::Load (ExtensionLoader &loader) { LoadInternal (loader); }
0 commit comments