How to implement batch norm #89

Closed
albertz opened this issue Jan 5, 2022 · 2 comments

albertz commented Jan 5, 2022

There are a couple of open questions regarding how to implement batch norm using the building blocks of returnn-common. Of course, we could also wrap the existing BatchNormLayer in RETURNN (which needs rwth-i6/returnn#891, though), but even if that were the implementation of BatchNorm in returnn-common, the question would still remain how to implement it from scratch using the building blocks of returnn-common. In any case, this should be possible, and preferably in a straightforward way.

One question is how to handle the train flag. This is #18.

Another question is how to do custom updates for the running-statistics variables. This is #90.

Another question is how to make use of the TF fused op, which would be important for efficiency; specifically, tf.compat.v1.nn.fused_batch_norm with data_format="NCHW" (see the call sketch below).

Related are also the batch norm defaults (#83), although they are not too relevant to the question of how to implement this.
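
For reference, a minimal sketch of how the fused TF op mentioned above would be called directly (plain TensorFlow, not returnn-common; the tensor shapes here are made-up examples):

import tensorflow as tf

# x in NCHW layout: [batch, channels, height, width]
x = tf.random.normal([8, 32, 64, 64])
gamma = tf.ones([32])    # scale, shape [channels]
beta = tf.zeros([32])    # offset, shape [channels]
running_mean = tf.zeros([32])
running_var = tf.ones([32])

# Training: statistics are computed from the current batch and also returned.
y, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
    x, gamma, beta, epsilon=1e-5, data_format="NCHW", is_training=True)

# Inference: the stored running statistics must be passed in.
y_eval, _, _ = tf.compat.v1.nn.fused_batch_norm(
    x, gamma, beta, mean=running_mean, variance=running_var,
    epsilon=1e-5, data_format="NCHW", is_training=False)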

albertz commented Jan 5, 2022

Demo implementation:

from typing import Optional
from returnn_common import nn  # assumed import convention


class BatchNorm(nn.Module):
  """
  Batch normalization, normalizing over all axes except the feature axis.

  Note: this is a design sketch. The nn.Cond usage below follows a proposed API:
  cond.else(...) closes the true (training) branch with its result values,
  cond.end(...) closes the else (inference) branch and returns the selected values.
  """

  def __init__(self, in_dim: Optional[nn.Dim] = None, *, affine: bool = True):
    """
    :param in_dim: the feature dimension of the input
    :param affine: whether to use the learnable parameters gamma and beta
    """
    super().__init__()
    self.in_dim = in_dim
    self.mean = None  # type: Optional[nn.Parameter]
    self.var = None  # type: Optional[nn.Parameter]
    self.affine = affine
    self.gamma = None  # type: Optional[nn.Parameter]
    self.beta = None  # type: Optional[nn.Parameter]
    if in_dim:
      self._lazy_init(in_dim)

  def _lazy_init(self, in_dim: nn.Dim):
    self.in_dim = in_dim
    # Auxiliary parameters: running statistics, not updated by the optimizer.
    self.mean = nn.Parameter([in_dim], auxiliary=True)
    self.var = nn.Parameter([in_dim], auxiliary=True)
    if self.affine:
      self.gamma = nn.Parameter([in_dim])
      self.beta = nn.Parameter([in_dim])

  def __call__(self, source: nn.LayerRef, *, epsilon=1e-5, momentum=0.1) -> nn.Layer:
    source = nn.check_in_feature_dim_lazy_init(source, self.in_dim, self._lazy_init)
    # Reduce over all axes except the feature axis.
    reduce_dims = [d for d in source.data.dim_tags if d != self.in_dim]
    with nn.Cond(nn.get_train_flag()) as cond:
      # Training branch: statistics of the current batch.
      mean_cur_batch, var_cur_batch = nn.moments(source, reduce_dims)
      cond.else((mean_cur_batch, var_cur_batch))
      # Else branch (inference): the stored running statistics.
      mean, var = cond.end((self.mean, self.var))
    with nn.Cond(nn.get_train_flag()) as cond:  # separate Cond such that this can be delayed
      # Exponential-moving-average update of the running statistics (see #90).
      self.mean.assign_add((mean - self.mean) * momentum)
      self.var.assign_add((var - self.var) * momentum)
    norm = (source - mean) * nn.rsqrt(var + epsilon)
    if self.affine:
      norm = norm * self.gamma + self.beta
    return norm
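
A hypothetical usage sketch (feat_dim and x are made-up names for illustration: feat_dim is the feature nn.Dim of the input, x is some tensor which has that feature dim):

batch_norm = BatchNorm(feat_dim)  # or BatchNorm() to lazily init on the first call
y = batch_norm(x)                 # same dims as x, normalized over all non-feature axes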

albertz added a commit that referenced this issue Jan 6, 2022

albertz commented Jan 6, 2022

Because we want the fused op to be used when possible, we need to wrap some RETURNN layer in any case. Because of this, I have now just wrapped the whole module around the RETURNN layer.

Some of these questions (how to handle/implement custom auxiliary-variable updates) still remain, but they are no longer needed for this case. Also, the case here is basically clear and just needs to be implemented, which can be done once it is needed (maybe for something else). But that is #90, and maybe also #18.
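
For illustration only, a rough sketch of what such wrapping could look like (not the actual returnn-common implementation; nn.make_layer and the exact option names of RETURNN's "batch_norm" layer used here are assumptions):

class WrappedBatchNorm(nn.Module):  # hypothetical name
  """
  Batch normalization by delegating to RETURNN's BatchNormLayer ("class": "batch_norm"),
  which can use the fused TF op internally.
  """

  def __init__(self, *, momentum: float = 0.1, epsilon: float = 1e-5):
    super().__init__()
    self.momentum = momentum
    self.epsilon = epsilon

  def __call__(self, source: nn.LayerRef) -> nn.Layer:
    # The parameters (gamma/beta and the running statistics) live inside the RETURNN layer.
    return nn.make_layer(
      {"class": "batch_norm", "from": source,
       "momentum": self.momentum, "epsilon": self.epsilon},
      name="batch_norm")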

So I will close this now.

albertz closed this as completed Jan 6, 2022