76 changes: 30 additions & 46 deletions .github/workflows/pr_modular_tests.yml
@@ -77,62 +77,46 @@ jobs:

run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines

name: ${{ matrix.config.name }}

name: Fast PyTorch Modular Pipeline CPU tests
runs-on:
group: ${{ matrix.config.runner }}

group: aws-highmemory-32-plus
container:
image: ${{ matrix.config.image }}
image: diffusers/diffusers-pytorch-cpu
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/

defaults:
run:
shell: bash

steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2

- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps

- name: Environment
run: |
python utils/print_env.py
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2

- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps

- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Environment
run: |
python utils/print_env.py

- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports
- name: Run fast PyTorch Pipeline CPU tests
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_cpu_modular_pipelines \
tests/modular_pipelines

- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt

- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
path: reports
172 changes: 172 additions & 0 deletions tests/modular_pipelines/test_modular_pipelines_custom_blocks.py
@@ -0,0 +1,172 @@
from collections import deque
from typing import List

import numpy as np
import torch

from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
ComponentSpec,
InputParam,
ModularPipelineBlocks,
OutputParam,
PipelineState,
WanModularPipeline,
)

from ..testing_utils import nightly, require_torch, slow


class DummyCustomBlockSimple(ModularPipelineBlocks):
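    """Minimal custom block that prefixes the input prompt with "Modular diffusers + " and, optionally, declares a FluxTransformer2DModel component for component-loading tests."""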
def __init__(self, use_dummy_model_component=False):
self.use_dummy_model_component = use_dummy_model_component
super().__init__()

@property
def expected_components(self):
if self.use_dummy_model_component:
return [ComponentSpec("transformer", FluxTransformer2DModel)]
else:
return []

@property
def inputs(self) -> List[InputParam]:
return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]

@property
def intermediate_inputs(self) -> List[InputParam]:
return []

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"output_prompt",
type_hint=str,
description="Modified prompt",
)
]

def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

old_prompt = block_state.prompt
block_state.output_prompt = "Modular diffusers + " + old_prompt
self.set_block_state(state, block_state)

return components, state


class TestModularCustomBlocks:
def _test_block_properties(self, block):
assert not block.expected_components
assert not block.intermediate_inputs

actual_inputs = [inp.name for inp in block.inputs]
actual_intermediate_outputs = [out.name for out in block.intermediate_outputs]
assert actual_inputs == ["prompt"]
assert actual_intermediate_outputs == ["output_prompt"]

def test_custom_block_properties(self):
custom_block = DummyCustomBlockSimple()
self._test_block_properties(custom_block)

def test_custom_block_output(self):
custom_block = DummyCustomBlockSimple()
pipe = custom_block.init_pipeline()
prompt = "Diffusers is nice"
output = pipe(prompt=prompt)

actual_inputs = [inp.name for inp in custom_block.inputs]
actual_intermediate_outputs = [out.name for out in custom_block.intermediate_outputs]
assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs)

output_prompt = output.values["output_prompt"]
assert output_prompt.startswith("Modular diffusers + ")

def test_custom_block_supported_components(self):
custom_block = DummyCustomBlockSimple(use_dummy_model_component=True)
pipe = custom_block.init_pipeline("hf-internal-testing/tiny-flux-kontext-pipe")
pipe.load_components()

assert len(pipe.components) == 1
assert pipe.component_names[0] == "transformer"

def test_custom_block_loads_from_hub(self):
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
self._test_block_properties(block)

pipe = block.init_pipeline()

prompt = "Diffusers is nice"
output = pipe(prompt=prompt)
output_prompt = output.values["output_prompt"]
assert output_prompt.startswith("Modular diffusers + ")


@slow
@nightly
@require_torch
class TestModularCustomBlocksIntegration:
def test_krea_realtime_video_loading(self):
repo_id = "krea/krea-realtime-video"
blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
block_names = sorted(blocks.sub_blocks)

assert block_names == sorted(["text_encoder", "before_denoise", "denoise", "decode"])

pipe = WanModularPipeline(blocks, repo_id)
pipe.load_components(
trust_remote_code=True,
device_map="cuda",
torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
)
assert len(pipe.components) == 7
assert sorted(pipe.components) == sorted(
["text_encoder", "tokenizer", "guider", "scheduler", "vae", "transformer", "video_processor"]
)

def test_forward(self):
repo_id = "krea/krea-realtime-video"
blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
pipe = WanModularPipeline(blocks, repo_id)
pipe.load_components(
trust_remote_code=True,
device_map="cuda",
torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
)

num_frames_per_block = 2
num_blocks = 2
Comment on lines +240 to +241 (Member Author):

Small numbers to keep the runtime small. Rest of the codebase is from
https://huggingface.co/krea/krea-realtime-video#use-it-with-%F0%9F%A7%A8-diffusers


state = PipelineState()
state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len))

prompt = ["a cat sitting on a boat"]

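        # Fuse the self-attention projections, mirroring the usage example on the krea-realtime-video model card.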
for block in pipe.transformer.blocks:
block.self_attn.fuse_projections()

for block_idx in range(num_blocks):
Comment from @yiyixuxu (Collaborator, Oct 29, 2025):

So, a little bit related to the idea you had here: https://huggingface.slack.com/archives/C065E480NN9/p1761726088454629?thread_ts=1761713343.592699&cid=C065E480NN9

In addition to being a docstring example, I think if remote custom pipelines include an executable example string, we can also programmatically discover it and create slow tests for them in our test suite.
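
A rough sketch of what that discovery could look like. This is purely illustrative: the `EXAMPLE_DOC_STRING` attribute on remote block classes and the candidate repo list are assumptions, not an existing diffusers convention.

import pytest

from diffusers.modular_pipelines import ModularPipelineBlocks

# Hypothetical: repos whose remote block modules ship an executable example string.
CANDIDATE_REPOS = ["hf-internal-testing/tiny-modular-diffusers-block"]


@pytest.mark.parametrize("repo_id", CANDIDATE_REPOS)
def test_remote_block_example_runs(repo_id):
    block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
    # Assumed convention: the remote block class exposes its usage example as a plain string.
    example = getattr(type(block), "EXAMPLE_DOC_STRING", None)
    if example is None:
        pytest.skip(f"{repo_id} does not ship an executable example string")
    # Run the example in an isolated namespace; it should execute end to end.
    exec(compile(example, f"<{repo_id} example>", "exec"), {})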

Reply (Member Author):

As discussed over Slack, we will do it after a few days.

state = pipe(
state,
prompt=prompt,
num_inference_steps=2,
num_blocks=num_blocks,
num_frames_per_block=num_frames_per_block,
block_idx=block_idx,
generator=torch.manual_seed(42),
)
current_frames = np.array(state.values["videos"][0])
current_frames_flat = current_frames.flatten()
actual_slices = np.concatenate([current_frames_flat[:4], current_frames_flat[-4:]]).tolist()

if block_idx == 0:
assert current_frames.shape == (5, 480, 832, 3)
expected_slices = np.array([211, 229, 238, 208, 195, 180, 188, 193])
else:
assert current_frames.shape == (8, 480, 832, 3)
expected_slices = np.array([179, 203, 214, 176, 194, 181, 187, 191])

assert np.allclose(actual_slices, expected_slices)