Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add silu and mul grad #502

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/flag_gems/fused/silu_and_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,32 @@ def silu_and_mul_kernel(x, y):
return x_silu * y


@pointwise_dynamic(
promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def silu_and_mul_grad_kernel(x, y, dgrad):
x_fp32 = x.to(tl.float32)
sig = 1 / (1 + tl.exp(-x_fp32))
x_silu = x_fp32 * sig
d_x_silu = sig * (1 + x_fp32 * (1 - sig))
dx = d_x_silu * dgrad * y
dy = dgrad * x_silu
return dx, dy


class SiluAndMul(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B):
ctx.save_for_backward(A, B)
logging.debug("GEMS SILU AND MUL FORWARD")
return silu_and_mul_kernel(A, B)

def backward(ctx, grad_output):
A, B = ctx.saved_tensors
grad_A, grad_B = silu_and_mul_grad_kernel(A, B, grad_output)
return grad_A, grad_B


def silu_and_mul(A, B):
return SiluAndMul.apply(A, B)
17 changes: 15 additions & 2 deletions tests/test_binary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,16 +1353,29 @@ def test_accuracy_rsub(shape, alpha, dtype):
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_silu_and_mul(shape, dtype):
inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
ref_inp1 = to_reference(inp1, True)
ref_inp2 = to_reference(inp2, True)

ref_out = torch.mul(torch.nn.functional.silu(ref_inp1), ref_inp2)
with flag_gems.use_gems():
res_out = flag_gems.silu_and_mul(inp1, inp2)

out_grad = torch.randn_like(res_out)
ref_grad = to_reference(out_grad, True)

(ref_inp1_grad, ref_inp2_grad) = torch.autograd.grad(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend splitting out the backward test, so that the input data received by gems function could be the same as the reference function

ref_out, (ref_inp1, ref_inp2), ref_grad
)

(res_inp1_grad, res_inp2_grad) = torch.autograd.grad(
res_out, (inp1, inp2), out_grad
)

gems_assert_close(res_out, ref_out, dtype)
gems_assert_close(res_inp1_grad, ref_inp1_grad, dtype)
gems_assert_close(res_inp2_grad, ref_inp2_grad, dtype)


@pytest.mark.sub
Expand Down
Loading