Skip to content

Commit cb15d3f

Browse files
authored
[CIR][LowerToLLVM][NFC] Refactor GlobalOpLowering for better readability and maintainability (#1525)
As noted in [this comment](#1442 (comment)), the nested if-arms in `GlobalOpLowering` are somewhat confusing and error-prone. This PR simplifies the logic into more straightforward components. Since LLVM's GlobalOp accepts two types of initializers (either an initializer value or an initializer region), we've extracted the decision logic into a separate function called `lowerInitializer`. This function takes two inout arguments: `mlir::Attribute &init` (for the attribute value) and `bool useInitializerRegion` (as the decision indicator). All code paths then converge at a common epilogue that handles the operation rewriting. The previous implementation for lowering `DataMemberAttr` initializers relied on recursion between MLIR rewrite calls, which made the control flow somewhat opaque. The new version makes this explicit by using a clear self-recursive pattern within `lowerInitializer`.
1 parent 55a2236 commit cb15d3f

File tree

2 files changed

+156
-116
lines changed

2 files changed

+156
-116
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+136-109
Original file line numberDiff line numberDiff line change
@@ -2379,49 +2379,16 @@ mlir::LogicalResult CIRToLLVMSwitchFlatOpLowering::matchAndRewrite(
23792379
return mlir::success();
23802380
}
23812381

2382-
/// Replace CIR global with a region initialized LLVM global and update
2383-
/// insertion point to the end of the initializer block.
2384-
void CIRToLLVMGlobalOpLowering::createRegionInitializedLLVMGlobalOp(
2385-
cir::GlobalOp op, mlir::Attribute attr,
2386-
mlir::ConversionPatternRewriter &rewriter,
2387-
SmallVector<mlir::NamedAttribute> attributes) const {
2388-
const auto llvmType =
2389-
convertTypeForMemory(*getTypeConverter(), dataLayout, op.getSymType());
2390-
auto newGlobalOp = rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
2391-
op, llvmType, op.getConstant(), convertLinkage(op.getLinkage()),
2392-
op.getSymName(), nullptr,
2393-
/*alignment*/ op.getAlignment().value_or(0),
2394-
/*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op),
2395-
/*dsoLocal*/ false, /*threadLocal*/ (bool)op.getTlsModelAttr(),
2396-
/*comdat*/ mlir::SymbolRefAttr(), attributes);
2397-
newGlobalOp.getRegion().push_back(new mlir::Block());
2398-
rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock());
2399-
2400-
rewriter.create<mlir::LLVM::ReturnOp>(
2401-
op->getLoc(),
2402-
lowerCirAttrAsValue(op, attr, rewriter, typeConverter, dataLayout));
2403-
}
2404-
2405-
mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
2406-
cir::GlobalOp op, OpAdaptor adaptor,
2407-
mlir::ConversionPatternRewriter &rewriter) const {
2408-
2409-
// Fetch required values to create LLVM op.
2410-
const auto CIRSymType = op.getSymType();
2382+
llvm::SmallVector<mlir::NamedAttribute>
2383+
CIRToLLVMGlobalOpLowering::lowerGlobalAttributes(
2384+
cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const {
2385+
SmallVector<mlir::NamedAttribute> attributes;
24112386

2412-
const auto llvmType =
2413-
convertTypeForMemory(*getTypeConverter(), dataLayout, CIRSymType);
2414-
const auto isConst = op.getConstant();
2415-
const auto isDsoLocal = op.getDsolocal();
2416-
const auto linkage = convertLinkage(op.getLinkage());
2417-
const auto symbol = op.getSymName();
24182387
std::optional<mlir::StringRef> section = op.getSection();
2419-
std::optional<mlir::Attribute> init = op.getInitialValue();
24202388
mlir::LLVM::VisibilityAttr visibility = mlir::LLVM::VisibilityAttr::get(
24212389
getContext(), lowerCIRVisibilityToLLVMVisibility(
24222390
op.getGlobalVisibilityAttr().getValue()));
24232391

2424-
SmallVector<mlir::NamedAttribute> attributes;
24252392
if (section.has_value())
24262393
attributes.push_back(rewriter.getNamedAttr(
24272394
"section", rewriter.getStringAttr(section.value())));
@@ -2433,88 +2400,147 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
24332400
attributes.push_back(rewriter.getNamedAttr("externally_initialized",
24342401
rewriter.getUnitAttr()));
24352402
}
2403+
return attributes;
2404+
}
24362405

2437-
if (init.has_value()) {
2438-
if (mlir::isa<cir::FPAttr, cir::IntAttr, cir::BoolAttr>(init.value())) {
2439-
// If a directly equivalent attribute is available, use it.
2440-
init =
2441-
llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
2442-
.Case<cir::FPAttr>([&](cir::FPAttr attr) {
2443-
return rewriter.getFloatAttr(llvmType, attr.getValue());
2444-
})
2445-
.Case<cir::IntAttr>([&](cir::IntAttr attr) {
2446-
return rewriter.getIntegerAttr(llvmType, attr.getValue());
2447-
})
2448-
.Case<cir::BoolAttr>([&](cir::BoolAttr attr) {
2449-
return rewriter.getBoolAttr(attr.getValue());
2450-
})
2451-
.Default([&](mlir::Attribute attr) { return mlir::Attribute(); });
2452-
// If initRewriter returned a null attribute, init will have a value but
2453-
// the value will be null. If that happens, initRewriter didn't handle the
2454-
// attribute type. It probably needs to be added to
2455-
// GlobalInitAttrRewriter.
2456-
if (!init.value()) {
2457-
op.emitError() << "unsupported initializer '" << init.value() << "'";
2458-
return mlir::failure();
2459-
}
2460-
} else if (mlir::isa<cir::ZeroAttr, cir::ConstPtrAttr, cir::UndefAttr,
2461-
cir::ConstStructAttr, cir::GlobalViewAttr,
2462-
cir::VTableAttr, cir::TypeInfoAttr>(init.value())) {
2463-
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
2464-
// should be updated. For now, we use a custom op to initialize globals
2465-
// to the appropriate value.
2466-
createRegionInitializedLLVMGlobalOp(op, init.value(), rewriter,
2467-
attributes);
2468-
return mlir::success();
2469-
} else if (auto constArr =
2470-
mlir::dyn_cast<cir::ConstArrayAttr>(init.value())) {
2471-
// Initializer is a constant array: convert it to a compatible llvm init.
2472-
if (auto attr = mlir::dyn_cast<mlir::StringAttr>(constArr.getElts())) {
2473-
llvm::SmallString<256> literal(attr.getValue());
2474-
if (constArr.getTrailingZerosNum())
2475-
literal.append(constArr.getTrailingZerosNum(), '\0');
2476-
init = rewriter.getStringAttr(literal);
2477-
} else if (auto attr =
2478-
mlir::dyn_cast<mlir::ArrayAttr>(constArr.getElts())) {
2479-
// Failed to use a compact attribute as an initializer:
2480-
// initialize elements individually.
2481-
if (!(init = lowerConstArrayAttr(constArr, getTypeConverter()))) {
2482-
createRegionInitializedLLVMGlobalOp(op, constArr, rewriter,
2483-
attributes);
2484-
return mlir::success();
2485-
}
2486-
} else {
2487-
op.emitError()
2488-
<< "unsupported lowering for #cir.const_array with value "
2489-
<< constArr.getElts();
2490-
return mlir::failure();
2491-
}
2492-
} else if (auto dataMemberAttr =
2493-
mlir::dyn_cast<cir::DataMemberAttr>(init.value())) {
2494-
assert(lowerMod && "lower module is not available");
2495-
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
2496-
mlir::TypedAttr abiValue = lowerMod->getCXXABI().lowerDataMemberConstant(
2497-
dataMemberAttr, layout, *typeConverter);
2498-
auto abiOp = mlir::cast<GlobalOp>(rewriter.clone(*op.getOperation()));
2499-
abiOp.setInitialValueAttr(abiValue);
2500-
abiOp.setSymType(abiValue.getType());
2501-
abiOp->setAttrs(attributes);
2502-
rewriter.replaceOp(op, abiOp);
2503-
return mlir::success();
2504-
} else {
2505-
op.emitError() << "unsupported initializer '" << init.value() << "'";
2506-
return mlir::failure();
2507-
}
2406+
mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer(
2407+
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
2408+
mlir::Type llvmType, mlir::Attribute &init,
2409+
bool &useInitializerRegion) const {
2410+
if (!init)
2411+
return mlir::success();
2412+
2413+
if (mlir::isa<cir::FPAttr, cir::IntAttr, cir::BoolAttr>(init)) {
2414+
// If a directly equivalent attribute is available, use it.
2415+
return lowerInitializerDirect(rewriter, op, llvmType, init,
2416+
useInitializerRegion);
2417+
} else if (mlir::isa<cir::ZeroAttr, cir::ConstPtrAttr, cir::UndefAttr,
2418+
cir::ConstStructAttr, cir::GlobalViewAttr,
2419+
cir::VTableAttr, cir::TypeInfoAttr>(init)) {
2420+
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
2421+
// should be updated. For now, we use a custom op to initialize globals
2422+
// to the appropriate value.
2423+
useInitializerRegion = true;
2424+
return mlir::success();
2425+
} else if (mlir::isa<cir::ConstArrayAttr>(init)) {
2426+
return lowerInitializerForConstArray(rewriter, op, init,
2427+
useInitializerRegion);
2428+
} else if (auto dataMemberAttr = mlir::dyn_cast<cir::DataMemberAttr>(init)) {
2429+
assert(lowerMod && "lower module is not available");
2430+
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
2431+
mlir::TypedAttr abiValue = lowerMod->getCXXABI().lowerDataMemberConstant(
2432+
dataMemberAttr, layout, *typeConverter);
2433+
init = abiValue;
2434+
llvmType = convertTypeForMemory(*getTypeConverter(), dataLayout,
2435+
abiValue.getType());
2436+
// Recursively lower the CIR attribute produced by CXXABI.
2437+
return lowerInitializer(rewriter, op, llvmType, init, useInitializerRegion);
2438+
} else {
2439+
op.emitError() << "unsupported initializer '" << init << "'";
2440+
return mlir::failure();
2441+
}
2442+
llvm_unreachable("unreachable");
2443+
}
2444+
mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstArray(
2445+
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
2446+
mlir::Attribute &init, bool &useInitializerRegion) const {
2447+
auto constArr = mlir::cast<cir::ConstArrayAttr>(init);
2448+
2449+
// Initializer is a constant array: convert it to a compatible llvm init.
2450+
if (auto attr = mlir::dyn_cast<mlir::StringAttr>(constArr.getElts())) {
2451+
llvm::SmallString<256> literal(attr.getValue());
2452+
if (constArr.getTrailingZerosNum())
2453+
literal.append(constArr.getTrailingZerosNum(), '\0');
2454+
init = rewriter.getStringAttr(literal);
2455+
useInitializerRegion = false;
2456+
return mlir::success();
2457+
} else if (auto attr = mlir::dyn_cast<mlir::ArrayAttr>(constArr.getElts())) {
2458+
// If failed to use a compact attribute as an initializer, we initialize
2459+
// elements individually.
2460+
if (auto val = lowerConstArrayAttr(constArr, getTypeConverter());
2461+
val.has_value()) {
2462+
init = val.value();
2463+
useInitializerRegion = false;
2464+
} else
2465+
useInitializerRegion = true;
2466+
return mlir::success();
2467+
} else {
2468+
op.emitError() << "unsupported lowering for #cir.const_array with value "
2469+
<< constArr.getElts();
2470+
return mlir::failure();
2471+
}
2472+
}
2473+
mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerDirect(
2474+
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
2475+
mlir::Type llvmType, mlir::Attribute &init,
2476+
bool &useInitializerRegion) const {
2477+
init = llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init)
2478+
.Case<cir::FPAttr>([&](cir::FPAttr attr) {
2479+
return rewriter.getFloatAttr(llvmType, attr.getValue());
2480+
})
2481+
.Case<cir::IntAttr>([&](cir::IntAttr attr) {
2482+
return rewriter.getIntegerAttr(llvmType, attr.getValue());
2483+
})
2484+
.Case<cir::BoolAttr>([&](cir::BoolAttr attr) {
2485+
return rewriter.getBoolAttr(attr.getValue());
2486+
})
2487+
.Default([&](mlir::Attribute attr) { return mlir::Attribute(); });
2488+
useInitializerRegion = false;
2489+
// If initRewriter returned a null attribute, init will have a value but
2490+
// the value will be null. If that happens, initRewriter didn't handle the
2491+
// attribute type. It probably needs to be added to
2492+
// GlobalInitAttrRewriter.
2493+
if (!init) {
2494+
op.emitError() << "unsupported initializer '" << init << "'";
2495+
return mlir::failure();
25082496
}
2497+
return mlir::success();
2498+
}
25092499

2500+
mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
2501+
cir::GlobalOp op, OpAdaptor adaptor,
2502+
mlir::ConversionPatternRewriter &rewriter) const {
2503+
2504+
// Fetch required values to create LLVM op.
2505+
const auto cirSymType = op.getSymType();
2506+
2507+
const auto llvmType =
2508+
convertTypeForMemory(*getTypeConverter(), dataLayout, cirSymType);
2509+
const auto isConst = op.getConstant();
2510+
const auto isDsoLocal = op.getDsolocal();
2511+
const auto linkage = convertLinkage(op.getLinkage());
2512+
const auto symbol = op.getSymName();
2513+
mlir::Attribute init = op.getInitialValueAttr();
2514+
2515+
SmallVector<mlir::NamedAttribute> attributes =
2516+
lowerGlobalAttributes(op, rewriter);
2517+
2518+
bool useInitializerRegion = false;
2519+
if (lowerInitializer(rewriter, op, llvmType, init, useInitializerRegion)
2520+
.failed())
2521+
return mlir::failure();
2522+
2523+
auto initValue = useInitializerRegion ? mlir::Attribute{} : init;
25102524
// Rewrite op.
25112525
auto llvmGlobalOp = rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
2512-
op, llvmType, isConst, linkage, symbol, init.value_or(mlir::Attribute()),
2526+
op, llvmType, isConst, linkage, symbol, initValue,
25132527
/*alignment*/ op.getAlignment().value_or(0),
25142528
/*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op),
25152529
/*dsoLocal*/ isDsoLocal, /*threadLocal*/ (bool)op.getTlsModelAttr(),
25162530
/*comdat*/ mlir::SymbolRefAttr(), attributes);
25172531

