@@ -296,7 +296,7 @@ def quantized_layer_norm_per_tensor(
296
296
)
297
297
298
298
299
- def quantized_conv (
299
+ def quantized_conv_per_tensor (
300
300
input_tensor : torch .Tensor ,
301
301
weight : torch .Tensor ,
302
302
bias : torch .Tensor ,
@@ -305,12 +305,12 @@ def quantized_conv(
305
305
dilation : tuple [int , int ],
306
306
groups : int ,
307
307
in_zero_point : int ,
308
- weight_zero_point : torch . Tensor ,
309
- bias_scale : torch . Tensor ,
308
+ weight_zero_point : int ,
309
+ bias_scale : float ,
310
310
output_scale : float ,
311
311
output_zero_point : int ,
312
- out_multiplier : torch . Tensor ,
313
- out_shift : torch . Tensor ,
312
+ out_multiplier : int ,
313
+ out_shift : int ,
314
314
) -> torch .Tensor :
315
315
"""
316
316
Quantized convolution operation.
@@ -324,19 +324,13 @@ def quantized_conv(
324
324
- dilation (Tuple[int]): The dilation of the convolution
325
325
- groups (int): The number of groups
326
326
- in_zero_point (int): The quantized mapping of zero for the input
327
- - weight_zero_point (Tensor ): The quantized mapping of zero for the weight
328
- - bias_scale (Tensor ): The quantized bias scale
327
+ - weight_zero_point (int ): The quantized mapping of zero for the weight
328
+ - bias_scale (float ): The quantized bias scale
329
329
- output_scale (float): The scale of the output
330
330
- output_zero_point (int): The zero point of the output
331
- - out_multiplier (Tensor ): Unused
332
- - out_shift (Tensor ): Unused
331
+ - out_multiplier (int ): Unused
332
+ - out_shift (int ): Unused
333
333
"""
334
- if weight_zero_point .view (- 1 ).shape != (1 ,):
335
- raise ValueError ("Weight zero point must be a scalar" )
336
-
337
- if bias_scale .view (- 1 ).shape != (1 ,):
338
- raise ValueError ("Bias scale must be a scalar" )
339
-
340
334
if len (input_tensor .shape ) == 3 :
341
335
float_out = torch .nn .functional .conv1d (
342
336
(input_tensor - in_zero_point ).float (),
@@ -371,8 +365,8 @@ def quantized_conv(
371
365
)
372
366
373
367
374
- @impl (m , "quantized_conv_nchw " )
375
- def quantized_conv_nchw (
368
+ @impl (m , "quantized_conv_nchw_per_tensor " )
369
+ def quantized_conv_nchw_per_tensor (
376
370
input_tensor : torch .Tensor ,
377
371
weight : torch .Tensor ,
378
372
bias : torch .Tensor ,
@@ -381,12 +375,12 @@ def quantized_conv_nchw(
381
375
dilation : tuple [int , int ],
382
376
groups : int ,
383
377
in_zero_point : int ,
384
- weight_zero_point : torch . Tensor ,
385
- bias_scale : torch . Tensor ,
378
+ weight_zero_point : int ,
379
+ bias_scale : float ,
386
380
output_scale : float ,
387
381
output_zero_point : int ,
388
- out_multiplier : torch . Tensor ,
389
- out_shift : torch . Tensor ,
382
+ out_multiplier : int ,
383
+ out_shift : int ,
390
384
) -> torch .Tensor :
391
385
"""
392
386
Quantized convolution operation.
@@ -400,16 +394,16 @@ def quantized_conv_nchw(
400
394
- dilation (Tuple[int]): The dilation of the convolution
401
395
- groups (int): The number of groups
402
396
- in_zero_point (int): The quantized mapping of zero for the input
403
- - weight_zero_point (Tensor ): The quantized mapping of zero for the weight
404
- - bias_scale (Tensor ): The quantized bias scale
397
+ - weight_zero_point (int ): The quantized mapping of zero for the weight
398
+ - bias_scale (float ): The quantized bias scale
405
399
- output_scale (float): The scale of the output
406
400
- output_zero_point (int): The zero point of the output
407
- - out_multiplier (Tensor ): Unused
408
- - out_shift (Tensor ): Unused
401
+ - out_multiplier (int ): Unused
402
+ - out_shift (int ): Unused
409
403
"""
410
404
if not input_tensor .is_contiguous (memory_format = torch .contiguous_format ):
411
405
raise ValueError ("Input tensor must be in NCHW format" )
412
- return quantized_conv (
406
+ return quantized_conv_per_tensor (
413
407
input_tensor ,
414
408
weight ,
415
409
bias ,
@@ -427,8 +421,8 @@ def quantized_conv_nchw(
427
421
)
428
422
429
423
430
- @impl (m , "quantized_conv_nhwc " )
431
- def quantized_conv_nhwc (
424
+ @impl (m , "quantized_conv_nhwc_per_tensor " )
425
+ def quantized_conv_nhwc_per_tensor (
432
426
input_tensor : torch .Tensor ,
433
427
weight : torch .Tensor ,
434
428
bias : torch .Tensor ,
@@ -437,12 +431,12 @@ def quantized_conv_nhwc(
437
431
dilation : tuple [int , int ],
438
432
groups : int ,
439
433
in_zero_point : int ,
440
- weight_zero_point : torch . Tensor ,
441
- bias_scale : torch . Tensor ,
434
+ weight_zero_point : int ,
435
+ bias_scale : float ,
442
436
output_scale : float ,
443
437
output_zero_point : int ,
444
- out_multiplier : torch . Tensor ,
445
- out_shift : torch . Tensor ,
438
+ out_multiplier : int ,
439
+ out_shift : int ,
446
440
) -> torch .Tensor :
447
441
"""
448
442
Quantized convolution operation.
@@ -456,18 +450,18 @@ def quantized_conv_nhwc(
456
450
- dilation (Tuple[int]): The dilation of the convolution
457
451
- groups (int): The number of groups
458
452
- in_zero_point (int): The quantized mapping of zero for the input
459
- - weight_zero_point (Tensor ): The quantized mapping of zero for the weight
460
- - bias_scale (Tensor ): The quantized bias scale
453
+ - weight_zero_point (int ): The quantized mapping of zero for the weight
454
+ - bias_scale (float ): The quantized bias scale
461
455
- output_scale (float): The scale of the output
462
456
- output_zero_point (int): The zero point of the output
463
- - out_multiplier (Tensor ): Unused
464
- - out_shift (Tensor ): Unused
457
+ - out_multiplier (int ): Unused
458
+ - out_shift (int ): Unused
465
459
"""
466
460
467
461
if not input_tensor .is_contiguous (memory_format = torch .channels_last ):
468
462
raise ValueError ("Input tensor must be in NHWC format" )
469
463
470
- return quantized_conv (
464
+ return quantized_conv_per_tensor (
471
465
input_tensor ,
472
466
weight ,
473
467
bias ,
0 commit comments