Skip to content

Commit 1dd573c

Browse files
committed
Address review comments
1 parent 34c2599 commit 1dd573c

File tree

3 files changed

+13
-15
lines changed

3 files changed

+13
-15
lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3589,7 +3589,7 @@ class ConvertSymConstrainRangeOp
35893589
int64_t minValue = std::numeric_limits<int64_t>::min();
35903590
int64_t maxValue = std::numeric_limits<int64_t>::max();
35913591

3592-
Type operandType = rewriter.getI64Type();
3592+
Type operandType = getTypeConverter()->convertType(op.getSize().getType());
35933593

35943594
if (!isa<Torch::ConstantNoneOp>(minOp))
35953595
if (!matchPattern(min, m_TorchConstantInt(&minValue)))
@@ -3615,10 +3615,10 @@ class ConvertSymConstrainRangeOp
36153615

36163616
// FIXME:: Skip the below checks if constraint ops are already inserted as
36173617
// part of symbol expr evaluation
3618-
auto checkMin = createLessThanOrEqual(rewriter, loc, operandType, min,
3619-
adaptor.getSize());
3620-
auto checkMax = createLessThanOrEqual(rewriter, loc, operandType,
3621-
adaptor.getSize(), max);
3618+
auto checkMin = rewriter.create<arith::CmpIOp>(
3619+
loc, arith::CmpIPredicate::sle, min, adaptor.getSize());
3620+
auto checkMax = rewriter.create<arith::CmpIOp>(
3621+
loc, arith::CmpIPredicate::sle, adaptor.getSize(), max);
36223622
auto compareVal = rewriter.create<arith::AndIOp>(loc, checkMin, checkMax);
36233623

36243624
std::string assertMessage = "Invalid value range for size between [" +

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6488,7 +6488,7 @@ def __init__(self):
64886488
def forward(self, x):
64896489
a = x.item()
64906490
torch._check_is_size(a)
6491-
# max should be >= 2
6491+
# max should be > 2
64926492
torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10)
64936493
return a
64946494

test/Conversion/TorchToLinalg/constraints.mlir

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,26 @@
1010
// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]]
1111
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i64
1212
// CHECK: %[[VAL_7:.*]] = arith.constant 9223372036854775807 : i64
13-
// CHECK: %[[VAL_8:.*]] = arith.cmpi ule, %[[VAL_6]], %[[VAL_5]] : i64
14-
// CHECK: %[[VAL_9:.*]] = arith.cmpi ule, %[[VAL_5]], %[[VAL_7]] : i64
13+
// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_5]] : i64
14+
// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_7]] : i64
1515
// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_8]], %[[VAL_9]] : i1
1616
// CHECK: cf.assert %[[VAL_10]], "Invalid value range for size between [0, 9223372036854775807]"
1717
// CHECK: %[[VAL_11:.*]] = arith.constant 0 : i64
1818
// CHECK: %[[VAL_12:.*]] = arith.constant 7 : i64
19-
// CHECK: %[[VAL_13:.*]] = arith.cmpi ule, %[[VAL_11]], %[[VAL_5]] : i64
20-
// CHECK: %[[VAL_14:.*]] = arith.cmpi ule, %[[VAL_5]], %[[VAL_12]] : i64
19+
// CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_11]], %[[VAL_5]] : i64
20+
// CHECK: %[[VAL_14:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_12]] : i64
2121
// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1
2222
// CHECK: cf.assert %[[VAL_15]], "Invalid value range for size between [0, 7]"
2323
// CHECK: return %[[VAL_4]] : !torch.int
2424

25-
module {
26-
func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
25+
func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
2726
%int7 = torch.constant.int 7
2827
%int0 = torch.constant.int 0
2928
%none = torch.constant.none
3029
%0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
3130
torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none
3231
torch.aten.sym_constrain_range %0, %int0, %int7 : !torch.int, !torch.int, !torch.int
3332
return %0 : !torch.int
34-
}
3533
}
3634

3735
// -----
@@ -47,8 +45,8 @@ module {
4745
// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]]
4846
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i64
4947
// CHECK: %[[VAL_9:.*]] = arith.constant 9223372036854775807 : i64
50-
// CHECK: %[[VAL_10:.*]] = arith.cmpi ule, %[[VAL_8]], %[[VAL_7]] : i64
51-
// CHECK: %[[VAL_11:.*]] = arith.cmpi ule, %[[VAL_7]], %[[VAL_9]] : i64
48+
// CHECK: %[[VAL_10:.*]] = arith.cmpi sle, %[[VAL_8]], %[[VAL_7]] : i64
49+
// CHECK: %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_9]] : i64
5250
// CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_10]], %[[VAL_11]] : i1
5351
// CHECK: cf.assert %[[VAL_12]], "Invalid value range for size between [0, 9223372036854775807]"
5452
// CHECK: %[[VAL_13:.*]] = torch.aten.ge.int %[[VAL_6]], %[[VAL_4]] : !torch.int, !torch.int -> !torch.bool

0 commit comments

Comments
 (0)