diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 9ffb7c1dc0f3..e9779d8f85af 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1759,6 +1759,61 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern { } // namespace +namespace { +class ConvertAtenSoftmaxIntOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenSoftmaxIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Value input = adaptor.getSelf(); + + Value result = op.getResult(); + Value dimValue = op.getDim(); + auto constOp = dimValue.getDefiningOp(); + if (!constOp) { + return rewriter.notifyMatchFailure( + op, "dimension must be a constant integer"); + } + + int64_t dimInt = constOp.getValue(); + + // Handle negative dimensions by converting to positive + if (auto tensorType = cast(input.getType())) { + int64_t rank = tensorType.getRank(); + if (dimInt < 0) { + dimInt += rank; + } + if (dimInt < 0 || dimInt >= rank) { + return rewriter.notifyMatchFailure(op, "dimension out of bounds"); + } + } + + IntegerAttr dimAttr = rewriter.getI64IntegerAttr(dimInt); + + Type newResultType = + getTypeConverter()->convertType(op.getResult().getType()); + auto resultType = cast(newResultType); + Value result_tensor = rewriter.create( + loc, resultType.getShape(), resultType.getElementType()); + + auto softmax = rewriter.create( + loc, TypeRange{resultType}, input, result_tensor, dimAttr); + + rewriter.replaceOp(op, softmax.getResult()); + // if we know constop is only used by this softmax op, erase it + if (constOp.getResult().hasOneUse()) { + rewriter.eraseOp(constOp); + } + return success(); + } +}; +} // namespace void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1775,4 +1830,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); }