-
Notifications
You must be signed in to change notification settings - Fork 6.5k
[modular] Tests for custom blocks in modular diffusers #12557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
760a914
77e5015
1be88f0
316b71f
ecdd843
5f1afc1
ddb5ba7
b5f13d9
9f113f8
728655c
b8809f7
ea4f29f
fd88f3d
c0ce538
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,172 @@ | ||
| from collections import deque | ||
| from typing import List | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from diffusers import FluxTransformer2DModel | ||
| from diffusers.modular_pipelines import ( | ||
| ComponentSpec, | ||
| InputParam, | ||
| ModularPipelineBlocks, | ||
| OutputParam, | ||
| PipelineState, | ||
| WanModularPipeline, | ||
| ) | ||
|
|
||
| from ..testing_utils import nightly, require_torch, slow | ||
|
|
||
|
|
||
| class DummyCustomBlockSimple(ModularPipelineBlocks): | ||
| def __init__(self, use_dummy_model_component=False): | ||
| self.use_dummy_model_component = use_dummy_model_component | ||
| super().__init__() | ||
|
|
||
| @property | ||
| def expected_components(self): | ||
| if self.use_dummy_model_component: | ||
| return [ComponentSpec("transformer", FluxTransformer2DModel)] | ||
| else: | ||
| return [] | ||
|
|
||
| @property | ||
| def inputs(self) -> List[InputParam]: | ||
| return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")] | ||
|
|
||
| @property | ||
| def intermediate_inputs(self) -> List[InputParam]: | ||
| return [] | ||
|
|
||
| @property | ||
| def intermediate_outputs(self) -> List[OutputParam]: | ||
| return [ | ||
| OutputParam( | ||
| "output_prompt", | ||
| type_hint=str, | ||
| description="Modified prompt", | ||
| ) | ||
| ] | ||
|
|
||
| def __call__(self, components, state: PipelineState) -> PipelineState: | ||
| block_state = self.get_block_state(state) | ||
|
|
||
| old_prompt = block_state.prompt | ||
| block_state.output_prompt = "Modular diffusers + " + old_prompt | ||
| self.set_block_state(state, block_state) | ||
|
|
||
| return components, state | ||
|
|
||
|
|
||
| class TestModularCustomBlocks: | ||
| def _test_block_properties(self, block): | ||
| assert not block.expected_components | ||
| assert not block.intermediate_inputs | ||
|
|
||
| actual_inputs = [inp.name for inp in block.inputs] | ||
| actual_intermediate_outputs = [out.name for out in block.intermediate_outputs] | ||
| assert actual_inputs == ["prompt"] | ||
| assert actual_intermediate_outputs == ["output_prompt"] | ||
|
|
||
| def test_custom_block_properties(self): | ||
| custom_block = DummyCustomBlockSimple() | ||
| self._test_block_properties(custom_block) | ||
|
|
||
| def test_custom_block_output(self): | ||
| custom_block = DummyCustomBlockSimple() | ||
| pipe = custom_block.init_pipeline() | ||
| prompt = "Diffusers is nice" | ||
| output = pipe(prompt=prompt) | ||
|
|
||
| actual_inputs = [inp.name for inp in custom_block.inputs] | ||
| actual_intermediate_outputs = [out.name for out in custom_block.intermediate_outputs] | ||
| assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs) | ||
|
|
||
| output_prompt = output.values["output_prompt"] | ||
| assert output_prompt.startswith("Modular diffusers + ") | ||
|
|
||
| def test_custom_block_supported_components(self): | ||
| custom_block = DummyCustomBlockSimple(use_dummy_model_component=True) | ||
| pipe = custom_block.init_pipeline("hf-internal-testing/tiny-flux-kontext-pipe") | ||
| pipe.load_components() | ||
|
|
||
| assert len(pipe.components) == 1 | ||
| assert pipe.component_names[0] == "transformer" | ||
|
|
||
| def test_custom_block_loads_from_hub(self): | ||
| repo_id = "hf-internal-testing/tiny-modular-diffusers-block" | ||
| block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) | ||
| self._test_block_properties(block) | ||
|
|
||
| pipe = block.init_pipeline() | ||
|
|
||
| prompt = "Diffusers is nice" | ||
| output = pipe(prompt=prompt) | ||
| output_prompt = output.values["output_prompt"] | ||
| assert output_prompt.startswith("Modular diffusers + ") | ||
|
|
||
|
|
||
| @slow | ||
| @nightly | ||
| @require_torch | ||
| class TestKreaCustomBlocksIntegration: | ||
| repo_id = "krea/krea-realtime-video" | ||
|
|
||
| def test_loading_from_hub(self): | ||
| blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True) | ||
| block_names = sorted(blocks.sub_blocks) | ||
|
|
||
| assert block_names == sorted(["text_encoder", "before_denoise", "denoise", "decode"]) | ||
|
|
||
| pipe = WanModularPipeline(blocks, self.repo_id) | ||
| pipe.load_components( | ||
| trust_remote_code=True, | ||
| device_map="cuda", | ||
| torch_dtype={"default": torch.bfloat16, "vae": torch.float16}, | ||
| ) | ||
| assert len(pipe.components) == 7 | ||
| assert sorted(pipe.components) == sorted( | ||
| ["text_encoder", "tokenizer", "guider", "scheduler", "vae", "transformer", "video_processor"] | ||
| ) | ||
|
|
||
| def test_forward(self): | ||
| blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True) | ||
| pipe = WanModularPipeline(blocks, self.repo_id) | ||
| pipe.load_components( | ||
| trust_remote_code=True, | ||
| device_map="cuda", | ||
| torch_dtype={"default": torch.bfloat16, "vae": torch.float16}, | ||
| ) | ||
|
|
||
| num_frames_per_block = 2 | ||
| num_blocks = 2 | ||
|
|
||
| state = PipelineState() | ||
| state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len)) | ||
|
|
||
| prompt = ["a cat sitting on a boat"] | ||
|
|
||
| for block in pipe.transformer.blocks: | ||
| block.self_attn.fuse_projections() | ||
|
|
||
| for block_idx in range(num_blocks): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so a little bit related to the idea you had here https://huggingface.slack.com/archives/C065E480NN9/p1761726088454629?thread_ts=1761713343.592699&cid=C065E480NN9 in addition to being a docstring example, I think If remote custom pipelines include an executable example string, we can also programmatically discover it and create slow tests for them in our test suite. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed over Slack, we will do it after a few days. |
||
| state = pipe( | ||
| state, | ||
| prompt=prompt, | ||
| num_inference_steps=2, | ||
| num_blocks=num_blocks, | ||
| num_frames_per_block=num_frames_per_block, | ||
| block_idx=block_idx, | ||
| generator=torch.manual_seed(42), | ||
| ) | ||
| current_frames = np.array(state.values["videos"][0]) | ||
| current_frames_flat = current_frames.flatten() | ||
| actual_slices = np.concatenate([current_frames_flat[:4], current_frames_flat[-4:]]).tolist() | ||
|
|
||
| if block_idx == 0: | ||
| assert current_frames.shape == (5, 480, 832, 3) | ||
| expected_slices = np.array([211, 229, 238, 208, 195, 180, 188, 193]) | ||
| else: | ||
| assert current_frames.shape == (8, 480, 832, 3) | ||
| expected_slices = np.array([179, 203, 214, 176, 194, 181, 187, 191]) | ||
|
|
||
| assert np.allclose(actual_slices, expected_slices) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small numbers to keep the runtime small. Rest of the codebase is from
https://huggingface.co/krea/krea-realtime-video#use-it-with-%F0%9F%A7%A8-diffusers