@@ -910,22 +910,25 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
910
910
}
911
911
912
912
auto inlineSource =
913
- [elemental, &designate](
914
- fir::FirOpBuilder builder, mlir::Location loc,
915
- const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
913
+ [elemental,
914
+ &designate](fir::FirOpBuilder builder, mlir::Location loc,
915
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices)
916
+ -> mlir::Value {
916
917
if (elemental) {
917
918
// Inline the elemental and get the value from it.
918
- auto yield = inlineElementalOp (loc, builder, elemental, indices);
919
+ auto yield =
920
+ inlineElementalOp (loc, builder, elemental, oneBasedIndices);
919
921
auto tmp = yield.getElementValue ();
920
922
yield->erase ();
921
923
return tmp;
922
924
}
923
925
if (designate) {
924
- // Create a designator over designator, then load the reference.
925
- auto resEntity = hlfir::Entity{designate.getResult ()};
926
- auto tmp = builder.create <hlfir::DesignateOp>(
927
- loc, getVariableElementType (resEntity), designate, indices);
928
- return builder.create <fir::LoadOp>(loc, tmp);
926
+ // Create a designator over the array designator, then load the
927
+ // reference.
928
+ mlir::Value elementAddr = hlfir::getElementAt (
929
+ loc, builder, hlfir::Entity{designate.getResult ()},
930
+ oneBasedIndices);
931
+ return builder.create <fir::LoadOp>(loc, elementAddr);
929
932
}
930
933
llvm_unreachable (" unsupported type" );
931
934
};
@@ -936,38 +939,41 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
936
939
GenBodyFn genBodyFn;
937
940
if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
938
941
init = builder.createIntegerConstant (loc, builder.getI1Type (), 0 );
939
- genBodyFn =
940
- [inlineSource]( fir::FirOpBuilder builder, mlir::Location loc,
941
- mlir::Value reduction,
942
- const llvm::SmallVectorImpl<mlir::Value> &indices )
942
+ genBodyFn = [inlineSource](
943
+ fir::FirOpBuilder builder, mlir::Location loc,
944
+ mlir::Value reduction,
945
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices )
943
946
-> mlir::Value {
944
947
// Conditionally set the reduction variable.
945
948
mlir::Value cond = builder.create <fir::ConvertOp>(
946
- loc, builder.getI1Type (), inlineSource (builder, loc, indices));
949
+ loc, builder.getI1Type (),
950
+ inlineSource (builder, loc, oneBasedIndices));
947
951
return builder.create <mlir::arith::OrIOp>(loc, reduction, cond);
948
952
};
949
953
} else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
950
954
init = builder.createIntegerConstant (loc, builder.getI1Type (), 1 );
951
- genBodyFn =
952
- [inlineSource]( fir::FirOpBuilder builder, mlir::Location loc,
953
- mlir::Value reduction,
954
- const llvm::SmallVectorImpl<mlir::Value> &indices )
955
+ genBodyFn = [inlineSource](
956
+ fir::FirOpBuilder builder, mlir::Location loc,
957
+ mlir::Value reduction,
958
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices )
955
959
-> mlir::Value {
956
960
// Conditionally set the reduction variable.
957
961
mlir::Value cond = builder.create <fir::ConvertOp>(
958
- loc, builder.getI1Type (), inlineSource (builder, loc, indices));
962
+ loc, builder.getI1Type (),
963
+ inlineSource (builder, loc, oneBasedIndices));
959
964
return builder.create <mlir::arith::AndIOp>(loc, reduction, cond);
960
965
};
961
966
} else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
962
967
init = builder.createIntegerConstant (loc, op.getType (), 0 );
963
- genBodyFn =
964
- [inlineSource]( fir::FirOpBuilder builder, mlir::Location loc,
965
- mlir::Value reduction,
966
- const llvm::SmallVectorImpl<mlir::Value> &indices )
968
+ genBodyFn = [inlineSource](
969
+ fir::FirOpBuilder builder, mlir::Location loc,
970
+ mlir::Value reduction,
971
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices )
967
972
-> mlir::Value {
968
973
// Conditionally add one to the current value
969
974
mlir::Value cond = builder.create <fir::ConvertOp>(
970
- loc, builder.getI1Type (), inlineSource (builder, loc, indices));
975
+ loc, builder.getI1Type (),
976
+ inlineSource (builder, loc, oneBasedIndices));
971
977
mlir::Value one =
972
978
builder.createIntegerConstant (loc, reduction.getType (), 1 );
973
979
mlir::Value add1 =
@@ -984,12 +990,12 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
984
990
std::is_same_v<Op, hlfir::MinvalOp>) {
985
991
bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
986
992
init = makeMinMaxInitValGenerator (isMax)(builder, loc, op.getType ());
987
- genBodyFn = [inlineSource,
988
- isMax]( fir::FirOpBuilder builder, mlir::Location loc,
989
- mlir::Value reduction,
990
- const llvm::SmallVectorImpl<mlir::Value> &indices )
993
+ genBodyFn = [inlineSource, isMax](
994
+ fir::FirOpBuilder builder, mlir::Location loc,
995
+ mlir::Value reduction,
996
+ const llvm::SmallVectorImpl<mlir::Value> &oneBasedIndices )
991
997
-> mlir::Value {
992
- mlir::Value val = inlineSource (builder, loc, indices );
998
+ mlir::Value val = inlineSource (builder, loc, oneBasedIndices );
993
999
mlir::Value cmp =
994
1000
generateMinMaxComparison (builder, loc, val, reduction, isMax);
995
1001
return builder.create <mlir::arith::SelectOp>(loc, cmp, val, reduction);
0 commit comments