Skip to content

Commit c3e186d

Browse files
committed
Use llvm global [need to cowork with yijie/mainfunc_wrapper]
1 parent 5eb0ac0 commit c3e186d

File tree

1 file changed

+47
-47
lines changed

1 file changed

+47
-47
lines changed

lib/gc/Transforms/CST.cpp

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3131
#include "llvm/Support/Debug.h"
3232

33-
// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
33+
#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
3434

3535
namespace mlir {
3636
namespace gc {
@@ -300,13 +300,13 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
300300
// void *allocator(size_t size) { return std::aligned_alloc(64, size); }
301301
// void deallocator(void *ptr) { std::free(ptr); }
302302

303-
// std::shared_ptr<const_cache_proxy> create_const_cache_proxy(size_t size) {
304-
// // simply allocate buffer and return
305-
// std::shared_ptr<void> base =
306-
// std::shared_ptr<void>{std::aligned_alloc(64, size), [](void *p) {
307-
// std::free(p); }};
308-
// return std::make_shared<const_cache_proxy>(base, base.get(), size, true);
309-
// }
303+
std::shared_ptr<const_cache_proxy> create_const_cache_proxy(size_t size) {
304+
// simply allocate buffer and return
305+
std::shared_ptr<void> base =
306+
std::shared_ptr<void>{std::aligned_alloc(64, size), [](void *p) {
307+
std::free(p); }};
308+
return std::make_shared<const_cache_proxy>(base, base.get(), size, true);
309+
}
310310

311311
size_t divide_and_ceil(size_t x, size_t y) { return (x + y - 1) / y; }
312312

@@ -330,12 +330,12 @@ struct const_graph_tensor_cache_manager {
330330
total_size += divide_and_ceil(buffers_size[i], 64) * 64;
331331
}
332332
llvm::dbgs() << "Alloc total size: " << total_size << '\n';
333-
// auto base = create_const_cache_proxy(total_size);
333+
auto base = create_const_cache_proxy(total_size);
334334
std::vector<uint64_t> global_ids(buffers_size.size());
335335
size_t offset = 0;
336336
for (size_t i = 0; i < buffers_size.size(); i++) {
337337
llvm::dbgs() << "Alloc offset: " << offset << '\n';
338-
// reg_cached_tensor(cached_tensor_global_id, base, offset);
338+
reg_cached_tensor(cached_tensor_global_id, base, offset);
339339
global_ids[i] = cached_tensor_global_id;
340340
++cached_tensor_global_id;
341341
offset += divide_and_ceil(buffers_size[i], 64) * 64;
@@ -344,52 +344,52 @@ struct const_graph_tensor_cache_manager {
344344
}
345345
};
346346

347-
// static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
348-
// StringRef name, int64_t value) {
349-
// OpBuilder::InsertionGuard insertGuard(builder);
350-
// builder.setInsertionPointToStart(module.getBody());
351-
352-
// auto type = IntegerType::get(builder.getContext(), 8);
353-
// LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
354-
// loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
355-
// builder.getIndexAttr(value),
356-
// /*alignment=*/0);
357-
// }
358-
359-
// static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder,
360-
// StringRef name, ArrayRef<int64_t> array) {
361-
// OpBuilder::InsertionGuard insertGuard(builder);
362-
// builder.setInsertionPointToStart(module.getBody());
347+
static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
348+
StringRef name, int64_t value) {
349+
OpBuilder::InsertionGuard insertGuard(builder);
350+
builder.setInsertionPointToStart(module.getBody());
363351

364-
// auto type = LLVM::LLVMArrayType::get(
365-
// IntegerType::get(builder.getContext(), 8), array.size());
366-
// LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
367-
// loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
368-
// builder.getIndexArrayAttr(array),
369-
// /*alignment=*/0);
370-
// }
352+
auto type = IntegerType::get(builder.getContext(), 8);
353+
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
354+
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
355+
builder.getIndexAttr(value),
356+
/*alignment=*/0);
357+
}
371358

372359
static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder,
373360
StringRef name, ArrayRef<int64_t> array) {
374361
OpBuilder::InsertionGuard insertGuard(builder);
375362
builder.setInsertionPointToStart(module.getBody());
376363

377-
MemRefType type = MemRefType::Builder(array.size(), builder.getIndexType());
378-
IntegerAttr memrefAlignment = IntegerAttr();
379-
auto global = builder.create<memref::GlobalOp>(
380-
loc, name,
381-
/*sym_visibility=*/builder.getStringAttr("public"),
382-
/*type=*/type,
383-
/*initial_value=*/builder.getIndexTensorAttr(array),
384-
/*constant=*/true,
385-
/*alignment=*/memrefAlignment);
364+
auto type = LLVM::LLVMArrayType::get(
365+
IntegerType::get(builder.getContext(), 8), array.size());
366+
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
367+
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
368+
builder.getIndexArrayAttr(array),
369+
/*alignment=*/0);
386370
}
387371

388-
static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
389-
StringRef name, int64_t value) {
390-
SmallVector<int64_t> array{value};
391-
addGlobalArray(module, loc, builder, name, array);
392-
}
372+
// static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder,
373+
// StringRef name, ArrayRef<int64_t> array) {
374+
// OpBuilder::InsertionGuard insertGuard(builder);
375+
// builder.setInsertionPointToStart(module.getBody());
376+
377+
// MemRefType type = MemRefType::Builder(array.size(), builder.getIndexType());
378+
// IntegerAttr memrefAlignment = IntegerAttr();
379+
// auto global = builder.create<memref::GlobalOp>(
380+
// loc, name,
381+
// /*sym_visibility=*/builder.getStringAttr("public"),
382+
// /*type=*/type,
383+
// /*initial_value=*/builder.getIndexTensorAttr(array),
384+
// /*constant=*/true,
385+
// /*alignment=*/memrefAlignment);
386+
// }
387+
388+
// static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
389+
// StringRef name, int64_t value) {
390+
// SmallVector<int64_t> array{value};
391+
// addGlobalArray(module, loc, builder, name, array);
392+
// }
393393

394394
// Operate on tensors. Create fold() and compute() on module. The
395395
// folded weights and first-run flag is maintained by upper-level runtime.

0 commit comments

Comments
 (0)