From 97159971d8ab462386a54942689c4f05334e9791 Mon Sep 17 00:00:00 2001 From: Mrityunjai Singh Date: Sat, 19 Jul 2025 02:41:38 -0700 Subject: [PATCH] feature: add implementation for 'same' mode for convolve --- autograd/scipy/signal.py | 54 ++++++++++++++++++++++++++++++++-------- tests/test_scipy.py | 2 +- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/autograd/scipy/signal.py b/autograd/scipy/signal.py index 49b76194..daecd9cd 100644 --- a/autograd/scipy/signal.py +++ b/autograd/scipy/signal.py @@ -9,14 +9,16 @@ @primitive def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"): - assert mode in ["valid", "full"], f"Mode {mode} not yet implemented" + assert mode in ["valid", "full", "same"], ( + f"Mode {mode} undefined, it can be one of 'valid', 'full', and 'same'" + ) if axes is None: axes = [list(range(A.ndim)), list(range(A.ndim))] wrong_order = any([B.shape[ax_B] < A.shape[ax_A] for ax_A, ax_B in zip(*axes)]) if wrong_order: if mode == "valid" and not all([B.shape[ax_B] <= A.shape[ax_A] for ax_A, ax_B in zip(*axes)]): raise Exception("One array must be larger than the other along all convolved dimensions") - elif mode != "full" or B.size <= A.size: # Tie breaker + elif mode == "valid" or (mode == "full" and B.size <= A.size): # Tie breaker i1 = B.ndim - len(dot_axes[1]) - len(axes[1]) # B ignore i2 = i1 + A.ndim - len(dot_axes[0]) - len(axes[0]) # A ignore i3 = i2 + len(axes[0]) @@ -27,8 +29,12 @@ def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"): ignore_A + ignore_B + conv ) - if mode == "full": - B = pad_to_full(B, A, axes[::-1]) + if mode == "same": + B, A = A, B + axes = axes[::-1] + dot_axes = dot_axes[::-1] + if mode != "valid": + B = pad(B, A, axes[::-1], mode=mode) B_view_shape = list(B.shape) B_view_strides = list(B.strides) flipped_idxs = [slice(None)] * A.ndim @@ -40,7 +46,16 @@ def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"): B_view = as_strided(B, B_view_shape, B_view_strides) A_view = A[tuple(flipped_idxs)] all_axes = [list(axes[i]) + list(dot_axes[i]) for i in [0, 1]] - return einsum_tensordot(A_view, B_view, all_axes) + if mode == "same": + i1 = B.ndim - len(dot_axes[1]) - len(axes[1]) # B ignore + i2 = i1 + len(axes[0]) + i3 = i2 + A.ndim - len(dot_axes[0]) - len(axes[0]) # A ignore + ignore_A = list(range(i1)) + conv = list(range(i1, i2)) + ignore_B = list(range(i2, i3)) + return einsum_tensordot(B_view, A_view, all_axes[::-1]).transpose(ignore_A + ignore_B + conv) + else: + return einsum_tensordot(A_view, B_view, all_axes) def einsum_tensordot(A, B, axes, reverse=False): @@ -54,10 +69,15 @@ def einsum_tensordot(A, B, axes, reverse=False): return npo.einsum(A, A_axnums, B, B_axnums) -def pad_to_full(A, B, axes): +def pad(A, B, axes, mode="full"): A_pad = [(0, 0)] * A.ndim - for ax_A, ax_B in zip(*axes): - A_pad[ax_A] = (B.shape[ax_B] - 1,) * 2 + if mode == "full": + for ax_A, ax_B in zip(*axes): + A_pad[ax_A] = (B.shape[ax_B] - 1,) * 2 + elif mode == "same": + for ax_A, ax_B in zip(*axes): + right_bound = (B.shape[ax_B] - 1) // 2 + A_pad[ax_A] = (B.shape[ax_B] - 1 - right_bound, right_bound) return npo.pad(A, A_pad, mode="constant") @@ -124,7 +144,9 @@ def flipped_idxs(ndim, axes): def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode="full"): - assert mode in ["valid", "full"], f"Grad for mode {mode} not yet implemented" + assert mode in ["valid", "full", "same"], ( + f"Mode {mode} undefined, it can be one of 'valid', 'full', and 'same'" + ) axes, shapes = parse_axes(A.shape, B.shape, axes, dot_axes, mode) if argnum == 0: X, Y = A, B @@ -139,11 +161,23 @@ def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode="full"): if mode == "full": new_mode = "valid" - else: + elif mode == "valid": if any([x_size > y_size for x_size, y_size in zip(shapes[_X_]["conv"], shapes[_Y_]["conv"])]): new_mode = "full" else: new_mode = "valid" + elif mode == "same": + if _X_ == "A": + new_mode = "same" + else: + # This step makes convolve once differentiable, which is fine. + new_mode = "valid" + Y = pad( + Y, + X, + [axes[_Y_]["conv"], axes[_X_]["conv"]], + mode="same", + ) def vjp(g): result = convolve( diff --git a/tests/test_scipy.py b/tests/test_scipy.py index ba563eff..a0a78764 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -274,7 +274,7 @@ def test_convolve_generalization(): A_2543 = R(2, 5, 4, 3) A_24232 = R(2, 4, 2, 3, 2) - for mode in ["valid", "full"]: + for mode in ["valid", "full", "same"]: assert npo.allclose( ag_convolve(A_35, A_34, axes=([1], [0]), mode=mode)[1, 2], sp_convolve(A_35[1, :], A_34[:, 2], mode),