|
19 | 19 | UNet2DConditionModel,
|
20 | 20 | )
|
21 | 21 | from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
|
22 |
| -from diffusers.utils.testing_utils import torch_device |
| 22 | +from diffusers.utils.testing_utils import require_torch_gpu, torch_device |
23 | 23 |
|
24 | 24 |
|
25 | 25 | class IsSafetensorsCompatibleTests(unittest.TestCase):
|
@@ -826,3 +826,104 @@ def test_video_to_video(self):
|
826 | 826 | with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
|
827 | 827 | _ = pipe(**inputs)
|
828 | 828 | self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
|
| 829 | + |
| 830 | + |
| 831 | +@require_torch_gpu |
| 832 | +class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase): |
| 833 | + expected_pipe_device = torch.device("cuda:0") |
| 834 | + expected_pipe_dtype = torch.float64 |
| 835 | + |
| 836 | + def get_dummy_components_image_generation(self): |
| 837 | + cross_attention_dim = 8 |
| 838 | + |
| 839 | + torch.manual_seed(0) |
| 840 | + unet = UNet2DConditionModel( |
| 841 | + block_out_channels=(4, 8), |
| 842 | + layers_per_block=1, |
| 843 | + sample_size=32, |
| 844 | + in_channels=4, |
| 845 | + out_channels=4, |
| 846 | + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), |
| 847 | + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), |
| 848 | + cross_attention_dim=cross_attention_dim, |
| 849 | + norm_num_groups=2, |
| 850 | + ) |
| 851 | + scheduler = DDIMScheduler( |
| 852 | + beta_start=0.00085, |
| 853 | + beta_end=0.012, |
| 854 | + beta_schedule="scaled_linear", |
| 855 | + clip_sample=False, |
| 856 | + set_alpha_to_one=False, |
| 857 | + ) |
| 858 | + torch.manual_seed(0) |
| 859 | + vae = AutoencoderKL( |
| 860 | + block_out_channels=[4, 8], |
| 861 | + in_channels=3, |
| 862 | + out_channels=3, |
| 863 | + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], |
| 864 | + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], |
| 865 | + latent_channels=4, |
| 866 | + norm_num_groups=2, |
| 867 | + ) |
| 868 | + torch.manual_seed(0) |
| 869 | + text_encoder_config = CLIPTextConfig( |
| 870 | + bos_token_id=0, |
| 871 | + eos_token_id=2, |
| 872 | + hidden_size=cross_attention_dim, |
| 873 | + intermediate_size=16, |
| 874 | + layer_norm_eps=1e-05, |
| 875 | + num_attention_heads=2, |
| 876 | + num_hidden_layers=2, |
| 877 | + pad_token_id=1, |
| 878 | + vocab_size=1000, |
| 879 | + ) |
| 880 | + text_encoder = CLIPTextModel(text_encoder_config) |
| 881 | + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") |
| 882 | + |
| 883 | + components = { |
| 884 | + "unet": unet, |
| 885 | + "scheduler": scheduler, |
| 886 | + "vae": vae, |
| 887 | + "text_encoder": text_encoder, |
| 888 | + "tokenizer": tokenizer, |
| 889 | + "safety_checker": None, |
| 890 | + "feature_extractor": None, |
| 891 | + "image_encoder": None, |
| 892 | + } |
| 893 | + return components |
| 894 | + |
| 895 | + def test_deterministic_device(self): |
| 896 | + components = self.get_dummy_components_image_generation() |
| 897 | + |
| 898 | + pipe = StableDiffusionPipeline(**components) |
| 899 | + pipe.to(device=torch_device, dtype=torch.float32) |
| 900 | + |
| 901 | + pipe.unet.to(device="cpu") |
| 902 | + pipe.vae.to(device="cuda") |
| 903 | + pipe.text_encoder.to(device="cuda:0") |
| 904 | + |
| 905 | + pipe_device = pipe.device |
| 906 | + |
| 907 | + self.assertEqual( |
| 908 | + self.expected_pipe_device, |
| 909 | + pipe_device, |
| 910 | + f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.", |
| 911 | + ) |
| 912 | + |
| 913 | + def test_deterministic_dtype(self): |
| 914 | + components = self.get_dummy_components_image_generation() |
| 915 | + |
| 916 | + pipe = StableDiffusionPipeline(**components) |
| 917 | + pipe.to(device=torch_device, dtype=torch.float32) |
| 918 | + |
| 919 | + pipe.unet.to(dtype=torch.float16) |
| 920 | + pipe.vae.to(dtype=torch.float32) |
| 921 | + pipe.text_encoder.to(dtype=torch.float64) |
| 922 | + |
| 923 | + pipe_dtype = pipe.dtype |
| 924 | + |
| 925 | + self.assertEqual( |
| 926 | + self.expected_pipe_dtype, |
| 927 | + pipe_dtype, |
| 928 | + f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.", |
| 929 | + ) |
0 commit comments