Skip to content

Commit ffffb0c

Browse files
More changes
1 parent d7a9066 commit ffffb0c

File tree

2 files changed: 16 additions (+16) and 16 deletions (−16)

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -330,16 +330,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
330330
Value outputZp = operands[7];
331331
Value output = operands.size() == 9 ? operands[8] : nullptr;
332332

333-
// auto check = [](Value v) {
334-
// auto vTy = cast<Torch::ValueTensorType>(v.getType());
335-
// return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1;
336-
// });
337-
// };
338-
// if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
339-
// !check(cScale) || !check(cScale))
340-
// return rewriter.notifyMatchFailure(
341-
// binder.op, "not supported for non per-tensor quantization");
342-
343333
auto extract = [&rewriter, &binder](Value v) {
344334
auto vTy = cast<Torch::ValueTensorType>(v.getType());
345335
Type extractTy = rewriter.getType<Torch::FloatType>();
@@ -374,14 +364,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
374364
input = makePerTensor(input, inputScale, inputZp);
375365
// The onnx's QLinearConv op expects per channel quantization only for
376366
// the weight tensor for axis = 0.
377-
llvm::outs() << "I'm here\n";
378367
auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
379368
auto weightScaleTy =
380369
dyn_cast<Torch::ValueTensorType>(weightScale.getType());
381370
if (!weightTy || !weightScaleTy || !weightTy.hasSizes() ||
382371
!weightScaleTy.hasSizes())
383372
return failure();
384-
llvm::outs() << "I'm here 1\n";
385373
auto weightShape = weightTy.getSizes();
386374
auto weightScaleShape = weightScaleTy.getSizes();
387375
Value weightScaleScalar = extract(weightScale);
@@ -395,13 +383,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
395383
weightZp = extract(weightZp);
396384
weight = makePerTensor(weight, weightScaleScalar, weightZp);
397385
}
398-
weight = weightScaleScalar;
386+
weightScale = weightScaleScalar;
399387

400388
auto outputTy = rewriter.getType<Torch::ValueTensorType>(
401389
resultType.getOptionalSizes(),
402390
rewriter.getIntegerType(32, /*issigned=*/true));
403391

404-
llvm::outs() << "I'm here 2\n";
405392
// TODO(suderman): insert convolution operator.
406393
llvm::SmallVector<Value> newOperands = {input, weight};
407394
if (output)
@@ -438,7 +425,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
438425
outputTy = rewriter.getType<Torch::ValueTensorType>(
439426
resultType.getOptionalSizes(), rewriter.getF32Type());
440427

441-
llvm::outs() << "I'm here 3\n";
442428
output = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(),
443429
outputTy, output);
444430
outputTy = getQTorchTypeFromTorchIntType(resultType);
@@ -452,7 +438,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
452438
binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
453439
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
454440
output);
455-
llvm::outs() << "I'm here 4\n";
456441
return success();
457442
});
458443
patterns.onOp(

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,21 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
779779
weight = make.getSelf();
780780
weightZp = make.getZeroPoint();
781781

782+
weight = typeConverter->materializeTargetConversion(
783+
rewriter, loc, typeConverter->convertType(weight.getType()), weight);
784+
weightZp = typeConverter->materializeTargetConversion(
785+
rewriter, loc, typeConverter->convertType(weightZp.getType()),
786+
weightZp);
787+
weightZp = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(),
788+
weightZp);
789+
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
790+
weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
791+
} else if (auto make =
792+
op.getWeight()
793+
.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>()) {
794+
weight = make.getSelf();
795+
weightZp = make.getZeroPoint();
796+
782797
weight = typeConverter->materializeTargetConversion(
783798
rewriter, loc, typeConverter->convertType(weight.getType()), weight);
784799
weightZp = typeConverter->materializeTargetConversion(

0 commit comments

Comments (0)