Skip to content

Commit

Permalink
[flang] fix MAXVAL(x%array_comp_with_custom_lower_bounds) (llvm#129684)
Browse files Browse the repository at this point in the history
The HLFIR inlining of MAXVAL kicks in at O1 and more when the argument
is an array component reference but the implementation did not account
for the rare cases where the array components have non default lower
bounds.

This patch fixes the issue by using `getElementAt` to compute the
element address.
Rename `indices` to `oneBasedIndices` for more clarity.
  • Loading branch information
jeanPerier authored Mar 4, 2025
1 parent 17f0aaa commit 9a659fa
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 29 deletions.
64 changes: 35 additions & 29 deletions flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,22 +910,25 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
}

auto inlineSource =
[elemental, &designate](
fir::FirOpBuilder builder, mlir::Location loc,
const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
[elemental,
&designate](fir::FirOpBuilder builder, mlir::Location loc,
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
-> mlir::Value {
if (elemental) {
// Inline the elemental and get the value from it.
auto yield = inlineElementalOp(loc, builder, elemental, indices);
auto yield =
inlineElementalOp(loc, builder, elemental, oneBasedIndices);
auto tmp = yield.getElementValue();
yield->erase();
return tmp;
}
if (designate) {
// Create a designator over designator, then load the reference.
auto resEntity = hlfir::Entity{designate.getResult()};
auto tmp = builder.create<hlfir::DesignateOp>(
loc, getVariableElementType(resEntity), designate, indices);
return builder.create<fir::LoadOp>(loc, tmp);
// Create a designator over the array designator, then load the
// reference.
mlir::Value elementAddr = hlfir::getElementAt(
loc, builder, hlfir::Entity{designate.getResult()},
oneBasedIndices);
return builder.create<fir::LoadOp>(loc, elementAddr);
}
llvm_unreachable("unsupported type");
};
Expand All @@ -936,38 +939,41 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
GenBodyFn genBodyFn;
if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
genBodyFn =
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &indices)
genBodyFn = [inlineSource](
fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
-> mlir::Value {
// Conditionally set the reduction variable.
mlir::Value cond = builder.create<fir::ConvertOp>(
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
loc, builder.getI1Type(),
inlineSource(builder, loc, oneBasedIndices));
return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
};
} else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
genBodyFn =
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &indices)
genBodyFn = [inlineSource](
fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
-> mlir::Value {
// Conditionally set the reduction variable.
mlir::Value cond = builder.create<fir::ConvertOp>(
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
loc, builder.getI1Type(),
inlineSource(builder, loc, oneBasedIndices));
return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
};
} else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
init = builder.createIntegerConstant(loc, op.getType(), 0);
genBodyFn =
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &indices)
genBodyFn = [inlineSource](
fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
-> mlir::Value {
// Conditionally add one to the current value
mlir::Value cond = builder.create<fir::ConvertOp>(
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
loc, builder.getI1Type(),
inlineSource(builder, loc, oneBasedIndices));
mlir::Value one =
builder.createIntegerConstant(loc, reduction.getType(), 1);
mlir::Value add1 =
Expand All @@ -984,12 +990,12 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
std::is_same_v<Op, hlfir::MinvalOp>) {
bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
init = makeMinMaxInitValGenerator(isMax)(builder, loc, op.getType());
genBodyFn = [inlineSource,
isMax](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &indices)
genBodyFn = [inlineSource, isMax](
fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
-> mlir::Value {
mlir::Value val = inlineSource(builder, loc, indices);
mlir::Value val = inlineSource(builder, loc, oneBasedIndices);
mlir::Value cmp =
generateMinMaxComparison(builder, loc, val, reduction, isMax);
return builder.create<mlir::arith::SelectOp>(loc, cmp, val, reduction);
Expand Down
22 changes: 22 additions & 0 deletions flang/test/HLFIR/maxval-elemental.fir
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,25 @@ func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a
// CHECK-NEXT: hlfir.assign %[[V6]] to %3#0 : f32, !fir.ref<f32>
// CHECK-NEXT: return
// CHECK-NEXT: }

// Verify that lower bounds of designator are applied in the indexing inside
// the generated loop (hlfir.designate takes indices relative to the base lower
// bounds).
func.func @component_lower_bounds(%arg0: !fir.ref<!fir.type<sometype{i:!fir.array<10xi32>}>>) -> i32 {
%c10 = arith.constant 10 : index
%c101 = arith.constant 101 : index
%4 = fir.shape_shift %c101, %c10 : (index, index) -> !fir.shapeshift<1>
%5 = hlfir.designate %arg0{"i"} shape %4 : (!fir.ref<!fir.type<sometype{i:!fir.array<10xi32>}>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<10xi32>>
%6 = hlfir.maxval %5 : (!fir.box<!fir.array<10xi32>>) -> i32
return %6 : i32
}
// CHECK-LABEL: func.func @component_lower_bounds(
// CHECK: %[[VAL_1:.*]] = arith.constant 100 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 101 : index
// CHECK: %[[VAL_6:.*]] = fir.shape_shift %[[VAL_5]], %[[VAL_4]] : (index, index) -> !fir.shapeshift<1>
// CHECK: %[[VAL_7:.*]] = hlfir.designate %{{.*}}{"i"} shape %[[VAL_6]] : (!fir.ref<!fir.type<sometype{i:!fir.array<10xi32>}>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<10xi32>>
// CHECK: %[[VAL_8:.*]] = fir.do_loop %[[VAL_9:.*]] = %[[VAL_2]] to %[[VAL_4]] {{.*}}
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_1]] : index
// CHECK: hlfir.designate %[[VAL_7]] (%[[VAL_11]]) : (!fir.box<!fir.array<10xi32>>, index) -> !fir.ref<i32>

0 comments on commit 9a659fa

Please sign in to comment.