Skip to content

Commit 0616eeb

Browse files
committed
Enable async modelset + refactor of model creation (draft)
1 parent 43b91ff commit 0616eeb

29 files changed

+778
-634
lines changed

src/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ ADD_LIBRARY(redisai_obj OBJECT
2727
execution/command_parser.c
2828
execution/run_info.c
2929
execution/background_workers.c
30+
execution/background_modelset.c
3031
config/config.c
3132
execution/DAG/dag.c
3233
execution/DAG/dag_parser.c

src/backends/backends.c

+4-9
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,7 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) {
8888
init_backend(RedisModule_GetApi);
8989

9090
backend.model_create_with_nodes =
91-
(RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, size_t, const char **, size_t,
92-
const char **, const char *, size_t,
93-
RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTF");
91+
(int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTF");
9492
if (backend.model_create_with_nodes == NULL) {
9593
dlclose(handle);
9694
RedisModule_Log(ctx, "warning",
@@ -180,8 +178,7 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) {
180178
init_backend(RedisModule_GetApi);
181179

182180
backend.model_create =
183-
(RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t,
184-
RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTFLite");
181+
(int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTFLite");
185182
if (backend.model_create == NULL) {
186183
dlclose(handle);
187184
RedisModule_Log(ctx, "warning",
@@ -272,8 +269,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) {
272269
init_backend(RedisModule_GetApi);
273270

274271
backend.model_create =
275-
(RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t,
276-
RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTorch");
272+
(int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTorch");
277273
if (backend.model_create == NULL) {
278274
dlclose(handle);
279275
RedisModule_Log(ctx, "warning",
@@ -396,8 +392,7 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) {
396392
init_backend(RedisModule_GetApi);
397393

398394
backend.model_create =
399-
(RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t,
400-
RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateORT");
395+
(int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateORT");
401396
if (backend.model_create == NULL) {
402397
dlclose(handle);
403398
RedisModule_Log(ctx, "warning",

src/backends/backends.h

+4-7
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,12 @@
4040
*/
4141
typedef struct RAI_LoadedBackend {
4242
// ** model_create_with_nodes **: A callback function pointer that creates a
43-
// model given the RAI_ModelOpts and input and output nodes
44-
RAI_Model *(*model_create_with_nodes)(RAI_Backend, const char *, RAI_ModelOpts, size_t,
45-
const char **, size_t, const char **, const char *,
46-
size_t, RAI_Error *);
43+
// model given the RAI_ModelOpts and input and output nodes (which are stored in the model).
44+
int (*model_create_with_nodes)(RAI_Model *, RAI_Error *);
4745

4846
// ** model_create **: A callback function pointer that creates a model given
49-
// the RAI_ModelOpts
50-
RAI_Model *(*model_create)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t,
51-
RAI_Error *);
47+
// the RAI_ModelOpts (which are stored in the model).
48+
int (*model_create)(RAI_Model *, RAI_Error *);
5249

5350
// ** model_free **: A callback function pointer that frees a model given the
5451
// RAI_Model pointer

src/backends/onnxruntime.c

+40-42
Original file line numberDiff line numberDiff line change
@@ -326,12 +326,13 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long
326326
return NULL;
327327
}
328328

329-
RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
330-
const char *modeldef, size_t modellen, RAI_Error *error) {
329+
int RAI_ModelCreateORT(RAI_Model *model, RAI_Error *error) {
331330

332331
const OrtApi *ort = OrtGetApiBase()->GetApi(1);
333332
char **inputs_ = NULL;
334333
char **outputs_ = NULL;
334+
size_t ninputs;
335+
size_t noutputs;
335336
OrtSessionOptions *session_options = NULL;
336337
OrtSession *session = NULL;
337338
OrtStatus *status = NULL;
@@ -348,7 +349,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
348349
}
349350

350351
ONNX_VALIDATE_STATUS(ort->CreateSessionOptions(&session_options))
351-
if (strcasecmp(devicestr, "CPU") == 0) {
352+
if (strcasecmp(model->devicestr, "CPU") == 0) {
352353
// These are required to ensure that onnx will use the registered REDIS allocator (for
353354
// a model that is defined to run on CPU).
354355
ONNX_VALIDATE_STATUS(
@@ -359,24 +360,31 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
359360
// TODO: these options could be configured at the AI.CONFIG level
360361
ONNX_VALIDATE_STATUS(ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC))
361362
ONNX_VALIDATE_STATUS(
362-
ort->SetIntraOpNumThreads(session_options, (int)opts.backends_intra_op_parallelism))
363+
ort->SetIntraOpNumThreads(session_options, (int)model->opts.backends_intra_op_parallelism))
363364
ONNX_VALIDATE_STATUS(
364-
ort->SetInterOpNumThreads(session_options, (int)opts.backends_inter_op_parallelism))
365+
ort->SetInterOpNumThreads(session_options, (int)model->opts.backends_inter_op_parallelism))
365366

366367
// If the model is set for GPU, this will set CUDA provider for the session,
367368
// so that onnx will use its own allocator for CUDA (not Redis allocator)
368-
if (!setDeviceId(devicestr, session_options, error)) {
369+
if (!setDeviceId(model->devicestr, session_options, error)) {
369370
ort->ReleaseSessionOptions(session_options);
370-
return NULL;
371+
return REDISMODULE_ERR;
371372
}
372373

373374
ONNX_VALIDATE_STATUS(
374-
ort->CreateSessionFromArray(env, modeldef, modellen, session_options, &session))
375+
ort->CreateSessionFromArray(env, model->data, model->datalen, session_options, &session))
375376
ort->ReleaseSessionOptions(session_options);
376377

378+
model->session = session;
379+
377380
size_t n_input_nodes;
378-
ONNX_VALIDATE_STATUS(ort->SessionGetInputCount(session, &n_input_nodes))
379381
size_t n_output_nodes;
382+
383+
// We save the model's inputs and outputs only the first time that we create the model.
384+
// We might create the model again when loading from RDB; in that case the inputs and outputs
385+
// are already loaded from RDB.
386+
// if (!model->inputs) {
387+
ONNX_VALIDATE_STATUS(ort->SessionGetInputCount(session, &n_input_nodes))
380388
ONNX_VALIDATE_STATUS(ort->SessionGetOutputCount(session, &n_output_nodes))
381389

382390
inputs_ = array_new(char *, n_input_nodes);
@@ -393,27 +401,13 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
393401
outputs_ = array_append(outputs_, output_name);
394402
}
395403

396-
// Since ONNXRuntime doesn't have a re-serialization function,
397-
// we cache the blob in order to re-serialize it.
398-
// Not optimal for storage purposes, but again, it may be temporary
399-
char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
400-
memcpy(buffer, modeldef, modellen);
401-
402-
RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret));
403-
ret->model = NULL;
404-
ret->session = session;
405-
ret->backend = backend;
406-
ret->devicestr = RedisModule_Strdup(devicestr);
407-
ret->refCount = 1;
408-
ret->opts = opts;
409-
ret->data = buffer;
410-
ret->datalen = modellen;
411-
ret->ninputs = n_input_nodes;
412-
ret->noutputs = n_output_nodes;
413-
ret->inputs = inputs_;
414-
ret->outputs = outputs_;
415-
416-
return ret;
404+
model->ninputs = n_input_nodes;
405+
model->noutputs = n_output_nodes;
406+
model->inputs = inputs_;
407+
model->outputs = outputs_;
408+
//}
409+
410+
return REDISMODULE_OK;
417411

418412
error:
419413
RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
@@ -438,28 +432,32 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
438432
ort->ReleaseSession(session);
439433
}
440434
ort->ReleaseStatus(status);
441-
return NULL;
435+
return REDISMODULE_ERR;
442436
}
443437

444438
void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
445439
const OrtApi *ort = OrtGetApiBase()->GetApi(1);
446440
OrtStatus *status = NULL;
447441

448-
for (uint32_t i = 0; i < model->ninputs; i++) {
449-
ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->inputs[i]))
442+
if (model->inputs) {
443+
for (uint32_t i = 0; i < model->ninputs; i++) {
444+
ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->inputs[i]))
445+
}
446+
array_free(model->inputs);
447+
model->inputs = NULL;
450448
}
451-
array_free(model->inputs);
452449

453-
for (uint32_t i = 0; i < model->noutputs; i++) {
454-
ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->outputs[i]))
450+
if (model->outputs) {
451+
for (uint32_t i = 0; i < model->noutputs; i++) {
452+
ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->outputs[i]))
453+
}
454+
array_free(model->outputs);
455+
model->outputs = NULL;
455456
}
456-
array_free(model->outputs);
457457

458-
RedisModule_Free(model->devicestr);
459-
RedisModule_Free(model->data);
460-
ort->ReleaseSession(model->session);
461-
model->model = NULL;
462-
model->session = NULL;
458+
if (model->session) {
459+
ort->ReleaseSession(model->session);
460+
}
463461
return;
464462

465463
error:

src/backends/onnxruntime.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ unsigned long long RAI_GetMemoryAccessORT(void);
1111

1212
int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *));
1313

14-
RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
15-
const char *modeldef, size_t modellen, RAI_Error *err);
14+
int RAI_ModelCreateORT(RAI_Model *model, RAI_Error *err);
1615

1716
void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error);
1817

0 commit comments

Comments
 (0)