@@ -326,12 +326,13 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long
326
326
return NULL ;
327
327
}
328
328
329
- RAI_Model * RAI_ModelCreateORT (RAI_Backend backend , const char * devicestr , RAI_ModelOpts opts ,
330
- const char * modeldef , size_t modellen , RAI_Error * error ) {
329
+ int RAI_ModelCreateORT (RAI_Model * model , RAI_Error * error ) {
331
330
332
331
const OrtApi * ort = OrtGetApiBase ()-> GetApi (1 );
333
332
char * * inputs_ = NULL ;
334
333
char * * outputs_ = NULL ;
334
+ size_t ninputs ;
335
+ size_t noutputs ;
335
336
OrtSessionOptions * session_options = NULL ;
336
337
OrtSession * session = NULL ;
337
338
OrtStatus * status = NULL ;
@@ -348,7 +349,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
348
349
}
349
350
350
351
ONNX_VALIDATE_STATUS (ort -> CreateSessionOptions (& session_options ))
351
- if (strcasecmp (devicestr , "CPU" ) == 0 ) {
352
+ if (strcasecmp (model -> devicestr , "CPU" ) == 0 ) {
352
353
// These are required to ensure that onnx will use the registered REDIS allocator (for
353
354
// a model that defined to run on CPU).
354
355
ONNX_VALIDATE_STATUS (
@@ -359,24 +360,31 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
359
360
// TODO: these options could be configured at the AI.CONFIG level
360
361
ONNX_VALIDATE_STATUS (ort -> SetSessionGraphOptimizationLevel (session_options , ORT_ENABLE_BASIC ))
361
362
ONNX_VALIDATE_STATUS (
362
- ort -> SetIntraOpNumThreads (session_options , (int )opts .backends_intra_op_parallelism ))
363
+ ort -> SetIntraOpNumThreads (session_options , (int )model -> opts .backends_intra_op_parallelism ))
363
364
ONNX_VALIDATE_STATUS (
364
- ort -> SetInterOpNumThreads (session_options , (int )opts .backends_inter_op_parallelism ))
365
+ ort -> SetInterOpNumThreads (session_options , (int )model -> opts .backends_inter_op_parallelism ))
365
366
366
367
// If the model is set for GPU, this will set CUDA provider for the session,
367
368
// so that onnx will use its own allocator for CUDA (not Redis allocator)
368
- if (!setDeviceId (devicestr , session_options , error )) {
369
+ if (!setDeviceId (model -> devicestr , session_options , error )) {
369
370
ort -> ReleaseSessionOptions (session_options );
370
- return NULL ;
371
+ return REDISMODULE_ERR ;
371
372
}
372
373
373
374
ONNX_VALIDATE_STATUS (
374
- ort -> CreateSessionFromArray (env , modeldef , modellen , session_options , & session ))
375
+ ort -> CreateSessionFromArray (env , model -> data , model -> datalen , session_options , & session ))
375
376
ort -> ReleaseSessionOptions (session_options );
376
377
378
+ model -> session = session ;
379
+
377
380
size_t n_input_nodes ;
378
- ONNX_VALIDATE_STATUS (ort -> SessionGetInputCount (session , & n_input_nodes ))
379
381
size_t n_output_nodes ;
382
+
383
+ // We save the model's inputs and outputs only in the first time that we create the model.
384
+ // We might create the model again when loading from RDB, in this case the inputs and outputs
385
+ // are already loaded from RDB.
386
+ // if (!model->inputs) {
387
+ ONNX_VALIDATE_STATUS (ort -> SessionGetInputCount (session , & n_input_nodes ))
380
388
ONNX_VALIDATE_STATUS (ort -> SessionGetOutputCount (session , & n_output_nodes ))
381
389
382
390
inputs_ = array_new (char * , n_input_nodes );
@@ -393,27 +401,13 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
393
401
outputs_ = array_append (outputs_ , output_name );
394
402
}
395
403
396
- // Since ONNXRuntime doesn't have a re-serialization function,
397
- // we cache the blob in order to re-serialize it.
398
- // Not optimal for storage purposes, but again, it may be temporary
399
- char * buffer = RedisModule_Calloc (modellen , sizeof (* buffer ));
400
- memcpy (buffer , modeldef , modellen );
401
-
402
- RAI_Model * ret = RedisModule_Calloc (1 , sizeof (* ret ));
403
- ret -> model = NULL ;
404
- ret -> session = session ;
405
- ret -> backend = backend ;
406
- ret -> devicestr = RedisModule_Strdup (devicestr );
407
- ret -> refCount = 1 ;
408
- ret -> opts = opts ;
409
- ret -> data = buffer ;
410
- ret -> datalen = modellen ;
411
- ret -> ninputs = n_input_nodes ;
412
- ret -> noutputs = n_output_nodes ;
413
- ret -> inputs = inputs_ ;
414
- ret -> outputs = outputs_ ;
415
-
416
- return ret ;
404
+ model -> ninputs = n_input_nodes ;
405
+ model -> noutputs = n_output_nodes ;
406
+ model -> inputs = inputs_ ;
407
+ model -> outputs = outputs_ ;
408
+ //}
409
+
410
+ return REDISMODULE_OK ;
417
411
418
412
error :
419
413
RAI_SetError (error , RAI_EMODELCREATE , ort -> GetErrorMessage (status ));
@@ -438,28 +432,32 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
438
432
ort -> ReleaseSession (session );
439
433
}
440
434
ort -> ReleaseStatus (status );
441
- return NULL ;
435
+ return REDISMODULE_ERR ;
442
436
}
443
437
444
438
void RAI_ModelFreeORT (RAI_Model * model , RAI_Error * error ) {
445
439
const OrtApi * ort = OrtGetApiBase ()-> GetApi (1 );
446
440
OrtStatus * status = NULL ;
447
441
448
- for (uint32_t i = 0 ; i < model -> ninputs ; i ++ ) {
449
- ONNX_VALIDATE_STATUS (ort -> AllocatorFree (global_allocator , model -> inputs [i ]))
442
+ if (model -> inputs ) {
443
+ for (uint32_t i = 0 ; i < model -> ninputs ; i ++ ) {
444
+ ONNX_VALIDATE_STATUS (ort -> AllocatorFree (global_allocator , model -> inputs [i ]))
445
+ }
446
+ array_free (model -> inputs );
447
+ model -> inputs = NULL ;
450
448
}
451
- array_free (model -> inputs );
452
449
453
- for (uint32_t i = 0 ; i < model -> noutputs ; i ++ ) {
454
- ONNX_VALIDATE_STATUS (ort -> AllocatorFree (global_allocator , model -> outputs [i ]))
450
+ if (model -> outputs ) {
451
+ for (uint32_t i = 0 ; i < model -> noutputs ; i ++ ) {
452
+ ONNX_VALIDATE_STATUS (ort -> AllocatorFree (global_allocator , model -> outputs [i ]))
453
+ }
454
+ array_free (model -> outputs );
455
+ model -> outputs = NULL ;
455
456
}
456
- array_free (model -> outputs );
457
457
458
- RedisModule_Free (model -> devicestr );
459
- RedisModule_Free (model -> data );
460
- ort -> ReleaseSession (model -> session );
461
- model -> model = NULL ;
462
- model -> session = NULL ;
458
+ if (model -> session ) {
459
+ ort -> ReleaseSession (model -> session );
460
+ }
463
461
return ;
464
462
465
463
error :
0 commit comments