
Conversation

sayakpaul (Member)

What does this PR do?

Code to test
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.flux.modular_blocks import AUTO_BLOCKS_KONTEXT
from diffusers.utils import load_image

model_id = "black-forest-labs/FLUX.1-Kontext-dev"

blocks = SequentialPipelineBlocks.from_blocks_dict(AUTO_BLOCKS_KONTEXT)

pipeline = blocks.init_pipeline()
pipeline.load_components(["text_encoder"], repo=model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
pipeline.load_components(["tokenizer"], repo=model_id, subfolder="tokenizer")
pipeline.load_components(["text_encoder_2"], repo=model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)
pipeline.load_components(["tokenizer_2"], repo=model_id, subfolder="tokenizer_2")
pipeline.load_components(["scheduler"], repo=model_id, subfolder="scheduler")
pipeline.load_components(["transformer"], repo=model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipeline.load_components(["vae"], repo=model_id, subfolder="vae", torch_dtype=torch.bfloat16)
pipeline.to("cuda")

# T2I
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
output = pipeline(
    prompt=prompt, num_inference_steps=28, guidance_scale=3.5, generator=torch.manual_seed(0)
)
output.values["images"][0].save("modular_flux.png")

# Edit
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" 
).convert("RGB")
prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
output = pipeline(
    prompt=prompt, image=image, num_inference_steps=28, guidance_scale=2.5, generator=torch.manual_seed(0)
)
output.values["images"][0].save("modular_flux_image.png")
Results: T2I and I2I output images attached.

I have compared this with the regular Kontext pipeline outputs and my eyes didn't catch any differences.

Code
from diffusers import DiffusionPipeline
from diffusers.utils import load_image
import torch 

model_id = "black-forest-labs/FLUX.1-Kontext-dev"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# T2I
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
output = pipe(
    prompt=prompt, num_inference_steps=28, guidance_scale=3.5, generator=torch.manual_seed(0)
).images[0]
output.save("regular_flux.png")

# Edit
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" 
).convert("RGB")
prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
output = pipe(
    prompt=prompt, image=image, num_inference_steps=28, guidance_scale=2.5, generator=torch.manual_seed(0)
).images[0]
output.save("regular_flux_image.png")
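As a quick numeric sanity check on top of the eyeballing, something like the snippet below could compare the saved outputs pixel-wise (illustrative only, not part of the PR; it assumes the four PNGs produced by the two scripts above are in the working directory):

import numpy as np
from PIL import Image

# Identical seeds/settings should give near-identical images; tiny differences
# can still come from non-deterministic kernels.
for modular_path, regular_path in [
    ("modular_flux.png", "regular_flux.png"),
    ("modular_flux_image.png", "regular_flux_image.png"),
]:
    modular = np.asarray(Image.open(modular_path), dtype=np.float32)
    regular = np.asarray(Image.open(regular_path), dtype=np.float32)
    print(modular_path, "max abs diff:", np.abs(modular - regular).max())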

sayakpaul requested review from DN6 and yiyixuxu on September 2, 2025 at 14:08
@sayakpaul (Member Author)

Cc: @asomoza if you wanna test.


@yiyixuxu (Collaborator) left a comment

thanks! @sayakpaul

I left some thoughts/comments. I think the best approach is to wait a couple of days for our mellon nodes to be merged (they are almost ready), so we can play with them and test how they work with Flux.

This way you can get first-hand experience with the node-based workflow and with how we best structure our blocks to support that use case. The plan is to have all our modular pipelines work out of the box with the same set of mellon nodes :)

block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
# TODO: `_auto_resize` is currently forced to True. Since it's private anyway, I thought of not adding it.
block_state.image = self.preprocess_image(
@yiyixuxu (Collaborator), Sep 10, 2025:

I think image processing should be part of the VAE encoder, because in a node-based modular workflow all the "encoding"/"decoding" needs to happen as a separate process
(similar to how you can generate text embeddings, delete the text encoders, and run inference with the embeddings directly -> the same process can be applied to images).

So we arrange our blocks to support the node-based workflow: I usually put all the encode blocks first (text encoder / VAE encoder / image encoder in the IP-Adapter case ...), so that these blocks can be popped out into separate processes if needed,

and then I put an input step after that to standardize all the encoder outputs and prepare for the actual denoise step, which usually includes set_timesteps, prepare_latents, and the denoising loop.
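
For the text-embedding analogy above, here is a minimal sketch using the standard (non-modular) Kontext pipeline, assuming FluxKontextPipeline.encode_prompt has the same signature as FluxPipeline.encode_prompt; this is an illustration of the idea, not code from this PR:

import torch
from diffusers import FluxKontextPipeline

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Encode the prompt once ...
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt="yarn art pikachu holding a sign", prompt_2=None
    )

# ... after which the text encoders are no longer needed for denoising and
# could live in a separate process (or be dropped entirely).
pipe.text_encoder, pipe.text_encoder_2 = None, None
torch.cuda.empty_cache()

image = pipe(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=28,
    guidance_scale=2.5,
    generator=torch.manual_seed(0),
).images[0]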

)

@staticmethod
def prepare_latents(
@yiyixuxu (Collaborator), Sep 10, 2025:

we should move the encoding out of the prepare_latents step. This way we can pop the VAE encoder out as its own node and generate the image latents just once, since that part is usually the same within a workflow,

and then users can do different things with the image latents as they like: change steps/scheduler/seeds/batch_size/prompts.

This will also work nicely with hybrid inference once we move the VAE/text encoders remote.
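
A rough sketch of that "encode once, reuse" idea with plain diffusers components (assumes pipe is a loaded Kontext pipeline and image a PIL image as in the scripts above; the shift/scale handling mirrors what the standard Flux pipelines do with the VAE config, and is an assumption here rather than code from this PR):

import torch

# Preprocess and VAE-encode the input image a single time ...
image_tensor = pipe.image_processor.preprocess(image).to(device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    image_latents = pipe.vae.encode(image_tensor).latent_dist.sample()
    image_latents = (image_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

# ... then cache `image_latents` and reuse them across runs that only change
# steps/scheduler/seed/batch_size/prompt, instead of re-encoding every time.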

@sayakpaul (Member Author)

@yiyixuxu okay! Could you give this PR a ping once you think I can work out the changes you mentioned, or would you rather I addressed the feedback before that?
