
Commit 3ac7229

[Operator] register log_softmax backward
1 parent 4904bd0 commit 3ac7229

4 files changed: +79 -73 lines changed

4 files changed

+79
-73
lines changed

src/flag_gems/__init__.py (+2 -1)

@@ -171,7 +171,8 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
         ("any", any, Autograd.disable),
         ("any.dim", any_dim, Autograd.disable),
         ("any.dims", any_dims, Autograd.disable),
-        ("log_softmax.int", log_softmax, Autograd.enable),
+        ("_log_softmax", log_softmax, Autograd.disable),
+        ("_log_softmax_backward_data", log_softmax_backward, Autograd.disable),
         ("outer", outer, Autograd.enable),
         ("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
         ("nll_loss_forward", nll_loss_forward, Autograd.disable),

src/flag_gems/ops/__init__.py (+2 -1)

@@ -57,7 +57,7 @@
 from .isnan import isnan
 from .layernorm import layer_norm, layer_norm_backward
 from .le import le, le_scalar
-from .log_softmax import log_softmax
+from .log_softmax import log_softmax, log_softmax_backward
 from .logical_and import logical_and
 from .logical_not import logical_not
 from .logical_or import logical_or

@@ -276,6 +276,7 @@
     "var_mean",
     "vector_norm",
     "log_softmax",
+    "log_softmax_backward",
     "outer",
     "cross_entropy_loss",
     "where_self_out",

src/flag_gems/ops/log_softmax.py (+57 -66)

@@ -93,73 +93,64 @@ def log_softmax_backward_kernel(
     tl.store(in_grad_ptrs, in_grad, mask=mask)


-class LogSoftmax(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, dim, dtype):
-        logging.debug("GEMS LOG_SOFTMAX")
-
-        assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
-        dim = dim % x.ndim
-        M = 1
-        N = x.shape[dim]
-        for i in range(dim):
-            M *= x.shape[i]
-        inp = x.contiguous()
-        if dtype is None:
-            dtype = x.dtype
-        out = torch.empty_like(inp, dtype=dtype)
-        K = inp.numel() // M // N
-
-        grid = lambda meta: (
-            triton.cdiv(M, meta["BLOCK_M"]),
+def log_softmax(self, dim, half_to_float=False):
+    logging.debug("GEMS LOG_SOFTMAX")
+
+    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
+    dim = dim % self.ndim
+    M = 1
+    N = self.shape[dim]
+    for i in range(dim):
+        M *= self.shape[i]
+    inp = self.contiguous()
+    if half_to_float:
+        dtype = torch.float32
+    else:
+        dtype = self.dtype
+    out = torch.empty_like(inp, dtype=dtype)
+    K = inp.numel() // M // N
+
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_M"]),
+        K,
+    )
+    with torch_device_fn.device(inp.device):
+        log_softmax_kernel[grid](
+            out,
+            inp,
+            M,
+            N,
             K,
+            num_warps=8,
         )
-        with torch_device_fn.device(inp.device):
-            log_softmax_kernel[grid](
-                out,
-                inp,
-                M,
-                N,
-                K,
-                num_warps=8,
-            )
-        ctx.save_for_backward(out)
-        ctx.dim = dim
-        return out
-
-    @staticmethod
-    def backward(ctx, out_grad):
-        logging.debug("GEMS LOG_SOFTMAX VJP")
-
-        dim = ctx.dim
-        (out,) = ctx.saved_tensors
-
-        assert dim >= -out.ndim and dim < out.ndim, "Invalid dim"
-        dim = dim % out.ndim
-        M = 1
-        N = out.shape[dim]
-        for i in range(dim):
-            M *= out.shape[i]
-
-        out_grad = out_grad.contiguous()
-        in_grad = torch.empty_like(out)
-        K = out.numel() // M // N
-
-        grid = lambda meta: (
-            triton.cdiv(M, meta["BLOCK_M"]),
+    return out
+
+
+def log_softmax_backward(grad_output, output, dim, input_dtype):
+    logging.debug("GEMS LOG_SOFTMAX VJP")
+
+    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
+    dim = dim % output.ndim
+    M = 1
+    N = output.shape[dim]
+    for i in range(dim):
+        M *= output.shape[i]
+
+    grad_output = grad_output.contiguous()
+    in_grad = torch.empty_like(output, dtype=input_dtype)
+    K = output.numel() // M // N
+
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_M"]),
+        K,
+    )
+    with torch_device_fn.device(in_grad.device):
+        log_softmax_backward_kernel[grid](
+            output,
+            grad_output,
+            in_grad,
+            M,
+            N,
             K,
         )
-        with torch_device_fn.device(in_grad.device):
-            log_softmax_backward_kernel[grid](
-                out,
-                out_grad,
-                in_grad,
-                M,
-                N,
-                K,
-            )
-        return in_grad, None, None
-
-
-def log_softmax(x, dim=-1, dtype=None):
-    return LogSoftmax.apply(x, dim, dtype)
+    return in_grad
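For reference, the quantity both the removed backward and the new log_softmax_backward compute is the standard log-softmax VJP: for y = log_softmax(x), dx = dy - exp(y) * sum(dy) along the reduced dim. A minimal eager-mode PyTorch equivalent that the kernel output can be sanity-checked against; this is a reference sketch, not the Triton kernel, and the shapes are illustrative:

import torch

def log_softmax_backward_reference(grad_output, output, dim, input_dtype):
    # exp(log_softmax(x)) == softmax(x), so dx = dy - softmax(x) * sum(dy) over dim
    dx = grad_output - torch.exp(output) * grad_output.sum(dim=dim, keepdim=True)
    return dx.to(input_dtype)

out = torch.log_softmax(torch.randn(4, 128), dim=1)
dy = torch.randn_like(out)
dx = log_softmax_backward_reference(dy, out, 1, out.dtype)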

tests/test_reduction_ops.py (+18 -5)

@@ -339,19 +339,32 @@ def test_accuracy_count_nonzero(shape, dtype):
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 def test_accuracy_log_softmax(shape, dtype):
     dim = 1
-    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
+    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
     ref_inp = to_reference(inp, True)

     ref_out = torch.nn.functional.log_softmax(ref_inp, dim=dim)
     with flag_gems.use_gems():
         res_out = torch.nn.functional.log_softmax(inp, dim=dim)
     gems_assert_close(res_out, ref_out, dtype)

-    out_grad = torch.randn_like(res_out)
-    ref_grad = to_reference(out_grad, True)

-    (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
-    (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
+@pytest.mark.log_softmax
+@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_log_softmax_backward(shape, dtype):
+    res_grad = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+    res_out = torch.randn_like(res_grad)
+    ref_grad = to_reference(res_grad, True)
+    ref_out = to_reference(res_out, True)
+    dim = 1
+
+    ref_in_grad = torch.ops.aten._log_softmax_backward_data(
+        ref_grad, ref_out, dim, ref_grad.dtype
+    )
+    with flag_gems.use_gems():
+        res_in_grad = torch.ops.aten._log_softmax_backward_data(
+            res_grad, res_out, dim, dtype
+        )
     gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim])
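The forward test no longer tracks gradients, and the new backward test drives torch.ops.aten._log_softmax_backward_data directly on random tensors. Outside the test suite, the same op can be compared against the native ATen result; a sketch with illustrative shapes and tolerances, assuming the native backend also implements the op on flag_gems.device:

import torch
import flag_gems

dim = 1
out = torch.log_softmax(torch.randn(4, 128, device=flag_gems.device), dim=dim)
grad = torch.randn_like(out)

ref = torch.ops.aten._log_softmax_backward_data(grad, out, dim, grad.dtype)  # native kernel
with flag_gems.use_gems():
    res = torch.ops.aten._log_softmax_backward_data(grad, out, dim, grad.dtype)  # flag_gems Triton kernel
torch.testing.assert_close(res, ref, rtol=1e-3, atol=1e-3)  # tolerances are illustrative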
