Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions flax/nnx/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flax.nnx import rnglib, variablelib
from flax.nnx.module import Module, first_from
from flax.nnx.nn import dtypes, initializers
from flax import nnx
from flax.typing import (
Dtype,
Shape,
Expand Down Expand Up @@ -873,6 +872,8 @@ def maybe_broadcast(
class ConvTranspose(Module):
"""Convolution Module wrapping ``lax.conv_transpose``.

**Note:** The `padding` argument behaves differently from PyTorch; see the argument description below.

Example usage::

>>> from flax import nnx
Expand Down Expand Up @@ -919,13 +920,31 @@ class ConvTranspose(Module):
sequence of integers.
strides: an integer or a sequence of ``n`` integers, representing the
inter-window strides (default: 1).
padding: either the string ``'SAME'``, the string ``'VALID'``, the string
``'CIRCULAR'`` (periodic boundary conditions), or a sequence of ``n``
``(low, high)`` integer pairs that give the padding to apply before and after each
padding: either a string indicating a specialized padding mode,
or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each
spatial dimension. A single int is interpeted as applying the same padding
in all dims and passign a single int in a sequence causes the same padding
to be used on both sides. ``'CAUSAL'`` padding for a 1D convolution will
left-pad the convolution axis, resulting in same-sized output.
in all dims and a single int in a sequence causes the same padding
to be used on both sides.

**Note that this behavior is different from
PyTorch**. In PyTorch, the padding argument effectively adds ``dilation * (kernel_size - 1) - padding``
amount of zero padding to the input instead. This is set so that when ``torch.Conv2d`` and ``torch.ConvTranspose2d``
are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes.
``nnx.Conv`` and ``nnx.ConvTranspose`` do *not* have this behavior; if you want a ``nnx.ConvTranspose`` layer
to invert the shape change produced by a ``nnx.Conv`` layer with a given padding and dilation, you should explicitly pass
``dilation * (kernel_size - 1) - padding`` as the `padding` argument to the ``nnx.ConvTranspose`` layer.

Strings for specifying padding modes can be one of the following:

- ``VALID`` adds ``dilation * (kernel_size - 1)`` padding to all dimensions. This is set so that a
``nnx.Conv`` layer with ``VALID`` padding would produce the inverse shape transformation.

- ``SAME`` pads the input so that the output shape is the same as the input shape.

- ``CIRCULAR`` pads the input with periodic boundary conditions.

- ``CAUSAL`` padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.

kernel_dilation: an integer or a sequence of ``n`` integers, giving the
dilation factor to apply in each spatial dimension of the convolution
kernel (default: 1). Convolution with kernel dilation
Expand Down
Loading