
Conversation

@moritzhauschulz (Contributor) commented Sep 23, 2025

Description

DRAFT PR for diffusion forecasting engine...

I summarise the changes below, with some open questions (Q) and some notes (NB).

Arguments:

(see diff_config.yml)

```yaml
forecast_policy: diffusion

loss_fcts_lat:
  -
    - "mse"
    - 1.0

fe_diff_sigma_min: 0.02
fe_diff_sigma_max: 88
fe_diff_sigma_data: 1  # I think in gencast this is hard coded to 1...
```

Components:

  • some changes copied (or adapted) from [Kerem’s PR](https://iffchat.fz-juelich.de/mlesm-dev/messages/@kerem_tezcan) in:
    • datasets/data_reader_base.py
    • datasets/multi_stream_data_sampler.py
    • datasets/stream_data.py
    • datasets/tokeniser_forecast.py
    • datasets/utils.py
    • train/loss_calculator.py
    • train/loss.py
    • train/trainer.py
    • train/trainer_base.py
    • model/model.py (for encoding target variables)
  • new class LinearNormConditioning in attention
    • this acts as a conditional scale-and-offset layer whose parameters depend on the noise level (a minimal sketch is given after this list)
    • added as noise_conditioning to MultiSelfAttentionHeadLocal and MultiSelfAttentionHeadGlobal
      • called after layer norm and prior to attention
        • Q: Should this replace the layer norm?
  • in ForecastingEngine added…
    • …new parameters:
      • sigma_min
      • sigma_max
      • sigma_data (this is set to 1, which assumes that the latent channels are normalised to unit variance; this is currently not implemented, since it clearly depends on the encoder, unlike in the GenCast case where normalisation is simply applied to the dataset; note, however, that we do apply this transformation prior to encoding)
    • …new layers:
      • map_noise, which embeds the noise level via a PositionalEmbedding layer
        • Q: Should we instead use the alternative embedding used in GenCast?
      • map_layer0 and map_layer1, two linear layers post-processing the embedding
        • Q: Why?
      • Q: Are we correctly freezing the model parameters for the new layers?
    • in Model class
      • add attributes P_mean, P_std, which determine the noise levels used to noise the target in training (NB: the noise used at sampling is separate)
        • Q: Why is that, and what is the difference?
      • in forward, add cases for forecast_policy == diffusion, distinguishing between training mode and inference mode
        • in both cases, the target (in latent space) has to be computed as the residual between source and target tokens
          • Q: What would be the impact if source and target tokens are based on different datasets?
        • in training mode: calls edm_denoise to obtain predicted denoised residuals
        • in non-training mode: calls edm_sampler to obtain predicted denoised residuals
        • NB: loss should be computed in latent space, so the relevant comparison is between the items in tokens_all and tokens_targets – yet, if loss_fcts are passed, physical losses can also be computed and added up.
        • also carrying over target embedding functions from Kerem (see above)
        • in forecast, added an optional argument noise_conditioning, which contains the noise that is passed to the blocks in diffusion mode
        • two new methods edm_denoise (called during training) and edm_sampler (called during inference)
          • edm_denoise
            • samples sigma as the exponential of a normal variable (mean P_mean, standard deviation P_std), i.e. sigma is log-normally distributed; sigma is the noise level per sample in the batch
            • sigma is then multiplied by a per-sample/cell/channel standard normal sample; this is the noise added to the target token
            • a per-sample weight is computed according to the formula from the EDM paper; this is later passed to the loss function (see the training-time sketch after this list)
            • finally calls edm_preconditioning , see below
          • edm_sampler
            • first computes the spacing of t_steps
            • then applies a (2nd-order) Heun sampling scheme using the conditioning tokens from the previous time step concatenated with normal noise (see the sampler sketch after this list)
              • note that at sampling time we are denoising pure noise
              • parameters taken from gencast supplementary
              • note also that this is taken from the EDM code; GenCast uses a variation of it…
          • both then call edm_preconditioning (see below) for the actual denoising
            • NB: might need to rename functions here to avoid confusion.
        • edm_preconditioning
          • this concatenates the conditioning tokens (i.e. the tokens from the previous time steps) with noised target (or the pure noise at inference time) and embeds the noise via the forecasting engine’s map_noise (PositionalEmbedding) layer
          • these are passed to the forecast method before the final output is obtained via the following formula, known from the EDM and GenCast papers:
            • $D_{\theta}\!\left( \mathbf{Z}^t_{\sigma}; \mathbf{X}^{t-1}, \sigma \right) := c_{\text{skip}}(\sigma) \cdot \mathbf{Z}^t_{\sigma} + c_{\text{out}}(\sigma) \cdot f_{\theta}\!\left( c_{\text{in}}(\sigma)\, \mathbf{Z}^t_{\sigma}; \mathbf{X}^{t-1}, c_{\text{noise}}(\sigma) \right)$
              • NB: the current model is not able to condition on $\mathbf{X}^{t-2}$ – this requires further work on the data pipeline…
      • small changes in loss_calculator.py
        • introduced weights_samples, which are computed in edm_denoise
          • NB: currently only implemented in mse_channel_location_weighted
        • some refactoring, mainly replacing preds with more comprehensive out
          • out contains preds, posteriors, weights, tokens_all, tokens_targets
            • where tokens_all are the predicted latents for each fstep and tokens_targets are the true encoded targets for each fstep
        • when training the diffusion model only in latent space, the loss calculation happens via the block following if self.loss_fcts_lat:
        • NB: caution should be taken not to specify loss_fcts in that case, though of course losses in physical space and latent space can be combined…
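
For orientation, here is a minimal sketch of the kind of conditional scale-and-offset layer that LinearNormConditioning is described as implementing. The class name is taken from the PR, but the constructor arguments, tensor shapes and the zero-initialisation are illustrative assumptions, not the actual implementation.

```python
import torch
from torch import nn

class LinearNormConditioning(nn.Module):
    """Sketch: per-channel scale and offset conditioned on a noise embedding.

    Intended to be applied after the layer norm and before attention, as
    described above; signature and shapes are illustrative assumptions.
    """

    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        # One linear layer produces both the scale and the offset.
        self.proj = nn.Linear(cond_dim, 2 * dim)
        # Zero-init so the module starts out as the identity (assumed choice).
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor, noise_emb: torch.Tensor) -> torch.Tensor:
        # x: (batch, tokens, dim), noise_emb: (batch, cond_dim)
        scale, offset = self.proj(noise_emb).chunk(2, dim=-1)
        return x * (1.0 + scale.unsqueeze(1)) + offset.unsqueeze(1)
```

Regarding the question of whether this should replace the layer norm: one common pattern is to disable the layer norm's own affine parameters and let a module like this supply the scale and offset instead.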
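
To make the edm_denoise / edm_preconditioning steps concrete, below is a minimal sketch using the standard formulas from the EDM paper (noise sampling, loss weight, and the $c_{\text{skip}}, c_{\text{out}}, c_{\text{in}}, c_{\text{noise}}$ coefficients). The function name, the net callable and the default values for P_mean / P_std are placeholders, not the PR's actual interface.

```python
import torch

def edm_training_step(net, target, conditioning, p_mean=-1.2, p_std=1.2, sigma_data=1.0):
    """Sketch of EDM training-time noising and preconditioning (assumed interface)."""
    b = target.shape[0]
    # sigma = exp(N(P_mean, P_std^2)): one log-normal noise level per sample.
    sigma = (p_mean + p_std * torch.randn(b, device=target.device)).exp()
    sigma = sigma.view(b, *([1] * (target.ndim - 1)))  # broadcast over cells/channels

    # Per-sample loss weight lambda(sigma) from the EDM paper.
    weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2

    # Noise the latent target with a per-sample/cell/channel Gaussian sample.
    noised = target + sigma * torch.randn_like(target)

    # EDM preconditioning coefficients.
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2).sqrt()
    c_in = 1.0 / (sigma**2 + sigma_data**2).sqrt()
    c_noise = sigma.log() / 4.0

    # D_theta(Z_sigma; X, sigma) = c_skip * Z_sigma + c_out * f_theta(c_in * Z_sigma; X, c_noise)
    denoised = c_skip * noised + c_out * net(c_in * noised, conditioning, c_noise)
    return denoised, weight
```

The returned weight corresponds to what loss_calculator.py consumes as weights_samples; in EDM the training loss is this weight times the squared error between the denoised output and the clean (latent) target.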
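
Similarly, a minimal sketch of the deterministic 2nd-order (Heun) sampler that edm_sampler is described as implementing, with the Karras-style spacing of t_steps. The denoise callable, the step count and rho are illustrative assumptions (the PR states its parameters come from the GenCast supplementary); sigma_min and sigma_max match the config above.

```python
import torch

@torch.no_grad()
def edm_heun_sampler(denoise, conditioning, shape, num_steps=20,
                     sigma_min=0.02, sigma_max=88.0, rho=7.0, device="cpu"):
    """Sketch of the deterministic 2nd-order (Heun) EDM sampler (assumed interface)."""
    # Karras-style spacing of noise levels between sigma_max and sigma_min.
    i = torch.arange(num_steps, device=device)
    t_steps = (sigma_max ** (1 / rho)
               + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([t_steps, torch.zeros(1, device=device)])  # final step to sigma = 0

    # At sampling time we start from pure noise at the largest noise level.
    x = torch.randn(shape, device=device) * t_steps[0]

    for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
        # Euler step using the denoised estimate at the current noise level.
        d_cur = (x - denoise(x, conditioning, t_cur)) / t_cur
        x_next = x + (t_next - t_cur) * d_cur
        # 2nd-order (Heun) correction, skipped on the last step where t_next == 0.
        if t_next > 0:
            d_next = (x_next - denoise(x_next, conditioning, t_next)) / t_next
            x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x
```

In terms of the PR, the denoise call corresponds to the preconditioned forecast (with the conditioning tokens from the previous time step), and the returned sample is the predicted latent residual that is added back to the previous state.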

Issue Number

Closes #702

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hedgedoc in the github issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@moritzhauschulz changed the title from Issue702 to Diffusion v1 (closes #702) on Sep 25, 2025
@moritzhauschulz (Author) commented Sep 25, 2025

Description will be updated regularly.