2532+
// Initialize the initializer region of LLVM global and update insertion point
2533+
// to the end of the initializer block.
2534+
if (useInitializerRegion) {
2535+
assert(init && "Expected initializer to use initializer region");
2536+
llvmGlobalOp.getInitializerRegion().push_back(new mlir::Block());
2537+
rewriter.setInsertionPointToEnd(llvmGlobalOp.getInitializerBlock());
2538+
2539+
rewriter.create<mlir::LLVM::ReturnOp>(
2540+
op->getLoc(),
2541+
lowerCirAttrAsValue(op, init, rewriter, typeConverter, dataLayout));
2542+
}
2543+
25182544
auto mod = op->getParentOfType<mlir::ModuleOp>();
25192545
if (op.getComdat())
25202546
addComdat(llvmGlobalOp, comdatOp, rewriter, mod);
@@ -4327,6 +4353,7 @@ void populateCIRToLLVMConversionPatterns(
43274353
CIRToLLVMVTTAddrPointOpLowering
43284354
#define GET_BUILTIN_LOWERING_LIST
43294355
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
4356+
43304357
#undef GET_BUILTIN_LOWERING_LIST
43314358
// clang-format on
43324359
>(converter, patterns.getContext());

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

+20-7
Original file line numberDiff line numberDiff line change
@@ -606,19 +606,32 @@ class CIRToLLVMGlobalOpLowering
606606
cir::LowerModule *lowerModule,
607607
mlir::DataLayout const &dataLayout)
608608
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule),
609-
dataLayout(dataLayout) {
610-
setHasBoundedRewriteRecursion();
611-
}
609+
dataLayout(dataLayout) {}
612610

