Skip to content

Commit ae44418

Browse files
authored
[Offload] Erase entries from JIT cache when program is destroyed (#148847)
When `unloadBinary` is called, any entries in the JITEngine's cache for that binary will be cleared. This fixes a nasty issue with liboffload program handles. If two handles happen to have had the same address (after one was free'd, for example), the cache would be hit and return the wrong program.
1 parent 6adbbcc commit ae44418

File tree

3 files changed

+29
-12
lines changed

3 files changed

+29
-12
lines changed

offload/plugins-nextgen/common/include/JIT.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ struct JITEngine {
5555
process(const __tgt_device_image &Image,
5656
target::plugin::GenericDeviceTy &Device);
5757

58+
/// Remove \p Image from the jit engine's cache
59+
void erase(const __tgt_device_image &Image,
60+
target::plugin::GenericDeviceTy &Device);
61+
5862
private:
5963
/// Compile the bitcode image \p Image and generate the binary image that can
6064
/// be loaded to the target device of the triple \p Triple architecture \p
@@ -89,11 +93,13 @@ struct JITEngine {
8993
/// LLVM Context in which the modules will be constructed.
9094
LLVMContext Context;
9195

92-
/// Output images generated from LLVM backend.
93-
SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;
96+
/// A map of embedded IR images to the buffer used to store JITed code
97+
DenseMap<const __tgt_device_image *, std::unique_ptr<MemoryBuffer>>
98+
JITImages;
9499

95100
/// A map of embedded IR images to JITed images.
96-
DenseMap<const __tgt_device_image *, __tgt_device_image *> TgtImageMap;
101+
DenseMap<const __tgt_device_image *, std::unique_ptr<__tgt_device_image>>
102+
TgtImageMap;
97103
};
98104

99105
/// Map from (march) "CPUs" (e.g., sm_80, or gfx90a), which we call compute

offload/plugins-nextgen/common/src/JIT.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,8 @@ JITEngine::compile(const __tgt_device_image &Image,
285285

286286
// Check if we JITed this image for the given compute unit kind before.
287287
ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
288-
if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image))
289-
return JITedImage;
288+
if (CUI.TgtImageMap.contains(&Image))
289+
return CUI.TgtImageMap[&Image].get();
290290

291291
auto ObjMBOrErr = getOrCreateObjFile(Image, CUI.Context, ComputeUnitKind);
292292
if (!ObjMBOrErr)
@@ -296,17 +296,15 @@ JITEngine::compile(const __tgt_device_image &Image,
296296
if (!ImageMBOrErr)
297297
return ImageMBOrErr.takeError();
298298

299-
CUI.JITImages.push_back(std::move(*ImageMBOrErr));
300-
__tgt_device_image *&JITedImage = CUI.TgtImageMap[&Image];
301-
JITedImage = new __tgt_device_image();
299+
CUI.JITImages.insert({&Image, std::move(*ImageMBOrErr)});
300+
auto &ImageMB = CUI.JITImages[&Image];
301+
CUI.TgtImageMap.insert({&Image, std::make_unique<__tgt_device_image>()});
302+
auto &JITedImage = CUI.TgtImageMap[&Image];
302303
*JITedImage = Image;
303-
304-
auto &ImageMB = CUI.JITImages.back();
305-
306304
JITedImage->ImageStart = const_cast<char *>(ImageMB->getBufferStart());
307305
JITedImage->ImageEnd = const_cast<char *>(ImageMB->getBufferEnd());
308306

309-
return JITedImage;
307+
return JITedImage.get();
310308
}
311309

312310
Expected<const __tgt_device_image *>
@@ -324,3 +322,13 @@ JITEngine::process(const __tgt_device_image &Image,
324322

325323
return &Image;
326324
}
325+
326+
void JITEngine::erase(const __tgt_device_image &Image,
327+
target::plugin::GenericDeviceTy &Device) {
328+
std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);
329+
const std::string &ComputeUnitKind = Device.getComputeUnitKind();
330+
ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
331+
332+
CUI.TgtImageMap.erase(&Image);
333+
CUI.JITImages.erase(&Image);
334+
}

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,9 @@ Error GenericDeviceTy::unloadBinary(DeviceImageTy *Image) {
854854
return Err;
855855
}
856856

857+
if (Image->getTgtImageBitcode())
858+
Plugin.getJIT().erase(*Image->getTgtImageBitcode(), Image->getDevice());
859+
857860
return unloadBinaryImpl(Image);
858861
}
859862

0 commit comments

Comments
 (0)