@@ -383,6 +383,48 @@ TEST_P(umfCUDAProviderTest, cudaProviderNullParams) {
383
383
EXPECT_EQ (res, UMF_RESULT_ERROR_INVALID_ARGUMENT);
384
384
}
385
385
386
+ TEST_P (umfCUDAProviderTest, cudaProviderInvalidCreate) {
387
+ CUdevice device;
388
+ int ret = get_cuda_device (&device);
389
+ ASSERT_EQ (ret, 0 );
390
+
391
+ CUcontext ctx;
392
+ ret = create_context (device, &ctx);
393
+ ASSERT_EQ (ret, 0 );
394
+
395
+ // wrong memory type
396
+ umf_cuda_memory_provider_params_handle_t params_wrong_memtype =
397
+ create_cuda_prov_params (ctx, device,
398
+ static_cast <umf_usm_memory_type_t >(0xFFFF ), 0 );
399
+ ASSERT_NE (params_wrong_memtype, nullptr );
400
+ umf_memory_provider_handle_t provider = nullptr ;
401
+ umf_result_t umf_result = umfMemoryProviderCreate (
402
+ umfCUDAMemoryProviderOps (), params_wrong_memtype, &provider);
403
+ ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
404
+ umf_result = umfCUDAMemoryProviderParamsDestroy (params_wrong_memtype);
405
+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
406
+
407
+ // wrong context
408
+ umf_cuda_memory_provider_params_handle_t params_wrong_ctx =
409
+ create_cuda_prov_params (nullptr , device, UMF_MEMORY_TYPE_HOST, 0 );
410
+ ASSERT_NE (params_wrong_ctx, nullptr );
411
+ umf_result = umfMemoryProviderCreate (umfCUDAMemoryProviderOps (),
412
+ params_wrong_ctx, &provider);
413
+ ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
414
+ umf_result = umfCUDAMemoryProviderParamsDestroy (params_wrong_ctx);
415
+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
416
+
417
+ // wrong device
418
+ umf_cuda_memory_provider_params_handle_t params_wrong_device =
419
+ create_cuda_prov_params (ctx, (CUdevice)-1 , UMF_MEMORY_TYPE_HOST, 0 );
420
+ ASSERT_NE (params_wrong_device, nullptr );
421
+ umf_result = umfMemoryProviderCreate (umfCUDAMemoryProviderOps (),
422
+ params_wrong_device, &provider);
423
+ ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
424
+ umf_result = umfCUDAMemoryProviderParamsDestroy (params_wrong_device);
425
+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
426
+ }
427
+
386
428
TEST_P (umfCUDAProviderTest, multiContext) {
387
429
CUdevice device;
388
430
int ret = get_cuda_device (&device);
0 commit comments