Description
What API design would you like to have changed or added to the library? Why?
My proposal is for schedulers to expose `convert_to_sample_prediction` and `convert_to_prediction_type` methods, which would do the following:
- `convert_to_sample_prediction`: Converts from a given `prediction_type` to `sample_prediction` (e.g. $x_0$-prediction). This function would accept a `prediction_type` argument which defaults to `self.config.prediction_type`.
- `convert_to_prediction_type`: Converts back from `sample_prediction` to the scheduler's `prediction_type`. This is intended to be the inverse function of `convert_to_sample_prediction`.
The motivating use case I have in mind is to support guidance strategies such as Adaptive Projected Guidance (APG) and Frequency-Decoupled Guidance (FDG), which prefer to operate on sample / $x_0$-predictions.
The reason I think schedulers should expose these methods explicitly is that performing these operations depends on the scheduler state and definition. For example, the prediction type conversion code in `EulerDiscreteScheduler` depends on the `self.sigmas` schedule:
diffusers/src/diffusers/schedulers/scheduling_euler_discrete.py
Lines 650 to 663 in ba2ba90
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
# backwards compatibility
if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
    pred_original_sample = model_output
elif self.config.prediction_type == "epsilon":
    pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction":
    # denoised = model_output * c_out + input * c_skip
    pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
    raise ValueError(
        f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
    )
As a possible alternative, code that uses a scheduler could instead try to infer the prediction type conversion logic from the presence of `alphas_cumprod` (for a DDPM-style conversion) or `sigmas` (for an EDM-style conversion) attributes. However, I think this is unreliable because a scheduler could use `alphas_cumprod` or `sigmas` in a non-standard way. Since schedulers essentially already implement the `convert_to_sample_prediction` logic in their `step` methods, I think it could be relatively easy to implement these methods, and calling code would not have to guess how to do the conversion.
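To make the fragility of the attribute-sniffing alternative concrete, here is a minimal sketch of what calling code would have to do for an epsilon prediction. The helper `guess_sample_from_epsilon` and the two toy objects are hypothetical and are not part of diffusers; the formulas are the standard EDM-style ($x = x_0 + \sigma\epsilon$) and DDPM-style ($x = \sqrt{\bar\alpha}\,x_0 + \sqrt{1 - \bar\alpha}\,\epsilon$) relations.

```python
from types import SimpleNamespace


def guess_sample_from_epsilon(scheduler, noise_pred, sample, index):
    """Guess x_0 from an epsilon prediction by sniffing scheduler attributes.

    Hypothetical helper, not part of diffusers. Works elementwise on floats,
    NumPy arrays, or torch tensors.
    """
    if hasattr(scheduler, "sigmas"):
        # EDM-style guess: sample = x_0 + sigma * eps
        sigma = scheduler.sigmas[index]
        return sample - sigma * noise_pred
    if hasattr(scheduler, "alphas_cumprod"):
        # DDPM-style guess: sample = sqrt(a) * x_0 + sqrt(1 - a) * eps
        alpha_prod = scheduler.alphas_cumprod[index]
        return (sample - (1 - alpha_prod) ** 0.5 * noise_pred) / alpha_prod ** 0.5
    raise ValueError("cannot infer a conversion for this scheduler")


# Two toy stand-ins for schedulers, each exposing only one attribute.
edm_like = SimpleNamespace(sigmas=[2.0, 1.0, 0.0])
ddpm_like = SimpleNamespace(alphas_cumprod=[0.25, 0.64])

x0, eps = 0.5, -1.0
edm_sample = x0 + 2.0 * eps
ddpm_sample = 0.25 ** 0.5 * x0 + 0.75 ** 0.5 * eps

edm_guess = guess_sample_from_epsilon(edm_like, eps, edm_sample, 0)
ddpm_guess = guess_sample_from_epsilon(ddpm_like, eps, ddpm_sample, 0)
```

Both guesses recover `x0` here only because the toy objects follow the textbook conventions. A scheduler that defines both attributes, or that scales its inputs differently, would make the order of the `hasattr` checks decide the result silently.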
A potential difficulty is ensuring that these methods work well with the `step` method, for example if they are called outside of a denoising loop (so internal state like `self.step_index` may not be properly initialized) or if the conversion can be non-deterministic (for example, when `gamma > 0` in `EulerDiscreteScheduler`).
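One way the uninitialized `step_index` case might be handled is to fall back to a timestep lookup that does not mutate scheduler state. The sketch below uses a toy class; `ToySigmaSchedule` and `sigma_for` are hypothetical names, and the sigma values are made up for illustration.

```python
class ToySigmaSchedule:
    """Toy stand-in for a scheduler; every name here is hypothetical."""

    def __init__(self, timesteps, sigmas):
        self.timesteps = list(timesteps)
        self.sigmas = list(sigmas)  # one more entry than timesteps (final 0.0)
        self.step_index = None      # only set once the denoising loop starts

    def sigma_for(self, timestep):
        """Return the sigma for `timestep` without mutating `step_index`."""
        if self.step_index is not None:
            # Inside a denoising loop: trust the loop's own bookkeeping.
            return self.sigmas[self.step_index]
        # Outside a denoising loop: fall back to looking the timestep up.
        try:
            index = self.timesteps.index(timestep)
        except ValueError:
            raise ValueError(f"timestep {timestep} is not in the schedule") from None
        return self.sigmas[index]


schedule = ToySigmaSchedule(timesteps=[999, 500, 1], sigmas=[14.6, 1.9, 0.03, 0.0])

sigma_outside_loop = schedule.sigma_for(500)  # step_index is still None

schedule.step_index = 1                       # as if the loop had reached t=500
sigma_inside_loop = schedule.sigma_for(500)
```

This does not address the non-deterministic `gamma > 0` case, where the conversion and `step` would additionally need to agree on the sampled noise.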
What use case would this enable or better enable? Can you give us a code example?
The motivating use case is to support guidance strategies which prefer to operate on sample predictions: the pipeline would convert the model output to `sample_prediction`, run the guider's `__call__` logic, and then convert back to the scheduler's `prediction_type` (as schedulers currently expect `model_output` in that `prediction_type`).
There may be other potential use cases as well that I haven't thought of.
As a concrete example, we can imagine modifying `EulerDiscreteScheduler` as follows:
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    ...

    def convert_to_sample_prediction(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        prediction_type: Optional[str] = None,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        if prediction_type is None:
            prediction_type = self.config.prediction_type

        # NOTE: there's a potential catch here if self.step_index isn't properly initialized
        sigma = self.sigmas[self.step_index]
        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
        sigma_hat = sigma * (gamma + 1)

        # NOTE: another potential problem is ensuring consistent computation with `step` if the conversion
        # can be non-deterministic (as below)
        if gamma > 0:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # Compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        # Branch on the resolved `prediction_type` argument so that an explicit override is honored
        if prediction_type == "original_sample" or prediction_type == "sample":
            pred_original_sample = model_output
        elif prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif prediction_type == "v_prediction":
            # denoised = model_output * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {prediction_type} must be one of `epsilon`, or `v_prediction`"
            )
        return pred_original_sample

    def convert_to_prediction_type(
        self,
        pred_original_sample: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        # Convert back to `prediction_type` from `pred_original_sample` and `sample`
        # A potential problem is correctly reconstructing `sample` in the `gamma > 0` case due to non-determinism
        pass

    ...
We could then replace the prediction type conversion logic in `step` with a call to `convert_to_sample_prediction`.
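For the deterministic case (`gamma == 0`, so `sigma_hat == sigma`), the inverse that `convert_to_prediction_type` leaves open can be obtained by solving each branch for the model output. The following is a sketch under that assumption only, written as free functions with hypothetical names (`to_sample_prediction`, `from_sample_prediction`) rather than as scheduler methods. It requires `sigma > 0` and works elementwise on floats, NumPy arrays, or tensors.

```python
def to_sample_prediction(model_output, sample, sigma, prediction_type):
    """Forward conversion, mirroring the Euler branches with gamma == 0."""
    if prediction_type == "sample":
        return model_output
    if prediction_type == "epsilon":
        return sample - sigma * model_output
    if prediction_type == "v_prediction":
        return model_output * (-sigma / (sigma**2 + 1) ** 0.5) + sample / (sigma**2 + 1)
    raise ValueError(f"unsupported prediction_type: {prediction_type}")


def from_sample_prediction(pred_original_sample, sample, sigma, prediction_type):
    """Inverse conversion: solve each forward branch for the model output."""
    if prediction_type == "sample":
        return pred_original_sample
    if prediction_type == "epsilon":
        # x_0 = x - sigma * eps  =>  eps = (x - x_0) / sigma
        return (sample - pred_original_sample) / sigma
    if prediction_type == "v_prediction":
        # x_0 = c_out * v + c_skip * x  =>  v = (c_skip * x - x_0) / (-c_out)
        return (sample / (sigma**2 + 1) - pred_original_sample) * (sigma**2 + 1) ** 0.5 / sigma
    raise ValueError(f"unsupported prediction_type: {prediction_type}")


# Round trip on scalar toy values: the model output should be recovered.
sigma, sample, model_output = 2.0, 1.5, -0.3
round_trip = {
    ptype: from_sample_prediction(
        to_sample_prediction(model_output, sample, sigma, ptype), sample, sigma, ptype
    )
    for ptype in ("sample", "epsilon", "v_prediction")
}
```

With `gamma > 0` the forward conversion perturbs `sample` with fresh noise, so the inverse would need the same perturbed `sample` (or the same noise) to round-trip. That is the open problem noted in the comments of the stub.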
In calling code, such as a `DiffusionPipeline` or a `PipelineBlock` like `StableDiffusionXLLoopDenoiser`, we could use the methods as follows:
class StableDiffusionXLLoopDenoiser(PipelineBlock):
    ...

    def __call__(...):
        ...
        # run the denoiser for each guidance batch
        for guider_state_batch in guider_state:
            ...
            model_output = components.unet(
                block_state.scaled_latents,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=block_state.timestep_cond,
                cross_attention_kwargs=block_state.cross_attention_kwargs,
                added_cond_kwargs=cond_kwargs,
                return_dict=False,
            )[0]
            # Assume guiders expose a config that says that they prefer sample predictions to e.g. noise predictions
            if components.guider.config.prefer_sample_input:
                model_output = components.scheduler.convert_to_sample_prediction(
                    model_output, t, block_state.latents
                )
            guider_state_batch.noise_pred = model_output
            components.guider.cleanup_models(components.unet)

        # Perform guidance
        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)

        # Convert back to prediction_type expected by scheduler
        if components.guider.config.prefer_sample_input:
            block_state.noise_pred = components.scheduler.convert_to_prediction_type(
                block_state.noise_pred, t, block_state.latents
            )
        ...

    ...
One potential downside here is that it makes the guidance logic in the denoising loop more complicated, although I guess the calling pipeline code could choose not to honor the `prefer_sample_input` config.
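One way to contain that complexity might be to factor the convert / guide / convert-back sequence into a single helper, so the loop body has one call site. The sketch below is purely illustrative: `guide_in_sample_space` is a hypothetical name, the guider is a plain classifier-free guidance function rather than a diffusers guider, and the conversions are the deterministic epsilon formulas on scalar toy values.

```python
def guide_in_sample_space(to_sample, from_sample, guider, model_outputs):
    """Run `guider` on x_0 predictions and map the guided result back.

    Hypothetical helper: `to_sample` and `from_sample` stand in for the
    proposed scheduler methods with timestep and sample already bound.
    """
    sample_preds = [to_sample(out) for out in model_outputs]
    guided_sample = guider(sample_preds)
    return from_sample(guided_sample)


def cfg(preds, scale=7.5):
    """Plain classifier-free guidance on a [conditional, unconditional] pair."""
    cond, uncond = preds
    return uncond + scale * (cond - uncond)


# Toy epsilon-prediction setup with a fixed sigma and noisy sample.
sigma, sample = 2.0, 1.0
to_sample = lambda eps: sample - sigma * eps
from_sample = lambda x0: (sample - x0) / sigma

eps_cond, eps_uncond = 0.2, 0.4
guided_via_sample = guide_in_sample_space(to_sample, from_sample, cfg, [eps_cond, eps_uncond])
guided_directly = cfg([eps_cond, eps_uncond])
```

For plain CFG the two results agree, because the conversion is affine in the model output and the CFG weights sum to one. The detour through sample space only changes the outcome for guiders that are not affine combinations of their inputs, which is the situation the proposal targets.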