API Suggestion: Expose Methods to Convert to Sample Prediction in Schedulers #12079

@dg845

Description

What API design would you like to have changed or added to the library? Why?

My proposal is for schedulers to expose convert_to_sample_prediction and convert_to_prediction_type methods, which would do the following:

  1. convert_to_sample_prediction: converts a model output from a given prediction_type to a sample prediction (e.g. an $x_0$-prediction). This method would accept a prediction_type argument which defaults to self.config.prediction_type.
  2. convert_to_prediction_type: converts a sample prediction back to the scheduler's prediction_type. This is intended to be the inverse of convert_to_sample_prediction (see the sketch after this list).
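
For instance, the two methods would ideally satisfy a round-trip property along these lines (a hypothetical sketch; scheduler, model_output, t, and sample stand in for the objects in a real denoising loop):

# Hypothetical round trip: converting to a sample prediction and back
# should recover the original model output (in the deterministic case)
x0_pred = scheduler.convert_to_sample_prediction(model_output, t, sample)
recovered = scheduler.convert_to_prediction_type(x0_pred, t, sample)
assert torch.allclose(recovered, model_output, atol=1e-5)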

The motivating use case I have in mind is to support guidance strategies such as Adaptive Projected Guidance (APG) and Frequency-Decoupled Guidance (FDG), which prefer to operate on sample / $x_0$-predictions. A code example is given below.

The reason I think schedulers should expose these methods explicitly is that performing these operations depends on the scheduler's state and definition. For example, the prediction type conversion code in EulerDiscreteScheduler depends on the self.sigmas schedule:

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
# backwards compatibility
if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
    pred_original_sample = model_output
elif self.config.prediction_type == "epsilon":
    pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction":
    # denoised = model_output * c_out + input * c_skip
    pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
    raise ValueError(
        f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
    )

As a possible alternative, code that uses a scheduler could instead try to infer the prediction type conversion logic from the presence of an alphas_cumprod attribute (for a DDPM-style conversion) or a sigmas attribute (for an EDM-style conversion); both conventions are sketched below. However, I think this is unreliable because a scheduler could use alphas_cumprod or sigmas in a non-standard way. Since schedulers essentially already implement the convert_to_sample_prediction logic in their step methods, I think these methods would be relatively easy to implement, and calling code would not have to guess how to do the conversion.
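
For reference, here is a rough sketch of the two standard conversions such inference logic would be guessing between, assuming an epsilon-prediction model (the indexing via timestep and step_index is illustrative):

# DDPM-style conversion, using the scheduler's alphas_cumprod
alpha_prod_t = scheduler.alphas_cumprod[timestep]
pred_x0_ddpm = (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t**0.5

# EDM-style conversion, using the scheduler's sigmas
sigma = scheduler.sigmas[step_index]
pred_x0_edm = sample - sigma * model_output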

A potential difficulty is ensuring that these methods work well with the step method, for example if they are called outside of a denoising loop (so internal state like self.step_index may not be properly initialized; see the sketch below) or if the conversion is non-deterministic (for example, when gamma > 0 in EulerDiscreteScheduler).
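
For the step_index issue, the conversion methods could mirror what step already does in EulerDiscreteScheduler and lazily initialize the index from the timestep:

# Lazily initialize the step index from the timestep, as `step` does
if self.step_index is None:
    self._init_step_index(timestep)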

What use case would this enable or better enable? Can you give us a code example?

The motivating use case is to support guidance strategies which prefer to operate on $x_0$-predictions. For this use case, we want to convert the denoising model's prediction to a sample prediction, run the guider's __call__ logic, and then convert back to the scheduler's prediction_type (as schedulers currently expect the model_output in that prediction_type).

There may also be other use cases that I haven't thought of.

As a concrete example, we can imagine modifying EulerDiscreteScheduler as follows:

class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    ...
    def convert_to_sample_prediction(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        prediction_type: Optional[str] = None,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        if prediction_type is None:
            prediction_type = self.config.prediction_type

        # NOTE: there's a potential catch here if self.step_index isn't properly initialized
        sigma = self.sigmas[self.step_index]
        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
        sigma_hat = sigma * (gamma + 1)

        # NOTE: another potential problem is ensuring consistent computation with `step` if the conversion
        # can be non-deterministic (as below)
        if gamma > 0:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # Compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        if prediction_type == "original_sample" or prediction_type == "sample":
            pred_original_sample = model_output
        elif prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif prediction_type == "v_prediction":
            # denoised = model_output * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {prediction_type} must be one of `sample`, `epsilon`, or `v_prediction`"
            )

        return pred_original_sample

    def convert_to_prediction_type(
        self,
        pred_original_sample: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        # Convert back to `prediction_type` from `pred_original_sample` and `sample`
        # A potential problem is correctly reconstructing `sample` in the `gamma > 0` case due to
        # non-determinism; a deterministic-case sketch is given after this class
        pass
    ...
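
For completeness, here is a rough sketch of the deterministic (gamma == 0, so sigma_hat == sigma) case of convert_to_prediction_type, obtained by inverting the formulas above (the simplified signature and the prediction_type argument are illustrative, not part of the proposed API):

    def convert_to_prediction_type(self, pred_original_sample, timestep, sample, prediction_type=None):
        if prediction_type is None:
            prediction_type = self.config.prediction_type

        # Deterministic case only: no churn, so sigma_hat == sigma
        sigma = self.sigmas[self.step_index]

        if prediction_type == "original_sample" or prediction_type == "sample":
            model_output = pred_original_sample
        elif prediction_type == "epsilon":
            # Invert pred_original_sample = sample - sigma * model_output
            model_output = (sample - pred_original_sample) / sigma
        elif prediction_type == "v_prediction":
            # Invert pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + sample / (sigma**2 + 1)
            model_output = (sample / (sigma**2 + 1) - pred_original_sample) * ((sigma**2 + 1) ** 0.5) / sigma
        else:
            raise ValueError(f"prediction_type given as {prediction_type} is not supported")

        return model_output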

We could then replace the prediction type conversion logic in step with a call to convert_to_sample_prediction, roughly as follows.
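
A sketch of the relevant part of step (the surrounding logic and most arguments are elided):

    def step(self, model_output, timestep, sample, generator=None, return_dict=True, **kwargs):
        ...
        # Replaces the inline prediction type conversion branch quoted above
        pred_original_sample = self.convert_to_sample_prediction(
            model_output, timestep, sample, generator=generator
        )
        ...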

In calling code, such as a DiffusionPipeline or a PipelineBlock like StableDiffusionXLLoopDenoiser, we could use the methods as follows:

class StableDiffusionXLLoopDenoiser(PipelineBlock):
    ...
    def __call__(...):
        ...
        # run the denoiser for each guidance batch
        for guider_state_batch in guider_state:
            ...
            model_output = components.unet(
                block_state.scaled_latents,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=block_state.timestep_cond,
                cross_attention_kwargs=block_state.cross_attention_kwargs,
                added_cond_kwargs=cond_kwargs,
                return_dict=False,
            )[0]
            # Assume guiders expose a config flag indicating that they prefer sample predictions to e.g. noise predictions
            if components.guider.config.prefer_sample_input:
                model_output = components.scheduler.convert_to_sample_prediction(
                    model_output, t, block_state.latents
                )
            guider_state_batch.noise_pred = model_output
            components.guider.cleanup_models(components.unet)
        
        # Perform guidance
        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)

        # Convert back to prediction_type expected by scheduler
        if components.guider.config.prefer_sample_input:
            block_state.noise_pred = components.scheduler.convert_to_prediction_type(
                block_state.noise_pred, t, block_state.latents
            )
        ...
    ...

One potential downside is that this makes the guidance logic in the denoising loop more complicated, although the calling pipeline code could choose not to honor the prefer_sample_input config.