613611
mlir::LogicalResult
614612
matchAndRewrite(cir::GlobalOp op, OpAdaptor,
615613
mlir::ConversionPatternRewriter &) const override;
616614

617615
private:
618-
void createRegionInitializedLLVMGlobalOp(
619-
cir::GlobalOp op, mlir::Attribute attr,
620-
mlir::ConversionPatternRewriter &rewriter,
621-
llvm::SmallVector<mlir::NamedAttribute> attributes) const;
616+
mlir::LogicalResult
617+
lowerInitializer(mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
618+
mlir::Type llvmType, mlir::Attribute &init,
619+
bool &useInitializerRegion) const;
620+
621+
mlir::LogicalResult
622+
lowerInitializerForConstArray(mlir::ConversionPatternRewriter &rewriter,
623+
cir::GlobalOp op, mlir::Attribute &init,
624+
bool &useInitializerRegion) const;
625+
626+
mlir::LogicalResult
627+
lowerInitializerDirect(mlir::ConversionPatternRewriter &rewriter,
628+
cir::GlobalOp op, mlir::Type llvmType,
629+
mlir::Attribute &init,
630+
bool &useInitializerRegion) const;
631+
632+
llvm::SmallVector<mlir::NamedAttribute>
633+
lowerGlobalAttributes(cir::GlobalOp op,
634+
mlir::ConversionPatternRewriter &rewriter) const;
622635

623636
mutable mlir::LLVM::ComdatOp comdatOp = nullptr;
624637
static void addComdat(mlir::LLVM::GlobalOp &op,

0 commit comments

Comments
 (0)