Add more fine-tuning advice
wesselb committed Dec 12, 2024
1 parent ecc7f7f commit 6561408
Showing 2 changed files with 62 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/beware.md
@@ -17,6 +17,11 @@ exactly the right variables
at exactly the right pressure levels
from exactly the right source.

This also means that the performance of the model will be sensitive to how the
data is regridded.
For optimal performance, you should ensure that the data is regridded
exactly like the data seen during pretraining and fine-tuning.

(t0-vs-analysis)=
## HRES IFS T0 Versus HRES IFS Analysis

57 changes: 57 additions & 0 deletions docs/finetuning.md
@@ -37,6 +37,37 @@ loss = ...
loss.backward()
```

## Exploding Gradients

When fine-tuning, you may run into very large gradient values.
Gradient clipping and internal layer normalisation layers mitigate the impact
of large gradients,
meaning that large gradients will not immediately lead to abnormal model outputs and loss values.
Nevertheless, if gradients do blow up, the model will stop learning and the loss value
will eventually blow up too.
You should carefully monitor the gradient values so that you detect exploding gradients early.

One cause of exploding gradients is internal activations that become too large.
Typically, this can be fixed by judiciously inserting a layer normalisation layer.

We have identified the level aggregation as a weak point of the model that can be susceptible
to exploding gradients.
You can stabilise the level aggregation of the model
by setting the following flag in the constructor: `stabilise_level_agg=True`.
Note that `stabilise_level_agg=True` will considerably perturb the model,
so significant additional fine-tuning may be required to get to the desired level of performance.

```python
from aurora import Aurora
from aurora.normalisation import locations, scales

model = Aurora(
use_lora=False,
stabilise_level_agg=True, # Insert extra layer norm. to mitigate exploding gradients.
)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
```
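
To monitor gradients as recommended earlier in this section, one option is to record the global gradient norm at every step and flag values that are far above the recent typical level. The following is a minimal sketch in plain Python and is not part of Aurora: the class `GradNormMonitor`, the window size, and the spike factor are hypothetical choices. It assumes that you obtain the norm from `torch.nn.utils.clip_grad_norm_`, which returns the total gradient norm before clipping.

```python
import math
from collections import deque
from statistics import median


class GradNormMonitor:
    """Flag gradient norms that are non-finite or far above the recent median.

    Hypothetical helper for illustration; not part of Aurora.
    """

    def __init__(self, window: int = 100, spike_factor: float = 10.0) -> None:
        self.history = deque(maxlen=window)  # Recent non-spike gradient norms.
        self.spike_factor = spike_factor

    def update(self, grad_norm: float) -> bool:
        """Record `grad_norm` and return `True` if it looks like an explosion."""
        if not math.isfinite(grad_norm):
            return True  # NaN or inf: always suspicious, never stored.
        is_spike = (
            len(self.history) > 0
            and grad_norm > self.spike_factor * median(self.history)
        )
        # Spikes are kept out of the history so that they do not raise the baseline.
        if not is_spike:
            self.history.append(grad_norm)
        return is_spike
```

In a training loop, you would call `monitor.update(float(total_norm))` after `loss.backward()`, where `total_norm` is the value returned by `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)` and `max_norm` is whatever clipping threshold you already use.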

## Extending Aurora with New Variables

Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`,
@@ -66,6 +97,18 @@ scales["new_static_var"] = 1.0
scales["new_atmos_var"] = 1.0
```

To learn new variables more efficiently, we recommend using a separate learning rate for
the patch embeddings of the new variables in the encoder and decoder.
For example, if you are using Adam, you can try `1e-3` for the new patch embeddings
and `3e-4` for the other parameters.
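
The separate learning rates can be expressed as optimiser parameter groups. The sketch below is plain Python and only builds the groups. The helper `build_param_groups` is hypothetical, and it assumes that every parameter belonging to a new variable has the variable's name as the last component of its parameter name, as is the case for entries of a `torch.nn.ParameterDict`. Check this against the output of `model.named_parameters()` for your model, in particular for the decoder.

```python
from typing import Any, Iterable


def build_param_groups(
    named_params: Iterable[tuple[str, Any]],
    new_vars: set[str],
    lr_new: float = 1e-3,
    lr_other: float = 3e-4,
) -> list[dict[str, Any]]:
    """Split parameters into new-variable embeddings and everything else."""
    new_params, other_params = [], []
    for name, param in named_params:
        # ASSUMPTION: the last name component identifies the variable.
        if name.rsplit(".", 1)[-1] in new_vars:
            new_params.append(param)
        else:
            other_params.append(param)
    return [
        {"params": new_params, "lr": lr_new},
        {"params": other_params, "lr": lr_other},
    ]
```

The returned list has the format that PyTorch optimisers accept, so it could be used as `torch.optim.Adam(build_param_groups(model.named_parameters(), {"new_atmos_var"}))`.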

By default, patch embeddings in the encoder for new variables are initialised randomly.
This means that adding new variables to the model perturbs the predictions for the existing
variables.
If you do not want this, you can instead initialise the new patch embeddings in the encoder
to zero.
The relevant parameter dictionaries are `model.encoder.{surf,atmos}_token_embeds.weights`.
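
Zero initialisation could look like the sketch below. The helper `zero_init_embeddings` is hypothetical, and it assumes that the parameter dictionaries are `torch.nn.ParameterDict` objects keyed by variable name. A stand-in dictionary with made-up shapes replaces the real `model.encoder.atmos_token_embeds.weights` so that the example runs without a checkpoint.

```python
import torch


def zero_init_embeddings(weights: torch.nn.ParameterDict, new_vars: list[str]) -> None:
    """Set the patch embeddings of `new_vars` to zero in place."""
    with torch.no_grad():  # Do not track the in-place update in autograd.
        for name in new_vars:
            weights[name].zero_()


# Stand-in for `model.encoder.atmos_token_embeds.weights`; shapes are made up.
weights = torch.nn.ParameterDict(
    {
        "t": torch.nn.Parameter(torch.randn(8, 4)),
        "new_atmos_var": torch.nn.Parameter(torch.randn(8, 4)),
    }
)
zero_init_embeddings(weights, ["new_atmos_var"])
```

With a real model, you would pass the encoder dictionaries themselves after constructing the model and loading the checkpoint, and before starting fine-tuning.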

## Other Model Extensions

It is possible to extend the model in any way you like.
@@ -83,3 +126,17 @@ model = Aurora(...)

model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
```

## Triple Check Your Fine-Tuning Data!

When fine-tuning the model, it is absolutely essential to carefully check your fine-tuning data.

* Are the old (and possibly new) normalisation statistics appropriate for the new data?

* Is any data missing?

* Does the data contain zeros or NaNs?

* Does the data contain any outliers that could possibly interfere with fine-tuning?

_Et cetera._
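
Some of these checks can be automated. The following sketch uses NumPy to summarise a single variable. The function `check_variable`, the report keys, and the outlier threshold of 10 normalised units are hypothetical, and `location` and `scale` stand for the normalisation statistics of that variable.

```python
import numpy as np


def check_variable(
    data: np.ndarray,
    location: float,
    scale: float,
    outlier_threshold: float = 10.0,
) -> dict[str, float]:
    """Summarise NaNs, zeros, and outliers for a single variable."""
    finite = data[np.isfinite(data)]
    normalised = (finite - location) / scale
    return {
        "frac_nan": float(np.isnan(data).mean()),
        "frac_zero": float((data == 0).mean()),
        # If the statistics fit the data, the mean is near 0 and the std near 1.
        "normalised_mean": float(normalised.mean()),
        "normalised_std": float(normalised.std()),
        # Fraction of the finite values that lie far from the expected range.
        "frac_outlier": float((np.abs(normalised) > outlier_threshold).mean()),
    }
```

A normalised mean far from zero or a normalised standard deviation far from one suggests that the normalisation statistics do not match the new data. Nonzero fractions of NaNs, zeros, or outliers are not necessarily errors, but each deserves a closer look before fine-tuning.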
