From 6561408c157294a286dc47378433db6e36cee636 Mon Sep 17 00:00:00 2001
From: Wessel Bruinsma
Date: Thu, 12 Dec 2024 13:44:27 +0100
Subject: [PATCH] Add more fine-tuning advice

---
 docs/beware.md     |  5 ++++
 docs/finetuning.md | 57 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+)

diff --git a/docs/beware.md b/docs/beware.md
index f8235a1..6141661 100644
--- a/docs/beware.md
+++ b/docs/beware.md
@@ -17,6 +17,11 @@ exactly the right variables at
 exactly the right pressure levels from
 exactly the right source.
 
+This also means that the performance of the model will be sensitive to how the
+data is regridded.
+For optimal performance, you should ensure that the data is regridded
+exactly like the data seen during pretraining and fine-tuning.
+
 (t0-vs-analysis)=
 ## HRES IFS T0 Versus HRES IFS Analysis
 
diff --git a/docs/finetuning.md b/docs/finetuning.md
index 266f894..4787242 100644
--- a/docs/finetuning.md
+++ b/docs/finetuning.md
@@ -37,6 +37,37 @@ loss = ...
 loss.backward()
 ```
 
+## Exploding Gradients
+
+When fine-tuning, you may run into very large gradient values.
+Gradient clipping and internal layer normalisation layers mitigate the impact
+of large gradients,
+meaning that large gradients will not immediately lead to abnormal model outputs and loss values.
+Nevertheless, if gradients do blow up, the model will not learn anymore and eventually the loss value
+will also blow up.
+You should carefully monitor the value of the gradients to detect exploding gradients.
+
+One cause of exploding gradients is internal activations that grow too large.
+Typically this can be fixed by judiciously inserting a layer normalisation layer.
+
+We have identified the level aggregation as a weak point of the model that can be susceptible
+to exploding gradients.
+You can stabilise the level aggregation of the model
+by setting the following flag in the constructor: `stabilise_level_agg=True`.
+Note that `stabilise_level_agg=True` will considerably perturb the model,
+so significant additional fine-tuning may be required to get to the desired level of performance.
+
+```python
+from aurora import Aurora
+from aurora.normalisation import locations, scales
+
+model = Aurora(
+    use_lora=False,
+    stabilise_level_agg=True,  # Insert extra layer norm. to mitigate exploding gradients.
+)
+model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
+```
+
 ## Extending Aurora with New Variables
 
 Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`,
@@ -66,6 +97,18 @@ scales["new_static_var"] = 1.0
 scales["new_atmos_var"] = 1.0
 ```
 
+To more efficiently learn new variables, it is recommended to use a separate learning rate for
+the patch embeddings of the new variables in the encoder and decoder.
+For example, if you are using Adam, you can try `1e-3` for the new patch embeddings
+and `3e-4` for the other parameters.
+
+By default, patch embeddings in the encoder for new variables are initialised randomly.
+This means that adding new variables to the model perturbs the predictions for the existing
+variables.
+If you do not want this, you can alternatively initialise the new patch embeddings in the encoder
+to zero.
+The relevant parameter dictionaries are `model.encoder.{surf,atmos}_token_embeds.weights`.
+
 ## Other Model Extensions
 
 It is possible to extend to model in any way you like.
@@ -83,3 +126,17 @@ model = Aurora(...)
 
 model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
 ```
+
+## Triple Check Your Fine-Tuning Data!
+
+When fine-tuning the model, it is absolutely essential to carefully check your fine-tuning data.
+
+* Are the old (and possibly new) normalisation statistics appropriate for the new data?
+
+* Is any data missing?
+
+* Does the data contain zeros or NaNs?
+
+* Does the data contain any outliers that could possibly interfere with fine-tuning?
+
+_Et cetera._
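
As one possible way to act on the data checklist in this patch, the sketch below counts NaNs, infinities, exact zeros, and crude z-score outliers in a flat list of values. The helper name `check_values` and the `outlier_z` threshold are hypothetical and not part of Aurora. It assumes each variable has already been flattened to plain Python floats, and a z-score is only one rough outlier criterion among many.

```python
import math
from statistics import mean, stdev


def check_values(values, outlier_z=6.0):
    """Count NaNs, infinities, exact zeros, and z-score outliers in a list of floats.

    Hypothetical helper for illustration only; not part of the Aurora package.
    """
    nans = sum(1 for v in values if math.isnan(v))
    infs = sum(1 for v in values if math.isinf(v))
    # Only finite values take part in the zero and outlier checks.
    finite = [v for v in values if math.isfinite(v)]
    zeros = sum(1 for v in finite if v == 0.0)
    outliers = 0
    if len(finite) >= 2:
        mu, sigma = mean(finite), stdev(finite)
        if sigma > 0:
            # Flag values more than `outlier_z` sample standard deviations from the mean.
            outliers = sum(1 for v in finite if abs(v - mu) > outlier_z * sigma)
    return {"nan": nans, "inf": infs, "zero": zeros, "outlier": outliers}


# Example: fifty ordinary values plus one zero, one NaN, and one extreme value.
report = check_values([1.0] * 50 + [0.0, float("nan"), 1000.0])
print(report)
```

Running a check like this per variable before fine-tuning, and again after normalisation, is one way to surface missing or suspicious data early. Which counts are acceptable depends on the variable; exact zeros, for instance, may be perfectly legitimate for some fields.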