
Commit 464374f

EDMEulerScheduler accept sigmas, add final_sigmas_type (#10734)
1 parent d43ce14 commit 464374f

File tree

1 file changed: +45 -10 lines changed


src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 45 additions & 10 deletions
```diff
@@ -14,7 +14,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 
```

```diff
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
             Video](https://imagen.research.google/video/paper.pdf) paper).
         rho (`float`, *optional*, defaults to 7.0):
             The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
     """
 
     _compatibles = []
```
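For illustration, a minimal sketch of constructing the scheduler with the new option. The option name and values come from the diff above; the printed values are what the `__init__` logic below implies, given the scheduler's default `sigma_min=0.002`:

```python
from diffusers import EDMEulerScheduler

# Default: the training sigmas are padded with a final sigma of 0.
scheduler_zero = EDMEulerScheduler(final_sigmas_type="zero")
print(scheduler_zero.sigmas[-1])  # tensor(0.)

# Alternative: the schedule ends at the last (smallest) training sigma.
scheduler_min = EDMEulerScheduler(final_sigmas_type="sigma_min")
print(scheduler_min.sigmas[-1])   # tensor(0.0020), i.e. sigma_min

# Any other string raises a ValueError (see the __init__ changes below).
```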
```diff
@@ -92,22 +95,32 @@ def __init__(
         num_train_timesteps: int = 1000,
         prediction_type: str = "epsilon",
         rho: float = 7.0,
+        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
         if sigma_schedule not in ["karras", "exponential"]:
             raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
 
         # setable values
         self.num_inference_steps = None
 
-        ramp = torch.linspace(0, 1, num_train_timesteps)
+        sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
         if sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)
 
         self.timesteps = self.precondition_noise(sigmas)
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
 
         self.is_scale_input_called = False
```
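Note the ramp change in `__init__`: `torch.linspace(0, 1, num_train_timesteps)` becomes `torch.arange(num_train_timesteps + 1) / num_train_timesteps`, so the training schedule now has one extra point and an exact step of `1 / num_train_timesteps`. A standalone sketch of the difference:

```python
import torch

num_train_timesteps = 1000

old_ramp = torch.linspace(0, 1, num_train_timesteps)
new_ramp = torch.arange(num_train_timesteps + 1) / num_train_timesteps

print(old_ramp.shape, new_ramp.shape)  # torch.Size([1000]) torch.Size([1001])
print(old_ramp[1].item())  # 0.001001... (step 1/999)
print(new_ramp[1].item())  # 0.001 (step 1/1000)
```

Both ramps still run from 0 to 1 inclusive, so `sigma_max` and `sigma_min` are unchanged; only the count and spacing of the intermediate sigmas differ.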

```diff
@@ -197,7 +210,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
         self.is_scale_input_called = True
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
```
```diff
@@ -206,19 +224,36 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
+                Custom sigmas to use for the denoising process. If not defined, the default behavior when
+                `num_inference_steps` is passed will be used.
         """
         self.num_inference_steps = num_inference_steps
 
-        ramp = torch.linspace(0, 1, self.num_inference_steps)
+        if sigmas is None:
+            sigmas = torch.linspace(0, 1, self.num_inference_steps)
+        elif isinstance(sigmas, float):
+            sigmas = torch.tensor(sigmas, dtype=torch.float32)
+        else:
+            sigmas = sigmas
         if self.config.sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif self.config.sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)
 
         sigmas = sigmas.to(dtype=torch.float32, device=device)
         self.timesteps = self.precondition_noise(sigmas)
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
         self._step_index = None
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
```
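A hedged usage sketch of the new `sigmas` argument. Per the body above, a tensor passed as `sigmas` replaces the default `linspace` ramp and is still run through the configured sigma schedule, so it should be a ramp in [0, 1] rather than a list of final sigma values; the `custom_ramp` below is an arbitrary illustration, not a recommended schedule:

```python
import torch
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler()

# Default path: a linspace ramp over `num_inference_steps`.
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.sigmas.shape)  # torch.Size([26]): 25 sigmas + the final sigma

# Custom path: this ramp clusters values near 0, i.e. steps near sigma_max,
# before the default "karras" schedule maps it to actual sigma values.
custom_ramp = torch.linspace(0, 1, 10) ** 2
scheduler.set_timesteps(sigmas=custom_ramp)
print(scheduler.sigmas.shape)  # torch.Size([11]): 10 sigmas + the final sigma
```

Because `num_inference_steps` now defaults to `None`, calling `set_timesteps` with only `sigmas` is valid; `self.num_inference_steps` is simply left as `None` in that case.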
