@@ -325,62 +325,137 @@ def test_decoupled_execute_cancel(self):
         self.assertIn("[execute_cancel] Request cancelled at ", log_text)
 
     def test_decoupled_bls_cancel(self):
-        model_name = "decoupled_bls_cancel"
+        model_names = ["decoupled_bls_cancel", "decoupled_bls_async_cancel"]
         input_value = 1
         max_sum_value = 10
+        ignore_cancel = False
         user_data = UserData()
+        for model_name in model_names:
+            with self._shm_leak_detector.Probe() as shm_probe:
+                with grpcclient.InferenceServerClient(
+                    f"{_tritonserver_ipaddr}:8001"
+                ) as client:
+                    client.start_stream(callback=partial(callback, user_data))
+                    input_data = np.array([input_value], dtype=np.int32)
+                    max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                    ignore_cancel_data = np.array([ignore_cancel], dtype=np.bool_)
+                    inputs = [
+                        grpcclient.InferInput(
+                            "INPUT",
+                            input_data.shape,
+                            np_to_triton_dtype(input_data.dtype),
+                        ),
+                        grpcclient.InferInput(
+                            "MAX_SUM",
+                            max_sum_data.shape,
+                            np_to_triton_dtype(max_sum_data.dtype),
+                        ),
+                        grpcclient.InferInput(
+                            "IGNORE_CANCEL",
+                            ignore_cancel_data.shape,
+                            np_to_triton_dtype(ignore_cancel_data.dtype),
+                        ),
+                    ]
+                    inputs[0].set_data_from_numpy(input_data)
+                    inputs[1].set_data_from_numpy(max_sum_data)
+                    inputs[2].set_data_from_numpy(ignore_cancel_data)
+                    client.async_stream_infer(model_name, inputs)
+
+                    # Check the results of the decoupled model using BLS
+                    def check_result(result):
+                        # Make sure the result is not an exception
+                        self.assertIsNot(type(result), InferenceServerException)
+                        is_cancelled = result.as_numpy("IS_CANCELLED")
+                        self.assertTrue(
+                            is_cancelled[0],
+                            "error: expected the request to be cancelled",
+                        )
 
-        with self._shm_leak_detector.Probe() as shm_probe:
-            with grpcclient.InferenceServerClient(
-                f"{_tritonserver_ipaddr}:8001"
-            ) as client:
-                client.start_stream(callback=partial(callback, user_data))
-                input_data = np.array([input_value], dtype=np.int32)
-                max_sum_data = np.array([max_sum_value], dtype=np.int32)
-                inputs = [
-                    grpcclient.InferInput(
-                        "INPUT", input_data.shape, np_to_triton_dtype(input_data.dtype)
-                    ),
-                    grpcclient.InferInput(
-                        "MAX_SUM",
-                        max_sum_data.shape,
-                        np_to_triton_dtype(max_sum_data.dtype),
-                    ),
-                ]
-                inputs[0].set_data_from_numpy(input_data)
-                inputs[1].set_data_from_numpy(max_sum_data)
-                client.async_stream_infer(model_name, inputs)
+                        sum_data = result.as_numpy("SUM")
+                        self.assertIsNotNone(sum_data, "error: expected 'SUM'")
+                        self.assertTrue(
+                            np.array_equal(sum_data, max_sum_data),
+                            "error: expected output {} to match input {}".format(
+                                sum_data, max_sum_data
+                            ),
+                        )
 
-                # Check the results of the decoupled model using BLS
-                def check_result(result):
-                    # Make sure the result is not an exception
-                    self.assertIsNot(type(result), InferenceServerException)
+                    result = user_data._completed_requests.get()
+                    check_result(result)
 
-                    sum_data = result.as_numpy("SUM")
-                    self.assertIsNotNone(sum_data, "error: expected 'SUM'")
-                    self.assertTrue(
-                        np.array_equal(sum_data, max_sum_data),
-                        "error: expected output {} to match input {}".format(
-                            sum_data, max_sum_data
+    def test_decoupled_bls_ignore_cancel(self):
+        model_names = ["decoupled_bls_cancel", "decoupled_bls_async_cancel"]
+        input_value = 1
+        max_sum_value = 10
+        ignore_cancel = True
+        user_data = UserData()
+        for model_name in model_names:
+            with self._shm_leak_detector.Probe() as shm_probe:
+                with grpcclient.InferenceServerClient(
+                    f"{_tritonserver_ipaddr}:8001"
+                ) as client:
+                    client.start_stream(callback=partial(callback, user_data))
+                    input_data = np.array([input_value], dtype=np.int32)
+                    max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                    ignore_cancel_data = np.array([ignore_cancel], dtype=np.bool_)
+                    inputs = [
+                        grpcclient.InferInput(
+                            "INPUT",
+                            input_data.shape,
+                            np_to_triton_dtype(input_data.dtype),
                         ),
-                    )
+                        grpcclient.InferInput(
+                            "MAX_SUM",
+                            max_sum_data.shape,
+                            np_to_triton_dtype(max_sum_data.dtype),
+                        ),
+                        grpcclient.InferInput(
+                            "IGNORE_CANCEL",
+                            ignore_cancel_data.shape,
+                            np_to_triton_dtype(ignore_cancel_data.dtype),
+                        ),
+                    ]
+                    inputs[0].set_data_from_numpy(input_data)
+                    inputs[1].set_data_from_numpy(max_sum_data)
+                    inputs[2].set_data_from_numpy(ignore_cancel_data)
+                    client.async_stream_infer(model_name, inputs)
+
+                    # Check the results of the decoupled model using BLS
+                    def check_result(result):
+                        # Make sure the result is not an exception
+                        self.assertIsNot(type(result), InferenceServerException)
+                        is_cancelled = result.as_numpy("IS_CANCELLED")
+                        self.assertFalse(
+                            is_cancelled[0],
+                            "error: expected the request not to be cancelled",
+                        )
 
-                result = user_data._completed_requests.get()
-                check_result(result)
+                        sum_data = result.as_numpy("SUM")
+                        self.assertIsNotNone(sum_data, "error: expected 'SUM'")
+                        self.assertTrue(
+                            sum_data > max_sum_data,
+                            "error: expected sum_data {} to be greater than max_sum_data {}".format(
+                                sum_data, max_sum_data
+                            ),
+                        )
+
+                    result = user_data._completed_requests.get()
+                    check_result(result)
 
-    def test_decoupled_bls_async_cancel(self):
-        model_name = "decoupled_bls_async_cancel"
+    def test_decoupled_bls_cancel_after_completion(self):
+        model_name = "decoupled_bls_cancel_after_complete"
         input_value = 1
         max_sum_value = 10
+        ignore_cancel = False
         user_data = UserData()
-
         with self._shm_leak_detector.Probe() as shm_probe:
             with grpcclient.InferenceServerClient(
                 f"{_tritonserver_ipaddr}:8001"
             ) as client:
                 client.start_stream(callback=partial(callback, user_data))
                 input_data = np.array([input_value], dtype=np.int32)
                 max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                ignore_cancel_data = np.array([ignore_cancel], dtype=np.bool_)
                 inputs = [
                     grpcclient.InferInput(
                         "INPUT", input_data.shape, np_to_triton_dtype(input_data.dtype)
@@ -390,15 +465,25 @@ def test_decoupled_bls_async_cancel(self):
                         max_sum_data.shape,
                         np_to_triton_dtype(max_sum_data.dtype),
                     ),
+                    grpcclient.InferInput(
+                        "IGNORE_CANCEL",
+                        ignore_cancel_data.shape,
+                        np_to_triton_dtype(ignore_cancel_data.dtype),
+                    ),
                 ]
                 inputs[0].set_data_from_numpy(input_data)
                 inputs[1].set_data_from_numpy(max_sum_data)
+                inputs[2].set_data_from_numpy(ignore_cancel_data)
                 client.async_stream_infer(model_name, inputs)
 
                 # Check the results of the decoupled model using BLS
                 def check_result(result):
                     # Make sure the result is not an exception
                     self.assertIsNot(type(result), InferenceServerException)
+                    is_cancelled = result.as_numpy("IS_CANCELLED")
+                    self.assertTrue(
+                        is_cancelled[0], "error: expected the request to be cancelled"
+                    )
 
                     sum_data = result.as_numpy("SUM")
                     self.assertIsNotNone(sum_data, "error: expected 'SUM'")
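
Note for readers of this diff: the UserData and callback helpers used above come from the test suite's shared utilities and are not part of this change. A minimal sketch of what they typically look like, assumed here rather than copied from the suite, is:

# Hypothetical stand-ins for the UserData/callback helpers imported by these
# tests from the shared QA utilities; the real definitions may differ.
import queue


class UserData:
    def __init__(self):
        # Holds every streamed result or error, in arrival order.
        self._completed_requests = queue.Queue()


def callback(user_data, result, error):
    # start_stream() invokes this with (result, error); functools.partial()
    # binds user_data as the first argument in the tests above.
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)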
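
The server-side models exercised here (decoupled_bls_cancel, decoupled_bls_async_cancel, decoupled_bls_cancel_after_complete) are not shown in this diff. The sketch below is only an illustration of how a BLS model could honor an IGNORE_CANCEL input and report IS_CANCELLED; it is not the repository's model.py. The target model name "streaming_add", its tensor names, and the cancelled-error handling are assumptions, and for brevity the sketch returns its response directly rather than through a decoupled response sender as the tested models presumably do.

# model.py - illustrative sketch only, assuming the python_backend BLS API
# (pb_utils.InferenceRequest.exec(decoupled=True) and the response iterator's
# cancel() method); check the python_backend docs for your Triton version.
import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        responses = []
        for request in requests:
            input_value = pb_utils.get_input_tensor_by_name(request, "INPUT").as_numpy()
            max_sum = pb_utils.get_input_tensor_by_name(request, "MAX_SUM").as_numpy()
            ignore_cancel = bool(
                pb_utils.get_input_tensor_by_name(request, "IGNORE_CANCEL").as_numpy()[0]
            )

            # Decoupled BLS call to a hypothetical model that keeps streaming the
            # input value back until it is cancelled or stops on its own.
            bls_request = pb_utils.InferenceRequest(
                model_name="streaming_add",  # hypothetical target model
                inputs=[
                    pb_utils.Tensor("IN", input_value),
                    pb_utils.Tensor("IGNORE_CANCEL", np.array([ignore_cancel])),
                ],
                requested_output_names=["OUT"],
            )
            response_iterator = bls_request.exec(decoupled=True)

            total = np.zeros_like(input_value)
            is_cancelled = False
            cancel_sent = False
            for bls_response in response_iterator:
                if bls_response.has_error():
                    # A cancelled BLS stream ends with an error response; the real
                    # model presumably checks the error code for CANCELLED here.
                    is_cancelled = True
                    break
                out = pb_utils.get_output_tensor_by_name(bls_response, "OUT")
                if out is None:  # e.g. the empty final response of the stream
                    continue
                total += out.as_numpy()
                if total.sum() >= max_sum.sum() and not cancel_sent:
                    # Ask Triton to cancel the in-flight BLS request. If the target
                    # ignores cancellation, it keeps streaming until it finishes,
                    # so SUM exceeds MAX_SUM and IS_CANCELLED stays False.
                    response_iterator.cancel()
                    cancel_sent = True

            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[
                        pb_utils.Tensor("SUM", total),
                        pb_utils.Tensor("IS_CANCELLED", np.array([is_cancelled])),
                    ]
                )
            )
        return responses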