30
30
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
31
31
#include " llvm/Support/Debug.h"
32
32
33
- // #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
33
+ #include " gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
34
34
35
35
namespace mlir {
36
36
namespace gc {
@@ -300,13 +300,13 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
300
300
// void *allocator(size_t size) { return std::aligned_alloc(64, size); }
301
301
// void deallocator(void *ptr) { std::free(ptr); }
302
302
303
- // std::shared_ptr<const_cache_proxy> create_const_cache_proxy(size_t size) {
304
- // // simply allocate buffer and return
305
- // std::shared_ptr<void> base =
306
- // std::shared_ptr<void>{std::aligned_alloc(64, size), [](void *p) {
307
- // std::free(p); }};
308
- // return std::make_shared<const_cache_proxy>(base, base.get(), size, true);
309
- // }
303
+ std::shared_ptr<const_cache_proxy> create_const_cache_proxy (size_t size) {
304
+ // simply allocate buffer and return
305
+ std::shared_ptr<void > base =
306
+ std::shared_ptr<void >{std::aligned_alloc (64 , size), [](void *p) {
307
+ std::free (p); }};
308
+ return std::make_shared<const_cache_proxy>(base, base.get (), size, true );
309
+ }
310
310
311
311
// Ceiling integer division: smallest q such that q * y >= x.
// Implemented as div+mod instead of `(x + y - 1) / y` so the addition cannot
// wrap around size_t when x is close to SIZE_MAX (the original form returns a
// far-too-small quotient in that case).
size_t divide_and_ceil(size_t x, size_t y) {
  return x / y + (x % y != 0 ? 1 : 0);
}
312
312
@@ -330,12 +330,12 @@ struct const_graph_tensor_cache_manager {
330
330
total_size += divide_and_ceil (buffers_size[i], 64 ) * 64 ;
331
331
}
332
332
llvm::dbgs () << " Alloc total size: " << total_size << ' \n ' ;
333
- // auto base = create_const_cache_proxy(total_size);
333
+ auto base = create_const_cache_proxy (total_size);
334
334
std::vector<uint64_t > global_ids (buffers_size.size ());
335
335
size_t offset = 0 ;
336
336
for (size_t i = 0 ; i < buffers_size.size (); i++) {
337
337
llvm::dbgs () << " Alloc offset: " << offset << ' \n ' ;
338
- // reg_cached_tensor(cached_tensor_global_id, base, offset);
338
+ reg_cached_tensor (cached_tensor_global_id, base, offset);
339
339
global_ids[i] = cached_tensor_global_id;
340
340
++cached_tensor_global_id;
341
341
offset += divide_and_ceil (buffers_size[i], 64 ) * 64 ;
@@ -344,52 +344,52 @@ struct const_graph_tensor_cache_manager {
344
344
}
345
345
};
346
346
347
- // static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
348
- // StringRef name, int64_t value) {
349
- // OpBuilder::InsertionGuard insertGuard(builder);
350
- // builder.setInsertionPointToStart(module.getBody());
351
-
352
- // auto type = IntegerType::get(builder.getContext(), 8);
353
- // LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
354
- // loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
355
- // builder.getIndexAttr(value),
356
- // /*alignment=*/0);
357
- // }
358
-
359
- // static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder,
360
- // StringRef name, ArrayRef<int64_t> array) {
361
- // OpBuilder::InsertionGuard insertGuard(builder);
362
- // builder.setInsertionPointToStart(module.getBody());
347
+ static void addGlobal (ModuleOp module , Location loc, OpBuilder &builder,
348
+ StringRef name, int64_t value) {
349
+ OpBuilder::InsertionGuard insertGuard (builder);
350
+ builder.setInsertionPointToStart (module .getBody ());
363
351
364
- // auto type = LLVM::LLVMArrayType::get(
365
- // IntegerType::get(builder.getContext(), 8), array.size());
366
- // LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
367
- // loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
368
- // builder.getIndexArrayAttr(array),
369
- // /*alignment=*/0);
370
- // }
352
+ auto type = IntegerType::get (builder.getContext (), 8 );
353
+ LLVM::GlobalOp global = builder.create <LLVM::GlobalOp>(
354
+ loc, type, /* isConstant=*/ true , LLVM::Linkage::Internal, name,
355
+ builder.getIndexAttr (value),
356
+ /* alignment=*/ 0 );
357
+ }
371
358
372
359
static void addGlobalArray (ModuleOp module , Location loc, OpBuilder &builder,
373
360
StringRef name, ArrayRef<int64_t > array) {
374
361
OpBuilder::InsertionGuard insertGuard (builder);
375
362
builder.setInsertionPointToStart (module .getBody ());
376
363
377
- MemRefType type = MemRefType::Builder (array.size (), builder.getIndexType ());
378
- IntegerAttr memrefAlignment = IntegerAttr ();
379
- auto global = builder.create <memref::GlobalOp>(
380
- loc, name,
381
- /* sym_visibility=*/ builder.getStringAttr (" public" ),
382
- /* type=*/ type,
383
- /* initial_value=*/ builder.getIndexTensorAttr (array),
384
- /* constant=*/ true ,
385
- /* alignment=*/ memrefAlignment);
364
+ auto type = LLVM::LLVMArrayType::get (
365
+ IntegerType::get (builder.getContext (), 8 ), array.size ());
366
+ LLVM::GlobalOp global = builder.create <LLVM::GlobalOp>(
367
+ loc, type, /* isConstant=*/ true , LLVM::Linkage::Internal, name,
368
+ builder.getIndexArrayAttr (array),
369
+ /* alignment=*/ 0 );
386
370
}
387
371
388
- static void addGlobal (ModuleOp module , Location loc, OpBuilder &builder,
389
- StringRef name, int64_t value) {
390
- SmallVector<int64_t > array{value};
391
- addGlobalArray (module , loc, builder, name, array);
392
- }
372
+ // static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder,
373
+ // StringRef name, ArrayRef<int64_t> array) {
374
+ // OpBuilder::InsertionGuard insertGuard(builder);
375
+ // builder.setInsertionPointToStart(module.getBody());
376
+
377
+ // MemRefType type = MemRefType::Builder(array.size(), builder.getIndexType());
378
+ // IntegerAttr memrefAlignment = IntegerAttr();
379
+ // auto global = builder.create<memref::GlobalOp>(
380
+ // loc, name,
381
+ // /*sym_visibility=*/builder.getStringAttr("public"),
382
+ // /*type=*/type,
383
+ // /*initial_value=*/builder.getIndexTensorAttr(array),
384
+ // /*constant=*/true,
385
+ // /*alignment=*/memrefAlignment);
386
+ // }
387
+
388
+ // static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
389
+ // StringRef name, int64_t value) {
390
+ // SmallVector<int64_t> array{value};
391
+ // addGlobalArray(module, loc, builder, name, array);
392
+ // }
393
393
394
394
// Operate on tensors. Create fold() and compute() on module. The
395
395
// folded weights and first-run flag is maintained by upper-level runtime.
0 commit comments