Draft
Changes from all commits (98 commits)
1c7d04e
added the era5 v8 file
kctezcan Aug 11, 2025
03588f5
functions for encoding target source like
kctezcan Aug 11, 2025
5067531
more functions for source like encoding
kctezcan Aug 11, 2025
68a271c
take only preds_all in loss calc
kctezcan Aug 11, 2025
f848c57
set epoch=epoch in inference
kctezcan Aug 11, 2025
599feaf
smaller stuff, nccl, paths etc
kctezcan Aug 11, 2025
9376353
unpack preds
kctezcan Aug 12, 2025
95f599b
ruff
kctezcan Aug 12, 2025
9b72723
added the necessary changes to the default config for visibility
kctezcan Aug 12, 2025
c797622
added the latent loss
kctezcan Aug 12, 2025
fe409eb
added terminal logging
kctezcan Aug 12, 2025
1e275f9
ensure working without the latent loss
kctezcan Aug 12, 2025
01af6ef
more ensure working without the latent loss
kctezcan Aug 12, 2025
2980c44
merged dev
kctezcan Aug 12, 2025
7b22945
merged lat loss for streams
kctezcan Aug 12, 2025
babcd89
using tensor for lat_loss_hist
kctezcan Aug 12, 2025
153f88d
fixed terminal logging
kctezcan Aug 12, 2025
5b22fae
testing lat loss weight
kctezcan Aug 12, 2025
5d3ddc9
Cleaned up to use proper logger
clessig Aug 12, 2025
a707beb
Cleaned up to use proper logger
clessig Aug 12, 2025
3876de2
Fix logging: needs to be registered per output stream and not per log…
clessig Aug 12, 2025
f9365b9
Set logging level consistently with debug to file
clessig Aug 12, 2025
f24bf2b
Merge branch 'develop' into clessig/develop/fix_logging_719
clessig Aug 12, 2025
44e1be1
merged clessig/develop/fix_logging_719
kctezcan Aug 13, 2025
599b491
fixed CL's fix for duplicate logging
kctezcan Aug 13, 2025
0566b83
getting the embed and embed_pe from the srclk targets
kctezcan Aug 13, 2025
aece959
explicit detaching
kctezcan Aug 13, 2025
d4017dd
fixed bug in case fo=1, fs=1
kctezcan Aug 13, 2025
d5b17ca
fixed wrong time_win1-2
kctezcan Aug 13, 2025
ae28f98
fixed alignment of timesteps in the loss calculation
kctezcan Aug 13, 2025
689da80
added comment
kctezcan Aug 13, 2025
2a940c3
added the def config
kctezcan Aug 13, 2025
f2a7093
def config
kctezcan Aug 13, 2025
bdfcf1a
hotfix for device
kctezcan Aug 13, 2025
7ffa025
fix loss history if no lat loss
kctezcan Aug 14, 2025
7169d87
config
kctezcan Aug 14, 2025
3d292ee
merged dev
kctezcan Aug 14, 2025
f051b1a
trying stuff for NCCL
kctezcan Aug 14, 2025
5df4ef2
config
kctezcan Aug 14, 2025
6def791
adjusted run_evaluation and utils code to take into account forecast_…
moritzhauschulz Aug 15, 2025
ef5a577
print statement change
moritzhauschulz Aug 15, 2025
6984996
catching error when fstep not present in zarr file
moritzhauschulz Aug 18, 2025
7b0bdd2
upgrades based on PR feedback
moritzhauschulz Aug 19, 2025
9d2b23d
intermediate commit
moritzhauschulz Aug 19, 2025
b3ed5a2
intermediate commit
moritzhauschulz Aug 20, 2025
f7935a9
new functions _get_channels_fsteps_samples and check_metric
moritzhauschulz Aug 21, 2025
e417627
edited plotting
moritzhauschulz Aug 21, 2025
0c24902
Merge branch 'develop' into issue717
moritzhauschulz Aug 21, 2025
49cad51
working changes
kctezcan Sep 1, 2025
07be423
can batchify targets sources with diff no of channels
kctezcan Sep 2, 2025
7555c56
saving tokens to scratch
kctezcan Sep 2, 2025
d6812d6
Merge branch 'develop' into issue717
moritzhauschulz Sep 2, 2025
93b717a
inter commit
moritzhauschulz Sep 2, 2025
dfc3e44
separating preds from tokens for writing inference
kctezcan Sep 3, 2025
7e89c83
added run_evaluate.py for debugging
kctezcan Sep 3, 2025
92b1144
fixed bug in get_data
moritzhauschulz Sep 3, 2025
9c8aee8
self review
moritzhauschulz Sep 3, 2025
64ae9ea
merge conf def conf
kctezcan Sep 3, 2025
92b08d1
name back the streams_anemoi
kctezcan Sep 3, 2025
ff6def4
changed the order of target/pred
kctezcan Sep 4, 2025
017db32
refactor
iluise Sep 4, 2025
94cf694
dummy commit
iluise Sep 4, 2025
8e73d9d
inter commit
moritzhauschulz Sep 5, 2025
f5ea7e0
Merge branch 'iluise/test717' into issue702
moritzhauschulz Sep 5, 2025
379d577
Merge branch 'develop' into issue717
moritzhauschulz Sep 5, 2025
9aa7722
feedback applied
moritzhauschulz Sep 5, 2025
c2bd100
Merge branch 'issue702' into issue717
moritzhauschulz Sep 5, 2025
a3f2e28
incorporate review feedback
moritzhauschulz Sep 5, 2025
5496c1f
removed sorting of fsteps_final
moritzhauschulz Sep 5, 2025
34e36e0
remove comments
moritzhauschulz Sep 5, 2025
094f5bd
first draft of forward pass
Sep 9, 2025
79df612
first draft of train and sample – not running
Sep 9, 2025
8be832f
inter connect
moritzhauschulz Sep 10, 2025
1481a9f
error fixes and edited compute_iffsets_scatter_embed for target varia…
moritzhauschulz Sep 10, 2025
3b4ed4f
implemented provisional target embedding... to be tested
moritzhauschulz Sep 11, 2025
ef92e1f
def config
kctezcan Sep 16, 2025
b1e4d43
inter commit
moritzhauschulz Sep 16, 2025
620a1c8
Merge branch 'ktezcan/dev/iss706_latent_loss' into issue702
moritzhauschulz Sep 16, 2025
ae46101
Merge branch 'develop' into issue702
moritzhauschulz Sep 16, 2025
bda087d
added era5.yml back
moritzhauschulz Sep 18, 2025
6d46dd8
inter commit – debugging in progress
moritzhauschulz Sep 18, 2025
43e82ad
inter commit
moritzhauschulz Sep 18, 2025
22b8e17
Merge branch 'develop' into issue702
moritzhauschulz Sep 18, 2025
f0f530b
inter commit
moritzhauschulz Sep 18, 2025
5d058e4
inter commit
moritzhauschulz Sep 19, 2025
fbed208
creating draft PR
moritzhauschulz Sep 23, 2025
322bb84
reset default config
moritzhauschulz Sep 23, 2025
228cf43
remove default_config_diff
moritzhauschulz Sep 23, 2025
0be48aa
added diff_config
moritzhauschulz Sep 23, 2025
dc93b4e
inter commit
moritzhauschulz Sep 25, 2025
7423cea
minor refactoring model.py
moritzhauschulz Oct 7, 2025
eb338db
added pretrain config
moritzhauschulz Oct 7, 2025
4c5b77e
hot fix to adjust att dimension for concatenated input
moritzhauschulz Oct 8, 2025
79d3c78
inter commit
moritzhauschulz Oct 8, 2025
4fa8bdf
inter commit
moritzhauschulz Oct 8, 2025
5c4642c
inter commit
moritzhauschulz Oct 8, 2025
c54fb09
inter commit
moritzhauschulz Oct 8, 2025
dab1029
inter commit
moritzhauschulz Oct 8, 2025
159 changes: 159 additions & 0 deletions config/diff_config_forecast.yml
@@ -0,0 +1,159 @@
streams_directory: "./config/streams/era5_1deg/"

embed_orientation: "channels"
embed_local_coords: True
embed_centroids_local_coords: False
embed_size_centroids: 0
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

target_cell_local_prediction: True

ae_local_dim_embed: 1024
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 256
ae_global_num_blocks: 8
ae_global_num_heads: 16
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
ae_global_att_dense_rate: 0.2
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True

# number of steps offset applied to the first target window; if it is zero and forecast_steps=0
# then one is training an auto-encoder
forecast_offset: 0
forecast_delta_hrs: 0
forecast_steps: 1
forecast_policy: "diffusion"
forecast_freeze_model: False
forecast_att_dense_rate: 0.2
fe_global_block_factor: 32
fe_local_num_queries: 1
fe_num_blocks: 8
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_diff_sigma_min: 0.02
fe_diff_sigma_max: 88
fe_diff_sigma_data: 1 # I think in GenCast this is hard-coded to 1
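# For orientation (not config keys): in the EDM framework (Karras et al. 2022)
# that GenCast builds on, these sigmas parameterize the denoiser preconditioning,
#   c_skip(sigma) = sigma_data^2 / (sigma^2 + sigma_data^2)
#   c_out(sigma)  = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
#   c_in(sigma)   = 1 / sqrt(sigma^2 + sigma_data^2)
# with noise levels confined to [fe_diff_sigma_min, fe_diff_sigma_max].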

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

loss_fcts_lat:
  -
    - "mse"
    - 1.0
loss_fcts_val:
  -
    - "mse"
    - 1.0
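# Structurally, each block above parses as a list of ["<loss name>", <weight>]
# pairs (here [["mse", 1.0]]); presumably further weighted losses can be added
# as additional inner lists.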

batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1

# a regex that needs to fully match the name of the modules you want to freeze
# e.g. ".*ERA5" will match any module whose name ends in ERA5
# encoders and decoders that exist per stream have the stream name attached at the end
freeze_modules: ""

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train in auto-encoder mode, forecast_offset should be 0
training_mode: "forecast"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.6
# sample the masking rate (with normal distribution centered at masking_rate)
# note that a sampled masking rate leads to varying requirements (e.g. memory)
masking_rate_sampling: True
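# Roughly, as a hypothetical sketch (the actual sampler lives in the training
# code): rate = clip(normal(mean=masking_rate, std=<implementation-defined>), 0, 1)
# when masking_rate_sampling is True, else rate = masking_rate.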
# sample a subset of all target points, useful e.g. to reduce memory requirements (can also be specified per stream)
sampling_rate_target: 1.0
# masking strategy; currently only "random", "block", "healpix", "channel", "causal" and "combination" are supported
masking_strategy: "random"
# masking_strategy_config is a dictionary of additional parameters for the masking strategy
# required for "healpix" and "channel" masking strategies
# "healpix": requires healpix mask level to be specified with `hl_mask`
# "channel": requires "mode" to be specified, "per_cell" or "global",
masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
                          "probabilities": [0.34, 0.33, 0.33],
                          "hl_mask": 3, "mode": "per_cell",
                          "same_strategy_per_batch": false}
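# Illustration (hypothetical, not config keys): with the "combination" strategy,
# the "strategies"/"probabilities" entries above would be used to draw one
# strategy per batch, or per sample when "same_strategy_per_batch" is false,
# e.g. strategy = random.choices(strategies, weights=probabilities)[0]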

num_epochs: 32
samples_per_epoch: 4096
samples_per_validation: 512
shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 1e-6
lr_final: 0.0
lr_steps_warmup: 512
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "linear"
lr_policy_cooldown: "linear"
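# One plausible reading of the schedule (not verified against the trainer):
# cosine warmup from lr_start to lr_max over lr_steps_warmup steps, linear decay
# from lr_max to lr_final_decay over the remaining steps, then a linear cooldown
# to lr_final over lr_steps_cooldown steps.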

grad_clip: 1.0
weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"

start_date: 197901010000
end_date: 202012310000
start_date_val: 202101010000
end_date_val: 202201010000
len_hrs: 6
step_hrs: 6
input_window_steps: 1

val_initial: False

loader_num_workers: 8
log_validation: 0
analysis_streams_output: ["ERA5"]

istep: 0
run_history: []

desc: ""
data_loader_rng_seed: ???
run_id: ???

# Parameters for logging/printing in the training loop
train_log:
  # The period to log metrics (in number of batch steps)
  log_interval: 20
152 changes: 152 additions & 0 deletions config/diff_config_old.yml
@@ -0,0 +1,152 @@
streams_directory: "./config/streams/era5_1deg/"

embed_orientation: "channels"
embed_local_coords: True
embed_centroids_local_coords: False
embed_size_centroids: 0
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

target_cell_local_prediction: True

ae_local_dim_embed: 1024
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 256
ae_global_num_blocks: 8
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
ae_global_att_dense_rate: 0.2
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True

# number of steps offset applied to the first target window; if it is zero and forecast_steps=0
# then one is training an auto-encoder
forecast_offset: 0
forecast_delta_hrs: 0
forecast_steps: 1
forecast_policy: "diffusion"
forecast_freeze_model: False
forecast_att_dense_rate: 1.0
fe_num_blocks: 2
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_diff_sigma_min: 0.02
fe_diff_sigma_max: 88
fe_diff_sigma_data: 1 # I think in GenCast this is hard-coded to 1

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

loss_fcts_lat:
  -
    - "mse"
    - 1.0

loss_fcts:
  -
    - "mse"
    - 1.0
loss_fcts_val:
  -
    - "mse"
    - 1.0

batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1

# a regex that needs to fully match the name of the modules you want to freeze
# e.g. ".*ERA5" will match any module whose name ends in ERA5
# encoders and decoders that exist per stream have the stream name attached at the end
freeze_modules: ""

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train in auto-encoder mode, forecast_offset should be 0
training_mode: "forecast"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.6
# sample the masking rate (with normal distribution centered at masking_rate)
# note that a sampled masking rate leads to varying requirements (e.g. memory)
masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements (can also be specified per stream)
sampling_rate_target: 1.0
# masking strategy; currently only "random", "block", "healpix" and "channel" are supported
masking_strategy: "random"
# masking_strategy_config is a dictionary of additional parameters for the masking strategy
# required for "healpix" and "channel" masking strategies
# "healpix": requires healpix mask level to be specified with `hl_mask`
# "channel": requires "mode" to be specified, "per_cell" or "global",
masking_strategy_config: {"hl_mask": 3}

num_epochs: 32
samples_per_epoch: 4096
samples_per_validation: 512
shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 1e-6
lr_final: 0.0
lr_steps_warmup: 512
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "linear"
lr_policy_cooldown: "linear"

grad_clip: 1.0
weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"

start_date: 197901010000
end_date: 202012310000
start_date_val: 202101010000
end_date_val: 202201010000
len_hrs: 6
step_hrs: 6
input_window_steps: 1

val_initial: False

loader_num_workers: 8
log_validation: 0
analysis_streams_output: ["ERA5"]

istep: 0
run_history: []

desc: ""
data_loader_rng_seed: ???
run_id: ???

# Parameters for logging/printing in the training loop
train_log:
  # The period to log metrics (in number of batch steps)
  log_interval: 20