Skip to content

Commit 4363915

Browse files
committed
Fix visibility and type
1 parent 25f611e commit 4363915

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

lib/gc/Transforms/CST.cpp

+15-11
Original file line numberDiff line numberDiff line change
@@ -343,41 +343,43 @@ struct constGraphTensorCacheManager {
343343
}
344344
};
345345

346-
static void addGlobalI32(ModuleOp module, Location loc, OpBuilder &builder,
346+
static void addGlobalI32(ModuleOp &module, Location loc, OpBuilder &builder,
347347
StringRef name, int32_t value) {
348348
OpBuilder::InsertionGuard insertGuard(builder);
349349
builder.setInsertionPointToStart(module.getBody());
350350

351351
auto type = IntegerType::get(builder.getContext(), 32);
352352
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
353-
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
353+
loc, type, /*isConstant=*/true, LLVM::Linkage::External, name,
354354
builder.getI32IntegerAttr(value),
355355
/*alignment=*/0);
356356
}
357357

358-
static void addGlobalI64Array(ModuleOp module, Location loc, OpBuilder &builder,
359-
StringRef name, ArrayRef<int64_t> array) {
358+
static void addGlobalI64Array(ModuleOp &module, Location loc,
359+
OpBuilder &builder, StringRef name,
360+
ArrayRef<int64_t> array) {
360361
OpBuilder::InsertionGuard insertGuard(builder);
361362
builder.setInsertionPointToStart(module.getBody());
362363

363364
auto type = LLVM::LLVMArrayType::get(
364365
IntegerType::get(builder.getContext(), 64), array.size());
365366
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
366-
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
367-
builder.getI64ArrayAttr(array),
367+
loc, type, /*isConstant=*/true, LLVM::Linkage::External, name,
368+
builder.getI64TensorAttr(array),
368369
/*alignment=*/0);
369370
}
370371

371-
static void addGlobalI32Array(ModuleOp module, Location loc, OpBuilder &builder,
372-
StringRef name, ArrayRef<int32_t> array) {
372+
static void addGlobalI32Array(ModuleOp &module, Location loc,
373+
OpBuilder &builder, StringRef name,
374+
ArrayRef<int32_t> array) {
373375
OpBuilder::InsertionGuard insertGuard(builder);
374376
builder.setInsertionPointToStart(module.getBody());
375377

376378
auto type = LLVM::LLVMArrayType::get(
377379
IntegerType::get(builder.getContext(), 32), array.size());
378380
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
379-
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
380-
builder.getI32ArrayAttr(array),
381+
loc, type, /*isConstant=*/true, LLVM::Linkage::External, name,
382+
builder.getI32TensorAttr(array),
381383
/*alignment=*/0);
382384
}
383385

@@ -493,7 +495,7 @@ void CST::runOnOperation() {
493495

494496
FunctionType foldFuncType =
495497
FunctionType::get(context, inputTypes, outputTypes);
496-
auto foldFunc =
498+
func::FuncOp foldFunc =
497499
builder.create<func::FuncOp>(topFunc.getLoc(), funcName, foldFuncType);
498500
Block *foldBlock = foldFunc.addEntryBlock();
499501
// values of folded constant weights in foldBlock
@@ -541,6 +543,8 @@ void CST::runOnOperation() {
541543
globalIndexes);
542544

543545
foldFunc.setVisibility(SymbolTable::Visibility::Public);
546+
foldFunc->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
547+
UnitAttr::get(context));
544548
moduleOp.push_back(foldFunc);
545549
symbolTable.insert(foldFunc);
546550

0 commit comments

Comments
 (0)