[torch-mlir] Support lowering of aten constraint ops #3943

Merged (5 commits) on Feb 5, 2025

71 changes: 71 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -17771,6 +17771,77 @@ def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [
}];
}

def Torch_AtenSymConstrainRangeOp : Torch_Op<"aten.sym_constrain_range", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sym_constrain_range : (Scalar, int?, int?) -> ()`";
let arguments = (ins
AnyTorchScalarType:$size,
AnyTorchOptionalIntType:$min,
AnyTorchOptionalIntType:$max
);
let results = (outs
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSymConstrainRangeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 0);
}
void AtenSymConstrainRangeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 0);
}
}];
}

def Torch_AtenSymConstrainRangeForSizeOp : Torch_Op<"aten.sym_constrain_range_for_size", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()`";
let arguments = (ins
AnyTorchScalarType:$size,
AnyTorchOptionalIntType:$min,
AnyTorchOptionalIntType:$max
);
let results = (outs
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSymConstrainRangeForSizeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 0);
}
void AtenSymConstrainRangeForSizeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 0);
}
}];
}

def Torch_Aten_AssertScalarOp : Torch_Op<"aten._assert_scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_assert_scalar : (Scalar, str) -> ()`";
let arguments = (ins
AnyTorchScalarType:$self,
Torch_StringType:$assert_msg
);
let results = (outs
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_AssertScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 0);
}
void Aten_AssertScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 0);
}
}];
}
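
For reference, a short sketch (mirroring the e2e tests added below) of how these three ops are reached from PyTorch; under torch.export, the scalar produced by `.item()` becomes a data-dependent SymInt:

```python
import torch

x = torch.tensor(4)
a = x.item()  # data-dependent scalar; a SymInt under torch.export

# Constrain a symbolic scalar to [min, max]; either bound may be omitted.
torch.ops.aten.sym_constrain_range(a, max=5)

# Size-like variant for values that will be used as tensor dimensions.
torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10)

# Runtime assertion on a scalar condition, with a constant message.
torch.ops.aten._assert_scalar(a > 3, "Assertion failed for condition x.item() > 3")
```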

def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
AllowsTypeRefinement,
HasValueSemantics,
66 changes: 66 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -21,10 +21,12 @@
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>
#include <string>
#include <type_traits>

using namespace mlir;
@@ -3564,6 +3566,68 @@ class ConvertAtenPolarOp : public OpConversionPattern<AtenPolarOp> {
};
} // namespace

namespace {
class ConvertSymConstrainRangeOp
: public OpConversionPattern<AtenSymConstrainRangeOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenSymConstrainRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

auto loc = op.getLoc();
auto min = op.getMin();
auto max = op.getMax();

int64_t minValue = std::numeric_limits<int64_t>::min();
int64_t maxValue = std::numeric_limits<int64_t>::max();

Type operandType = getTypeConverter()->convertType(op.getSize().getType());

if (!isa<Torch::NoneType>(min.getType()))
if (!matchPattern(min, m_TorchConstantInt(&minValue)))
return rewriter.notifyMatchFailure(
op, "Expected min value to be constant integer");

if (!isa<Torch::NoneType>(max.getType()))
if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
return rewriter.notifyMatchFailure(
op, "Expected max value to be constant integer");

if (maxValue < minValue) {
std::string errorMsg =
"Max must be greater than or equal to min, got min = " +
std::to_string(minValue) + ", max = " + std::to_string(maxValue);
return op.emitError(errorMsg);
}

min = getConstant(rewriter, loc, minValue, operandType);
max = getConstant(rewriter, loc, maxValue, operandType);

// Check that min <= size <= max

// FIXME: Skip the checks below if constraint ops were already inserted as
// part of symbolic expression evaluation.
auto checkMin = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, min, adaptor.getSize());
auto checkMax = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, adaptor.getSize(), max);
auto compareVal = rewriter.create<arith::AndIOp>(loc, checkMin, checkMax);

std::string assertMessage = "Size constraint failed. Expected range: [" +
std::to_string(minValue) + ", " +
std::to_string(maxValue) + "]";
rewriter.create<cf::AssertOp>(loc, compareVal,
rewriter.getStringAttr(assertMessage));

rewriter.eraseOp(op);
return success();
}
};
} // namespace
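
In effect, the pattern materializes a bounds check ahead of the op and then erases it. A minimal Python sketch of the lowered runtime semantics (names are illustrative; as in the C++ above, a None bound defaults to the int64 range):

```python
INT64_MIN, INT64_MAX = -(2**63), 2**63 - 1

def sym_constrain_range_check(size, min_value=None, max_value=None):
    lo = INT64_MIN if min_value is None else min_value
    hi = INT64_MAX if max_value is None else max_value
    if hi < lo:  # rejected at compile time by the pattern
        raise ValueError(
            f"Max must be greater than or equal to min, got min = {lo}, max = {hi}")
    # Lowered as two arith.cmpi ops combined with arith.andi, feeding cf.assert.
    assert lo <= size <= hi, \
        f"Size constraint failed. Expected range: [{lo}, {hi}]"
```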

void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -3626,4 +3690,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenLinalgDetOp>(typeConverter, context);
target.addIllegalOp<AtenPolarOp>();
patterns.add<ConvertAtenPolarOp>(typeConverter, context);
target.addIllegalOp<AtenSymConstrainRangeOp>();
patterns.add<ConvertSymConstrainRangeOp>(typeConverter, context);
}
78 changes: 78 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -11455,6 +11455,80 @@ class DecomposeAtenSpecialExpm1Op
};
} // namespace

namespace {
class DecomposeAtenConstrainRangeForSizeOp
: public OpRewritePattern<AtenSymConstrainRangeForSizeOp> {
public:
using OpRewritePattern<AtenSymConstrainRangeForSizeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSymConstrainRangeForSizeOp op,
PatternRewriter &rewriter) const override {

auto loc = op.getLoc();
auto min = op.getMin();
auto max = op.getMax();

int64_t minValue, maxValue;

if (isa<Torch::NoneType>(min.getType())) {
// Set min value to 0
min = rewriter.create<Torch::ConstantIntOp>(loc, 0);
} else {
// Check if min value is a constant
if (!matchPattern(min, m_TorchConstantInt(&minValue)))
return rewriter.notifyMatchFailure(
op, "Expected min value to be constant integer");
}

if (!isa<Torch::NoneType>(max.getType())) {
// Verify that max value is greater than 2
if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
return rewriter.notifyMatchFailure(
op, "Expected max value to be constant integer");

if (maxValue <= 2) {
std::string errorMsg = "Max value to constrain_range_for_size must be "
"greater than 2, got: " +
std::to_string(maxValue);
return op.emitError(errorMsg);
}
}

rewriter.replaceOpWithNewOp<AtenSymConstrainRangeOp>(op, op.getSize(), min,
max);
return success();
}
};
} // namespace
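
The decomposition is equivalent to the following sketch (illustrative Python, not the pattern itself): a missing min defaults to 0, a constant max must be greater than 2, and the op is then rewritten to aten.sym_constrain_range:

```python
import torch

def decompose_constrain_range_for_size(size, min_value=None, max_value=None):
    if min_value is None:
        min_value = 0  # matches the Torch::ConstantIntOp(0) default above
    if max_value is not None and max_value <= 2:
        raise ValueError(
            "Max value to constrain_range_for_size must be "
            f"greater than 2, got: {max_value}")
    # replaceOpWithNewOp<AtenSymConstrainRangeOp>(op, size, min, max)
    torch.ops.aten.sym_constrain_range(size, min=min_value, max=max_value)
```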

namespace {
class DecomposeAten_AssertScalarOp
: public OpRewritePattern<Aten_AssertScalarOp> {
public:
using OpRewritePattern<Aten_AssertScalarOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_AssertScalarOp op,
PatternRewriter &rewriter) const override {

auto loc = op.getLoc();
auto assertCond = op.getSelf();

if (isa<Torch::IntType>(assertCond.getType()))
assertCond = rewriter.create<AtenBoolIntOp>(loc, assertCond);
else if (isa<Torch::FloatType>(assertCond.getType()))
assertCond = rewriter.create<AtenBoolFloatOp>(loc, assertCond);
assert(isa<Torch::BoolType>(assertCond.getType()) &&
"Unhandled type encountered in aten._assert_scalar op");

std::string assertMessage;
if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage)))
return rewriter.notifyMatchFailure(
op, "Assert message must be a constant string");

rewriter.replaceOpWithNewOp<RuntimeAssertOp>(op, assertCond, assertMessage);
return success();
}
};
} // namespace
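
Likewise, the aten._assert_scalar decomposition behaves like this sketch: an int or float condition is first converted to bool (via aten.Bool.int / aten.Bool.float), then the op is replaced with a Torch RuntimeAssertOp carrying the constant message:

```python
def decompose_assert_scalar(cond, assert_msg: str):
    cond = bool(cond)  # AtenBoolIntOp / AtenBoolFloatOp in the pattern above
    assert cond, assert_msg  # becomes Torch::RuntimeAssertOp
```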

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -11753,6 +11827,10 @@ class DecomposeComplexOpsPass
// Torchvision ops
addPatternIfTargetOpIsIllegal<DecomposeTorchvisionNmsOp>(patterns);

addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
12 changes: 11 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
@@ -35,6 +35,10 @@
"Aten_TrilinearModuleZerodDimBug_basic",
# missing lowering from aten.pow.Tensor_Tensor for integer result
"PowIntIntModule_basic",
# Unknown builtin op: aten::_check_is_size in TorchScript
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
}

if torch_version_for_comparison() < version.parse("2.5.0.dev"):
@@ -623,7 +627,6 @@
"AtenMmQMixedSigni8_basic",
"AtenMmQint8_basic",
"AtenMmQuint8_basic",
"AtenNonzero1DDynamicModule_basic",
"AtenRealView128Module_basic",
"AtenRealView64Module_basic",
"AtenTopKModule_basic",
@@ -941,6 +944,9 @@
"UniformModule_basic",
"UniformStaticShapeModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -964,6 +970,7 @@
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
"AtenNonzero1DDynamicModule_basic", # error: Mismatched ranks of types2 vs 1
}

STABLEHLO_PASS_SET = {
@@ -3254,6 +3261,9 @@
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"ScaledDotProductAttentionGQAModule_basic",
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1232,6 +1232,11 @@ def emit_with_mutating_variants(key, **kwargs):
)
emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")

# Constraint ops
emit("aten::sym_constrain_range : (Scalar, int?, int?) -> ()")
emit("aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()")
emit("aten::_assert_scalar : (Scalar, str) -> ()")

# ==========================================================================
# `prim::` namespace.
# ==========================================================================
Expand Down
59 changes: 59 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6480,3 +6480,62 @@ def forward(self, x):
@register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule())
def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))


# ==============================================================================


class AtenSymConstrainRange(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1], torch.int, True)])
def forward(self, x):
a = x.item()
torch.ops.aten.sym_constrain_range(a, max=5)
return a


@register_test_case(module_factory=lambda: AtenSymConstrainRange())
def AtenSymConstrainRange_basic(module, tu: TestUtils):
module.forward(torch.tensor(4))


# ==============================================================================


class AtenSymConstrainRangeForSize(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1], torch.int, True)])
def forward(self, x):
a = x.item()
torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10)
return a


@register_test_case(module_factory=lambda: AtenSymConstrainRangeForSize())
def AtenSymConstrainRangeForSize_basic(module, tu: TestUtils):
module.forward(torch.tensor(4))


# ==============================================================================


class Aten_AssertScalar(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1], torch.int, True)])
def forward(self, x):
a = x.item()
assert_msg = "Assertion failed for condition x.item() > 3"
torch.ops.aten._assert_scalar(a > 3, assert_msg)
return a


@register_test_case(module_factory=lambda: Aten_AssertScalar())
def Aten_AssertScalar_basic(module, tu: TestUtils):
module.forward(torch.tensor(4))