@@ -14,7 +14,7 @@

 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch

@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
             Video](https://imagen.research.google/video/paper.pdf) paper).
         rho (`float`, *optional*, defaults to 7.0):
             The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
     """

     _compatibles = []
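To make the new option concrete, here is a minimal standalone sketch of what each `final_sigmas_type` setting appends to the schedule. This is not the scheduler code itself; the `append_final_sigma` helper and the example sigma values are made up for illustration.

```python
import torch

def append_final_sigma(sigmas: torch.Tensor, final_sigmas_type: str = "zero") -> torch.Tensor:
    # Pick the terminal sigma appended after the sampling schedule.
    if final_sigmas_type == "sigma_min":
        sigma_last = sigmas[-1].item()  # reuse the smallest sigma in the schedule
    elif final_sigmas_type == "zero":
        sigma_last = 0.0                # denoise all the way down to sigma = 0
    else:
        raise ValueError(f"unknown final_sigmas_type: {final_sigmas_type}")
    return torch.cat([sigmas, torch.full((1,), sigma_last)])

schedule = torch.tensor([80.0, 10.0, 1.0, 0.02])  # illustrative descending sigmas
print(append_final_sigma(schedule, "zero"))       # tensor([80.0000, 10.0000, 1.0000, 0.0200, 0.0000])
print(append_final_sigma(schedule, "sigma_min"))  # tensor([80.0000, 10.0000, 1.0000, 0.0200, 0.0200])
```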
@@ -92,22 +95,32 @@ def __init__(
         num_train_timesteps: int = 1000,
         prediction_type: str = "epsilon",
         rho: float = 7.0,
+        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
         if sigma_schedule not in ["karras", "exponential"]:
             raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")

         # setable values
         self.num_inference_steps = None

-        ramp = torch.linspace(0, 1, num_train_timesteps)
+        sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
         if sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)

         self.timesteps = self.precondition_noise(sigmas)

-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero' or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])

         self.is_scale_input_called = False
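Assuming a diffusers build that includes this change, the new argument is threaded through the constructor like any other registered config field. A hedged usage sketch; the keyword values are illustrative, not recommendations:

```python
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler(
    num_train_timesteps=1000,
    prediction_type="epsilon",
    rho=7.0,
    final_sigmas_type="sigma_min",  # end on the schedule's smallest sigma instead of 0
)
print(scheduler.sigmas[-1])  # equals scheduler.sigmas[-2] rather than 0
```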
@@ -197,7 +210,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
         self.is_scale_input_called = True
         return sample

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
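One detail worth flagging for callers, visible in the hunk below: custom `sigmas` are still passed through `_compute_karras_sigmas` or `_compute_exponential_sigmas`, so they act as the 0-to-1 ramp that parameterizes the schedule rather than as literal sigma values. A hedged sketch of both call styles, reusing the `scheduler` constructed above; the ramp values are illustrative, not tuned:

```python
# Default path: derive the ramp from the step count.
scheduler.set_timesteps(num_inference_steps=25)

# Custom path: supply the ramp yourself.
scheduler.set_timesteps(sigmas=[0.0, 0.25, 0.5, 0.75, 1.0])
print(scheduler.sigmas)  # 5 schedule values plus the terminal sigma
```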
@@ -206,19 +224,35 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
+                Custom sigmas to use for the denoising process. If not defined, the default behavior when
+                `num_inference_steps` is passed will be used.
         """
         self.num_inference_steps = num_inference_steps

-        ramp = torch.linspace(0, 1, self.num_inference_steps)
+        if sigmas is None:
+            sigmas = torch.linspace(0, 1, self.num_inference_steps)
+        elif isinstance(sigmas, list):
+            # convert a plain list of floats to a tensor; tensors pass through unchanged
+            sigmas = torch.tensor(sigmas, dtype=torch.float32)

         if self.config.sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif self.config.sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)

         sigmas = sigmas.to(dtype=torch.float32, device=device)
         self.timesteps = self.precondition_noise(sigmas)

-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero' or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
         self._step_index = None
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
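A side effect of the `__init__` change above: the training-time ramp now has `num_train_timesteps + 1` evenly spaced points including both endpoints, where the old `linspace` call produced `num_train_timesteps` points. A tiny numeric check:

```python
import torch

n = 4  # tiny stand-in for num_train_timesteps
old_ramp = torch.linspace(0, 1, n)  # tensor([0.0000, 0.3333, 0.6667, 1.0000])
new_ramp = torch.arange(n + 1) / n  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
print(old_ramp.numel(), new_ramp.numel())  # 4 5
```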