Skip to content

Commit

Permalink
Fixed GroupNorm raising a spurious runtime error (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger authored Jun 6, 2022
1 parent aa0d968 commit b6abc4c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
5 changes: 4 additions & 1 deletion equinox/nn/array_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Tuple

import jax.numpy as jnp

from ..custom_types import Array


def left_broadcast_to(arr: Array, shape: Tuple[int]):
return arr.reshape(shape + (1,) * (len(shape) - arr.ndim))
arr = arr.reshape(arr.shape + (1,) * (len(shape) - arr.ndim))
return jnp.broadcast_to(arr, shape)
4 changes: 4 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,10 @@ def test_group_norm(getkey):
x = jrandom.uniform(getkey(), (128,))
assert gn(x).shape == (128,)

gn = eqx.nn.GroupNorm(groups=4, channels=128)
x = jrandom.uniform(getkey(), (128, 4, 5))
assert gn(x).shape == (128, 4, 5)

gn = eqx.nn.GroupNorm(groups=4, channels=128, channelwise_affine=False)
x = jrandom.uniform(getkey(), (128, 4, 5))
assert gn(x).shape == (128, 4, 5)
Expand Down

0 comments on commit b6abc4c

Please sign in to comment.