@@ -330,16 +330,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
330330 Value outputZp = operands[7 ];
331331 Value output = operands.size () == 9 ? operands[8 ] : nullptr ;
332332
333- // auto check = [](Value v) {
334- // auto vTy = cast<Torch::ValueTensorType>(v.getType());
335- // return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1;
336- // });
337- // };
338- // if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
339- // !check(cScale) || !check(cScale))
340- // return rewriter.notifyMatchFailure(
341- // binder.op, "not supported for non per-tensor quantization");
342-
343333 auto extract = [&rewriter, &binder](Value v) {
344334 auto vTy = cast<Torch::ValueTensorType>(v.getType ());
345335 Type extractTy = rewriter.getType <Torch::FloatType>();
@@ -374,14 +364,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
374364 input = makePerTensor (input, inputScale, inputZp);
375365 // The onnx's QLinearConv op expects per channel quantization only for
376366 // the weight tensor for axis = 0.
377- llvm::outs () << " I'm here\n " ;
378367 auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType ());
379368 auto weightScaleTy =
380369 dyn_cast<Torch::ValueTensorType>(weightScale.getType ());
381370 if (!weightTy || !weightScaleTy || !weightTy.hasSizes () ||
382371 !weightScaleTy.hasSizes ())
383372 return failure ();
384- llvm::outs () << " I'm here 1\n " ;
385373 auto weightShape = weightTy.getSizes ();
386374 auto weightScaleShape = weightScaleTy.getSizes ();
387375 Value weightScaleScalar = extract (weightScale);
@@ -395,13 +383,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
395383 weightZp = extract (weightZp);
396384 weight = makePerTensor (weight, weightScaleScalar, weightZp);
397385 }
398- weight = weightScaleScalar;
386+ weightScale = weightScaleScalar;
399387
400388 auto outputTy = rewriter.getType <Torch::ValueTensorType>(
401389 resultType.getOptionalSizes (),
402390 rewriter.getIntegerType (32 , /* issigned=*/ true ));
403391
404- llvm::outs () << " I'm here 2\n " ;
405392 // TODO(suderman): insert convolution operator.
406393 llvm::SmallVector<Value> newOperands = {input, weight};
407394 if (output)
@@ -438,7 +425,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
438425 outputTy = rewriter.getType <Torch::ValueTensorType>(
439426 resultType.getOptionalSizes (), rewriter.getF32Type ());
440427
441- llvm::outs () << " I'm here 3\n " ;
442428 output = rewriter.create <Torch::AtenDequantizeSelfOp>(binder.getLoc (),
443429 outputTy, output);
444430 outputTy = getQTorchTypeFromTorchIntType (resultType);
@@ -452,7 +438,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
452438 binder.getLoc (), outputTy, output, outputScale, outputZp, dtyVal);
453439 rewriter.replaceOpWithNewOp <Torch::AtenIntReprOp>(binder.op , resultType,
454440 output);
455- llvm::outs () << " I'm here 4\n " ;
456441 return success ();
457442 });
458443 patterns.onOp (
0 commit comments