
Commit 9f79739

Committed on Feb 6, 2025
[Operator] register batch_norm backward
1 parent 3ac7229 commit 9f79739

File tree: 5 files changed, +192 -192 lines

src/flag_gems/__init__.py (+2 -1)
@@ -25,7 +25,8 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
         ("arange.start_step", arange_start, Autograd.disable),
         ("arange.start", arange_start, Autograd.disable),
         ("arange", arange, Autograd.disable),
-        ("batch_norm", batch_norm, Autograd.enable),
+        ("native_batch_norm", batch_norm, Autograd.disable),
+        ("native_batch_norm_backward", batch_norm_backward, Autograd.disable),
         ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable),
         ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable),
         ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable),

src/flag_gems/ops/__init__.py (+2 -1)
@@ -8,7 +8,7 @@
 from .argmax import argmax
 from .argmin import argmin
 from .attention import scaled_dot_product_attention
-from .batch_norm import batch_norm
+from .batch_norm import batch_norm, batch_norm_backward
 from .bitwise_and import (
     bitwise_and_scalar,
     bitwise_and_scalar_tensor,
@@ -150,6 +150,7 @@
     "arange",
     "arange_start",
     "batch_norm",
+    "batch_norm_backward",
     "bitwise_and_tensor",
     "bitwise_and_scalar",
     "bitwise_and_scalar_tensor",

src/flag_gems/ops/batch_norm.py (+111 -152)
@@ -8,7 +8,6 @@
 from .. import runtime
 from ..runtime import torch_device_fn
 from ..utils import libentry, tl_extra_shim
-from ..utils.type_utils import get_accumulator_dtype
 
 rsqrt = tl_extra_shim.rsqrt
 
@@ -63,8 +62,6 @@ def batch_norm_forward_kernel(
     output_spatial_stride,
     momentum,
     eps,
-    affine: tl.constexpr,
-    save_stats: tl.constexpr,
     is_train: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -114,9 +111,8 @@ def batch_norm_forward_kernel(
         inv_std = rsqrt(var + eps)
         mean = final_mean
 
-        if save_stats:
-            tl.store(feat_pid + mean_pointer, mean)
-            tl.store(feat_pid + inv_std_pointer, inv_std)
+        tl.store(feat_pid + mean_pointer, mean)
+        tl.store(feat_pid + inv_std_pointer, inv_std)
 
         running_mean_pointer += feat_pid
         running_var_pointer += feat_pid
@@ -135,12 +131,13 @@ def batch_norm_forward_kernel(
         mean = tl.load(feat_pid + running_mean_pointer)
         inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)
 
-        if affine:
-            weight = tl.load(feat_pid + weight_pointer)
-            bias = tl.load(feat_pid + bias_pointer)
-
+        if weight_pointer:
+            weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
         else:
             weight = 1.0
+        if bias_pointer:
+            bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
+        else:
             bias = 0.0
 
     for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
@@ -203,7 +200,9 @@ def batch_norm_backward_kernel(
     input_grad_batch_stride,
     input_grad_feat_stride,
     input_grad_spatial_stride,
-    affine: tl.constexpr,
+    input_grad_mask: tl.constexpr,
+    weight_grad_mask: tl.constexpr,
+    bias_grad_mask: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -250,11 +249,16 @@ def batch_norm_backward_kernel(
     term1 = tl.sum(term1)
     term2 = tl.sum(term2)
 
-    if affine:
-        weight = tl.load(feat_pid + weight_pointer)
-        weight_grad_acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        bias_grad_acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+    if weight_grad_mask:
+        tl.store(feat_pid + weight_grad_pointer, term1)
+    if bias_grad_mask:
+        tl.store(feat_pid + bias_grad_pointer, term2)
+
+    if not input_grad_mask:
+        return
 
+    if weight_pointer:
+        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
     else:
         weight = 1.0
 
@@ -306,152 +310,107 @@ def batch_norm_backward_kernel(
             mask=batch_mask[:, None] & spatial_mask[None, :],
         )
 
-        if affine:
-            weight_grad_acc += curr_pre_lin * curr_output_grad
-            bias_grad_acc += curr_output_grad
-
-    if affine:
-        tl.store(feat_pid + weight_grad_pointer, tl.sum(weight_grad_acc))
-        tl.store(feat_pid + bias_grad_pointer, tl.sum(bias_grad_acc))
-
-
-class BatchNorm(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        input: Tensor,
-        weight=None,
-        bias=None,
-        running_mean=None,  # self.running_mean if not self.training or self.track_running_state else None
-        running_var=None,
-        training=False,  # (self.running_mean is None) and (self.running_var is None)
-        momentum=0.1,
-        eps=1e-05,
-        cudnn_enable=True,
-    ):
-        logging.debug("GEMS BATCHNORM FORWARD")
-
-        input_3d = make_3d_for_bn(input)
-
-        affine = weight is not None and bias is not None
-        requires_grad = (
-            input.requires_grad
-            or (affine and weight.requires_grad)
-            or (affine and bias.requires_grad)
-        )
-
-        batch_dim, feat_dim, spatial_dim = input_3d.shape
-        output = torch.empty_like(input_3d)
 
-        if requires_grad:
-            acc_type = get_accumulator_dtype(input.dtype)
-            mean = torch.empty(feat_dim, device=input.device, dtype=acc_type)
-            inv_std = torch.empty(feat_dim, device=input.device, dtype=acc_type)
-
-        else:
-            mean = inv_std = None
-
-        running_mean = input if running_mean is None else running_mean
-        running_var = input if running_var is None else running_var
+def batch_norm(
+    input: Tensor,
+    weight=None,
+    bias=None,
+    running_mean=None,  # self.running_mean if not self.training or self.track_running_state else None
+    running_var=None,
+    training=False,  # (self.running_mean is None) and (self.running_var is None)
+    momentum=0.1,
+    eps=1e-05,
+):
+    logging.debug("GEMS BATCHNORM FORWARD")
+
+    input_3d = make_3d_for_bn(input)
+
+    batch_dim, feat_dim, spatial_dim = input_3d.shape
+    output = torch.empty_like(input_3d)
+
+    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
+    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
+
+    running_mean = input if running_mean is None else running_mean
+    running_var = input if running_var is None else running_var
+
+    # Launches 1D grid where each program operates over one feature.
+    with torch_device_fn.device(input.device):
+        batch_norm_forward_kernel[(feat_dim,)](
+            input_3d,
+            weight,
+            bias,
+            mean,
+            inv_std,
+            output,
+            running_mean,
+            running_var,
+            batch_dim,
+            spatial_dim,
+            *input_3d.stride(),
+            *output.stride(),
+            momentum,
+            eps,
+            is_train=training,
+        )
 
-        # Launches 1D grid where each program operates over one feature.
-        with torch_device_fn.device(input.device):
-            batch_norm_forward_kernel[(feat_dim,)](
-                input_3d,
-                weight,
-                bias,
-                mean,
-                inv_std,
-                output,
-                running_mean,
-                running_var,
-                batch_dim,
-                spatial_dim,
-                *input_3d.stride(),
-                *output.stride(),
-                momentum,
-                eps,
-                affine=affine,
-                save_stats=requires_grad,
-                is_train=training,
-            )
+    return output.view_as(input), mean, inv_std
 
-        ctx.affine = affine
-        if requires_grad:
-            ctx.save_for_backward(input, mean, inv_std, weight)
 
-        return output.view_as(input)
+def batch_norm_backward(
+    grad_out,
+    input,
+    weight=None,
+    running_mean=None,
+    running_var=None,
+    save_mean=None,
+    save_invstd=None,
+    train=False,
+    eps=1e-05,
+    output_mask=None,
+):
+    logging.debug("GEMS BATCHNORM BACKWARD")
+    input_3d = make_3d_for_bn(input)
+    output_grad_3d = make_3d_for_bn(grad_out)
 
-    @staticmethod
-    def backward(ctx, output_grad):
-        logging.debug("GEMS BATCHNORM BACKWARD")
-        (input, mean, inv_std, weight) = ctx.saved_tensors
-        input_3d = make_3d_for_bn(input)
-        output_grad_3d = make_3d_for_bn(output_grad)
+    batch_dim, feat_dim, spatial_dim = input_3d.shape
 
-        batch_dim, feat_dim, spatial_dim = input_3d.shape
+    if output_mask[0]:
         input_grad = torch.empty_like(input_3d)
-
-        if ctx.affine:
-            weight_grad = torch.empty((feat_dim,), device=input.device)
-            bias_grad = torch.empty_like(weight_grad)
-
-        else:
-            weight_grad = bias_grad = None
-
-        # Launches 1D grid where each program operates over one feature.
-        with torch_device_fn.device(input.device):
-            batch_norm_backward_kernel[(feat_dim,)](
-                output_grad_3d,
-                input_3d,
-                mean,
-                inv_std,
-                weight,
-                input_grad,
-                weight_grad,
-                bias_grad,
-                batch_dim,
-                spatial_dim,
-                *output_grad_3d.stride(),
-                *input_3d.stride(),
-                *input_grad.stride(),
-                affine=ctx.affine,
-            )
-
-        # Pads output with None because a gradient is necessary for
-        # all input arguments.
-        return (
-            input_grad.view_as(input),
+    else:
+        input_grad = None
+    if output_mask[1]:
+        weight_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
+    else:
+        weight_grad = None
+    if output_mask[2]:
+        bias_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
+    else:
+        bias_grad = None
+
+    # Launches 1D grid where each program operates over one feature.
+    with torch_device_fn.device(input.device):
+        batch_norm_backward_kernel[(feat_dim,)](
+            output_grad_3d,
+            input_3d,
+            save_mean,
+            save_invstd,
+            weight,
+            input_grad,
             weight_grad,
             bias_grad,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
+            batch_dim,
+            spatial_dim,
+            *output_grad_3d.stride(),
+            *input_3d.stride(),
+            *input_grad.stride(),
+            *output_mask,
         )
 
-
-def batch_norm(
-    input,
-    weight=None,
-    bias=None,
-    running_mean=None,
-    running_var=None,
-    training=False,
-    momentum=0.1,
-    eps=1e-05,
-    cudnn_enable=True,
-):
-    return BatchNorm.apply(
-        input,
-        weight,
-        bias,
-        running_mean,
-        running_var,
-        training,
-        momentum,
-        eps,
-        cudnn_enable,
+    # Pads output with None because a gradient is necessary for
+    # all input arguments.
+    return (
+        input_grad.view_as(input),
+        weight_grad,
+        bias_grad,
     )
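
To summarize the new calling convention of the rewritten wrappers, a short usage sketch follows (illustrative only, not taken from the repository; tensor sizes are made up): batch_norm now returns the saved statistics alongside the output, and batch_norm_backward selects which gradients to materialize via output_mask, returning None for masked-out positions.

import torch
import flag_gems
from flag_gems.ops import batch_norm, batch_norm_backward

x = torch.randn(32, 32, 32, device=flag_gems.device)
weight = torch.randn(32, device=flag_gems.device)
bias = torch.randn(32, device=flag_gems.device)
running_mean = torch.zeros(32, device=flag_gems.device)
running_var = torch.ones(32, device=flag_gems.device)

# Forward returns (output, save_mean, save_invstd) instead of a single tensor.
out, save_mean, save_invstd = batch_norm(
    x, weight, bias, running_mean, running_var, training=True, momentum=0.1, eps=1e-05
)

# output_mask = (input_grad?, weight_grad?, bias_grad?); False entries are
# returned as None and the corresponding kernel stores are skipped.
grad_out = torch.randn_like(out)
grad_in, grad_weight, grad_bias = batch_norm_backward(
    grad_out, x, weight, running_mean, running_var,
    save_mean, save_invstd, True, 1e-05, [True, True, True],
)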

src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml (-3)
@@ -966,9 +966,6 @@ batch_norm:
   META: {}
   num_warps: warps
   warps:
-  - 1
-  - 2
   - 4
   - 8
   - 16
-  - 32

tests/test_norm_ops.py (+77 -35)
@@ -592,25 +592,14 @@ def test_accuracy_vectornorm(shape, ord, dim, keepdim, dtype):
 )
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 @pytest.mark.parametrize("affine", [True, False])
-@pytest.mark.parametrize("require_grad", [True, False])
-def test_accuracy_batch_norm(shape, dtype, affine, require_grad):
+def test_accuracy_batch_norm(shape, dtype, affine):
     C = shape[1]
-    inp = torch.randn(
-        size=shape, dtype=dtype, device=flag_gems.device, requires_grad=require_grad
-    )
+    inp = torch.randn(size=shape, dtype=dtype, device=flag_gems.device)
     weight = (
-        torch.randn(
-            size=(C,), dtype=dtype, device=flag_gems.device, requires_grad=require_grad
-        )
-        if affine
-        else None
+        torch.randn(size=(C,), dtype=dtype, device=flag_gems.device) if affine else None
     )
     bias = (
-        torch.randn(
-            size=(C,), dtype=dtype, device=flag_gems.device, requires_grad=require_grad
-        )
-        if affine
-        else None
+        torch.randn(size=(C,), dtype=dtype, device=flag_gems.device) if affine else None
     )
 
     running_mean = torch.zeros(size=(C,), dtype=dtype, device=flag_gems.device)
@@ -624,15 +613,12 @@ def test_accuracy_batch_norm(shape, dtype, affine, require_grad):
     ref_running_mean = to_reference(running_mean, True)
     ref_running_var = to_reference(running_var, True)
 
-    training = require_grad
-
     ref_out = torch.nn.functional.batch_norm(
         ref_inp,
         ref_running_mean,
         ref_running_var,
         weight=ref_weight,
         bias=ref_bias,
-        training=training,
         eps=eps,
     )
 
@@ -643,36 +629,92 @@ def test_accuracy_batch_norm(shape, dtype, affine, require_grad):
         running_var,
         weight=weight,
         bias=bias,
-        training=training,
         eps=eps,
     )
 
     gems_assert_close(res_out, ref_out, dtype)
     gems_assert_close(running_mean, ref_running_mean, dtype)
     gems_assert_close(running_var, ref_running_var, dtype)
 
-    if not require_grad:
-        return
 
-    out_grad = torch.randn_like(inp)
-    ref_grad = to_reference(out_grad, True)
-    reduce_dim = int(math.prod(shape) / C)
+@pytest.mark.batch_norm
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (16, 3),
+        (32, 32, 32),
+        (8, 32, 224, 224),
+        (2050, 16, 32, 32),
+        (8, 16, 3, 224, 224),
+    ],
+)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.parametrize("affine", [True, False])
+def test_accuracy_batch_norm_backward(shape, dtype, affine):
+    C = shape[1]
+    res_grad = torch.randn(size=shape, dtype=dtype, device=flag_gems.device)
+    res_inp = torch.randn_like(res_grad)
+    res_weight = (
+        torch.randn(size=(C,), dtype=dtype, device=flag_gems.device) if affine else None
+    )
+    res_running_mean = torch.zeros(size=(C,), dtype=dtype, device=flag_gems.device)
+    res_running_var = torch.ones(size=(C,), dtype=dtype, device=flag_gems.device)
+    res_save_mean = torch.randn(C, dtype=torch.float32, device=flag_gems.device)
+    res_save_invstd = torch.randn(C, dtype=torch.float32, device=flag_gems.device)
 
+    ref_grad = to_reference(res_grad, True)
+    ref_inp = to_reference(res_inp, True)
+    ref_weight = to_reference(res_weight, True)
+    ref_running_mean = to_reference(res_running_mean, True)
+    ref_running_var = to_reference(res_running_var, True)
+    ref_save_mean = to_reference(res_save_mean, True)
+    ref_save_invstd = to_reference(res_save_invstd, True)
+
+    train = True
+    eps = 1e-05
     if affine:
-        (ref_in_grad, ref_weight_grad, ref_bias_grad) = torch.autograd.grad(
-            ref_out, (ref_inp, ref_weight, ref_bias), ref_grad
-        )
-        (res_in_grad, res_weight_grad, res_bias_grad) = torch.autograd.grad(
-            res_out, (inp, weight, bias), out_grad
+        output_mask = [True, True, True]
+    else:
+        output_mask = [True, False, False]
+
+    (
+        ref_in_grad,
+        ref_weight_grad,
+        ref_bias_grad,
+    ) = torch.ops.aten.native_batch_norm_backward(
+        ref_grad,
+        ref_inp,
+        ref_weight,
+        ref_running_mean,
+        ref_running_var,
+        ref_save_mean,
+        ref_save_invstd,
+        train,
+        eps,
+        output_mask,
+    )
+    with flag_gems.use_gems():
+        (
+            res_in_grad,
+            res_weight_grad,
+            res_bias_grad,
+        ) = torch.ops.aten.native_batch_norm_backward(
+            res_grad,
+            res_inp,
+            res_weight,
+            res_running_mean,
+            res_running_var,
+            res_save_mean,
+            res_save_invstd,
+            train,
+            eps,
+            output_mask,
        )
 
-        gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=reduce_dim)
+    reduce_dim = math.prod(shape) // C
+    gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=reduce_dim)
+    if affine:
         gems_assert_close(
             res_weight_grad, ref_weight_grad, dtype, reduce_dim=reduce_dim
         )
         gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=reduce_dim)
-    else:
-        (ref_in_grad,) = torch.autograd.grad(ref_out, (ref_inp,), ref_grad)
-        (res_in_grad,) = torch.autograd.grad(res_out, (inp,), out_grad)
-
-        gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=reduce_dim)
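
Assuming a standard pytest setup for this repository, the new case can be run on its own with, for example: pytest tests/test_norm_ops.py -k test_accuracy_batch_norm_backward.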
