Skip to content

Commit 2c13310

Browse files
Revert some changes to original state
1 parent b8902df commit 2c13310

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
497497
if (!isPerChannelQuantization)
498498
input = makePerTensor(input, inputScale, inputZp);
499499

500-
// TODO(suderman): insert convolution operator.
501500
llvm::SmallVector<Value> newOperands = {input, weight};
502501
if (bias)
503502
newOperands.push_back(bias);
@@ -524,6 +523,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
524523
newAttributes, binder.op->getRegions().size())
525524
.getResult(0);
526525

526+
if (!isPerChannelQuantization) {
527+
Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
528+
binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,
529+
weightScale);
530+
Value outZp = rewriter.create<Torch::ConstantIntOp>(
531+
binder.getLoc(), rewriter.getType<Torch::IntType>(),
532+
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
533+
output = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
534+
binder.getLoc(), outputTy, output, outScale, outZp);
535+
outputTy = rewriter.getType<Torch::ValueTensorType>(
536+
resultType.getOptionalSizes(), rewriter.getF32Type());
537+
538+
output = rewriter.create<Torch::AtenDequantizeSelfOp>(
539+
binder.getLoc(), outputTy, output);
540+
}
541+
527542
outputTy = getQTorchTypeFromTorchIntType(resultType);
528543
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
529544
binder.getLoc(), rewriter.getType<Torch::IntType>(),

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,12 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1:
9090
// CHECK: %[[NONE:.+]] = torch.constant.none
9191
// CHECK: %[[INT1_5:.+]] = torch.constant.int 1
9292
// CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %[[NONE]], %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
93+
// CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float
94+
// CHECK: %[[INT0_6:.+]] = torch.constant.int 0
95+
// CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
96+
// CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32>
9397
// CHECK: %[[INT13:.+]] = torch.constant.int 13
94-
// CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[CONV]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
98+
// CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
9599
// CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8>
96100
// CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8>
97101
return %0 : !torch.vtensor<[1,1,7,7],ui8>
@@ -124,8 +128,12 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t
124128
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
125129
// CHECK: %[[INT1_5:.+]] = torch.constant.int 1
126130
// CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %arg8, %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.vtensor<[7],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
131+
// CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float
132+
// CHECK: %[[INT0_6:.+]] = torch.constant.int 0
133+
// CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
134+
// CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32>
127135
// CHECK: %[[INT13:.+]] = torch.constant.int 13
128-
// CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[CONV]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
136+
// CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
129137
// CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8>
130138
// CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8>
131139
return %0 : !torch.vtensor<[1,1,7,7],ui8>

0 commit comments

Comments (0)