Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KIKIMR-22131] Handle potential race in computation pattern cache #16010

Merged
merged 1 commit into from
Mar 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions ydb/library/yql/dq/runtime/dq_tasks_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,23 @@ class TDqTaskRunner : public IDqTaskRunner {
bool canBeCached;
if (UseSeparatePatternAlloc(task) && Context.PatternCache) {
auto& cache = Context.PatternCache;
auto ticket = cache->FindOrSubscribe(program.GetRaw());
if (!ticket.HasFuture()) {
entry = CreateComputationPattern(task, program.GetRaw(), true, canBeCached);
if (canBeCached && entry->Pattern->GetSuitableForCache()) {
cache->EmplacePattern(task.GetProgram().GetRaw(), entry);
ticket.Close();
} else {
cache->IncNotSuitablePattern();
auto future = cache->FindOrSubscribe(program.GetRaw());
if (!future.HasValue()) {
try {
entry = CreateComputationPattern(task, program.GetRaw(), true, canBeCached);
if (canBeCached && entry->Pattern->GetSuitableForCache()) {
cache->EmplacePattern(task.GetProgram().GetRaw(), entry);
} else {
cache->IncNotSuitablePattern();
cache->NotifyPatternMissing(program.GetRaw());
}
} catch (...) {
// TODO: verify whether pattern creation can throw at all; if it cannot, this handler is redundant.
cache->NotifyPatternMissing(program.GetRaw());
throw;
}
} else {
entry = ticket.GetValueSync();
entry = future.GetValueSync();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,29 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl
return CurrentPatternsCompiledCodeSizeInBytes;
}

std::shared_ptr<TPatternCacheEntry>* Find(const TString& serializedProgram) {
TPatternCacheEntryPtr Find(const TString& serializedProgram) {
auto it = SerializedProgramToPatternCacheHolder.find(serializedProgram);
if (it == SerializedProgramToPatternCacheHolder.end()) {
return nullptr;
return {};
}

PromoteEntry(&it->second);

return &it->second.Entry;
return it->second.Entry;
}

void Insert(const TString& serializedProgram, std::shared_ptr<TPatternCacheEntry>& entry) {
void Insert(const TString& serializedProgram, TPatternCacheEntryPtr& entry) {
auto [it, inserted] = SerializedProgramToPatternCacheHolder.emplace(std::piecewise_construct,
std::forward_as_tuple(serializedProgram),
std::forward_as_tuple(serializedProgram, entry));

if (!inserted) {
RemoveEntryFromLists(&it->second);
entry = it->second.Entry;
} else {
entry->UpdateSizeForCache();
}

entry->UpdateSizeForCache();

/// New item is inserted, insert it in the back of both LRU lists and recalculate sizes
CurrentPatternsSizeBytes += entry->SizeForCache;
LRUPatternList.PushBack(&it->second);
Expand All @@ -69,15 +70,20 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl
ClearIfNeeded();
}

void NotifyPatternCompiled(const TString & serializedProgram) {
void NotifyPatternCompiled(const TString& serializedProgram) {
auto it = SerializedProgramToPatternCacheHolder.find(serializedProgram);
if (it == SerializedProgramToPatternCacheHolder.end()) {
return;
}

const auto& entry = it->second.Entry;

Y_ASSERT(entry->Pattern->IsCompiled());
if (!entry->Pattern->IsCompiled()) {
// This can happen if the original entry was evicted from the cache while it was being compiled and a new, not-yet-compiled entry for the same program was inserted in its place.
// TODO: add metrics for this inefficient cache usage.
// TODO: make this scenario more consistent - don't waste compilation result.
return;
}

if (it->second.LinkedInCompiledPatternLRUList()) {
return;
Expand Down Expand Up @@ -113,7 +119,7 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl
* Most recently accessed items are in back of the lists, least recently accessed items are in front of the lists.
*/
struct TPatternCacheHolder : public TIntrusiveListItem<TPatternCacheHolder, TPatternLRUListTag>, TIntrusiveListItem<TPatternCacheHolder, TCompiledPatternLRUListTag> {
TPatternCacheHolder(TString serializedProgram, std::shared_ptr<TPatternCacheEntry> entry)
TPatternCacheHolder(TString serializedProgram, TPatternCacheEntryPtr entry)
: SerializedProgram(std::move(serializedProgram))
, Entry(std::move(entry))
{}
Expand All @@ -126,8 +132,8 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl
return !TIntrusiveListItem<TPatternCacheHolder, TCompiledPatternLRUListTag>::Empty();
}

TString SerializedProgram;
std::shared_ptr<TPatternCacheEntry> Entry;
const TString SerializedProgram;
TPatternCacheEntryPtr Entry;
};

void PromoteEntry(TPatternCacheHolder* holder) {
Expand Down Expand Up @@ -228,52 +234,51 @@ TComputationPatternLRUCache::~TComputationPatternLRUCache() {
CleanCache();
}

std::shared_ptr<TPatternCacheEntry> TComputationPatternLRUCache::Find(const TString& serializedProgram) {
TPatternCacheEntryPtr TComputationPatternLRUCache::Find(const TString& serializedProgram) {
std::lock_guard<std::mutex> lock(Mutex);
if (auto it = Cache->Find(serializedProgram)) {
++*Hits;

if ((*it)->Pattern->IsCompiled())
if (it->Pattern->IsCompiled())
++*HitsCompiled;

return *it;
return it;
}

++*Misses;
return {};
}

TComputationPatternLRUCache::TTicket TComputationPatternLRUCache::FindOrSubscribe(const TString& serializedProgram) {
TPatternCacheEntryFuture TComputationPatternLRUCache::FindOrSubscribe(const TString& serializedProgram) {
std::lock_guard lock(Mutex);
if (auto it = Cache->Find(serializedProgram)) {
++*Hits;
AccessPattern(serializedProgram, *it);
return TTicket(serializedProgram, false, NThreading::MakeFuture<std::shared_ptr<TPatternCacheEntry>>(*it), nullptr);
AccessPattern(serializedProgram, it);
return NThreading::MakeFuture<TPatternCacheEntryPtr>(it);
}

auto [notifyIt, isNew] = Notify.emplace(serializedProgram, Nothing());
auto [notifyIt, isNew] = Notify.emplace(std::piecewise_construct, std::forward_as_tuple(serializedProgram), std::forward_as_tuple());
if (isNew) {
++*Misses;
return TTicket(serializedProgram, true, {}, this);
// First future is empty - so the subscriber can initiate the entry creation.
return {};
}

++*Waits;
auto promise = NThreading::NewPromise<std::shared_ptr<TPatternCacheEntry>>();
auto promise = NThreading::NewPromise<TPatternCacheEntryPtr>();
auto& subscribers = notifyIt->second;
if (!subscribers) {
subscribers.ConstructInPlace();
}
subscribers.push_back(promise);

subscribers->push_back(promise);
return TTicket(serializedProgram, false, promise, nullptr);
// Second and next futures are not empty - so subscribers can wait while first one creates the entry.
return promise;
}

void TComputationPatternLRUCache::EmplacePattern(const TString& serializedProgram, std::shared_ptr<TPatternCacheEntry> patternWithEnv) {
void TComputationPatternLRUCache::EmplacePattern(const TString& serializedProgram, TPatternCacheEntryPtr& patternWithEnv) {
Y_DEBUG_ABORT_UNLESS(patternWithEnv && patternWithEnv->Pattern);
TMaybe<TVector<NThreading::TPromise<std::shared_ptr<TPatternCacheEntry>>>> subscribers;
TVector<NThreading::TPromise<TPatternCacheEntryPtr>> subscribers;

{
std::lock_guard<std::mutex> lock(Mutex);
std::lock_guard lock(Mutex);
Cache->Insert(serializedProgram, patternWithEnv);

auto notifyIt = Notify.find(serializedProgram);
Expand All @@ -288,10 +293,8 @@ void TComputationPatternLRUCache::EmplacePattern(const TString& serializedProgra
*SizeCompiledBytes = Cache->PatternsCompiledCodeSizeInBytes();
}

if (subscribers) {
for (auto& subscriber : *subscribers) {
subscriber.SetValue(patternWithEnv);
}
for (auto& subscriber : subscribers) {
subscriber.SetValue(patternWithEnv);
}
}

Expand All @@ -300,6 +303,24 @@ void TComputationPatternLRUCache::NotifyPatternCompiled(const TString& serialize
Cache->NotifyPatternCompiled(serializedProgram);
}

void TComputationPatternLRUCache::NotifyPatternMissing(const TString& serializedProgram) {
TVector<NThreading::TPromise<std::shared_ptr<TPatternCacheEntry>>> subscribers;
{
std::lock_guard lock(Mutex);

auto notifyIt = Notify.find(serializedProgram);
if (notifyIt != Notify.end()) {
subscribers.swap(notifyIt->second);
Notify.erase(notifyIt);
}
}

for (auto& subscriber : subscribers) {
// It's part of API - to set nullptr as broken promise.
subscriber.SetValue(nullptr);
}
}

size_t TComputationPatternLRUCache::GetSize() const {
std::lock_guard lock(Mutex);
return Cache->PatternsSize();
Expand All @@ -314,7 +335,7 @@ void TComputationPatternLRUCache::CleanCache() {
Cache->Clear();
}

void TComputationPatternLRUCache::AccessPattern(const TString & serializedProgram, std::shared_ptr<TPatternCacheEntry> & entry) {
void TComputationPatternLRUCache::AccessPattern(const TString& serializedProgram, TPatternCacheEntryPtr entry) {
if (!Configuration.PatternAccessTimesBeforeTryToCompile || entry->Pattern->IsCompiled()) {
return;
}
Expand All @@ -326,22 +347,4 @@ void TComputationPatternLRUCache::AccessPattern(const TString & serializedProgra
}
}

void TComputationPatternLRUCache::NotifyMissing(const TString& serialized) {
TMaybe<TVector<NThreading::TPromise<std::shared_ptr<TPatternCacheEntry>>>> subscribers;
{
std::lock_guard<std::mutex> lock(Mutex);
auto notifyIt = Notify.find(serialized);
if (notifyIt != Notify.end()) {
subscribers.swap(notifyIt->second);
Notify.erase(notifyIt);
}
}

if (subscribers) {
for (auto& subscriber : *subscribers) {
subscriber.SetValue(nullptr);
}
}
}

} // namespace NKikimr::NMiniKQL
Original file line number Diff line number Diff line change
Expand Up @@ -53,43 +53,11 @@ struct TPatternCacheEntry {
}
};

// Shared-ownership handle to a pattern cache entry.
using TPatternCacheEntryPtr = std::shared_ptr<TPatternCacheEntry>;
// Future whose value is a TPatternCacheEntryPtr.
using TPatternCacheEntryFuture = NThreading::TFuture<TPatternCacheEntryPtr>;

class TComputationPatternLRUCache {
public:
/// Non-copyable handle bundling a serialized program, an ownership flag, a future
/// for the cache entry and an optional back-pointer to the cache.
///
/// On destruction, if the back-pointer is still set, the cache is told via
/// NotifyMissing() that no pattern was provided for the serialized program.
/// Calling Close() clears the back-pointer and thereby suppresses that notification.
class TTicket : private TNonCopyable {
public:
    /// @param serialized  serialized program this ticket refers to (copied).
    /// @param isOwned     when true, HasFuture() reports false.
    /// @param future      future returned by GetValueSync() when HasFuture() is true.
    /// @param cache       cache to notify on destruction; nullptr disables the notification.
    TTicket(const TString& serialized, bool isOwned, const NThreading::TFuture<std::shared_ptr<TPatternCacheEntry>>& future, TComputationPatternLRUCache* cache)
        : Serialized(serialized)
        , IsOwned(isOwned)
        , Future(future)
        , Cache(cache)
    {}

    /// Notifies the cache about the missing pattern unless the ticket was closed
    /// or was constructed without a cache.
    ~TTicket() {
        if (!Cache) {
            return;
        }
        Cache->NotifyMissing(Serialized);
    }

    /// True when the ticket is not owned, i.e. GetValueSync() may be called.
    bool HasFuture() const {
        return !IsOwned;
    }

    /// Waits for the future and returns its value; aborts if HasFuture() is false.
    std::shared_ptr<TPatternCacheEntry> GetValueSync() const {
        Y_ABORT_UNLESS(HasFuture());
        std::shared_ptr<TPatternCacheEntry> value = Future.GetValueSync();
        return value;
    }

    /// Detaches the ticket from the cache so the destructor stays silent.
    void Close() {
        Cache = nullptr;
    }

private:
    const TString Serialized;
    const bool IsOwned;
    const NThreading::TFuture<std::shared_ptr<TPatternCacheEntry>> Future;
    TComputationPatternLRUCache* Cache;
};

struct Config {
Config(size_t maxSizeBytes, size_t maxCompiledSizeBytes)
: MaxSizeBytes(maxSizeBytes)
Expand Down Expand Up @@ -120,17 +88,17 @@ class TComputationPatternLRUCache {

~TComputationPatternLRUCache();

static std::shared_ptr<TPatternCacheEntry> CreateCacheEntry(bool useAlloc = true) {
static TPatternCacheEntryPtr CreateCacheEntry(bool useAlloc = true) {
return std::make_shared<TPatternCacheEntry>(useAlloc);
}

std::shared_ptr<TPatternCacheEntry> Find(const TString& serializedProgram);
TPatternCacheEntryPtr Find(const TString& serializedProgram);
TPatternCacheEntryFuture FindOrSubscribe(const TString& serializedProgram);

TTicket FindOrSubscribe(const TString& serializedProgram);

void EmplacePattern(const TString& serializedProgram, std::shared_ptr<TPatternCacheEntry> patternWithEnv);
void EmplacePattern(const TString& serializedProgram, TPatternCacheEntryPtr& patternWithEnv);

void NotifyPatternCompiled(const TString& serializedProgram);
void NotifyPatternMissing(const TString& serializedProgram);

size_t GetSize() const;

Expand Down Expand Up @@ -159,27 +127,22 @@ class TComputationPatternLRUCache {
return PatternsToCompile.size();
}

void GetPatternsToCompile(THashMap<TString, std::shared_ptr<TPatternCacheEntry>> & result) {
void GetPatternsToCompile(THashMap<TString, TPatternCacheEntryPtr> & result) {
std::lock_guard lock(Mutex);
result.swap(PatternsToCompile);
}

private:
void AccessPattern(const TString & serializedProgram, std::shared_ptr<TPatternCacheEntry> & entry);

void NotifyMissing(const TString& serialized);
class TLRUPatternCacheImpl;

static constexpr size_t CacheMaxElementsSize = 10000;

friend class TTicket;
void AccessPattern(const TString& serializedProgram, TPatternCacheEntryPtr entry);

mutable std::mutex Mutex;
THashMap<TString, TMaybe<TVector<NThreading::TPromise<std::shared_ptr<TPatternCacheEntry>>>>> Notify;

class TLRUPatternCacheImpl;
std::unique_ptr<TLRUPatternCacheImpl> Cache;

THashMap<TString, std::shared_ptr<TPatternCacheEntry>> PatternsToCompile;
THashMap<TString, TVector<NThreading::TPromise<TPatternCacheEntryPtr>>> Notify; // protected by Mutex
std::unique_ptr<TLRUPatternCacheImpl> Cache; // protected by Mutex
THashMap<TString, TPatternCacheEntryPtr> PatternsToCompile; // protected by Mutex

const Config Configuration;

Expand Down
Loading