Merged

35 commits
- b5863c2 Enable the i2vgenpipeline (yuanwu2017, Nov 15, 2024)
- 0a16d64 Debug memory issue (yuanwu2017, Nov 21, 2024)
- e90726a Cpu offload (yuanwu2017, Nov 22, 2024)
- 3a40645 Merge branch 'main' into i2v-debug (yuanwu2017, Dec 26, 2024)
- 2fccd3f Use the FusedSdpa (yuanwu2017, Jan 6, 2025)
- 1a49704 Use graph mode (yuanwu2017, Jan 15, 2025)
- dc544b4 Optimize the i2v pipeline (yuanwu2017, Jan 18, 2025)
- 0b8ccd7 Fix the reshape error (yuanwu2017, Jan 18, 2025)
- e3026bc Fix the error (yuanwu2017, Jan 18, 2025)
- ef93e3c Fix the error and clean the code (yuanwu2017, Jan 18, 2025)
- 5dbb14b Add i2v example and clean code (yuanwu2017, Jan 19, 2025)
- 24930eb Add the export_to_gif (yuanwu2017, Jan 19, 2025)
- 7e3becf Fix the video size issue (yuanwu2017, Jan 19, 2025)
- b4612e0 Add comments and clean the code (yuanwu2017, Jan 19, 2025)
- cb192f7 Remove debug files (yuanwu2017, Jan 19, 2025)
- ba7c9f4 Remove useless modifications. (yuanwu2017, Jan 19, 2025)
- 10df9ca Fix errors of make style (yuanwu2017, Jan 19, 2025)
- 94df7fe Change the default value (yuanwu2017, Jan 20, 2025)
- b2bddba Add the sdp_on_bf16 (yuanwu2017, Jan 20, 2025)
- 8cdcf70 Add default value for negative_prompts (yuanwu2017, Jan 20, 2025)
- b379365 Add ReadME (yuanwu2017, Jan 20, 2025)
- e3ef373 Update examples/stable-diffusion/image_to_video_generation.py (yuanwu2017, Jan 30, 2025)
- ef14e26 Remove reloading image (yuanwu2017, Jan 30, 2025)
- d2c4fa4 Merge branch 'huggingface:main' into i2v-debug (yuanwu2017, Jan 31, 2025)
- 07db61b Add CI test (yuanwu2017, Jan 31, 2025)
- a3c8129 Fix errors of make style (yuanwu2017, Jan 31, 2025)
- 827f16a Fix ReadME typo (yuanwu2017, Jan 31, 2025)
- c54bf41 Change to is_i2v_model (yuanwu2017, Feb 6, 2025)
- 0afc536 Remove the examples commands in README. (yuanwu2017, Feb 6, 2025)
- ed98a1b Fix typo (yuanwu2017, Feb 6, 2025)
- c8b075d Enable the autocast according to gaudi_config (yuanwu2017, Feb 6, 2025)
- 2dbfc20 Update tests/test_diffusers.py (yuanwu2017, Feb 7, 2025)
- c5e08ec Merge branch 'huggingface:main' into i2v-debug (yuanwu2017, Feb 7, 2025)
- a22e73b Add the i2vgen-xl into ReadME and docs (yuanwu2017, Feb 7, 2025)
- 51b48cd Merge branch 'huggingface:main' into i2v-debug (yuanwu2017, Feb 7, 2025)
1 change: 1 addition & 0 deletions README.md
@@ -299,6 +299,7 @@ The following model architectures, tasks and device distributions have been validated
| FLUX.1 | <li>LoRA</li> | <li>Single card</li> | <li>[text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion#flux1)</li><li>[image-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion#flux1-image-to-image)</li> |
| Text to Video | | <li>Single card</li> | <li>[text-to-video generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion#text-to-video-generation)</li> |
| Image to Video | | <li>Single card</li> | <li>[image-to-video generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion#image-to-video-generation)</li> |
| i2vgen-xl | | <li>Single card</li> | <li>[image-to-video generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion#i2vgen-xl)</li> |

### PyTorch Image Models/TIMM:

1 change: 1 addition & 0 deletions docs/source/index.mdx
@@ -122,6 +122,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all been validated.
| LDM3D | | <div style="text-align:left"><li>Single card</li></div> | <li>[text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)</li> |
| FLUX.1 | <li>[fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#dreambooth-lora-fine-tuning-with-flux1-dev)</li> | <li>Single card</li> | <li>[text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)</li> |
| Text to Video | | <li>Single card</li> | <li>[text-to-video generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion#text-to-video-generation)</li> |
| i2vgen-xl | | <li>Single card</li> | <li>[image-to-video generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion#i2vgen-xl)</li> |

- PyTorch Image Models/TIMM:

25 changes: 25 additions & 0 deletions examples/stable-diffusion/README.md
@@ -849,6 +849,31 @@ python image_to_video_generation.py \
--height=512
```

# I2VGen-XL

I2VGen-XL performs high-quality image-to-video synthesis via cascaded diffusion models. For details, see the [Hugging Face I2VGen-XL model card](https://huggingface.co/ali-vilab/i2vgen-xl).

Here is how to generate a video from a single image and a text prompt:

```bash
PT_HPU_MAX_COMPOUND_OP_SIZE=1 \
python image_to_video_generation.py \
--model_name_or_path "ali-vilab/i2vgen-xl" \
--image_path "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png" \
--num_videos_per_prompt 1 \
--video_save_dir ./i2vgen_xl \
--num_inference_steps 50 \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion \
--gif \
--num_frames 16 \
--prompts "Papers were floating in the air on a table in the library" \
--negative_prompts "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms" \
--seed 8888 \
--sdp_on_bf16 \
--bf16
```
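
For reference, the I2VGen-XL path of the example script boils down to a few lines of Python. Below is a minimal sketch, assuming `GaudiI2VGenXLPipeline` accepts the same `use_habana`/`use_hpu_graphs`/`gaudi_config` loading keywords as the other Gaudi pipelines (which is what the script's shared `kwargs` suggest); the model, image URL, and prompts are taken from the command above:

```python
import torch
from diffusers.utils import export_to_gif, load_image

from optimum.habana.diffusers import GaudiI2VGenXLPipeline

# Load the pipeline on HPU; these loading kwargs mirror the CLI flags above
# and are assumed to apply to GaudiI2VGenXLPipeline like the other pipelines.
pipeline = GaudiI2VGenXLPipeline.from_pretrained(
    "ali-vilab/i2vgen-xl",
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
    torch_dtype=torch.bfloat16,  # --bf16
)

# I2VGen-XL takes the conditioning image converted to RGB, without resizing.
image = load_image(
    "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
).convert("RGB")

generator = torch.manual_seed(8888)  # --seed
outputs = pipeline(
    prompt="Papers were floating in the air on a table in the library",
    negative_prompt="Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms",
    image=image,
    num_frames=16,
    num_inference_steps=50,
    guidance_scale=9.0,  # hardcoded in the example script
    generator=generator,
)
export_to_gif(outputs.frames[0], "gen_video_00.gif")
```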

# Important Notes for Gaudi3 Users

- **Batch Size Limitation**: Due to a known issue, batch sizes for some Stable Diffusion models need to be reduced.
55 changes: 51 additions & 4 deletions examples/stable-diffusion/image_to_video_generation.py
@@ -19,9 +19,13 @@
 from pathlib import Path

 import torch
-from diffusers.utils import export_to_video, load_image
+from diffusers.utils import export_to_gif, export_to_video, load_image

-from optimum.habana.diffusers import GaudiEulerDiscreteScheduler, GaudiStableVideoDiffusionPipeline
+from optimum.habana.diffusers import (
+    GaudiEulerDiscreteScheduler,
+    GaudiI2VGenXLPipeline,
+    GaudiStableVideoDiffusionPipeline,
+)
 from optimum.habana.utils import set_seed

@@ -57,6 +61,20 @@ def main():
     )

     # Pipeline arguments
+    parser.add_argument(
+        "--prompts",
+        type=str,
+        nargs="*",
+        default="Papers were floating in the air on a table in the library",
+        help="The prompt or prompts to guide the image generation.",
+    )
+    parser.add_argument(
+        "--negative_prompts",
+        type=str,
+        nargs="*",
+        default="Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms",
+        help="The prompt or prompts not to guide the image generation.",
+    )
     parser.add_argument(
         "--image_path",
         type=str,
@@ -177,6 +195,7 @@ def main():
         ),
     )
     parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.")
+    parser.add_argument("--gif", action="store_true", help="Whether to generate the video in gif format.")
     parser.add_argument(
         "--sdp_on_bf16",
         action="store_true",
@@ -212,14 +231,20 @@ def main():
     )
     logger.setLevel(logging.INFO)

+    i2v_models = ["i2vgen-xl"]
+    is_i2v_model = any(model in args.model_name_or_path for model in i2v_models)
+
     # Load input image(s)
     input = []
     logger.info("Input image(s):")
     if isinstance(args.image_path, str):
         args.image_path = [args.image_path]
     for image_path in args.image_path:
         image = load_image(image_path)
-        image = image.resize((args.height, args.width))
+        if is_i2v_model:
+            image = image.convert("RGB")
+        else:
+            image = image.resize((args.height, args.width))
         input.append(image)
         logger.info(image_path)
@@ -281,6 +306,24 @@ def main():
             output_type=args.output_type,
             num_frames=args.num_frames,
         )
+    elif is_i2v_model:
+        del kwargs["scheduler"]
+        pipeline = GaudiI2VGenXLPipeline.from_pretrained(
+            args.model_name_or_path,
+            **kwargs,
+        )
+        generator = torch.manual_seed(args.seed)
+        outputs = pipeline(
+            prompt=args.prompts,
+            image=input,
+            num_videos_per_prompt=args.num_videos_per_prompt,
+            batch_size=args.batch_size,
+            num_frames=args.num_frames,
+            num_inference_steps=args.num_inference_steps,
+            negative_prompt=args.negative_prompts,
+            guidance_scale=9.0,
+            generator=generator,
+        )
     else:
         pipeline = GaudiStableVideoDiffusionPipeline.from_pretrained(
             args.model_name_or_path,
@@ -321,7 +364,11 @@ def main():
         video_save_dir.mkdir(parents=True, exist_ok=True)
         logger.info(f"Saving video frames in {video_save_dir.resolve()}...")
         for i, frames in enumerate(outputs.frames):
-            export_to_video(frames, args.video_save_dir + "/gen_video_" + str(i).zfill(2) + ".mp4", fps=7)
+            if args.gif:
+                export_to_gif(frames, args.video_save_dir + "/gen_video_" + str(i).zfill(2) + ".gif")
+            else:
+                export_to_video(frames, args.video_save_dir + "/gen_video_" + str(i).zfill(2) + ".mp4", fps=7)
+
         if args.save_frames_as_images:
             for j, frame in enumerate(frames):
                 frame.save(
1 change: 1 addition & 0 deletions optimum/habana/diffusers/__init__.py
@@ -6,6 +6,7 @@
 from .pipelines.ddpm.pipeline_ddpm import GaudiDDPMPipeline
 from .pipelines.flux.pipeline_flux import GaudiFluxPipeline
 from .pipelines.flux.pipeline_flux_img2img import GaudiFluxImg2ImgPipeline
+from .pipelines.i2vgen_xl.pipeline_i2vgen_xl import GaudiI2VGenXLPipeline
 from .pipelines.pipeline_utils import GaudiDiffusionPipeline
 from .pipelines.stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline
 from .pipelines.stable_diffusion.pipeline_stable_diffusion_depth2img import GaudiStableDiffusionDepth2ImgPipeline