Skip to content

Commit be54a95

Browse files
dimitribarbot and DN6 authored
Fix deterministic issue when getting pipeline dtype and device (#10696)
Co-authored-by: Dhruv Nair <[email protected]>
1 parent 6b9a333 commit be54a95

File tree

2 files changed

+107
-4
lines changed

2 files changed

+107
-4
lines changed

src/diffusers/pipelines/pipeline_utils.py

+5 -3
Original file line numberDiff line numberDiff line change
@@ -1610,7 +1610,7 @@ def _get_signature_keys(cls, obj):
16101610
expected_modules.add(name)
16111611
optional_parameters.remove(name)
16121612

1613-
return expected_modules, optional_parameters
1613+
return sorted(expected_modules), sorted(optional_parameters)
16141614

16151615
@classmethod
16161616
def _get_signature_types(cls):
@@ -1652,10 +1652,12 @@ def components(self) -> Dict[str, Any]:
16521652
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
16531653
}
16541654

1655-
if set(components.keys()) != expected_modules:
1655+
actual = sorted(set(components.keys()))
1656+
expected = sorted(expected_modules)
1657+
if actual != expected:
16561658
raise ValueError(
16571659
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
1658-
f" {expected_modules} to be defined, but {components.keys()} are defined."
1660+
f" {expected} to be defined, but {actual} are defined."
16591661
)
16601662

16611663
return components

tests/pipelines/test_pipeline_utils.py

+102 -1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
UNet2DConditionModel,
2020
)
2121
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
22-
from diffusers.utils.testing_utils import torch_device
22+
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
2323

2424

2525
class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -826,3 +826,104 @@ def test_video_to_video(self):
826826
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
827827
_ = pipe(**inputs)
828828
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
829+
830+
831+
@require_torch_gpu
832+
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
833+
expected_pipe_device = torch.device("cuda:0")
834+
expected_pipe_dtype = torch.float64
835+
836+
def get_dummy_components_image_generation(self):
837+
cross_attention_dim = 8
838+
839+
torch.manual_seed(0)
840+
unet = UNet2DConditionModel(
841+
block_out_channels=(4, 8),
842+
layers_per_block=1,
843+
sample_size=32,
844+
in_channels=4,
845+
out_channels=4,
846+
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
847+
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
848+
cross_attention_dim=cross_attention_dim,
849+
norm_num_groups=2,
850+
)
851+
scheduler = DDIMScheduler(
852+
beta_start=0.00085,
853+
beta_end=0.012,
854+
beta_schedule="scaled_linear",
855+
clip_sample=False,
856+
set_alpha_to_one=False,
857+
)
858+
torch.manual_seed(0)
859+
vae = AutoencoderKL(
860+
block_out_channels=[4, 8],
861+
in_channels=3,
862+
out_channels=3,
863+
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
864+
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
865+
latent_channels=4,
866+
norm_num_groups=2,
867+
)
868+
torch.manual_seed(0)
869+
text_encoder_config = CLIPTextConfig(
870+
bos_token_id=0,
871+
eos_token_id=2,
872+
hidden_size=cross_attention_dim,
873+
intermediate_size=16,
874+
layer_norm_eps=1e-05,
875+
num_attention_heads=2,
876+
num_hidden_layers=2,
877+
pad_token_id=1,
878+
vocab_size=1000,
879+
)
880+
text_encoder = CLIPTextModel(text_encoder_config)
881+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
882+
883+
components = {
884+
"unet": unet,
885+
"scheduler": scheduler,
886+
"vae": vae,
887+
"text_encoder": text_encoder,
888+
"tokenizer": tokenizer,
889+
"safety_checker": None,
890+
"feature_extractor": None,
891+
"image_encoder": None,
892+
}
893+
return components
894+
895+
def test_deterministic_device(self):
896+
components = self.get_dummy_components_image_generation()
897+
898+
pipe = StableDiffusionPipeline(**components)
899+
pipe.to(device=torch_device, dtype=torch.float32)
900+
901+
pipe.unet.to(device="cpu")
902+
pipe.vae.to(device="cuda")
903+
pipe.text_encoder.to(device="cuda:0")
904+
905+
pipe_device = pipe.device
906+
907+
self.assertEqual(
908+
self.expected_pipe_device,
909+
pipe_device,
910+
f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
911+
)
912+
913+
def test_deterministic_dtype(self):
914+
components = self.get_dummy_components_image_generation()
915+
916+
pipe = StableDiffusionPipeline(**components)
917+
pipe.to(device=torch_device, dtype=torch.float32)
918+
919+
pipe.unet.to(dtype=torch.float16)
920+
pipe.vae.to(dtype=torch.float32)
921+
pipe.text_encoder.to(dtype=torch.float64)
922+
923+
pipe_dtype = pipe.dtype
924+
925+
self.assertEqual(
926+
self.expected_pipe_dtype,
927+
pipe_dtype,
928+
f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
929+
)

0 commit comments

Comments
 (0)