Skip to content

Commit

Permalink
deploy: 00a1a7a
Browse files Browse the repository at this point in the history
  • Loading branch information
a-lucic committed Sep 15, 2024
1 parent 3a52249 commit d9deed2
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 26 deletions.
40 changes: 32 additions & 8 deletions _modules/aurora/batch.html
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,23 @@ <h1>Source code for aurora.batch</h1><div class="highlight"><pre>

<div class="viewcode-block" id="Batch.normalise">
<a class="viewcode-back" href="../../api.html#aurora.Batch.normalise">[docs]</a>
<span class="k">def</span> <span class="nf">normalise</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;Batch&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Normalise all variables in the batch.&quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">normalise</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">surf_stats</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="s2">&quot;Batch&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Normalise all variables in the batch.</span>

<span class="sd"> Args:</span>
<span class="sd"> surf_stats (dict[str, tuple[float, float]]): For these surface-level variables, adjust</span>
<span class="sd"> the normalisation to the given tuple consisting of a new location and scale.</span>

<span class="sd"> Returns:</span>
<span class="sd"> :class:`.Batch`: Normalised batch.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">Batch</span><span class="p">(</span>
<span class="n">surf_vars</span><span class="o">=</span><span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">normalise_surf_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">surf_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()},</span>
<span class="n">static_vars</span><span class="o">=</span><span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">normalise_surf_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">static_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()},</span>
<span class="n">surf_vars</span><span class="o">=</span><span class="p">{</span>
<span class="n">k</span><span class="p">:</span> <span class="n">normalise_surf_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">stats</span><span class="o">=</span><span class="n">surf_stats</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">surf_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">},</span>
<span class="n">static_vars</span><span class="o">=</span><span class="p">{</span>
<span class="n">k</span><span class="p">:</span> <span class="n">normalise_surf_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">stats</span><span class="o">=</span><span class="n">surf_stats</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">static_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">},</span>
<span class="n">atmos_vars</span><span class="o">=</span><span class="p">{</span>
<span class="n">k</span><span class="p">:</span> <span class="n">normalise_atmos_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">atmos_levels</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">atmos_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
Expand All @@ -437,11 +449,23 @@ <h1>Source code for aurora.batch</h1><div class="highlight"><pre>

<div class="viewcode-block" id="Batch.unnormalise">
<a class="viewcode-back" href="../../api.html#aurora.Batch.unnormalise">[docs]</a>
<span class="k">def</span> <span class="nf">unnormalise</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;Batch&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Unnormalise all variables in the batch.&quot;&quot;&quot;</span>
<span class="k">def</span> <span class="nf">unnormalise</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">surf_stats</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="s2">&quot;Batch&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Unnormalise all variables in the batch.</span>

<span class="sd"> Args:</span>
<span class="sd"> surf_stats (dict[str, tuple[float, float]]): For these surface-level variables, adjust</span>
<span class="sd"> the normalisation to the given tuple consisting of a new location and scale.</span>

<span class="sd"> Returns:</span>
<span class="sd"> :class:`.Batch`: Unnormalised batch.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">Batch</span><span class="p">(</span>
<span class="n">surf_vars</span><span class="o">=</span><span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">unnormalise_surf_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">surf_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()},</span>
<span class="n">static_vars</span><span class="o">=</span><span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">unnormalise_surf_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">static_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()},</span>
<span class="n">surf_vars</span><span class="o">=</span><span class="p">{</span>
<span class="n">k</span><span class="p">:</span> <span class="n">unnormalise_surf_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">stats</span><span class="o">=</span><span class="n">surf_stats</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">surf_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">},</span>
<span class="n">static_vars</span><span class="o">=</span><span class="p">{</span>
<span class="n">k</span><span class="p">:</span> <span class="n">unnormalise_surf_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">stats</span><span class="o">=</span><span class="n">surf_stats</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">static_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">},</span>
<span class="n">atmos_vars</span><span class="o">=</span><span class="p">{</span>
<span class="n">k</span><span class="p">:</span> <span class="n">unnormalise_atmos_var</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">metadata</span><span class="o">.</span><span class="n">atmos_levels</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">atmos_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
Expand Down
12 changes: 10 additions & 2 deletions _modules/aurora/model/aurora.html
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ <h1>Source code for aurora.model.aurora</h1><div class="highlight"><pre>
<span class="kn">import</span> <span class="nn">dataclasses</span>
<span class="kn">from</span> <span class="nn">datetime</span> <span class="kn">import</span> <span class="n">timedelta</span>
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span>

<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">huggingface_hub</span> <span class="kn">import</span> <span class="n">hf_hub_download</span>
Expand Down Expand Up @@ -395,6 +396,7 @@ <h1>Source code for aurora.model.aurora</h1><div class="highlight"><pre>
<span class="n">use_lora</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">lora_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">40</span><span class="p">,</span>
<span class="n">lora_mode</span><span class="p">:</span> <span class="n">LoRAMode</span> <span class="o">=</span> <span class="s2">&quot;single&quot;</span><span class="p">,</span>
<span class="n">surf_stats</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">autocast</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Construct an instance of the model.</span>
Expand Down Expand Up @@ -441,13 +443,17 @@ <h1>Source code for aurora.model.aurora</h1><div class="highlight"><pre>
<span class="sd"> lora_mode (str, optional): LoRA mode. `&quot;single&quot;` uses the same LoRA for all roll-out</span>
<span class="sd"> steps, and `&quot;all&quot;` uses a different LoRA for every roll-out step. Defaults to</span>
<span class="sd"> `&quot;single&quot;`.</span>
<span class="sd"> surf_stats (dict[str, tuple[float, float]], optional): For these surface-level</span>
<span class="sd"> variables, adjust the normalisation to the given tuple consisting of a new location</span>
<span class="sd"> and scale.</span>
<span class="sd"> autocast (bool, optional): Use `torch.autocast` to reduce memory usage. Defaults to</span>
<span class="sd"> `False`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">surf_vars</span> <span class="o">=</span> <span class="n">surf_vars</span>
<span class="bp">self</span><span class="o">.</span><span class="n">atmos_vars</span> <span class="o">=</span> <span class="n">atmos_vars</span>
<span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">surf_stats</span> <span class="o">=</span> <span class="n">surf_stats</span> <span class="ow">or</span> <span class="nb">dict</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">autocast</span> <span class="o">=</span> <span class="n">autocast</span>

<span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">Perceiver3DEncoder</span><span class="p">(</span>
Expand Down Expand Up @@ -511,7 +517,7 @@ <h1>Source code for aurora.model.aurora</h1><div class="highlight"><pre>
<span class="c1"># Get the first parameter. We&#39;ll derive the data type and device from this parameter.</span>
<span class="n">p</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">normalise</span><span class="p">()</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">normalise</span><span class="p">(</span><span class="n">surf_stats</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">surf_stats</span><span class="p">)</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">crop</span><span class="p">(</span><span class="n">patch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">)</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>

Expand Down Expand Up @@ -560,7 +566,7 @@ <h1>Source code for aurora.model.aurora</h1><div class="highlight"><pre>
<span class="n">atmos_vars</span><span class="o">=</span><span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">pred</span><span class="o">.</span><span class="n">atmos_vars</span><span class="o">.</span><span class="n">items</span><span class="p">()},</span>
<span class="p">)</span>

<span class="n">pred</span> <span class="o">=</span> <span class="n">pred</span><span class="o">.</span><span class="n">unnormalise</span><span class="p">()</span>
<span class="n">pred</span> <span class="o">=</span> <span class="n">pred</span><span class="o">.</span><span class="n">unnormalise</span><span class="p">(</span><span class="n">surf_stats</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">surf_stats</span><span class="p">)</span>

<span class="k">return</span> <span class="n">pred</span></div>

Expand Down Expand Up @@ -685,6 +691,8 @@ <h1>Source code for aurora.model.aurora</h1><div class="highlight"><pre>
<span class="n">patch_size</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
<span class="n">encoder_depths</span><span class="o">=</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">),</span>
<span class="n">decoder_depths</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">6</span><span class="p">),</span>
<span class="c1"># One particular static variable requires a different normalisation.</span>
<span class="n">surf_stats</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;z&quot;</span><span class="p">:</span> <span class="p">(</span><span class="o">-</span><span class="mf">3.270407e03</span><span class="p">,</span> <span class="mf">6.540335e04</span><span class="p">)},</span>
<span class="p">)</span>
</pre></div>

Expand Down
Loading

0 comments on commit d9deed2

Please sign in to comment.