
Commit 2f25156

LEditsPP - examples, check height/width, add tiling/slicing (#10471)
* LEditsPP - examples, check height/width, add tiling/slicing

* make style
1 parent 6da6406 commit 2f25156

File tree

2 files changed: +95 −19 lines changed

src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py

+41 −6
@@ -34,21 +34,19 @@
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> import PIL
-        >>> import requests
         >>> import torch
-        >>> from io import BytesIO

         >>> from diffusers import LEditsPPPipelineStableDiffusion
         >>> from diffusers.utils import load_image

         >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
-        ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ...     "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
         >>> pipe = pipe.to("cuda")

         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
-        >>> image = load_image(img_url).convert("RGB")
+        >>> image = load_image(img_url).resize((512, 512))

         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)

@@ -152,7 +150,7 @@ def __init__(self, device):

         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -706,6 +704,35 @@ def clip_skip(self):
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1271,6 +1298,8 @@ def invert(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")
         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())

@@ -1360,6 +1389,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")

         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
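Putting this file's changes together: the doc example now loads the fp16 variant, resizes the input to a multiple of 32 (which `invert()`/`encode_image()` now enforce with a `ValueError`), and the new `enable_vae_tiling()`/`enable_vae_slicing()` helpers simply forward to the VAE. A minimal usage sketch follows; the final editing call mirrors the existing LEDITS++ example docstring rather than anything changed in this diff:

```py
import torch
from diffusers import LEditsPPPipelineStableDiffusion
from diffusers.utils import load_image

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_vae_tiling()   # new helper: tiled VAE encode/decode to cut peak memory
pipe.enable_vae_slicing()  # new helper: sliced VAE decode for larger batches
pipe = pipe.to("cuda")

img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
# Height and width must be multiples of 32, otherwise invert()/encode_image() raise ValueError.
image = load_image(img_url).resize((512, 512))

_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
# Illustrative edit, following the pipeline's existing example docstring:
edited = pipe(
    editing_prompt=["cherry blossom"], edit_guidance_scale=10.0, edit_threshold=0.75
).images[0]
```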

src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py

+54 −13
@@ -72,25 +72,18 @@
     Examples:
         ```py
         >>> import torch
-        >>> import PIL
-        >>> import requests
-        >>> from io import BytesIO

         >>> from diffusers import LEditsPPPipelineStableDiffusionXL
+        >>> from diffusers.utils import load_image

         >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
-        ...     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ...     "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
         >>> pipe = pipe.to("cuda")

-
-        >>> def download_image(url):
-        ...     response = requests.get(url)
-        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
-
-
         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
-        >>> image = download_image(img_url)
+        >>> image = load_image(img_url).resize((1024, 1024))

         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)

@@ -197,7 +190,7 @@ def __init__(self, device):

         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -768,6 +761,35 @@ def denoising_end(self):
     def num_timesteps(self):
         return self._num_timesteps

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
     def prepare_unet(self, attention_store, PnP: bool = False):
         attn_procs = {}
@@ -1401,6 +1423,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")

         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
@@ -1439,6 +1467,10 @@ def invert(
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         num_zero_noise_steps: int = 3,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        resize_mode: Optional[str] = "default",
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
     ):
         r"""
         The function to the pipeline for image inversion as described by the [LEDITS++
@@ -1486,6 +1518,8 @@ def invert(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")

         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())
@@ -1510,7 +1544,14 @@ def invert(
         do_classifier_free_guidance = source_guidance_scale > 1.0

         # 1. prepare image
-        x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+        x0, resized = self.encode_image(
+            image,
+            dtype=self.text_encoder_2.dtype,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
+        )
         width = x0.shape[2] * self.vae_scale_factor
         height = x0.shape[3] * self.vae_scale_factor
         self.size = (height, width)
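The SDXL pipeline gets the same VAE tiling/slicing helpers, plus new `height`, `width`, `resize_mode`, and `crops_coords` arguments on `invert()` that are forwarded to `encode_image()`, so the image processor can resize on the way in instead of requiring a pre-resized PIL image. A short sketch of the extended call; the 1024x1024 values are illustrative and must be multiples of 32:

```py
import torch
from diffusers import LEditsPPPipelineStableDiffusionXL
from diffusers.utils import load_image

pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_vae_tiling()  # keeps VAE memory manageable at 1024x1024
pipe = pipe.to("cuda")

image = load_image("https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg")

# Let the pipeline resize during inversion; non-multiples of 32 now raise ValueError
# before inversion starts.
_ = pipe.invert(
    image=image,
    num_inversion_steps=50,
    skip=0.2,
    height=1024,
    width=1024,
    resize_mode="default",
)
```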

0 commit comments