diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml
new file mode 100644
index 000000000000..e01345e32524
--- /dev/null
+++ b/.github/workflows/pr_modular_tests.yml
@@ -0,0 +1,141 @@
+name: Fast PR tests for Modular
+
+on:
+  pull_request:
+    branches: [main]
+    paths:
+      - "src/diffusers/modular_pipelines/**.py"
+      - "src/diffusers/models/modeling_utils.py"
+      - "src/diffusers/models/model_loading_utils.py"
+      - "src/diffusers/pipelines/pipeline_utils.py"
+      - "src/diffusers/pipeline_loading_utils.py"
+      - "src/diffusers/loaders/lora_base.py"
+      - "src/diffusers/loaders/lora_pipeline.py"
+      - "src/diffusers/loaders/peft.py"
+      - "tests/modular_pipelines/**.py"
+      - ".github/**.yml"
+      - "utils/**.py"
+      - "setup.py"
+  push:
+    branches:
+      - ci-*
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+env:
+  DIFFUSERS_IS_CI: yes
+  HF_HUB_ENABLE_HF_TRANSFER: 1
+  OMP_NUM_THREADS: 4
+  MKL_NUM_THREADS: 4
+  PYTEST_TIMEOUT: 60
+
+jobs:
+  check_code_quality:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[quality]
+      - name: Check quality
+        run: make quality
+      - name: Check if failure
+        if: ${{ failure() }}
+        run: |
+          echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
+
+  check_repository_consistency:
+    needs: check_code_quality
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[quality]
+      - name: Check repo consistency
+        run: |
+          python utils/check_copies.py
+          python utils/check_dummies.py
+          python utils/check_support_list.py
+          make deps_table_check_updated
+      - name: Check if failure
+        if: ${{ failure() }}
+        run: |
+          echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+
+  run_fast_tests:
+    needs: [check_code_quality, check_repository_consistency]
+    strategy:
+      fail-fast: false
+      matrix:
+        config:
+          - name: Fast PyTorch Modular Pipeline CPU tests
+            framework: pytorch_pipelines
+            runner: aws-highmemory-32-plus
+            image: diffusers/diffusers-pytorch-cpu
+            report: torch_cpu_modular_pipelines
+
+    name: ${{ matrix.config.name }}
+
+    runs-on:
+      group: ${{ matrix.config.runner }}
+
+    container:
+      image: ${{ matrix.config.image }}
+      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+    defaults:
+      run:
+        shell: bash
+
+    steps:
+      - name: Checkout diffusers
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+
+      - name: Install dependencies
+        run: |
+          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python -m uv pip install -e [quality,test]
+          pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+
+      - name: Environment
+        run: |
+          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python utils/print_env.py
+
+      - name: Run fast PyTorch Pipeline CPU tests
+        if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
+        run: |
+          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
+            -s -v -k "not Flax and not Onnx" \
+            --make-reports=tests_${{ matrix.config.report }} \
+            tests/modular_pipelines
+
+      - name: Failure short reports
+        if: ${{ failure() }}
+        run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v4
+        with:
+          name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
+          path: reports
+
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
index 0ef1d59f4d65..294ebe8ae933 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -493,6 +493,22 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
         return list(combined_dict.values())
 
+    @property
+    def input_names(self) -> List[str]:
+        return [input_param.name for input_param in self.inputs]
+
+    @property
+    def intermediate_input_names(self) -> List[str]:
+        return [input_param.name for input_param in self.intermediate_inputs]
+
+    @property
+    def intermediate_output_names(self) -> List[str]:
+        return [output_param.name for output_param in self.intermediate_outputs]
+
+    @property
+    def output_names(self) -> List[str]:
+        return [output_param.name for output_param in self.outputs]
+
 
 class PipelineBlock(ModularPipelineBlocks):
     """
@@ -2839,3 +2855,8 @@ def _dict_to_component_spec(
             type_hint=type_hint,
             **spec_dict,
         )
+
+    def set_progress_bar_config(self, **kwargs):
+        for sub_block_name, sub_block in self.blocks.sub_blocks.items():
+            if hasattr(sub_block, "set_progress_bar_config"):
+                sub_block.set_progress_bar_config(**kwargs)
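The new `*_names` properties give blocks a uniform way to expose parameter names, and `set_progress_bar_config` now fans out to every sub-block that supports it. A minimal sketch of how these accessors might be exercised, mirroring the tests added later in this PR (the tiny checkpoint id comes from those tests):

```python
# Hedged sketch: inspecting block names and silencing progress bars on a modular pipeline.
# Assumes the tiny SDXL modular test checkpoint used by the tests in this PR is reachable.
import torch
from diffusers import StableDiffusionXLAutoBlocks

blocks = StableDiffusionXLAutoBlocks()
print(blocks.input_names)               # user-facing inputs, e.g. "prompt"
print(blocks.intermediate_input_names)  # mutable intermediates, e.g. "generator"

pipe = blocks.init_pipeline("hf-internal-testing/tiny-sdxl-modular")
pipe.load_default_components(torch_dtype=torch.float32)
pipe.set_progress_bar_config(disable=True)  # recurses into every sub-block that supports it
```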
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
index c56f4af1b8a5..1800a613ec0f 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -744,8 +744,6 @@ def prepare_latents_inpaint(
         timestep=None,
         is_strength_max=True,
         add_noise=True,
-        return_noise=False,
-        return_image_latents=False,
     ):
         shape = (
             batch_size,
@@ -768,7 +766,7 @@ def prepare_latents_inpaint(
         if image.shape[1] == 4:
             image_latents = image.to(device=device, dtype=dtype)
             image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
-        elif return_image_latents or (latents is None and not is_strength_max):
+        elif latents is None and not is_strength_max:
             image = image.to(device=device, dtype=dtype)
             image_latents = self._encode_vae_image(components, image=image, generator=generator)
             image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
@@ -786,13 +784,7 @@ def prepare_latents_inpaint(
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
             latents = image_latents.to(device)
 
-        outputs = (latents,)
-
-        if return_noise:
-            outputs += (noise,)
-
-        if return_image_latents:
-            outputs += (image_latents,)
+        outputs = (latents, noise, image_latents)
 
         return outputs
 
@@ -864,7 +856,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
         block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
         block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
 
-        block_state.latents, block_state.noise = self.prepare_latents_inpaint(
+        block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
             components,
             block_state.batch_size * block_state.num_images_per_prompt,
             components.num_channels_latents,
@@ -878,8 +870,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
             timestep=block_state.latent_timestep,
             is_strength_max=block_state.is_strength_max,
             add_noise=block_state.add_noise,
-            return_noise=True,
-            return_image_latents=False,
         )
 
         # 7. Prepare mask latent variables
diff --git a/tests/modular_pipelines/__init__.py b/tests/modular_pipelines/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/modular_pipelines/stable_diffusion_xl/__init__.py b/tests/modular_pipelines/stable_diffusion_xl/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
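With this change, `prepare_latents_inpaint` always returns the `(latents, noise, image_latents)` triple instead of growing its return tuple behind `return_noise`/`return_image_latents` flags, so the call site unpacks three values unconditionally. For reference, the seeded-noise call kept in the hunk is diffusers' `randn_tensor` helper; a standalone sketch of that pattern (the latent shape here is illustrative, not the real SDXL shape computed above):

```python
# Standalone sketch of the generator-seeded noise pattern used in prepare_latents_inpaint.
import torch
from diffusers.utils.torch_utils import randn_tensor

shape = (1, 4, 8, 8)  # illustrative latent shape
generator = torch.Generator("cpu").manual_seed(0)
noise = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=torch.float32)
# prepare_latents_inpaint combines this noise with the encoded image latents,
# or uses it more directly when strength is maxed out
```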
diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
new file mode 100644
index 000000000000..4127d00c8e1a
--- /dev/null
+++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
@@ -0,0 +1,466 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+from typing import Any, Dict
+
+import numpy as np
+import torch
+from PIL import Image
+
+from diffusers import (
+    ClassifierFreeGuidance,
+    StableDiffusionXLAutoBlocks,
+    StableDiffusionXLModularPipeline,
+)
+from diffusers.loaders import ModularIPAdapterMixin
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    floats_tensor,
+    torch_device,
+)
+
+from ...models.unets.test_models_unet_2d_condition import (
+    create_ip_adapter_state_dict,
+)
+from ..test_modular_pipelines_common import (
+    ModularPipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class SDXLModularTests:
+    """
+    This mixin defines the methods to create a pipeline, base inputs, and base tests shared across all SDXL modular
+    tests.
+    """
+
+    pipeline_class = StableDiffusionXLModularPipeline
+    pipeline_blocks_class = StableDiffusionXLAutoBlocks
+    repo = "hf-internal-testing/tiny-sdxl-modular"
+    params = frozenset(
+        [
+            "prompt",
+            "height",
+            "width",
+            "negative_prompt",
+            "cross_attention_kwargs",
+            "image",
+            "mask_image",
+        ]
+    )
+    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
+
+    def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
+        pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
+        pipeline.load_default_components(torch_dtype=torch_dtype)
+        return pipeline
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "output_type": "np",
+        }
+        return inputs
+
+    def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        sd_pipe = self.get_pipeline()
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs, output="images")
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == expected_image_shape
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
+            "Image slice does not match the expected slice"
+        )
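For orientation, here is roughly what `get_pipeline` plus `get_dummy_inputs` amount to when run by hand; this is a hedged standalone sketch, with the checkpoint id and the 2-step/64x64 numbers taken from the dummy inputs above:

```python
# Rough standalone equivalent of get_pipeline() + get_dummy_inputs() above;
# assumes the tiny test checkpoint is reachable from the Hub.
import torch
from diffusers import StableDiffusionXLAutoBlocks

pipe = StableDiffusionXLAutoBlocks().init_pipeline("hf-internal-testing/tiny-sdxl-modular")
pipe.load_default_components(torch_dtype=torch.float32)
pipe = pipe.to("cpu")

generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(
    prompt="A painting of a squirrel eating a burger",
    generator=generator,
    num_inference_steps=2,
    output_type="np",
    output="images",  # modular pipelines return the named output requested here
)
print(image.shape)  # (1, 64, 64, 3) for the tiny checkpoint
```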
+ """ + + def test_pipeline_inputs_and_blocks(self): + blocks = self.pipeline_blocks_class() + parameters = blocks.input_names + + assert issubclass(self.pipeline_class, ModularIPAdapterMixin) + assert "ip_adapter_image" in parameters, ( + "`ip_adapter_image` argument must be supported by the `__call__` method" + ) + assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block" + + _ = blocks.sub_blocks.pop("ip_adapter") + parameters = blocks.input_names + intermediate_parameters = blocks.intermediate_input_names + assert "ip_adapter_image" not in parameters, ( + "`ip_adapter_image` argument must be removed from the `__call__` method" + ) + assert "ip_adapter_image_embeds" not in intermediate_parameters, ( + "`ip_adapter_image_embeds` argument must be supported by the `__call__` method" + ) + + def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): + return torch.randn((1, 1, cross_attention_dim), device=torch_device) + + def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32): + return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device) + + def _get_dummy_masks(self, input_size: int = 64): + _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device) + _masks[0, :, :, : int(input_size / 2)] = 1 + return _masks + + def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]): + blocks = self.pipeline_blocks_class() + _ = blocks.sub_blocks.pop("ip_adapter") + parameters = blocks.input_names + if "image" in parameters and "strength" in parameters: + inputs["num_inference_steps"] = 4 + + inputs["output_type"] = "np" + return inputs + + def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None): + r"""Tests for IP-Adapter. + + The following scenarios are tested: + - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter. + - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter. + - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. + - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. + """ + # Raising the tolerance for this test when it's run on a CPU because we + # compare against static slices and that can be shaky (with a VVVV low probability). + expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff + + blocks = self.pipeline_blocks_class() + _ = blocks.sub_blocks.pop("ip_adapter") + pipe = blocks.init_pipeline(self.repo) + pipe.load_default_components(torch_dtype=torch.float32) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + cross_attention_dim = pipe.unet.config.get("cross_attention_dim") + + # forward pass without ip adapter + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + if expected_pipe_slice is None: + output_without_adapter = pipe(**inputs, output="images") + else: + output_without_adapter = expected_pipe_slice + + # 1. 
Single IP-Adapter test cases + adapter_state_dict = create_ip_adapter_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights(adapter_state_dict) + + # forward pass with single ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(0.0) + output_without_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + pipe.set_ip_adapter_scale(42.0) + output_with_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() + max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() + + assert max_diff_without_adapter_scale < expected_max_diff, ( + "Output without ip-adapter must be same as normal inference" + ) + assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference" + + # 2. Multi IP-Adapter test cases + adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet) + adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2]) + + # forward pass with multi ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + pipe.set_ip_adapter_scale([0.0, 0.0]) + output_without_multi_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with multi ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 + pipe.set_ip_adapter_scale([42.0, 42.0]) + output_with_multi_adapter_scale = pipe(**inputs, output="images") + if expected_pipe_slice is not None: + output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_multi_adapter_scale = np.abs( + output_without_multi_adapter_scale - output_without_adapter + ).max() + max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() + assert max_diff_without_multi_adapter_scale < expected_max_diff, ( + "Output without multi-ip-adapter must be same as normal inference" + ) + assert 
max_diff_with_multi_adapter_scale > 1e-2, ( + "Output with multi-ip-adapter scale must be different from normal inference" + ) + + +class SDXLModularControlNetTests: + """ + This mixin is designed to test ControlNet. + """ + + def test_pipeline_inputs(self): + blocks = self.pipeline_blocks_class() + parameters = blocks.input_names + + assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method" + assert "controlnet_conditioning_scale" in parameters, ( + "`controlnet_conditioning_scale` argument must be supported by the `__call__` method" + ) + + def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]): + controlnet_embedder_scale_factor = 2 + image = torch.randn( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + device=torch_device, + ) + inputs["control_image"] = image + return inputs + + def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None): + r"""Tests for ControlNet. + + The following scenarios are tested: + - Single ControlNet with scale=0 should produce same output as no ControlNet. + - Single ControlNet with scale!=0 should produce different output compared to no ControlNet. + """ + # Raising the tolerance for this test when it's run on a CPU because we + # compare against static slices and that can be shaky (with a VVVV low probability). + expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff + + pipe = self.get_pipeline() + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # forward pass without controlnet + inputs = self.get_dummy_inputs(torch_device) + output_without_controlnet = pipe(**inputs, output="images") + output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten() + + # forward pass with single controlnet, but scale=0 which should have no effect + inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device)) + inputs["controlnet_conditioning_scale"] = 0.0 + output_without_controlnet_scale = pipe(**inputs, output="images") + output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single controlnet, but with scale of adapter weights + inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device)) + inputs["controlnet_conditioning_scale"] = 42.0 + output_with_controlnet_scale = pipe(**inputs, output="images") + output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max() + max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max() + + assert max_diff_without_controlnet_scale < expected_max_diff, ( + "Output without controlnet must be same as normal inference" + ) + assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference" + + def test_controlnet_cfg(self): + pipe = self.get_pipeline() + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # forward pass with CFG not applied + guider = ClassifierFreeGuidance(guidance_scale=1.0) + pipe.update_components(guider=guider) + + inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device)) + out_no_cfg = pipe(**inputs, output="images") + + # forward pass with CFG applied + guider = ClassifierFreeGuidance(guidance_scale=7.5) + 
pipe.update_components(guider=guider) + inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device)) + out_cfg = pipe(**inputs, output="images") + + assert out_cfg.shape == out_no_cfg.shape + max_diff = np.abs(out_cfg - out_no_cfg).max() + assert max_diff > 1e-2, "Output with CFG must be different from normal inference" + + +class SDXLModularGuiderTests: + def test_guider_cfg(self): + pipe = self.get_pipeline() + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # forward pass with CFG not applied + guider = ClassifierFreeGuidance(guidance_scale=1.0) + pipe.update_components(guider=guider) + + inputs = self.get_dummy_inputs(torch_device) + out_no_cfg = pipe(**inputs, output="images") + + # forward pass with CFG applied + guider = ClassifierFreeGuidance(guidance_scale=7.5) + pipe.update_components(guider=guider) + inputs = self.get_dummy_inputs(torch_device) + out_cfg = pipe(**inputs, output="images") + + assert out_cfg.shape == out_no_cfg.shape + max_diff = np.abs(out_cfg - out_no_cfg).max() + assert max_diff > 1e-2, "Output with CFG must be different from normal inference" + + +class SDXLModularPipelineFastTests( + SDXLModularTests, + SDXLModularIPAdapterTests, + SDXLModularControlNetTests, + SDXLModularGuiderTests, + ModularPipelineTesterMixin, + unittest.TestCase, +): + """Test cases for Stable Diffusion XL modular pipeline fast tests.""" + + def test_stable_diffusion_xl_euler(self): + self._test_stable_diffusion_xl_euler( + expected_image_shape=(1, 64, 64, 3), + expected_slice=[ + 0.5966781, + 0.62939394, + 0.48465094, + 0.51573336, + 0.57593524, + 0.47035995, + 0.53410417, + 0.51436996, + 0.47313565, + ], + expected_max_diff=1e-2, + ) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + +class SDXLImg2ImgModularPipelineFastTests( + SDXLModularTests, + SDXLModularIPAdapterTests, + SDXLModularControlNetTests, + SDXLModularGuiderTests, + ModularPipelineTesterMixin, + unittest.TestCase, +): + """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests.""" + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 + inputs["image"] = image + inputs["strength"] = 0.8 + + return inputs + + def test_stable_diffusion_xl_euler(self): + self._test_stable_diffusion_xl_euler( + expected_image_shape=(1, 64, 64, 3), + expected_slice=[ + 0.56943184, + 0.4702148, + 0.48048905, + 0.6235963, + 0.551138, + 0.49629188, + 0.60031277, + 0.5688907, + 0.43996853, + ], + expected_max_diff=1e-2, + ) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + +class SDXLInpaintingModularPipelineFastTests( + SDXLModularTests, + SDXLModularIPAdapterTests, + SDXLModularControlNetTests, + SDXLModularGuiderTests, + ModularPipelineTesterMixin, + unittest.TestCase, +): + """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests.""" + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + # create mask + image[8:, 8:, :] = 255 + mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64)) + + 
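Both CFG tests above toggle guidance by swapping the guider component at runtime rather than passing a `guidance_scale` argument. Roughly, the pattern in isolation looks like this; a hedged sketch assuming the tiny SDXL modular checkpoint used throughout this file:

```python
# Standalone sketch of the guider-swap pattern exercised by the CFG tests above.
import torch
from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks

pipe = StableDiffusionXLAutoBlocks().init_pipeline("hf-internal-testing/tiny-sdxl-modular")
pipe.load_default_components(torch_dtype=torch.float32)

# guidance_scale=1.0 effectively disables CFG; any larger scale enables it
pipe.update_components(guider=ClassifierFreeGuidance(guidance_scale=7.5))
image = pipe(prompt="a photo of a cat", num_inference_steps=2, output_type="np", output="images")
```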
inputs["image"] = init_image + inputs["mask_image"] = mask_image + inputs["strength"] = 1.0 + + return inputs + + def test_stable_diffusion_xl_euler(self): + self._test_stable_diffusion_xl_euler( + expected_image_shape=(1, 64, 64, 3), + expected_slice=[ + 0.40872607, + 0.38842705, + 0.34893104, + 0.47837183, + 0.43792963, + 0.5332134, + 0.3716843, + 0.47274873, + 0.45000193, + ], + expected_max_diff=1e-2, + ) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py new file mode 100644 index 000000000000..6240797742d4 --- /dev/null +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -0,0 +1,360 @@ +import gc +import tempfile +import unittest +from typing import Callable, Union + +import numpy as np +import torch + +import diffusers +from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks +from diffusers.utils import logging +from diffusers.utils.testing_utils import ( + backend_empty_cache, + numpy_cosine_similarity_distance, + require_accelerator, + require_torch, + torch_device, +) + + +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + +@require_torch +class ModularPipelineTesterMixin: + """ + This mixin is designed to be used with unittest.TestCase classes. + It provides a set of common tests for each modular pipeline, + including: + - test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters + - test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs + - test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input + - test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs + - test_to_device: check if the pipeline's __call__ method can handle different devices + """ + + # Canonical parameters that are passed to `__call__` regardless + # of the type of pipeline. They are always optional and have common + # sense default values. + optional_params = frozenset( + [ + "num_inference_steps", + "num_images_per_prompt", + "latents", + "output_type", + ] + ) + # this is modular specific: generator needs to be a intermediate input because it's mutable + intermediate_params = frozenset( + [ + "generator", + ] + ) + + def get_generator(self, seed): + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device).manual_seed(seed) + return generator + + @property + def pipeline_class(self) -> Union[Callable, ModularPipeline]: + raise NotImplementedError( + "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. " + "See existing pipeline tests for reference." + ) + + @property + def repo(self) -> str: + raise NotImplementedError( + "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference." + ) + + @property + def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]: + raise NotImplementedError( + "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. " + "See existing pipeline tests for reference." + ) + + def get_pipeline(self): + raise NotImplementedError( + "You need to implement `get_pipeline(self)` in the child test class. 
" + "See existing pipeline tests for reference." + ) + + def get_dummy_inputs(self, device, seed=0): + raise NotImplementedError( + "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. " + "See existing pipeline tests for reference." + ) + + @property + def params(self) -> frozenset: + raise NotImplementedError( + "You need to set the attribute `params` in the child test class. " + "`params` are checked for if all values are present in `__call__`'s signature." + " You can set `params` using one of the common set of parameters defined in `pipeline_params.py`" + " e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to " + "image pipelines, including prompts and prompt embedding overrides." + "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, " + "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline " + "with non-configurable height and width arguments should set the attribute as " + "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. " + "See existing pipeline tests for reference." + ) + + @property + def batch_params(self) -> frozenset: + raise NotImplementedError( + "You need to set the attribute `batch_params` in the child test class. " + "`batch_params` are the parameters required to be batched when passed to the pipeline's " + "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as " + "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's " + "set of batch arguments has minor changes from one of the common sets of batch arguments, " + "do not make modifications to the existing common sets of batch arguments. I.e. a text to " + "image pipeline `negative_prompt` is not batched should set the attribute as " + "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. " + "See existing pipeline tests for reference." 
+
+    def setUp(self):
+        # clean up the VRAM before each test
+        super().setUp()
+        torch.compiler.reset()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        # clean up the VRAM after each test in case of CUDA runtime errors
+        super().tearDown()
+        torch.compiler.reset()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_pipeline_call_signature(self):
+        pipe = self.get_pipeline()
+        input_parameters = pipe.blocks.input_names
+        intermediate_parameters = pipe.blocks.intermediate_input_names
+        optional_parameters = pipe.default_call_parameters
+
+        def _check_for_parameters(parameters, expected_parameters, param_type):
+            remaining_parameters = {param for param in parameters if param not in expected_parameters}
+            assert len(remaining_parameters) == 0, (
+                f"Required {param_type} parameters not present: {remaining_parameters}"
+            )
+
+        _check_for_parameters(self.params, input_parameters, "input")
+        _check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
+        _check_for_parameters(self.optional_params, optional_parameters, "optional")
+
+    def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
+        pipe = self.get_pipeline()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["generator"] = self.get_generator(0)
+
+        logger = logging.get_logger(pipe.__module__)
+        logger.setLevel(level=diffusers.logging.FATAL)
+
+        # prepare batched inputs
+        batched_inputs = []
+        for batch_size in batch_sizes:
+            batched_input = {}
+            batched_input.update(inputs)
+
+            for name in self.batch_params:
+                if name not in inputs:
+                    continue
+
+                value = inputs[name]
+                batched_input[name] = batch_size * [value]
+
+            if batch_generator and "generator" in inputs:
+                batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+            if "batch_size" in inputs:
+                batched_input["batch_size"] = batch_size
+
+            batched_inputs.append(batched_input)
+
+        logger.setLevel(level=diffusers.logging.WARNING)
+        for batch_size, batched_input in zip(batch_sizes, batched_inputs):
+            output = pipe(**batched_input, output="images")
+            assert len(output) == batch_size, "Output is different from expected batch size"
+
+    def test_inference_batch_single_identical(
+        self,
+        batch_size=2,
+        expected_max_diff=1e-4,
+    ):
+        pipe = self.get_pipeline()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        # Reset generator in case it has been used in self.get_dummy_inputs
+        inputs["generator"] = self.get_generator(0)
+
+        logger = logging.get_logger(pipe.__module__)
+        logger.setLevel(level=diffusers.logging.FATAL)
+
+        # batchify inputs
+        batched_inputs = {}
+        batched_inputs.update(inputs)
+
+        for name in self.batch_params:
+            if name not in inputs:
+                continue
+
+            value = inputs[name]
+            batched_inputs[name] = batch_size * [value]
+
+        if "generator" in inputs:
+            batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+        if "batch_size" in inputs:
+            batched_inputs["batch_size"] = batch_size
+
+        output = pipe(**inputs, output="images")
+        output_batch = pipe(**batched_inputs, output="images")
+
+        assert output_batch.shape[0] == batch_size
+
+        max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
+        assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
+
+    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+    @require_accelerator
+    def test_float16_inference(self, expected_max_diff=5e-2):
+        pipe = self.get_pipeline()
+        pipe.to(torch_device, torch.float32)
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe_fp16 = self.get_pipeline()
+        pipe_fp16.to(torch_device, torch.float16)
+        pipe_fp16.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it is used inside dummy inputs
+        if "generator" in inputs:
+            inputs["generator"] = self.get_generator(0)
+        output = pipe(**inputs, output="images")
+
+        fp16_inputs = self.get_dummy_inputs(torch_device)
+        # Reset generator in case it is used inside dummy inputs
+        if "generator" in fp16_inputs:
+            fp16_inputs["generator"] = self.get_generator(0)
+        output_fp16 = pipe_fp16(**fp16_inputs, output="images")
+
+        if isinstance(output, torch.Tensor):
+            output = output.cpu()
+            output_fp16 = output_fp16.cpu()
+
+        max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
+        assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
+
+    @require_accelerator
+    def test_to_device(self):
+        pipe = self.get_pipeline()
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.to("cpu")
+        model_devices = [
+            component.device.type for component in pipe.components.values() if hasattr(component, "device")
+        ]
+        assert all(device == "cpu" for device in model_devices), "Not all pipeline components are on CPU"
+
+        pipe.to(torch_device)
+        model_devices = [
+            component.device.type for component in pipe.components.values() if hasattr(component, "device")
+        ]
+        assert all(device == torch_device for device in model_devices), (
+            "Not all pipeline components are on the accelerator device"
+        )
+
+    def test_inference_is_not_nan_cpu(self):
+        pipe = self.get_pipeline()
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to("cpu")
+
+        output = pipe(**self.get_dummy_inputs("cpu"), output="images")
+        assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
+
+    @require_accelerator
+    def test_inference_is_not_nan(self):
+        pipe = self.get_pipeline()
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        output = pipe(**self.get_dummy_inputs(torch_device), output="images")
+        assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
+
+    def test_num_images_per_prompt(self):
+        pipe = self.get_pipeline()
+
+        if "num_images_per_prompt" not in pipe.blocks.input_names:
+            return
+
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        batch_sizes = [1, 2]
+        num_images_per_prompts = [1, 2]
+
+        for batch_size in batch_sizes:
+            for num_images_per_prompt in num_images_per_prompts:
+                inputs = self.get_dummy_inputs(torch_device)
+
+                for key in inputs.keys():
+                    if key in self.batch_params:
+                        inputs[key] = batch_size * [inputs[key]]
+
+                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
+
+                assert images.shape[0] == batch_size * num_images_per_prompt
+
+    @require_accelerator
+    def test_components_auto_cpu_offload_inference_consistent(self):
+        base_pipe = self.get_pipeline().to(torch_device)
+
+        cm = ComponentsManager()
+        cm.enable_auto_cpu_offload(device=torch_device)
+        offload_pipe = self.get_pipeline(components_manager=cm)
+
+        image_slices = []
+        for pipe in [base_pipe, offload_pipe]:
+            inputs = self.get_dummy_inputs(torch_device)
+            image = pipe(**inputs, output="images")
+
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+    def test_save_from_pretrained(self):
+        pipes = []
+        base_pipe = self.get_pipeline().to(torch_device)
+        pipes.append(base_pipe)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            base_pipe.save_pretrained(tmpdirname)
+            pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
+            pipe.load_default_components(torch_dtype=torch.float32)
+            pipe.to(torch_device)
+
+        pipes.append(pipe)
+
+        image_slices = []
+        for pipe in pipes:
+            inputs = self.get_dummy_inputs(torch_device)
+            image = pipe(**inputs, output="images")
+
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py
index 4e2c4dcdd9cb..3db7c9fa1b0c 100644
--- a/tests/pipelines/pipeline_params.py
+++ b/tests/pipelines/pipeline_params.py
@@ -20,12 +20,6 @@
     ]
 )
 
-TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
-
-TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
-
-IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
-
 IMAGE_VARIATION_PARAMS = frozenset(
     [
         "image",
@@ -35,8 +29,6 @@
     ]
 )
 
-IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
-
 TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
     [
         "prompt",
@@ -50,8 +42,6 @@
     ]
 )
 
-TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
-
 TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
     [
         # Text guided image variation with an image mask
@@ -67,8 +57,6 @@
     ]
 )
 
-TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
-
 IMAGE_INPAINTING_PARAMS = frozenset(
     [
         # image variation with an image mask
@@ -80,8 +68,6 @@
     ]
 )
 
-IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
-
 IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
     [
         "example_image",
@@ -93,20 +79,12 @@
     ]
 )
 
-IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
 
 CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
 
 CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
 
-UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
-
-UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
-
-UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
-
-UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
-
 TEXT_TO_AUDIO_PARAMS = frozenset(
     [
         "prompt",
@@ -119,11 +97,38 @@
     ]
 )
 
-TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
 TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
 
-TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
+UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
 
-TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
+# image params
+TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
+
+IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
+
+
+# batch params
+TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
+
+TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
+
+TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
+
+IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
+
+IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+
+UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
+
+UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
+
+TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
 
 VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
+
+# callback params
+TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
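For completeness, these reorganized sets are consumed by test classes as plain frozensets, as the tester mixin docstrings describe: start from a common set and subtract keys rather than editing the shared definitions. A hypothetical sketch (the class name is a placeholder; normally it would also inherit the tester mixin):

```python
# Hypothetical usage of the reorganized parameter sets from pipeline_params.py.
from tests.pipelines.pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS


class MyTextToImagePipelineTests:  # placeholder test class
    # a pipeline with fixed height/width subtracts those keys from the shared set
    params = TEXT_TO_IMAGE_PARAMS - {"height", "width"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
```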