diff --git a/docs/source/en/api/models/omnigen_transformer.md b/docs/source/en/api/models/omnigen_transformer.md index ee700a04bdae..78d29fdab5e4 100644 --- a/docs/source/en/api/models/omnigen_transformer.md +++ b/docs/source/en/api/models/omnigen_transformer.md @@ -14,6 +14,17 @@ specific language governing permissions and limitations under the License. A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/). +The abstract from the paper is: + +*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.* + +```python +import torch +from diffusers import OmniGenTransformer2DModel + +transformer = OmniGenTransformer2DModel.from_pretrained("Shitao/OmniGen-v1-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + ## OmniGenTransformer2DModel [[autodoc]] OmniGenTransformer2DModel diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md index 0b826f182edd..114e3753e710 100644 --- a/docs/source/en/api/pipelines/omnigen.md +++ b/docs/source/en/api/pipelines/omnigen.md @@ -19,27 +19,7 @@ The abstract from the paper is: -*The emergence of Large Language Models (LLMs) has unified language -generation tasks and revolutionized human-machine interaction. -However, in the realm of image generation, a unified model capable of handling various tasks -within a single framework remains largely unexplored. In -this work, we introduce OmniGen, a new diffusion model -for unified image generation. OmniGen is characterized -by the following features: 1) Unification: OmniGen not -only demonstrates text-to-image generation capabilities but -also inherently supports various downstream tasks, such -as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of -OmniGen is highly simplified, eliminating the need for additional plugins. 
Moreover, compared to existing diffusion
-models, it is more user-friendly and can complete complex
-tasks end-to-end through instructions without the need for
-extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from
-learning in a unified format, OmniGen effectively transfers
-knowledge across different tasks, manages unseen tasks and
-domains, and exhibits novel capabilities. We also explore
-the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism.
-This work represents the first attempt at a general-purpose image generation model,
-and we will release our resources at https:
-//github.com/VectorSpaceLab/OmniGen to foster future advancements.*
+*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*

@@ -49,7 +29,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m

This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).

-
## Inference

First, load the pipeline:

@@ -57,17 +36,15 @@ First, load the pipeline:
```python
import torch
from diffusers import OmniGenPipeline
-pipe = OmniGenPipeline.from_pretrained(
-    "Shitao/OmniGen-v1-diffusers",
-    torch_dtype=torch.bfloat16
-)
+
+pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```

For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image. You can try setting the `height` and `width` parameters to generate images with different sizes.

-```py
+```python
prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
image = pipe(
    prompt=prompt,
@@ -76,14 +53,14 @@ image = pipe(
    guidance_scale=3,
    generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
-image
+image.save("output.png")
```

OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.

-```py
+```python
+from diffusers.utils import load_image
+
prompt="<|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe(
@@ -93,14 +70,11 @@ image = pipe(
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
-image
+image.save("output.png")
```

-
## OmniGenPipeline

[[autodoc]] OmniGenPipeline
  - all
  - __call__
-
-
diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md
index a3d98e4e60cc..40a9e81bcd52 100644
--- a/docs/source/en/using-diffusers/omnigen.md
+++ b/docs/source/en/using-diffusers/omnigen.md
@@ -19,25 +19,22 @@ For more information, please refer to the [paper](https://arxiv.org/pdf/2409.113

This guide will walk you through using OmniGen for various tasks and use cases.

## Load model checkpoints
+
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.

-```py
+```python
import torch
from diffusers import OmniGenPipeline

-pipe = OmniGenPipeline.from_pretrained(
-    "Shitao/OmniGen-v1-diffusers",
-    torch_dtype=torch.bfloat16
-)
-```
-
+pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
+```

## Text-to-image

For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image. You can try setting the `height` and `width` parameters to generate images with different sizes.

-```py
+```python
import torch
from diffusers import OmniGenPipeline

@@ -55,8 +52,9 @@
    guidance_scale=3,
    generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
-image
+image.save("output.png")
```
+
+<img src="..." alt="generated image"/>
@@ -67,7 +65,7 @@ OmniGen supports multimodal inputs. When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image. It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image. -```py +```python import torch from diffusers import OmniGenPipeline from diffusers.utils import load_image @@ -86,9 +84,11 @@ image = pipe( guidance_scale=2, img_guidance_scale=1.6, use_input_image_size_as_output=True, - generator=torch.Generator(device="cpu").manual_seed(222)).images[0] -image + generator=torch.Generator(device="cpu").manual_seed(222) +).images[0] +image.save("output.png") ``` +
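If you want the edited result at a different resolution instead, you can drop `use_input_image_size_as_output` and pass explicit `height` and `width` values, as in the text-to-image example. A minimal sketch reusing the `pipe`, `prompt`, and `input_images` from the example above (the 768x768 values are only illustrative):

```python
# Illustrative variant: request an explicit output size instead of matching
# the input image with `use_input_image_size_as_output`.
image = pipe(
    prompt=prompt,
    input_images=input_images,
    height=768,
    width=768,
    guidance_scale=2,
    img_guidance_scale=1.6,
    generator=torch.Generator(device="cpu").manual_seed(222),
).images[0]
image.save("output_768.png")
```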
@@ -101,7 +101,8 @@ image
OmniGen has some interesting features, such as visual reasoning, as shown in the example below. -```py + +```python prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <|image_1|>" input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")] image = pipe( @@ -110,20 +111,20 @@ image = pipe( guidance_scale=2, img_guidance_scale=1.6, use_input_image_size_as_output=True, - generator=torch.Generator(device="cpu").manual_seed(0)).images[0] -image + generator=torch.Generator(device="cpu").manual_seed(0) +).images[0] +image.save("output.png") ``` +
+<img src="..." alt="generated image"/>
- ## Controllable generation - OmniGen can handle several classic computer vision tasks. - As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images. +OmniGen can handle several classic computer vision tasks. As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images. -```py +```python import torch from diffusers import OmniGenPipeline from diffusers.utils import load_image @@ -142,8 +143,9 @@ image1 = pipe( guidance_scale=2, img_guidance_scale=1.6, use_input_image_size_as_output=True, - generator=torch.Generator(device="cpu").manual_seed(333)).images[0] -image1 + generator=torch.Generator(device="cpu").manual_seed(333) +).images[0] +image1.save("image1.png") prompt="Generate a new photo using the following picture and text as conditions: <|image_1|>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him." input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")] @@ -153,8 +155,9 @@ image2 = pipe( guidance_scale=2, img_guidance_scale=1.6, use_input_image_size_as_output=True, - generator=torch.Generator(device="cpu").manual_seed(333)).images[0] -image2 + generator=torch.Generator(device="cpu").manual_seed(333) +).images[0] +image2.save("image2.png") ```
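Because the pipeline returns standard PIL images, the detected skeleton from the first call can be fed straight back in as the condition for the second call without saving and reloading it. A minimal sketch under that assumption, reusing `pipe` and `image1` from the example above:

```python
# Sketch: chain both steps by conditioning directly on the freshly detected
# skeleton (`image1` is the PIL image returned by the first call).
prompt = "Generate a new photo using the following picture and text as conditions: <|image_1|>\n A young boy is sitting on a sofa in the library, holding a book."
image2 = pipe(
    prompt=prompt,
    input_images=[image1],
    guidance_scale=2,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(333),
).images[0]
image2.save("image2_chained.png")
```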
@@ -174,7 +177,8 @@ image2 OmniGen can also directly use relevant information from input images to generate new images. -```py + +```python import torch from diffusers import OmniGenPipeline from diffusers.utils import load_image @@ -193,9 +197,11 @@ image = pipe( guidance_scale=2, img_guidance_scale=1.6, use_input_image_size_as_output=True, - generator=torch.Generator(device="cpu").manual_seed(0)).images[0] -image + generator=torch.Generator(device="cpu").manual_seed(0) +).images[0] +image.save("output.png") ``` +
@@ -203,13 +209,12 @@ image
- ## ID and object preserving OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously. Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions. -```py +```python import torch from diffusers import OmniGenPipeline from diffusers.utils import load_image @@ -231,9 +236,11 @@ image = pipe( width=1024, guidance_scale=2.5, img_guidance_scale=1.6, - generator=torch.Generator(device="cpu").manual_seed(666)).images[0] -image + generator=torch.Generator(device="cpu").manual_seed(666) +).images[0] +image.save("output.png") ``` +
@@ -249,7 +256,6 @@ image
- ```py import torch from diffusers import OmniGenPipeline @@ -261,7 +267,6 @@ pipe = OmniGenPipeline.from_pretrained( ) pipe.to("cuda") - prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>." input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg") input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg") @@ -273,8 +278,9 @@ image = pipe( width=1024, guidance_scale=2.5, img_guidance_scale=1.6, - generator=torch.Generator(device="cpu").manual_seed(666)).images[0] -image + generator=torch.Generator(device="cpu").manual_seed(666) +).images[0] +image.save("output.png") ```
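The optimization hunk below recommends offloading and a smaller `max_input_image_size` when several input images are used. A minimal sketch of both options together, assuming `max_input_image_size` is accepted as a call argument as that guidance suggests:

```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
# Offloading trades speed for a lower peak memory footprint.
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>.",
    input_images=[
        load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg"),
        load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg"),
    ],
    # Downscale each input image to cut the number of image tokens
    # (assumed call argument, per the memory guidance below).
    max_input_image_size=512,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
).images[0]
image.save("output.png")
```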
@@ -292,13 +298,12 @@ image
-
-## Optimization when inputting multiple images
+## Optimization when using multiple images

For the text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024x1024 image on an A800 GPU).
However, when using input images, the computational cost increases.

-Here are some guidelines to help you reduce computational costs when inputting multiple images. The experiments are conducted on an A800 GPU with two input images.
+Here are some guidelines to help you reduce computational costs when using multiple images. The experiments are conducted on an A800 GPU with two input images.

Like other pipelines, you can reduce memory usage by offloading the model: `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload()`.
In OmniGen, you can also decrease computational overhead by reducing the `max_input_image_size`.
@@ -310,5 +315,3 @@ The memory consumption for different image sizes is shown in the table below:
| max_input_image_size=512 | 17GB |
| max_input_image_size=256 | 14GB |

-
-
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index bd3237c24c1c..c42fbbc9f0a3 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1199,7 +1199,7 @@ def apply_rotary_emb(
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    elif use_real_unbind_dim == -2:
-        # Used for Stable Audio
+        # Used for Stable Audio and OmniGen
        x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
        x_rotated = torch.cat([-x_imag, x_real], dim=-1)
    else:
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index 0774a3f2a6ee..8d5d1b3f8fea 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -13,17 +13,15 @@
# limitations under the License.

import math
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

import torch
+import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers
-from ..attention_processor import Attention, AttentionProcessor
+from ...utils import logging
+from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -34,39 +32,21 @@


class OmniGenFeedForward(nn.Module):
-    r"""
-    A feed-forward layer for OmniGen.
-
-    Parameters:
-        hidden_size (`int`):
-            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
-            hidden representations.
-        intermediate_size (`int`): The intermediate dimension of the feedforward layer.
- """ - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - ): + def __init__(self, hidden_size: int, intermediate_size: int): super().__init__() + self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - self.activation_fn = nn.SiLU() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: up_states = self.gate_up_proj(hidden_states) - gate, up_states = up_states.chunk(2, dim=-1) up_states = up_states * self.activation_fn(gate) - return self.down_proj(up_states) class OmniGenPatchEmbed(nn.Module): - """2D Image to Patch Embedding with support for OmniGen.""" - def __init__( self, patch_size: int = 2, @@ -99,7 +79,7 @@ def __init__( ) self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True) - def cropped_pos_embed(self, height, width): + def _cropped_pos_embed(self, height, width): """Crops positional embeddings for SD3 compatibility.""" if self.pos_embed_max_size is None: raise ValueError("`pos_embed_max_size` must be set for cropping.") @@ -122,43 +102,34 @@ def cropped_pos_embed(self, height, width): spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) return spatial_pos_embed - def patch_embeddings(self, latent, is_input_image: bool): + def _patch_embeddings(self, hidden_states: torch.Tensor, is_input_image: bool) -> torch.Tensor: if is_input_image: - latent = self.input_image_proj(latent) + hidden_states = self.input_image_proj(hidden_states) else: - latent = self.output_image_proj(latent) - latent = latent.flatten(2).transpose(1, 2) - return latent - - def forward(self, latent: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None): - """ - Args: - latent: encoded image latents - is_input_image: use input_image_proj or output_image_proj - padding_latent: - When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence - length. 
- - Returns: torch.Tensor - - """ - if isinstance(latent, list): + hidden_states = self.output_image_proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + return hidden_states + + def forward( + self, hidden_states: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None + ) -> torch.Tensor: + if isinstance(hidden_states, list): if padding_latent is None: - padding_latent = [None] * len(latent) + padding_latent = [None] * len(hidden_states) patched_latents = [] - for sub_latent, padding in zip(latent, padding_latent): + for sub_latent, padding in zip(hidden_states, padding_latent): height, width = sub_latent.shape[-2:] - sub_latent = self.patch_embeddings(sub_latent, is_input_image) - pos_embed = self.cropped_pos_embed(height, width) + sub_latent = self._patch_embeddings(sub_latent, is_input_image) + pos_embed = self._cropped_pos_embed(height, width) sub_latent = sub_latent + pos_embed if padding is not None: sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2) patched_latents.append(sub_latent) else: - height, width = latent.shape[-2:] - pos_embed = self.cropped_pos_embed(height, width) - latent = self.patch_embeddings(latent, is_input_image) - patched_latents = latent + pos_embed + height, width = hidden_states.shape[-2:] + pos_embed = self._cropped_pos_embed(height, width) + hidden_states = self._patch_embeddings(hidden_states, is_input_image) + patched_latents = hidden_states + pos_embed return patched_latents @@ -180,15 +151,16 @@ def __init__( self.long_factor = rope_scaling["long_factor"] self.original_max_position_embeddings = original_max_position_embeddings - @torch.no_grad() - def forward(self, x, position_ids): + def forward(self, hidden_states, position_ids): seq_len = torch.max(position_ids) + 1 if seq_len > self.original_max_position_embeddings: - ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=hidden_states.device) else: - ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=hidden_states.device) - inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + inv_freq_shape = ( + torch.arange(0, self.dim, 2, dtype=torch.int64, device=hidden_states.device).float() / self.dim + ) self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @@ -196,11 +168,11 @@ def forward(self, x, position_ids): # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type + device_type = hidden_states.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) + emb = torch.cat((freqs, freqs), dim=-1)[0] scale = self.max_position_embeddings / self.original_max_position_embeddings if scale <= 1.0: @@ -210,44 +182,7 @@ def forward(self, x, position_ids): cos = emb.cos() * scaling_factor sin = emb.sin() * scaling_factor - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def apply_rotary_emb( - x: torch.Tensor, - freqs_cis: Union[torch.Tensor, 
Tuple[torch.Tensor]], -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings - to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are - reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting - tensors contain rotary embeddings and are returned as real tensors. - - Args: - x (`torch.Tensor`): - Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply - freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. - """ - - cos, sin = freqs_cis # [S, D] - if len(cos.shape) == 2: - cos = cos[None, None] - sin = sin[None, None] - elif len(cos.shape) == 3: - cos = cos[:, None] - sin = sin[:, None] - cos, sin = cos.to(x.device), sin.to(x.device) - - # Rotates half the hidden dims of the input. this rorate function is widely used in LLM, e.g. Llama, Phi3, etc. - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - x_rotated = torch.cat((-x2, x1), dim=-1) - - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - return out + return cos, sin class OmniGenAttnProcessor2_0: @@ -278,7 +213,6 @@ def __call__( bsz, q_len, query_dim = query.size() inner_dim = key.shape[-1] head_dim = query_dim // attn.heads - dtype = query.dtype # Get key-value heads kv_heads = inner_dim // head_dim @@ -289,32 +223,19 @@ def __call__( # Apply RoPE if needed if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) + from ..embeddings import apply_rotary_emb - query, key = query.to(dtype), key.to(dtype) + query = apply_rotary_emb(query, image_rotary_emb, use_real_unbind_dim=-2) + key = apply_rotary_emb(key, image_rotary_emb, use_real_unbind_dim=-2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) - hidden_states = hidden_states.transpose(1, 2).to(dtype) + hidden_states = hidden_states.transpose(1, 2).type_as(query) hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim) hidden_states = attn.to_out[0](hidden_states) return hidden_states class OmniGenBlock(nn.Module): - """ - A LuminaNextDiTBlock for LuminaNextDiT2DModel. - - Parameters: - hidden_size (`int`): Embedding dimension of the input features. - num_attention_heads (`int`): Number of attention heads. - num_key_value_heads (`int`): - Number of attention heads in key and value features (if using GQA), or set to None for the same as query. - intermediate_size (`int`): size of intermediate layer. - rms_norm_eps (`float`): The eps for norm layer. - """ - def __init__( self, hidden_size: int, @@ -341,78 +262,77 @@ def __init__( self.mlp = OmniGenFeedForward(hidden_size, intermediate_size) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - image_rotary_emb: torch.Tensor, - ): - """ - Perform a forward pass through the LuminaNextDiTBlock. - - Parameters: - hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock. - attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. - image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. 
- """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - attn_outputs = self.self_attn( - hidden_states=hidden_states, - encoder_hidden_states=hidden_states, + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor + ) -> torch.Tensor: + # 1. Attention + norm_hidden_states = self.input_layernorm(hidden_states) + attn_output = self.self_attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) + hidden_states = hidden_states + attn_output - hidden_states = residual + attn_outputs - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - + # 2. Feed Forward + norm_hidden_states = self.post_attention_layernorm(hidden_states) + ff_output = self.mlp(norm_hidden_states) + hidden_states = hidden_states + ff_output return hidden_states -class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class OmniGenTransformer2DModel(ModelMixin, ConfigMixin): """ - The Transformer model introduced in OmniGen. - - Reference: https://arxiv.org/pdf/2409.11340 + The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340). Parameters: - hidden_size (`int`, *optional*, defaults to 3072): - The dimensionality of the hidden layers in the model. This parameter determines the width of the model's - hidden representations. - rms_norm_eps (`float`, *optional*, defaults to 1e-5): eps for RMSNorm layer. - num_attention_heads (`int`, *optional*, defaults to 32): - The number of attention heads in each attention layer. This parameter specifies how many separate attention - mechanisms are used. - num_kv_heads (`int`, *optional*, defaults to 32): - The number of key-value heads in the attention mechanism, if different from the number of attention heads. - If None, it defaults to num_attention_heads. - intermediate_size (`int`, *optional*, defaults to 8192): dimension of the intermediate layer in FFN - num_layers (`int`, *optional*, default to 32): - The number of layers in the model. This defines the depth of the neural network. - pad_token_id (`int`, *optional*, default to 32000): - id for pad token - vocab_size (`int`, *optional*, default to 32064): - size of vocabulary - patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input. - pos_embed_max_size (`int`, *optional*, defaults to 192): The max size of pos emb. + in_channels (`int`, defaults to `4`): + The number of channels in the input. + patch_size (`int`, defaults to `2`): + The size of the spatial patches to use in the patch embedding layer. + hidden_size (`int`, defaults to `3072`): + The dimensionality of the hidden layers in the model. + rms_norm_eps (`float`, defaults to `1e-5`): + Eps for RMSNorm layer. + num_attention_heads (`int`, defaults to `32`): + The number of heads to use for multi-head attention. + num_key_value_heads (`int`, defaults to `32`): + The number of heads to use for keys and values in multi-head attention. + intermediate_size (`int`, defaults to `8192`): + Dimension of the hidden layer in FeedForward layers. + num_layers (`int`, default to `32`): + The number of layers of transformer blocks to use. + pad_token_id (`int`, default to `32000`): + The id of the padding token. 
+ vocab_size (`int`, default to `32064`): + The size of the vocabulary of the embedding vocabulary. + rope_base (`int`, default to `10000`): + The default theta value to use when creating RoPE. + rope_scaling (`Dict`, optional): + The scaling factors for the RoPE. Must contain `short_factor` and `long_factor`. + pos_embed_max_size (`int`, default to `192`): + The maximum size of the positional embeddings. + time_step_dim (`int`, default to `256`): + Output dimension of timestep embeddings. + flip_sin_to_cos (`bool`, default to `True`): + Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings. + downscale_freq_shift (`int`, default to `0`): + The frequency shift to use when downscaling the timestep embeddings. + timestep_activation_fn (`str`, default to `silu`): + The activation function to use for the timestep embeddings. """ _supports_gradient_checkpointing = True _no_split_modules = ["OmniGenBlock"] + _skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"] @register_to_config def __init__( self, + in_channels: int = 4, + patch_size: int = 2, hidden_size: int = 3072, - rms_norm_eps: float = 1e-05, + rms_norm_eps: float = 1e-5, num_attention_heads: int = 32, num_key_value_heads: int = 32, intermediate_size: int = 8192, @@ -423,8 +343,6 @@ def __init__( original_max_position_embeddings: int = 4096, rope_base: int = 10000, rope_scaling: Dict = None, - patch_size=2, - in_channels=4, pos_embed_max_size: int = 192, time_step_dim: int = 256, flip_sin_to_cos: bool = True, @@ -434,8 +352,6 @@ def __init__( super().__init__() self.in_channels = in_channels self.out_channels = in_channels - self.patch_size = patch_size - self.pos_embed_max_size = pos_embed_max_size self.patch_embedding = OmniGenPatchEmbed( patch_size=patch_size, @@ -448,11 +364,8 @@ def __init__( self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn) self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn) - self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1) - self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) - self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id) - self.rotary_emb = OmniGenSuScaledRotaryEmbedding( + self.rope = OmniGenSuScaledRotaryEmbedding( hidden_size // num_attention_heads, max_position_embeddings=max_position_embeddings, original_max_position_embeddings=original_max_position_embeddings, @@ -462,126 +375,34 @@ def __init__( self.layers = nn.ModuleList( [ - OmniGenBlock( - hidden_size, - num_attention_heads, - num_key_value_heads, - intermediate_size, - rms_norm_eps, - ) + OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps) for _ in range(num_layers) ] ) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1) + self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False - def unpatchify(self, x, h, w): - """ - x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) - """ - c = self.out_channels - - x = x.reshape( - shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c) - ) - x = torch.einsum("nhwpqc->nchpwq", x) - imgs = x.reshape(shape=(x.shape[0], c, h, w)) - return imgs - - @property - # Copied from 
diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} + def _get_multimodal_embeddings( + self, input_ids: torch.Tensor, input_img_latents: List[torch.Tensor], input_image_sizes: Dict + ) -> Optional[torch.Tensor]: + if input_ids is None: + return None - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[OmniGenAttnProcessor2_0, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def get_multimodal_embeddings( - self, - input_ids: torch.Tensor, - input_img_latents: List[torch.Tensor], - input_image_sizes: Dict, - ): - """ - get the multi-modal conditional embeddings - - Args: - input_ids: a sequence of text id - input_img_latents: continues embedding of input images - input_image_sizes: the index of the input image in the input_ids sequence. - - Returns: torch.Tensor - - """ input_img_latents = [x.to(self.dtype) for x in input_img_latents] - condition_tokens = None - if input_ids is not None: - condition_tokens = self.embed_tokens(input_ids) - input_img_inx = 0 - if input_img_latents is not None: - input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True) - - for b_inx in input_image_sizes.keys(): - for start_inx, end_inx in input_image_sizes[b_inx]: - # replace the placeholder in text tokens with the image embedding. 
- condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to( - condition_tokens.dtype - ) - input_img_inx += 1 - + condition_tokens = self.embed_tokens(input_ids) + input_img_inx = 0 + input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True) + for b_inx in input_image_sizes.keys(): + for start_inx, end_inx in input_image_sizes[b_inx]: + # replace the placeholder in text tokens with the image embedding. + condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to( + condition_tokens.dtype + ) + input_img_inx += 1 return condition_tokens def forward( @@ -593,106 +414,55 @@ def forward( input_image_sizes: Dict[int, List[int]], attention_mask: torch.Tensor, position_ids: torch.Tensor, - attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, - ): - """ - The [`OmniGenTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - timestep (`torch.FloatTensor`): - Used to indicate denoising step. - input_ids (`torch.LongTensor`): - token ids - input_img_latents (`torch.Tensor`): - encoded image latents by VAE - input_image_sizes (`dict`): - the indices of the input_img_latents in the input_ids - attention_mask (`torch.Tensor`): - mask for self-attention - position_ids (`torch.LongTensor`): - id to represent position - past_key_values (`transformers.cache_utils.Cache`): - previous key and value states - offload_transformer_block (`bool`, *optional*, defaults to `True`): - offload transformer block to cpu - attention_kwargs: (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`OmniGen2DModelOutput`] instead of a plain tuple. - - Returns: - If `return_dict` is True, an [`OmniGen2DModelOutput`] is returned, otherwise a `tuple` where the first - element is the sample tensor. - - """ - - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 + ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]: + batch_size, num_channels, height, width = hidden_states.shape + p = self.config.patch_size + post_patch_height, post_patch_width = height // p, width // p - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - height, width = hidden_states.size()[-2:] + # 1. 
Patch & Timestep & Conditional Embedding hidden_states = self.patch_embedding(hidden_states, is_input_image=False) num_tokens_for_output_image = hidden_states.size(1) - time_token = self.time_token(self.time_proj(timestep).to(hidden_states.dtype)).unsqueeze(1) + timestep_proj = self.time_proj(timestep).type_as(hidden_states) + time_token = self.time_token(timestep_proj).unsqueeze(1) + temb = self.t_embedder(timestep_proj) - condition_tokens = self.get_multimodal_embeddings( - input_ids=input_ids, - input_img_latents=input_img_latents, - input_image_sizes=input_image_sizes, - ) + condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes) if condition_tokens is not None: - inputs_embeds = torch.cat([condition_tokens, time_token, hidden_states], dim=1) + hidden_states = torch.cat([condition_tokens, time_token, hidden_states], dim=1) else: - inputs_embeds = torch.cat([time_token, hidden_states], dim=1) + hidden_states = torch.cat([time_token, hidden_states], dim=1) - batch_size, seq_length = inputs_embeds.shape[:2] + seq_length = hidden_states.size(1) position_ids = position_ids.view(-1, seq_length).long() + # 2. Attention mask preprocessing if attention_mask is not None and attention_mask.dim() == 3: - dtype = inputs_embeds.dtype + dtype = hidden_states.dtype min_dtype = torch.finfo(dtype).min attention_mask = (1 - attention_mask) * min_dtype - attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype) - else: - raise Exception("attention_mask parameter was unavailable or invalid") + attention_mask = attention_mask.unsqueeze(1).type_as(hidden_states) - hidden_states = inputs_embeds + # 3. Rotary position embedding + image_rotary_emb = self.rope(hidden_states, position_ids) - image_rotary_emb = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + # 4. Transformer blocks + for block in self.layers: if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( - decoder_layer, hidden_states, attention_mask, image_rotary_emb + block, hidden_states, attention_mask, image_rotary_emb ) else: - hidden_states = decoder_layer( - hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb - ) + hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb) + # 5. Output norm & projection hidden_states = self.norm(hidden_states) - hidden_states = hidden_states[:, -num_tokens_for_output_image:] - timestep_proj = self.time_proj(timestep) - temb = self.t_embedder(timestep_proj.type_as(hidden_states)) hidden_states = self.norm_out(hidden_states, temb=temb) hidden_states = self.proj_out(hidden_states) - output = self.unpatchify(hidden_states, height, width) + hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1) + output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 41bfab5e3e04..5fe5be3b26d2 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -13,7 +13,7 @@ # limitations under the License. 
import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import numpy as np import torch @@ -23,11 +23,7 @@ from ...models.autoencoders import AutoencoderKL from ...models.transformers import OmniGenTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - is_torch_xla_available, - logging, - replace_example_docstring, -) +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .processor_omnigen import OmniGenMultiModalProcessor @@ -48,11 +44,12 @@ >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. >>> # Refer to the pipeline documentation for more details. >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] - >>> image.save("t2i.png") + >>> image.save("output.png") ``` """ @@ -200,7 +197,6 @@ def check_inputs( width, use_input_image_size_as_output, callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, ): if input_images is not None: if len(input_images) != len(prompt): @@ -324,10 +320,8 @@ def __call__( latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 120000, ): r""" Function invoked when calling the pipeline for generation. @@ -376,10 +370,6 @@ def __call__( [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -389,7 +379,6 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. Examples: @@ -414,11 +403,9 @@ def __call__( width, use_input_image_size_as_output, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale - self._attention_kwargs = attention_kwargs self._interrupt = False # 2. Define call parameters @@ -451,7 +438,8 @@ def __call__( ) self._num_timesteps = len(timesteps) - # 6. Prepare latents. + # 6. 
Prepare latents + transformer_dtype = self.transformer.dtype if use_input_image_size_as_output: height, width = processed_data["input_pixel_values"][0].shape[-2:] latent_channels = self.transformer.config.in_channels @@ -460,7 +448,7 @@ def __call__( latent_channels, height, width, - self.transformer.dtype, + torch.float32, device, generator, latents, @@ -471,6 +459,7 @@ def __call__( for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * (num_cfg + 1)) + latent_model_input = latent_model_input.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -483,7 +472,6 @@ def __call__( input_image_sizes=processed_data["input_image_sizes"], attention_mask=processed_data["attention_mask"], position_ids=processed_data["position_ids"], - attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -495,7 +483,6 @@ def __call__( noise_pred = uncond + guidance_scale * (cond - uncond) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if callback_on_step_end is not None: @@ -506,11 +493,6 @@ def __call__( latents = callback_outputs.pop("latents", latents) - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - progress_bar.update() if not output_type == "latent": diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index dd5e5fcb2918..2f9c4d4e3f8e 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -18,17 +18,10 @@ class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = OmniGenPipeline - params = frozenset( - [ - "prompt", - "guidance_scale", - ] - ) - batch_params = frozenset( - [ - "prompt", - ] - ) + params = frozenset(["prompt", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0)
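The transformer now also declares `_skip_layerwise_casting_patterns` (keeping the patch/token embeddings and norm layers in the compute dtype), and the fast tests enable `test_layerwise_casting`. A minimal sketch of what this unlocks for users, assuming diffusers' `enable_layerwise_casting` model API:

```python
import torch
from diffusers import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
# Store transformer weights in FP8 and upcast each layer on the fly during
# the forward pass; the patterns in `_skip_layerwise_casting_patterns` keep
# embeddings and norms in the compute dtype for numerical stability.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe("A cat holding a sign that says hello world", guidance_scale=2.5).images[0]
image.save("output.png")
```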