Update docs
wesselb committed Sep 9, 2024
1 parent 26c2af9 commit cfef01c
Showing 3 changed files with 44 additions and 26 deletions.
15 changes: 8 additions & 7 deletions aurora/model/aurora.py
@@ -52,14 +52,11 @@ def __init__(
         Args:
             surf_vars (tuple[str, ...], optional): All surface-level variables supported by the
-                model. The model is sensitive to the order of `surf_vars`! Currently, adding
-                one more variable here causes the model to incorrectly load the static variables.
-                It is possible to hack around this. We are working on a more principled fix. Please
-                open an issue if this is a problem for you.
+                model.
             static_vars (tuple[str, ...], optional): All static variables supported by the
-                model. The model is sensitive to the order of `static_vars`!
+                model.
             atmos_vars (tuple[str, ...], optional): All atmospheric variables supported by the
-                model. The model is sensitive to the order of `atmos-vars`!
+                model.
             window_size (tuple[int, int, int], optional): Vertical height, height, and width of the
                 window of the underlying Swin transformer.
             encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
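For illustration, the keyword arguments documented in this hunk can be passed like this (a minimal sketch; the variable names are taken from the `docs/finetuning.md` example later in this commit, minus the new variables):

```python
from aurora import Aurora

# Construct the model with an explicit set of supported variables
# (these match the defaults used in the fine-tuning docs below).
model = Aurora(
    surf_vars=("2t", "10u", "10v", "msl"),
    static_vars=("lsm", "z", "slt"),
    atmos_vars=("z", "u", "v", "t", "q"),
)
```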
@@ -231,7 +228,11 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
         path = hf_hub_download(repo_id=repo, filename=name)
         d = torch.load(path, map_location=device, weights_only=True)
 
-        # Rename keys to ensure compatibility.
+        # You can safely ignore the cumbersome processing below. We modified the model after
+        # training it, and the code below manually adapts the checkpoint so that it is
+        # compatible with the new model.
+
+        # Remove a possible prefix from the keys.
         for k, v in list(d.items()):
             if k.startswith("net."):
                 del d[k]
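The hunk above is truncated, so for context, the key-renaming pattern it touches amounts to something like the following standalone sketch (the file path is hypothetical, and re-inserting the value under the stripped key is an assumption; the actual adaptation in the repository does more than this):

```python
import torch

# Load the raw checkpoint (hypothetical local path).
d = torch.load("aurora-0.25-pretrained.ckpt", map_location="cpu", weights_only=True)

# Strip an optional "net." prefix so the keys match the current module names.
# NOTE: putting the value back under the stripped key is assumed here; the
# truncated hunk above only shows the deletion.
for k, v in list(d.items()):
    if k.startswith("net."):
        del d[k]
        d[k[len("net."):]] = v
```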
11 changes: 0 additions & 11 deletions docs/beware.md
@@ -42,14 +42,3 @@ If you changed the model and added or removed parameters, you need to set `strict=False` when
 loading a checkpoint `Aurora.load_checkpoint(..., strict=False)`.
 Importantly, enabling or disabling LoRA for a model that was trained respectively without or
 with LoRA changes the parameters!
-
-## Extending the Model with New Surface-Level Variables
-
-Whereas we have attempted to design a robust and flexible model,
-inevitably some unfortunate design choices slipped through.
-
-A notable unfortunate design choice is that extending the model with a new surface-level
-variable breaks compatibility with existing checkpoints.
-It is possible to hack around this in a relatively simple way.
-We are working on a more principled fix.
-Please open an issue if this is a problem for you.
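As a concrete example of the `strict=False` advice in the lines kept above (a sketch: the pretrained checkpoint is loaded with `use_lora=False` in the fine-tuning docs, so enabling LoRA changes the parameter set):

```python
from aurora import Aurora

# The pretrained checkpoint was trained without LoRA, so enabling LoRA here
# introduces parameters that the checkpoint does not contain; `strict=False`
# tells the loader to ignore the mismatch.
model = Aurora(use_lora=True)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
```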
44 changes: 36 additions & 8 deletions docs/finetuning.md
@@ -1,7 +1,7 @@
 # Fine-Tuning
 
-If you wish to fine-tune Aurora for you specific application,
-you should use the pretrained version:
+Generally, if you wish to fine-tune Aurora for a specific application,
+you should build on the pretrained version:
 
 ```python
 from aurora import Aurora
@@ -10,21 +10,49 @@ model = Aurora(use_lora=False) # Model is not fine-tuned.
 model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
 ```
 
-You are also free to extend the model for your particular use case.
-In that case, it might be that you add or remove parameters.
+## Extending Aurora with New Variables
+
+Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`,
+`static_vars`, and `atmos_vars`.
+When you add a new variable, you also need to set the normalisation statistics.
+
+```python
+from aurora import Aurora
+from aurora.normalisation import locations, scales
+
+model = Aurora(
+    use_lora=False,
+    surf_vars=("2t", "10u", "10v", "msl", "new_surf_var"),
+    static_vars=("lsm", "z", "slt", "new_static_var"),
+    atmos_vars=("z", "u", "v", "t", "q", "new_atmos_var"),
+)
+model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
+
+# Normalisation means:
+locations["new_surf_var"] = 0.0
+locations["new_static_var"] = 0.0
+locations["new_atmos_var"] = 0.0
+
+# Normalisation standard deviations:
+scales["new_surf_var"] = 1.0
+scales["new_static_var"] = 1.0
+scales["new_atmos_var"] = 1.0
+```
+
+## Other Model Extensions
+
+It is possible to extend the model in any way you like.
+If you do this, you will likely add or remove parameters.
 Then `Aurora.load_checkpoint` will error,
 because the existing checkpoint now mismatches with the model's parameters.
-Simply set `Aurora.load_checkpoint(..., strict=False)`:
+Simply set `Aurora.load_checkpoint(..., strict=False)` to ignore the mismatches:
 
 ```python
 from aurora import Aurora
 
 
 model = Aurora(...)
 
 ...  # Modify `model`.
 
 model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
 ```
 
 More instructions coming soon!
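For intuition, the `locations` and `scales` entries registered in the hunk above are ordinary location/scale statistics: a variable is standardised as (x - location) / scale. A minimal sketch of that idea (a hypothetical helper, not the library's API):

```python
import torch

def normalise(x: torch.Tensor, location: float, scale: float) -> torch.Tensor:
    """Standardise a variable with its normalisation mean and standard deviation."""
    return (x - location) / scale

# With location 0.0 and scale 1.0, as in the example above, the variable
# passes through unchanged.
x = torch.randn(2, 3)
assert torch.allclose(normalise(x, 0.0, 1.0), x)
```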
