diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp index 3d506abbaa454..96a3622f4afee 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp @@ -910,22 +910,25 @@ class ReductionConversion : public mlir::OpRewritePattern { } auto inlineSource = - [elemental, &designate]( - fir::FirOpBuilder builder, mlir::Location loc, - const llvm::SmallVectorImpl &indices) -> mlir::Value { + [elemental, + &designate](fir::FirOpBuilder builder, mlir::Location loc, + const llvm::SmallVectorImpl &oneBasedIndices) + -> mlir::Value { if (elemental) { // Inline the elemental and get the value from it. - auto yield = inlineElementalOp(loc, builder, elemental, indices); + auto yield = + inlineElementalOp(loc, builder, elemental, oneBasedIndices); auto tmp = yield.getElementValue(); yield->erase(); return tmp; } if (designate) { - // Create a designator over designator, then load the reference. - auto resEntity = hlfir::Entity{designate.getResult()}; - auto tmp = builder.create( - loc, getVariableElementType(resEntity), designate, indices); - return builder.create(loc, tmp); + // Create a designator over the array designator, then load the + // reference. + mlir::Value elementAddr = hlfir::getElementAt( + loc, builder, hlfir::Entity{designate.getResult()}, + oneBasedIndices); + return builder.create(loc, elementAddr); } llvm_unreachable("unsupported type"); }; @@ -936,38 +939,41 @@ class ReductionConversion : public mlir::OpRewritePattern { GenBodyFn genBodyFn; if constexpr (std::is_same_v) { init = builder.createIntegerConstant(loc, builder.getI1Type(), 0); - genBodyFn = - [inlineSource](fir::FirOpBuilder builder, mlir::Location loc, - mlir::Value reduction, - const llvm::SmallVectorImpl &indices) + genBodyFn = [inlineSource]( + fir::FirOpBuilder builder, mlir::Location loc, + mlir::Value reduction, + const llvm::SmallVectorImpl &oneBasedIndices) -> mlir::Value { // Conditionally set the reduction variable. mlir::Value cond = builder.create( - loc, builder.getI1Type(), inlineSource(builder, loc, indices)); + loc, builder.getI1Type(), + inlineSource(builder, loc, oneBasedIndices)); return builder.create(loc, reduction, cond); }; } else if constexpr (std::is_same_v) { init = builder.createIntegerConstant(loc, builder.getI1Type(), 1); - genBodyFn = - [inlineSource](fir::FirOpBuilder builder, mlir::Location loc, - mlir::Value reduction, - const llvm::SmallVectorImpl &indices) + genBodyFn = [inlineSource]( + fir::FirOpBuilder builder, mlir::Location loc, + mlir::Value reduction, + const llvm::SmallVectorImpl &oneBasedIndices) -> mlir::Value { // Conditionally set the reduction variable. mlir::Value cond = builder.create( - loc, builder.getI1Type(), inlineSource(builder, loc, indices)); + loc, builder.getI1Type(), + inlineSource(builder, loc, oneBasedIndices)); return builder.create(loc, reduction, cond); }; } else if constexpr (std::is_same_v) { init = builder.createIntegerConstant(loc, op.getType(), 0); - genBodyFn = - [inlineSource](fir::FirOpBuilder builder, mlir::Location loc, - mlir::Value reduction, - const llvm::SmallVectorImpl &indices) + genBodyFn = [inlineSource]( + fir::FirOpBuilder builder, mlir::Location loc, + mlir::Value reduction, + const llvm::SmallVectorImpl &oneBasedIndices) -> mlir::Value { // Conditionally add one to the current value mlir::Value cond = builder.create( - loc, builder.getI1Type(), inlineSource(builder, loc, indices)); + loc, builder.getI1Type(), + inlineSource(builder, loc, oneBasedIndices)); mlir::Value one = builder.createIntegerConstant(loc, reduction.getType(), 1); mlir::Value add1 = @@ -984,12 +990,12 @@ class ReductionConversion : public mlir::OpRewritePattern { std::is_same_v) { bool isMax = std::is_same_v; init = makeMinMaxInitValGenerator(isMax)(builder, loc, op.getType()); - genBodyFn = [inlineSource, - isMax](fir::FirOpBuilder builder, mlir::Location loc, - mlir::Value reduction, - const llvm::SmallVectorImpl &indices) + genBodyFn = [inlineSource, isMax]( + fir::FirOpBuilder builder, mlir::Location loc, + mlir::Value reduction, + const llvm::SmallVectorImpl &oneBasedIndices) -> mlir::Value { - mlir::Value val = inlineSource(builder, loc, indices); + mlir::Value val = inlineSource(builder, loc, oneBasedIndices); mlir::Value cmp = generateMinMaxComparison(builder, loc, val, reduction, isMax); return builder.create(loc, cmp, val, reduction); diff --git a/flang/test/HLFIR/maxval-elemental.fir b/flang/test/HLFIR/maxval-elemental.fir index 9dc028abe8da3..a21b4858412de 100644 --- a/flang/test/HLFIR/maxval-elemental.fir +++ b/flang/test/HLFIR/maxval-elemental.fir @@ -93,3 +93,25 @@ func.func @_QPtest_float(%arg0: !fir.box> {fir.bindc_name = "a // CHECK-NEXT: hlfir.assign %[[V6]] to %3#0 : f32, !fir.ref // CHECK-NEXT: return // CHECK-NEXT: } + +// Verify that lower bounds of designator are applied in the indexing inside +// the generated loop (hlfir.designate takes indices relative to the base lower +// bounds). +func.func @component_lower_bounds(%arg0: !fir.ref}>>) -> i32 { + %c10 = arith.constant 10 : index + %c101 = arith.constant 101 : index + %4 = fir.shape_shift %c101, %c10 : (index, index) -> !fir.shapeshift<1> + %5 = hlfir.designate %arg0{"i"} shape %4 : (!fir.ref}>>, !fir.shapeshift<1>) -> !fir.box> + %6 = hlfir.maxval %5 : (!fir.box>) -> i32 + return %6 : i32 +} +// CHECK-LABEL: func.func @component_lower_bounds( +// CHECK: %[[VAL_1:.*]] = arith.constant 100 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 101 : index +// CHECK: %[[VAL_6:.*]] = fir.shape_shift %[[VAL_5]], %[[VAL_4]] : (index, index) -> !fir.shapeshift<1> +// CHECK: %[[VAL_7:.*]] = hlfir.designate %{{.*}}{"i"} shape %[[VAL_6]] : (!fir.ref}>>, !fir.shapeshift<1>) -> !fir.box> +// CHECK: %[[VAL_8:.*]] = fir.do_loop %[[VAL_9:.*]] = %[[VAL_2]] to %[[VAL_4]] {{.*}} +// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_1]] : index +// CHECK: hlfir.designate %[[VAL_7]] (%[[VAL_11]]) : (!fir.box>, index) -> !fir.ref