Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions autograd/scipy/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@

@primitive
def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
assert mode in ["valid", "full"], f"Mode {mode} not yet implemented"
assert mode in ["valid", "full", "same"], (
f"Mode {mode} undefined, it can be one of 'valid', 'full', and 'same'"
)
if axes is None:
axes = [list(range(A.ndim)), list(range(A.ndim))]
wrong_order = any([B.shape[ax_B] < A.shape[ax_A] for ax_A, ax_B in zip(*axes)])
if wrong_order:
if mode == "valid" and not all([B.shape[ax_B] <= A.shape[ax_A] for ax_A, ax_B in zip(*axes)]):
raise Exception("One array must be larger than the other along all convolved dimensions")
elif mode != "full" or B.size <= A.size: # Tie breaker
elif mode == "valid" or (mode == "full" and B.size <= A.size): # Tie breaker
i1 = B.ndim - len(dot_axes[1]) - len(axes[1]) # B ignore
i2 = i1 + A.ndim - len(dot_axes[0]) - len(axes[0]) # A ignore
i3 = i2 + len(axes[0])
Expand All @@ -27,8 +29,12 @@ def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
ignore_A + ignore_B + conv
)

if mode == "full":
B = pad_to_full(B, A, axes[::-1])
if mode == "same":
B, A = A, B
axes = axes[::-1]
dot_axes = dot_axes[::-1]
if mode != "valid":
B = pad(B, A, axes[::-1], mode=mode)
B_view_shape = list(B.shape)
B_view_strides = list(B.strides)
flipped_idxs = [slice(None)] * A.ndim
Expand All @@ -40,7 +46,16 @@ def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
B_view = as_strided(B, B_view_shape, B_view_strides)
A_view = A[tuple(flipped_idxs)]
all_axes = [list(axes[i]) + list(dot_axes[i]) for i in [0, 1]]
return einsum_tensordot(A_view, B_view, all_axes)
if mode == "same":
i1 = B.ndim - len(dot_axes[1]) - len(axes[1]) # B ignore
i2 = i1 + len(axes[0])
i3 = i2 + A.ndim - len(dot_axes[0]) - len(axes[0]) # A ignore
ignore_A = list(range(i1))
conv = list(range(i1, i2))
ignore_B = list(range(i2, i3))
return einsum_tensordot(B_view, A_view, all_axes[::-1]).transpose(ignore_A + ignore_B + conv)
else:
return einsum_tensordot(A_view, B_view, all_axes)


def einsum_tensordot(A, B, axes, reverse=False):
Expand All @@ -54,10 +69,15 @@ def einsum_tensordot(A, B, axes, reverse=False):
return npo.einsum(A, A_axnums, B, B_axnums)


def pad(A, B, axes, mode="full"):
    """Pad ``A`` along its convolved axes so it can be convolved with ``B``.

    Args:
        A: Array to pad. Only ``A.ndim`` and the padding call are used here.
        B: Array whose shape determines how much padding is applied.
        axes: Pair ``(axes_of_A, axes_of_B)``. The two sequences are walked
            in lockstep; entry ``i`` of the first names an axis of ``A`` and
            entry ``i`` of the second names the matching axis of ``B``.
        mode: ``"full"`` pads both sides of each listed axis of ``A`` by
            ``B.shape[ax_B] - 1``. ``"same"`` splits that same total of
            ``B.shape[ax_B] - 1`` between the two sides, with
            ``(B.shape[ax_B] - 1) // 2`` after the data and the remainder
            before it. Any other value leaves every axis unpadded.

    Returns:
        The result of ``npo.pad(A, pad_widths, mode="constant")``. Axes of
        ``A`` that are not listed in ``axes[0]`` keep a pad width of
        ``(0, 0)``.
    """
    # NOTE(review): `npo` is expected to be the plain NumPy module imported
    # at the top of the file -- confirm against the file's import block.
    A_pad = [(0, 0)] * A.ndim
    for ax_A, ax_B in zip(*axes):
        # Total number of extra samples the kernel axis can contribute.
        extent = B.shape[ax_B] - 1
        if mode == "full":
            A_pad[ax_A] = (extent, extent)
        elif mode == "same":
            # For an odd `extent`, the extra sample goes on the leading side.
            right_bound = extent // 2
            A_pad[ax_A] = (extent - right_bound, right_bound)
    return npo.pad(A, A_pad, mode="constant")


Expand Down Expand Up @@ -124,7 +144,9 @@ def flipped_idxs(ndim, axes):


def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode="full"):
assert mode in ["valid", "full"], f"Grad for mode {mode} not yet implemented"
assert mode in ["valid", "full", "same"], (
f"Mode {mode} undefined, it can be one of 'valid', 'full', and 'same'"
)
axes, shapes = parse_axes(A.shape, B.shape, axes, dot_axes, mode)
if argnum == 0:
X, Y = A, B
Expand All @@ -139,11 +161,23 @@ def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode="full"):

if mode == "full":
new_mode = "valid"
else:
elif mode == "valid":
if any([x_size > y_size for x_size, y_size in zip(shapes[_X_]["conv"], shapes[_Y_]["conv"])]):
new_mode = "full"
else:
new_mode = "valid"
elif mode == "same":
if _X_ == "A":
new_mode = "same"
else:
# This step makes convolve once differentiable, which is fine.
new_mode = "valid"
Y = pad(
Y,
X,
[axes[_Y_]["conv"], axes[_X_]["conv"]],
mode="same",
)

def vjp(g):
result = convolve(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def test_convolve_generalization():
A_2543 = R(2, 5, 4, 3)
A_24232 = R(2, 4, 2, 3, 2)

for mode in ["valid", "full"]:
for mode in ["valid", "full", "same"]:
assert npo.allclose(
ag_convolve(A_35, A_34, axes=([1], [0]), mode=mode)[1, 2],
sp_convolve(A_35[1, :], A_34[:, 2], mode),
Expand Down