
Failed to load model in FSDP #1062

@ankitpatnala

Description


What happened?

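I am using an MTM (masked token modeling) pre-trained model.
For fine-tuning, I want to add forecasting blocks with updated forecasting parameters (fe_num_blocks: 8 in the forecast config, versus fe_num_blocks: 0 in the MTM config).
When I continue training from the pre-trained run, I get the following error: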

14: [rank14]: RuntimeError: FSDP parameters should be materialized from meta device before training, but the following were still on meta device: ['fe_blocks.0.lnorm.embed_aux.0.weight', 'fe_blocks.0.lnorm.embed_aux.0.bias', 'fe_blocks.0.lnorm.embed_aux.2.weight', 'fe_blocks.0.lnorm.embed_aux.2.bias', 'fe_blocks.0.proj_heads_q.weight', 'fe_blocks.0.proj_heads_k.weight', 'fe_blocks.0.proj_heads_v.weight', 'fe_blocks.0.proj_out.weight']
14: [rank14]: For example, call module.to_empty(device) to materialize to device and call module.reset_parameters() on each module to initialize values.

MTM config

streams_directory: "./config/streams/era5_1deg/"

embed_orientation: "channels"
embed_local_coords: True
embed_centroids_local_coords: False
embed_size_centroids: 0
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

target_cell_local_prediction: True

ae_local_dim_embed: 1024
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 8
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
ae_global_att_dense_rate: 0.2
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
forecast_offset : 0
forecast_delta_hrs: 0
forecast_steps: 0
forecast_policy: null
forecast_att_dense_rate: 1.0
fe_num_blocks: 0
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5 
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True 

loss_fcts:
  -
    - "mse"
    - 1.0
loss_fcts_val:
  -
    - "mse"
    - 1.0

batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1

# a regex that needs to fully match the name of the modules you want to freeze
# e.g. ".*ERA5" will match any module whose name ends in ERA5
# encoders and decoders that exist per stream have the stream name attached at the end
freeze_modules: ""

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
training_mode: "masking"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.8
# sample the masking rate (with normal distribution centered at masking_rate)
# note that a sampled masking rate leads to varying requirements
masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream)
sampling_rate_target: 1.0
# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination"
masking_strategy: "random"
# masking_strategy_config is a dictionary of additional parameters for the masking strategy
# required for "healpix" and "channel" masking strategies
# "healpix": requires healpix mask level to be specified with `hl_mask`
# "channel": requires "mode" to be specified, "per_cell" or "global",
masking_strategy_config: {"strategies": ["random", "healpix", "channel"], 
                          "probabilities": [0.34, 0.33, 0.33],
                          "hl_mask": 3, "mode": "per_cell",
                          "same_strategy_per_batch": false
                          }

num_epochs: 32
samples_per_epoch: 4096
samples_per_validation: 512
shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 1e-6
lr_final: 0.0
lr_steps_warmup: 512 
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "linear"
lr_policy_cooldown: "linear"

grad_clip: 1.0
weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"

start_date: 197901010000
end_date: 202012310000
start_date_val: 202101010000
end_date_val: 202201010000
len_hrs: 12
step_hrs: 6
input_window_steps: 1

val_initial: False

loader_num_workers: 8
log_validation: 0
analysis_streams_output: ["ERA5"]

istep: 0
run_history: []

desc: ""
data_loader_rng_seed: ???
run_id: ???

# Parameters for logging/printing in the training loop
train_log:
  # The period to log metrics (in number of batch steps)
  log_interval: 20
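The freeze_modules comment in the config above says the regex must fully match a module name, which in Python terms is re.fullmatch rather than re.search. A small sketch of that matching rule follows; the module names are hypothetical (only fe_blocks.0 appears in the error), and how WeatherGenerator actually enumerates module names is an assumption not shown here.

```python
import re


def frozen_modules(module_names: list[str], pattern: str) -> list[str]:
    """Return the names that `pattern` matches in full (re.fullmatch semantics)."""
    regex = re.compile(pattern)
    return [name for name in module_names if regex.fullmatch(name)]


# Hypothetical module names; per-stream modules carry the stream name as a suffix.
names = ["embeds.ERA5", "ae_global_blocks.0", "fe_blocks.0", "pred_heads.ERA5"]

print(frozen_modules(names, ".*ERA5"))  # ['embeds.ERA5', 'pred_heads.ERA5']
print(frozen_modules(names, "ERA5"))    # [] : must match the whole name, not a suffix
print(frozen_modules(names, ""))        # [] : an empty pattern matches no non-empty name
```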

Forecast config (forecasting_config.yml)

streams_directory: "./config/streams/era5_1deg/"

embed_orientation: "channels"
embed_local_coords: True
embed_centroids_local_coords: False
embed_size_centroids: 0
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

target_cell_local_prediction: True

ae_local_dim_embed: 1024
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 8
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
ae_global_att_dense_rate: 0.2
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
forecast_offset : 1
forecast_delta_hrs: 0
forecast_steps: 2
forecast_policy: "fixed"
forecast_freeze_model: False
forecast_att_dense_rate: 1.0
fe_num_blocks: 8
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5 
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True 

loss_fcts:
  -
    - "mse"
    - 1.0
loss_fcts_val:
  -
    - "mse"
    - 1.0

batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1

# a regex that needs to fully match the name of the modules you want to freeze
# e.g. ".*ERA5" will match any module whose name ends in ERA5
# encoders and decoders that exist per stream have the stream name attached at the end
freeze_modules: ""

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
training_mode: "forecast"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.6
# sample the masking rate (with normal distribution centered at masking_rate)
# note that a sampled masking rate leads to varying requirements
masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream)
sampling_rate_target: 1.0
# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "combination"
masking_strategy: "random"
# masking_strategy_config is a dictionary of additional parameters for the masking strategy
# required for "healpix" and "channel" masking strategies
# "healpix": requires healpix mask level to be specified with `hl_mask`
# "channel": requires "mode" to be specified, "per_cell" or "global",
masking_strategy_config: {"strategies": ["random", "healpix", "channel"], 
                          "probabilities": [0.34, 0.33, 0.33],
                          "hl_mask": 3, "mode": "per_cell",
                          "same_strategy_per_batch": false
                          }

num_epochs: 64
samples_per_epoch: 4096
samples_per_validation: 512
shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
lr_max: 0.0001
lr_final_decay: 2e-6
lr_final: 0.0
lr_steps_warmup: 256
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "linear"
lr_policy_cooldown: "linear"

grad_clip: 1.0
weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"

start_date: 197901010000
end_date: 202012310000
start_date_val: 202101010000
end_date_val: 202201010000
len_hrs: 6
step_hrs: 6
input_window_steps: 1

val_initial: False

loader_num_workers: 8
log_validation: 0
analysis_streams_output: ["ERA5"]

istep: 0
run_history: []

desc: ""
data_loader_rng_seed: ???
run_id: ???

# Parameters for logging/printing in the training loop
train_log:
  # The period to log metrics (in number of batch steps)
  log_interval: 20
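Most keys are identical between the two configs, so the differences are easy to miss in two long listings. The plain-Python sketch below prints only the changed keys, using a hand-copied subset of values from the configs above (the dictionaries are not parsed from the YAML files). It also encodes the auto-encoder condition stated in the forecast_offset comment. config_diff and is_autoencoder are hypothetical helpers.

```python
ABSENT = "<absent>"  # marker for keys that appear in only one config


def config_diff(old: dict, new: dict) -> dict:
    """Map each key whose value differs to its (old, new) pair, in sorted key order."""
    return {
        key: (old.get(key, ABSENT), new.get(key, ABSENT))
        for key in sorted(old.keys() | new.keys())
        if old.get(key, ABSENT) != new.get(key, ABSENT)
    }


def is_autoencoder(cfg: dict) -> bool:
    """Per the config comment: forecast_offset == 0 and forecast_steps == 0."""
    return cfg["forecast_offset"] == 0 and cfg["forecast_steps"] == 0


# Subset of values copied by hand from the MTM and forecast configs.
mtm_cfg = {
    "ae_global_dim_embed": 2048,
    "healpix_level": 5,
    "training_mode": "masking",
    "forecast_offset": 0,
    "forecast_steps": 0,
    "forecast_policy": None,
    "fe_num_blocks": 0,
    "len_hrs": 12,
}
forecast_cfg = {
    "ae_global_dim_embed": 2048,
    "healpix_level": 5,
    "training_mode": "forecast",
    "forecast_offset": 1,
    "forecast_steps": 2,
    "forecast_policy": "fixed",
    "forecast_freeze_model": False,
    "fe_num_blocks": 8,
    "len_hrs": 6,
}

diff = config_diff(mtm_cfg, forecast_cfg)
for key, (old, new) in diff.items():
    print(f"{key}: {old!r} -> {new!r}")

print(is_autoencoder(mtm_cfg), is_autoencoder(forecast_cfg))  # True False
```

The fe_num_blocks change from 0 to 8 lines up with the parameter names in the error. The fe_blocks.* modules exist only in the forecast model, so the MTM checkpoint presumably has no weights for them.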

Launch command:

./WeatherGenerator-private/hpc/launch-slurm.py --nodes=4 --ntasks=16 --account=hclimrep --partition=booster --time=24:00:00 --from-run-id=xepvwsy1 --config=./forecasting_config.yml

What are the steps to reproduce the bug?

Continue training (train_continue) from any masked-token-modeling pre-trained model using the forecast config above.

Version

0.1

Platform (OS and architecture)

JUWELS (Linux)

Relevant log output

Accompanying data

ERA5

Organisation

JSC

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), model (Related to model training or definition, not generic infra)
Type: No type
Project status: Done
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests