Add AuroraHighRes #12

Merged: 2 commits, Aug 23, 2024
aurora/model/aurora.py: 12 changes (9 additions, 3 deletions)

@@ -13,7 +13,7 @@
from aurora.model.lora import LoRAMode
from aurora.model.swin3d import Swin3DTransformerBackbone

-__all__ = ["Aurora", "AuroraSmall"]
+__all__ = ["Aurora", "AuroraSmall", "AuroraHighRes"]


class Aurora(torch.nn.Module):
@@ -63,14 +63,14 @@ def __init__(
window_size (tuple[int, int, int], optional): Vertical height, height, and width of the
    window of the underlying Swin transformer.
encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
-encoder_num_heads (tuple[int, ...], optional) Number of attention heads in each encoder
+encoder_num_heads (tuple[int, ...], optional): Number of attention heads in each encoder
    layer. The dimensionality doubles after every layer. To keep the dimensionality of
    every head constant, you want to double the number of heads after every layer. The
    dimensionality of attention head of the first layer is determined by `embed_dim`
    divided by the value here. For all cases except one, this is equal to `64`.
decoder_depths (tuple[int, ...], optional): Number of blocks in each decoder layer.
    Generally, you want this to be the reversal of `encoder_depths`.
-decoder_num_heads (tuple[int, ...], optional) Number of attention heads in each decoder
+decoder_num_heads (tuple[int, ...], optional): Number of attention heads in each decoder
    layer. Generally, you want this to be the reversal of `encoder_num_heads`.
latent_levels (int, optional): Number of latent pressure levels.
patch_size (int, optional): Patch size.
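The head-dimension bookkeeping described in the `encoder_num_heads` docstring above can be checked with a few lines of arithmetic. This is an illustrative sketch only; `embed_dim = 512` and heads doubling from 8 are assumed example values, not read from this diff.

```python
# Illustrative only: embed_dim and encoder_num_heads are assumed example values.
embed_dim = 512
encoder_num_heads = (8, 16, 32)

for layer, num_heads in enumerate(encoder_num_heads):
    layer_dim = embed_dim * 2**layer   # the dimensionality doubles after every layer
    head_dim = layer_dim // num_heads  # doubling the heads keeps this constant
    print(f"layer {layer}: dim={layer_dim}, heads={num_heads}, head_dim={head_dim}")
# layer 0: dim=512,  heads=8,  head_dim=64
# layer 1: dim=1024, heads=16, head_dim=64
# layer 2: dim=2048, heads=32, head_dim=64
```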
@@ -250,3 +250,9 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
    num_heads=8,
    use_lora=False,
)
+
+AuroraHighRes = partial(
+    Aurora,
+    encoder_depths=(6, 8, 8),
+    decoder_depths=(8, 8, 6),
+)
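A note on the pattern used in the last hunk: `functools.partial` pins new defaults onto the existing `Aurora` constructor rather than subclassing it, so other keyword arguments still pass through and the pinned ones can still be overridden at call time. The sketch below uses a hypothetical stand-in class `Model` with made-up defaults, not Aurora's actual ones.

```python
from functools import partial


class Model:
    """Hypothetical stand-in for Aurora with made-up default arguments."""

    def __init__(self, encoder_depths=(6, 10, 8), decoder_depths=(8, 10, 6), use_lora=True):
        self.encoder_depths = encoder_depths
        self.decoder_depths = decoder_depths
        self.use_lora = use_lora


# Same pattern as AuroraHighRes: the pinned keywords become the new defaults.
ModelHighRes = partial(Model, encoder_depths=(6, 8, 8), decoder_depths=(8, 8, 6))

m = ModelHighRes(use_lora=False)      # other keyword arguments still pass through
assert m.encoder_depths == (6, 8, 8)  # pinned default applies
assert ModelHighRes(encoder_depths=(4, 4, 4)).encoder_depths == (4, 4, 4)  # and can be overridden
```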
docs/api.rst: 3 changes (3 additions, 0 deletions)

@@ -22,3 +22,6 @@ Models

.. autoclass:: aurora.AuroraSmall
   :members:
+
+.. autoclass:: aurora.AuroraHighRes
+   :members:
docs/models.md: 8 changes (4 additions, 4 deletions)

@@ -127,9 +127,9 @@ Aurora 0.1° Fine-Tuned is a high-resolution version of Aurora.
### Usage

```python
-from aurora import Aurora
+from aurora import AuroraHighRes

-model = Aurora()
+model = AuroraHighRes()
model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt")
```

@@ -170,8 +170,8 @@ Therefore, you should use the static variables provided in
you can turn off LoRA to obtain more realistic predictions at the expensive of slightly higher long-term MSE:

```python
-from aurora import Aurora
+from aurora import AuroraHighRes

-model = Aurora(use_lora=False) # Disable LoRA for more realistic samples.
+model = AuroraHighRes(use_lora=False) # Disable LoRA for more realistic samples.
model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt", strict=False)
```
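One detail worth spelling out (my reading, not stated in this diff): with `use_lora=False` the model no longer has the LoRA parameters that the fine-tuned checkpoint contains, which is presumably why the second snippet passes `strict=False`. The generic PyTorch behaviour can be sketched as follows; `Tiny` is a hypothetical module, not Aurora's code.

```python
import torch


class Tiny(torch.nn.Module):
    """Hypothetical module illustrating checkpoint loading, not Aurora's code."""

    def __init__(self, use_lora: bool = True):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        if use_lora:
            # Stand-in for LoRA weights that exist only when LoRA is enabled.
            self.lora_a = torch.nn.Parameter(torch.zeros(4, 2))


ckpt = Tiny(use_lora=True).state_dict()  # checkpoint saved with LoRA weights
model = Tiny(use_lora=False)             # LoRA disabled at load time

# strict=True would raise because of the unexpected "lora_a" key;
# strict=False loads the matching weights and skips the rest.
model.load_state_dict(ckpt, strict=False)
```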