from .. import runtime
from ..runtime import torch_device_fn
from ..utils import libentry, tl_extra_shim
-from ..utils.type_utils import get_accumulator_dtype

rsqrt = tl_extra_shim.rsqrt

@@ -63,8 +62,6 @@ def batch_norm_forward_kernel(
    output_spatial_stride,
    momentum,
    eps,
-    affine: tl.constexpr,
-    save_stats: tl.constexpr,
    is_train: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
@@ -114,9 +111,8 @@ def batch_norm_forward_kernel(
        inv_std = rsqrt(var + eps)
        mean = final_mean

-        if save_stats:
-            tl.store(feat_pid + mean_pointer, mean)
-            tl.store(feat_pid + inv_std_pointer, inv_std)
+        tl.store(feat_pid + mean_pointer, mean)
+        tl.store(feat_pid + inv_std_pointer, inv_std)

        running_mean_pointer += feat_pid
        running_var_pointer += feat_pid
@@ -135,12 +131,13 @@ def batch_norm_forward_kernel(
        mean = tl.load(feat_pid + running_mean_pointer)
        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)

-    if affine:
-        weight = tl.load(feat_pid + weight_pointer)
-        bias = tl.load(feat_pid + bias_pointer)
-
+    if weight_pointer:
+        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
+    if bias_pointer:
+        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
+    else:
        bias = 0.0

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
@@ -203,7 +200,9 @@ def batch_norm_backward_kernel(
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
-    affine: tl.constexpr,
+    input_grad_mask: tl.constexpr,
+    weight_grad_mask: tl.constexpr,
+    bias_grad_mask: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
@@ -250,11 +249,16 @@ def batch_norm_backward_kernel(
    term1 = tl.sum(term1)
    term2 = tl.sum(term2)

-    if affine:
-        weight = tl.load(feat_pid + weight_pointer)
-        weight_grad_acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        bias_grad_acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+    if weight_grad_mask:
+        tl.store(feat_pid + weight_grad_pointer, term1)
+    if bias_grad_mask:
+        tl.store(feat_pid + bias_grad_pointer, term2)
+
+    if not input_grad_mask:
+        return

+    if weight_pointer:
+        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0

@@ -306,152 +310,107 @@ def batch_norm_backward_kernel(
            mask=batch_mask[:, None] & spatial_mask[None, :],
        )

-        if affine:
-            weight_grad_acc += curr_pre_lin * curr_output_grad
-            bias_grad_acc += curr_output_grad
-
-    if affine:
-        tl.store(feat_pid + weight_grad_pointer, tl.sum(weight_grad_acc))
-        tl.store(feat_pid + bias_grad_pointer, tl.sum(bias_grad_acc))
-
-
-class BatchNorm(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        input: Tensor,
-        weight=None,
-        bias=None,
-        running_mean=None,  # self.running_mean if not self.training or self.track_running_state else None
-        running_var=None,
-        training=False,  # (self.running_mean is None) and (self.running_var is None)
-        momentum=0.1,
-        eps=1e-05,
-        cudnn_enable=True,
-    ):
-        logging.debug("GEMS BATCHNORM FORWARD")
-
-        input_3d = make_3d_for_bn(input)
-
-        affine = weight is not None and bias is not None
-        requires_grad = (
-            input.requires_grad
-            or (affine and weight.requires_grad)
-            or (affine and bias.requires_grad)
-        )
-
-        batch_dim, feat_dim, spatial_dim = input_3d.shape
-        output = torch.empty_like(input_3d)

-        if requires_grad:
-            acc_type = get_accumulator_dtype(input.dtype)
-            mean = torch.empty(feat_dim, device=input.device, dtype=acc_type)
-            inv_std = torch.empty(feat_dim, device=input.device, dtype=acc_type)
-
-        else:
-            mean = inv_std = None
-
-        running_mean = input if running_mean is None else running_mean
-        running_var = input if running_var is None else running_var
+def batch_norm(
+    input: Tensor,
+    weight=None,
+    bias=None,
+    running_mean=None,  # self.running_mean if not self.training or self.track_running_state else None
+    running_var=None,
+    training=False,  # (self.running_mean is None) and (self.running_var is None)
+    momentum=0.1,
+    eps=1e-05,
+):
+    logging.debug("GEMS BATCHNORM FORWARD")
+
+    input_3d = make_3d_for_bn(input)
+
+    batch_dim, feat_dim, spatial_dim = input_3d.shape
+    output = torch.empty_like(input_3d)
+
+    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
+    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
+
+    running_mean = input if running_mean is None else running_mean
+    running_var = input if running_var is None else running_var
+
+    # Launches 1D grid where each program operates over one feature.
+    with torch_device_fn.device(input.device):
+        batch_norm_forward_kernel[(feat_dim,)](
+            input_3d,
+            weight,
+            bias,
+            mean,
+            inv_std,
+            output,
+            running_mean,
+            running_var,
+            batch_dim,
+            spatial_dim,
+            *input_3d.stride(),
+            *output.stride(),
+            momentum,
+            eps,
+            is_train=training,
+        )

-        # Launches 1D grid where each program operates over one feature.
-        with torch_device_fn.device(input.device):
-            batch_norm_forward_kernel[(feat_dim,)](
-                input_3d,
-                weight,
-                bias,
-                mean,
-                inv_std,
-                output,
-                running_mean,
-                running_var,
-                batch_dim,
-                spatial_dim,
-                *input_3d.stride(),
-                *output.stride(),
-                momentum,
-                eps,
-                affine=affine,
-                save_stats=requires_grad,
-                is_train=training,
-            )
+    return output.view_as(input), mean, inv_std

-        ctx.affine = affine
-        if requires_grad:
-            ctx.save_for_backward(input, mean, inv_std, weight)

-        return output.view_as(input)
+def batch_norm_backward(
+    grad_out,
+    input,
+    weight=None,
+    running_mean=None,
+    running_var=None,
+    save_mean=None,
+    save_invstd=None,
+    train=False,
+    eps=1e-05,
+    output_mask=None,
+):
+    logging.debug("GEMS BATCHNORM BACKWARD")
+    input_3d = make_3d_for_bn(input)
+    output_grad_3d = make_3d_for_bn(grad_out)

-    @staticmethod
-    def backward(ctx, output_grad):
-        logging.debug("GEMS BATCHNORM BACKWARD")
-        (input, mean, inv_std, weight) = ctx.saved_tensors
-        input_3d = make_3d_for_bn(input)
-        output_grad_3d = make_3d_for_bn(output_grad)
+    batch_dim, feat_dim, spatial_dim = input_3d.shape

-        batch_dim, feat_dim, spatial_dim = input_3d.shape
+    if output_mask[0]:
        input_grad = torch.empty_like(input_3d)
-
-        if ctx.affine:
-            weight_grad = torch.empty((feat_dim,), device=input.device)
-            bias_grad = torch.empty_like(weight_grad)
-
-        else:
-            weight_grad = bias_grad = None
-
-        # Launches 1D grid where each program operates over one feature.
-        with torch_device_fn.device(input.device):
-            batch_norm_backward_kernel[(feat_dim,)](
-                output_grad_3d,
-                input_3d,
-                mean,
-                inv_std,
-                weight,
-                input_grad,
-                weight_grad,
-                bias_grad,
-                batch_dim,
-                spatial_dim,
-                *output_grad_3d.stride(),
-                *input_3d.stride(),
-                *input_grad.stride(),
-                affine=ctx.affine,
-            )
-
-        # Pads output with None because a gradient is necessary for
-        # all input arguments.
-        return (
-            input_grad.view_as(input),
+    else:
+        input_grad = None
+    if output_mask[1]:
+        weight_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
+    else:
+        weight_grad = None
+    if output_mask[2]:
+        bias_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
+    else:
+        bias_grad = None
+
+    # Launches 1D grid where each program operates over one feature.
+    with torch_device_fn.device(input.device):
+        batch_norm_backward_kernel[(feat_dim,)](
+            output_grad_3d,
+            input_3d,
+            save_mean,
+            save_invstd,
+            weight,
+            input_grad,
            weight_grad,
            bias_grad,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
+            batch_dim,
+            spatial_dim,
+            *output_grad_3d.stride(),
+            *input_3d.stride(),
+            *input_grad.stride(),
+            *output_mask,
        )

-
-def batch_norm(
-    input,
-    weight=None,
-    bias=None,
-    running_mean=None,
-    running_var=None,
-    training=False,
-    momentum=0.1,
-    eps=1e-05,
-    cudnn_enable=True,
-):
-    return BatchNorm.apply(
-        input,
-        weight,
-        bias,
-        running_mean,
-        running_var,
-        training,
-        momentum,
-        eps,
-        cudnn_enable,
+    # Pads output with None because a gradient is necessary for
+    # all input arguments.
+    return (
+        input_grad.view_as(input),
+        weight_grad,
+        bias_grad,
    )
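
For reference, a minimal sketch of how the new split entry points could be exercised together. The import path, device, and tensor shapes below are assumptions for illustration; the call signatures and return values follow the diff above (forward returns the saved statistics, backward takes an output_mask in the style of native_batch_norm_backward).

```python
import torch

# Assumed import path; adjust to wherever batch_norm / batch_norm_backward are exposed.
from flag_gems import batch_norm, batch_norm_backward

# NCHW input with per-channel affine parameters, as in nn.BatchNorm2d.
x = torch.randn(8, 16, 32, 32, device="cuda")
weight = torch.ones(16, device="cuda")
bias = torch.zeros(16, device="cuda")
running_mean = torch.zeros(16, device="cuda")
running_var = torch.ones(16, device="cuda")

# Forward now returns the saved statistics alongside the output.
out, save_mean, save_invstd = batch_norm(
    x, weight, bias, running_mean, running_var,
    training=True, momentum=0.1, eps=1e-5,
)

# Backward takes an output_mask selecting which gradients to materialize.
grad_out = torch.randn_like(out)
grad_input, grad_weight, grad_bias = batch_norm_backward(
    grad_out, x, weight, running_mean, running_var,
    save_mean, save_invstd, train=True, eps=1e-5,
    output_mask=[True, True, True],
)
```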