Skip to content

Commit 9a659fa

Browse files
authored
[flang] fix MAXVAL(x%array_comp_with_custom_lower_bounds) (llvm#129684)
The HLFIR inlining of MAXVAL kicks in at O1 and more when the argument is an array component reference but the implementation did not account for the rare cases where the array components have non default lower bounds. This patch fixes the issue by using `getElementAt` to compute the element address. Rename `indices` to `oneBasedIndices` for more clarity.
1 parent 17f0aaa commit 9a659fa

File tree

2 files changed

+57
-29
lines changed

2 files changed

+57
-29
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -910,22 +910,25 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
910910
}
911911

912912
auto inlineSource =
913-
[elemental, &designate](
914-
fir::FirOpBuilder builder, mlir::Location loc,
915-
const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
913+
[elemental,
914+
&designate](fir::FirOpBuilder builder, mlir::Location loc,
915+
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
916+
-> mlir::Value {
916917
if (elemental) {
917918
// Inline the elemental and get the value from it.
918-
auto yield = inlineElementalOp(loc, builder, elemental, indices);
919+
auto yield =
920+
inlineElementalOp(loc, builder, elemental, oneBasedIndices);
919921
auto tmp = yield.getElementValue();
920922
yield->erase();
921923
return tmp;
922924
}
923925
if (designate) {
924-
// Create a designator over designator, then load the reference.
925-
auto resEntity = hlfir::Entity{designate.getResult()};
926-
auto tmp = builder.create<hlfir::DesignateOp>(
927-
loc, getVariableElementType(resEntity), designate, indices);
928-
return builder.create<fir::LoadOp>(loc, tmp);
926+
// Create a designator over the array designator, then load the
927+
// reference.
928+
mlir::Value elementAddr = hlfir::getElementAt(
929+
loc, builder, hlfir::Entity{designate.getResult()},
930+
oneBasedIndices);
931+
return builder.create<fir::LoadOp>(loc, elementAddr);
929932
}
930933
llvm_unreachable("unsupported type");
931934
};
@@ -936,38 +939,41 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
936939
GenBodyFn genBodyFn;
937940
if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
938941
init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
939-
genBodyFn =
940-
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
941-
mlir::Value reduction,
942-
const llvm::SmallVectorImpl<mlir::Value> &indices)
942+
genBodyFn = [inlineSource](
943+
fir::FirOpBuilder builder, mlir::Location loc,
944+
mlir::Value reduction,
945+
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
943946
-> mlir::Value {
944947
// Conditionally set the reduction variable.
945948
mlir::Value cond = builder.create<fir::ConvertOp>(
946-
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
949+
loc, builder.getI1Type(),
950+
inlineSource(builder, loc, oneBasedIndices));
947951
return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
948952
};
949953
} else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
950954
init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
951-
genBodyFn =
952-
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
953-
mlir::Value reduction,
954-
const llvm::SmallVectorImpl<mlir::Value> &indices)
955+
genBodyFn = [inlineSource](
956+
fir::FirOpBuilder builder, mlir::Location loc,
957+
mlir::Value reduction,
958+
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
955959
-> mlir::Value {
956960
// Conditionally set the reduction variable.
957961
mlir::Value cond = builder.create<fir::ConvertOp>(
958-
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
962+
loc, builder.getI1Type(),
963+
inlineSource(builder, loc, oneBasedIndices));
959964
return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
960965
};
961966
} else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
962967
init = builder.createIntegerConstant(loc, op.getType(), 0);
963-
genBodyFn =
964-
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
965-
mlir::Value reduction,
966-
const llvm::SmallVectorImpl<mlir::Value> &indices)
968+
genBodyFn = [inlineSource](
969+
fir::FirOpBuilder builder, mlir::Location loc,
970+
mlir::Value reduction,
971+
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
967972
-> mlir::Value {
968973
// Conditionally add one to the current value
969974
mlir::Value cond = builder.create<fir::ConvertOp>(
970-
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
975+
loc, builder.getI1Type(),
976+
inlineSource(builder, loc, oneBasedIndices));
971977
mlir::Value one =
972978
builder.createIntegerConstant(loc, reduction.getType(), 1);
973979
mlir::Value add1 =
@@ -984,12 +990,12 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
984990
std::is_same_v<Op, hlfir::MinvalOp>) {
985991
bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
986992
init = makeMinMaxInitValGenerator(isMax)(builder, loc, op.getType());
987-
genBodyFn = [inlineSource,
988-
isMax](fir::FirOpBuilder builder, mlir::Location loc,
989-
mlir::Value reduction,
990-
const llvm::SmallVectorImpl<mlir::Value> &indices)
993+
genBodyFn = [inlineSource, isMax](
994+
fir::FirOpBuilder builder, mlir::Location loc,
995+
mlir::Value reduction,
996+
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
991997
-> mlir::Value {
992-
mlir::Value val = inlineSource(builder, loc, indices);
998+
mlir::Value val = inlineSource(builder, loc, oneBasedIndices);
993999
mlir::Value cmp =
9941000
generateMinMaxComparison(builder, loc, val, reduction, isMax);
9951001
return builder.create<mlir::arith::SelectOp>(loc, cmp, val, reduction);

flang/test/HLFIR/maxval-elemental.fir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,25 @@ func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a
9393
// CHECK-NEXT: hlfir.assign %[[V6]] to %3#0 : f32, !fir.ref<f32>
9494
// CHECK-NEXT: return
9595
// CHECK-NEXT: }
96+
97+
// Verify that lower bounds of designator are applied in the indexing inside
98+
// the generated loop (hlfir.designate takes indices relative to the base lower
99+
// bounds).
100+
func.func @component_lower_bounds(%arg0: !fir.ref<!fir.type<sometype{i:!fir.array<10xi32>}>>) -> i32 {
101+
%c10 = arith.constant 10 : index
102+
%c101 = arith.constant 101 : index
103+
%4 = fir.shape_shift %c101, %c10 : (index, index) -> !fir.shapeshift<1>
104+
%5 = hlfir.designate %arg0{"i"} shape %4 : (!fir.ref<!fir.type<sometype{i:!fir.array<10xi32>}>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<10xi32>>
105+
%6 = hlfir.maxval %5 : (!fir.box<!fir.array<10xi32>>) -> i32
106+
return %6 : i32
107+
}
108+
// CHECK-LABEL: func.func @component_lower_bounds(
109+
// CHECK: %[[VAL_1:.*]] = arith.constant 100 : index
110+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
111+
// CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
112+
// CHECK: %[[VAL_5:.*]] = arith.constant 101 : index
113+
// CHECK: %[[VAL_6:.*]] = fir.shape_shift %[[VAL_5]], %[[VAL_4]] : (index, index) -> !fir.shapeshift<1>
114+
// CHECK: %[[VAL_7:.*]] = hlfir.designate %{{.*}}{"i"} shape %[[VAL_6]] : (!fir.ref<!fir.type<sometype{i:!fir.array<10xi32>}>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<10xi32>>
115+
// CHECK: %[[VAL_8:.*]] = fir.do_loop %[[VAL_9:.*]] = %[[VAL_2]] to %[[VAL_4]] {{.*}}
116+
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_1]] : index
117+
// CHECK: hlfir.designate %[[VAL_7]] (%[[VAL_11]]) : (!fir.box<!fir.array<10xi32>>, index) -> !fir.ref<i32>

0 commit comments

Comments
 (0)