Commit

Make naming more consistent
wesselb committed Dec 12, 2024
1 parent 6561408 commit a08be65
Showing 2 changed files with 7 additions and 7 deletions.
aurora/model/encoder.py (2 changes: 1 addition & 1 deletion)

@@ -123,7 +123,7 @@ def __init__(
             drop=drop_rate,
             mlp_ratio=mlp_ratio,
             ln_eps=perceiver_ln_eps,
-            k_q_ln=stabilise_level_agg,
+            ln_k_q=stabilise_level_agg,
         )

         # Drop patches after encoding.
aurora/model/perceiver.py (12 changes: 6 additions & 6 deletions)

@@ -97,7 +97,7 @@ def __init__(
         context_dim: int,
         head_dim: int = 64,
         num_heads: int = 8,
-        k_q_ln: bool = False,
+        ln_k_q: bool = False,
     ) -> None:
         """Initialise.
@@ -106,7 +106,7 @@ def __init__(
             context_dim (int): Dimensionality of the context features also given as input.
             head_dim (int): Attention head dimensionality.
             num_heads (int): Number of heads.
-            k_q_ln (bool): Apply an extra layer norm. to the keys and queries.
+            ln_k_q (bool): Apply an extra layer norm. to the keys and queries.
         """
         super().__init__()
         self.num_heads = num_heads
@@ -117,7 +117,7 @@ def __init__(
         self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False)
         self.to_out = nn.Linear(self.inner_dim, latent_dim, bias=False)

-        if k_q_ln:
+        if ln_k_q:
             self.ln_k = nn.LayerNorm(num_heads * head_dim)
             self.ln_q = nn.LayerNorm(num_heads * head_dim)
         else:
@@ -166,7 +166,7 @@ def __init__(
         drop: float = 0.0,
         residual_latent: bool = True,
         ln_eps: float = 1e-5,
-        k_q_ln: bool = False,
+        ln_k_q: bool = False,
     ) -> None:
         """Initialise.
@@ -183,7 +183,7 @@ def __init__(
                 Defaults to `True`.
             ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults to
                 `1e-5`.
-            k_q_ln (bool, optional): Apply an extra layer norm. to the keys and queries of the first
+            ln_k_q (bool, optional): Apply an extra layer norm. to the keys and queries of the first
                 resampling layer. Defaults to `False`.
         """
         super().__init__()
@@ -200,7 +200,7 @@ def __init__(
                     context_dim=context_dim,
                     head_dim=head_dim,
                     num_heads=num_heads,
-                    k_q_ln=k_q_ln if i == 0 else False,
+                    ln_k_q=ln_k_q if i == 0 else False,
                 ),
                 MLP(dim=latent_dim, hidden_features=mlp_hidden_dim, dropout=drop),
                 nn.LayerNorm(latent_dim, eps=ln_eps),
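For context: the renamed flag only gates an extra LayerNorm on the keys and queries of the resampler's first cross-attention layer, and the encoder enables it via `stabilise_level_agg`. The snippet below is a minimal, hypothetical sketch of how such a flag can be wired in; the class name, the forward pass, and the `nn.Identity` fallback are illustrative assumptions and are not taken from this commit.

```python
import torch
from torch import nn


class CrossAttentionSketch(nn.Module):
    """Hypothetical cross-attention block with an optional key/query LayerNorm."""

    def __init__(
        self,
        latent_dim: int,
        context_dim: int,
        head_dim: int = 64,
        num_heads: int = 8,
        ln_k_q: bool = False,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim

        self.to_q = nn.Linear(latent_dim, self.inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False)
        self.to_out = nn.Linear(self.inner_dim, latent_dim, bias=False)

        # The flag gates an extra normalisation of keys and queries; falling
        # back to `nn.Identity` when it is off is an assumption of this sketch.
        self.ln_k = nn.LayerNorm(self.inner_dim) if ln_k_q else nn.Identity()
        self.ln_q = nn.LayerNorm(self.inner_dim) if ln_k_q else nn.Identity()

    def forward(self, latents: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        B = latents.shape[0]
        q = self.ln_q(self.to_q(latents))
        k, v = self.to_kv(context).chunk(2, dim=-1)
        k = self.ln_k(k)

        def split_heads(x: torch.Tensor) -> torch.Tensor:
            # (B, N, H * D) -> (B, H, N, D).
            return x.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        out = nn.functional.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v)
        )
        out = out.transpose(1, 2).reshape(B, -1, self.inner_dim)
        return self.to_out(out)


# Usage: normalise keys and queries only in the first of several layers,
# mirroring the `ln_k_q=ln_k_q if i == 0 else False` pattern in the diff above.
layers = [CrossAttentionSketch(512, 512, ln_k_q=(i == 0)) for i in range(2)]
x, ctx = torch.randn(1, 16, 512), torch.randn(1, 128, 512)
for layer in layers:
    x = layer(x, ctx)
```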
