File tree Expand file tree Collapse file tree 2 files changed +5
-5
lines changed
Expand file tree Collapse file tree 2 files changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -526,7 +526,7 @@ def test_moving_to_cpu_throws_warning(self):
526526 reason = "Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release." ,
527527 strict = True ,
528528 )
529- def test_pipeline_device_placement_works_with_nf4 (self ):
529+ def test_pipeline_cuda_placement_works_with_nf4 (self ):
530530 transformer_nf4_config = BitsAndBytesConfig (
531531 load_in_4bit = True ,
532532 bnb_4bit_quant_type = "nf4" ,
@@ -560,7 +560,7 @@ def test_pipeline_device_placement_works_with_nf4(self):
560560 ).to (torch_device )
561561
562562 # Check if inference works.
563- _ = pipeline_4bit ("table" , max_sequence_length = 20 , num_inference_steps = 2 )
563+ _ = pipeline_4bit (self . prompt , max_sequence_length = 20 , num_inference_steps = 2 )
564564
565565 del pipeline_4bit
566566
Original file line number Diff line number Diff line change @@ -492,7 +492,7 @@ def test_generate_quality_dequantize(self):
492492 self .assertTrue (max_diff < 1e-2 )
493493
494494 # 8bit models cannot be offloaded to CPU.
495- self .assertTrue (self .pipeline_8bit .transformer .device .type == "cuda" )
495+ self .assertTrue (self .pipeline_8bit .transformer .device .type == torch_device )
496496 # calling it again shouldn't be a problem
497497 _ = self .pipeline_8bit (
498498 prompt = self .prompt ,
@@ -529,10 +529,10 @@ def test_pipeline_cuda_placement_works_with_mixed_int8(self):
529529 transformer = transformer_8bit ,
530530 text_encoder_3 = text_encoder_3_8bit ,
531531 torch_dtype = torch .float16 ,
532- ).to ("cuda" )
532+ ).to (torch_device )
533533
534534 # Check if inference works.
535- _ = pipeline_8bit ("table" , max_sequence_length = 20 , num_inference_steps = 2 )
535+ _ = pipeline_8bit (self . prompt , max_sequence_length = 20 , num_inference_steps = 2 )
536536
537537 del pipeline_8bit
538538
You can’t perform that action at this time.
0 commit comments