Diffusion v1 (closes #702) #944
Draft: moritzhauschulz wants to merge 98 commits into ecmwf:develop from moritzhauschulz:issue702
Description will be updated regularly.
Description
DRAFT PR for diffusion forecasting engine...
I summarise the changes below, with some open questions (Q) and some notes (NB).
Arguments (see `diff_config.yml`):

- `forecast_policy: diffusion`
- `fe_diff_sigma_min: 0.02`
- `fe_diff_sigma_max: 88`
- `fe_diff_sigma_data: 1` (I think in gencast this is hard-coded to 1...)
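For reference, `sigma_min` and `sigma_max` conventionally bound the EDM sampling noise schedule (and 0.02/88 happen to be the EDM defaults). A minimal sketch of the schedule these values would produce — the step count and `rho` are assumptions here, not values taken from this PR:

```python
import math

def edm_t_steps(num_steps=20, sigma_min=0.02, sigma_max=88.0, rho=7.0):
    """Standard EDM noise schedule: interpolate between sigma_max and
    sigma_min in rho-warped space, then append a final t_N = 0."""
    steps = []
    for i in range(num_steps):
        t = (sigma_max ** (1 / rho)
             + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
             ) ** rho
        steps.append(t)
    steps.append(0.0)  # final step denoises all the way to sigma = 0
    return steps

t_steps = edm_t_steps()
# t_steps[0] == sigma_max, t_steps[-2] == sigma_min, t_steps[-1] == 0
```

The `rho` warp spends more steps at low noise levels, which is where fine detail is resolved.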
Components:

- `datasets/data_reader_base.py`
- `datasets/multi_stream_data_sampler.py`
- `datasets/stream_data.py`
- `datasets/tokeniser_forecast.py`
- `datasets/utils.py`
- `train/loss_calculator.py`
- `train/loss.py`
- `train/trainer.py`
- `train/trainer_base.py`
- `model/model.py`
  - added `LinearNormConditioning` (for encoding target variables)
  - in attention, added `noise_conditioning` to `MultiSelfAttentionHeadLocal` and `MultiSelfAttentionHeadGlobal`
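The norm-conditioning idea in `model/model.py` can be sketched as an adaptive LayerNorm, as is common in diffusion transformers: normalise the activations, then scale and shift them with values predicted linearly from the conditioning embedding. This is a hedged illustration — the function name, shapes, and exact formulation of the PR's `LinearNormConditioning` are assumptions:

```python
import numpy as np

def linear_norm_conditioning(x, cond, w, b):
    """Sketch of adaptive-LayerNorm conditioning (illustrative).

    x:    (tokens, channels) activations
    cond: (embed_dim,) noise-conditioning embedding
    w, b: a linear layer mapping cond -> 2 * channels (scale and shift)
    """
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_norm = (x - mu) / np.sqrt(var + 1e-6)  # LayerNorm without affine params
    scale_shift = cond @ w + b               # (2 * channels,)
    scale, shift = np.split(scale_shift, 2)
    return x_norm * (1.0 + scale) + shift    # conditioned activations

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))                  # 4 tokens, 8 channels
cond = rng.normal(size=(16,))                # noise embedding
w = rng.normal(size=(16, 16)) * 0.02         # 16 -> 2 * 8
b = np.zeros(16)
out = linear_norm_conditioning(x, cond, w, b)
```

The `1.0 + scale` convention keeps the block close to a plain LayerNorm when the predicted scale is near zero.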
`ForecastingEngine`:

- added `sigma_min`, `sigma_max`, and `sigma_data` (the latter is set to 1, which assumes that the latent channels are normalised to unit variance – this is currently not implemented, since clearly this depends on the encoder, unlike in the gencast case where normalisation is simply done to the dataset; note however that we do apply this transformation prior to encoding)
- added `map_noise`, which embeds the noise via `PositionalEmbedding`
- added `map_layer0` and `map_layer1`, two linear layers post-processing the embedding
In the `Model` class:

- added `P_mean` and `P_std`, which are used to scale the noise used to noise the target in training (NB: the noise used at sampling is separate)
- in `forward`, added cases for `forecast_policy == diffusion`, distinguishing between training mode and not:
  - `denoise` to obtain predicted denoised residuals (in training)
  - `edm_sampler` to obtain predicted denoised residuals (otherwise)
- `forward` works with `tokens_all` and `tokens_targets` – yet, if `loss_fcts` are passed, physical losses can also be computed and added up.
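The role of `P_mean` and `P_std` in training-time noising can be sketched as follows. This follows the usual EDM recipe (log-normal noise levels); the numeric defaults and tensor shapes are illustrative, and note the EDM convention treats `P_std` as a standard deviation:

```python
import numpy as np

rng = np.random.default_rng(0)
P_mean, P_std = -1.2, 1.2  # EDM defaults; this PR's values may differ

def sample_training_noise(targets):
    """Draw one noise level per sample (exp of a normal with mean P_mean
    and std P_std), then per sample/cell/channel Gaussian noise scaled by
    that level -- the quantity added to the target tokens."""
    n_samples = targets.shape[0]
    sigma = np.exp(P_mean + P_std * rng.standard_normal(n_samples))  # (B,)
    noise = sigma[:, None, None] * rng.standard_normal(targets.shape)
    return targets + noise, sigma

targets = np.zeros((2, 4, 8))  # (batch, cells, channels), made up
noised, sigma = sample_training_noise(targets)
```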
- `forecast`: added optional argument `noise_conditioning`, which contains the noise that is passed to the blocks in diffusion mode
- added `edm_denoise` (called during training) and `edm_sampler` (called during inference)
- `edm_denoise`:
  - draws `sigma` from an exponential normal distribution (mean `P_mean`, variance `P_std`) – `sigma` is the noise level per sample in the batch
  - `sigma` is then multiplied by a per sample/cell/channel normal sample – this is the noise added to the target token
  - calls `edm_preconditioning`, see below
- `edm_sampler`:
  - computes `t_steps`
  - calls `edm_preconditioning` (see below) for the actual denoising
- `edm_preconditioning`:
  - embeds the noise via the `map_noise` (`PositionalEmbedding`) layer
  - calls the `forecast` method before the final output is obtained via the following formula known from the EDM and GenCast papers:
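For reference, the preconditioning formula from the EDM paper (Karras et al. 2022), also used by GenCast, is D(x; σ) = c_skip(σ)·x + c_out(σ)·F(c_in(σ)·x; c_noise(σ)). A sketch of the standard coefficients and the λ(σ) loss weight — presumably what feeds `weights_samples`, though that is an assumption about this PR:

```python
import math

sigma_data = 1.0  # fe_diff_sigma_data in the config above

def edm_precond_coeffs(sigma):
    """EDM preconditioning coefficients (Karras et al. 2022, Table 1):
    D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)."""
    c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
    c_out = sigma * sigma_data / math.sqrt(sigma ** 2 + sigma_data ** 2)
    c_in = 1.0 / math.sqrt(sigma ** 2 + sigma_data ** 2)
    c_noise = 0.25 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise

def edm_loss_weight(sigma):
    """lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2,
    the per-sample training loss weight; note lambda = 1 / c_out^2."""
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2

c_skip, c_out, c_in, c_noise = edm_precond_coeffs(1.0)
```

With `sigma_data = 1` and `sigma = 1`, the skip and output paths are balanced (`c_skip = 0.5`), which is the regime the unit-variance-latents assumption above is meant to enable.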
`train/loss_calculator.py`:

- added `weights_samples`, which are computed in `edm_noise`
- `mse_channel_location_weighted`: replaced `preds` with the more comprehensive `out`
- `out` contains `preds, posteriors, weights, tokens_all, tokens_targets`
- `tokens_all` are the predicted latents for each fstep and `tokens_targets` are the true encoded targets for each fstep
- `if self.loss_fcts_lat:` applies losses in latent space analogous to `loss_fcts`, though of course loss in physical space and latent space can be combined…

Issue Number
Closes #702
Checklist before asking for review

- `./scripts/actions.sh lint`
- `./scripts/actions.sh unit-test`
- `./scripts/actions.sh integration-test`
- `launch-slurm.py --time 60`