From 0e6814d45965d73807c25eac43bbe012a581278b Mon Sep 17 00:00:00 2001 From: Ivan Date: Tue, 25 Feb 2025 15:56:37 +0300 Subject: [PATCH] [KIKIMR-22131] Handle potential race in computation pattern cache (#14962) --- .../yql/dq/runtime/dq_tasks_runner.cpp | 24 ++-- .../mkql_computation_pattern_cache.cpp | 105 +++++++++--------- .../mkql_computation_pattern_cache.h | 65 +++-------- 3 files changed, 83 insertions(+), 111 deletions(-) diff --git a/ydb/library/yql/dq/runtime/dq_tasks_runner.cpp b/ydb/library/yql/dq/runtime/dq_tasks_runner.cpp index 6f909793c121..a8b08083fcfe 100644 --- a/ydb/library/yql/dq/runtime/dq_tasks_runner.cpp +++ b/ydb/library/yql/dq/runtime/dq_tasks_runner.cpp @@ -434,17 +434,23 @@ class TDqTaskRunner : public IDqTaskRunner { bool canBeCached; if (UseSeparatePatternAlloc(task) && Context.PatternCache) { auto& cache = Context.PatternCache; - auto ticket = cache->FindOrSubscribe(program.GetRaw()); - if (!ticket.HasFuture()) { - entry = CreateComputationPattern(task, program.GetRaw(), true, canBeCached); - if (canBeCached && entry->Pattern->GetSuitableForCache()) { - cache->EmplacePattern(task.GetProgram().GetRaw(), entry); - ticket.Close(); - } else { - cache->IncNotSuitablePattern(); + auto future = cache->FindOrSubscribe(program.GetRaw()); + if (!future.HasValue()) { + try { + entry = CreateComputationPattern(task, program.GetRaw(), true, canBeCached); + if (canBeCached && entry->Pattern->GetSuitableForCache()) { + cache->EmplacePattern(task.GetProgram().GetRaw(), entry); + } else { + cache->IncNotSuitablePattern(); + cache->NotifyPatternMissing(program.GetRaw()); + } + } catch (...) { + // TODO: not sure if there may be exceptions in the first place. + cache->NotifyPatternMissing(program.GetRaw()); + throw; } } else { - entry = ticket.GetValueSync(); + entry = future.GetValueSync(); } } diff --git a/ydb/library/yql/minikql/computation/mkql_computation_pattern_cache.cpp b/ydb/library/yql/minikql/computation/mkql_computation_pattern_cache.cpp index 1bd38fe24b73..1a2a490c5c2b 100644 --- a/ydb/library/yql/minikql/computation/mkql_computation_pattern_cache.cpp +++ b/ydb/library/yql/minikql/computation/mkql_computation_pattern_cache.cpp @@ -33,28 +33,29 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl return CurrentPatternsCompiledCodeSizeInBytes; } - std::shared_ptr* Find(const TString& serializedProgram) { + TPatternCacheEntryPtr Find(const TString& serializedProgram) { auto it = SerializedProgramToPatternCacheHolder.find(serializedProgram); if (it == SerializedProgramToPatternCacheHolder.end()) { - return nullptr; + return {}; } PromoteEntry(&it->second); - return &it->second.Entry; + return it->second.Entry; } - void Insert(const TString& serializedProgram, std::shared_ptr& entry) { + void Insert(const TString& serializedProgram, TPatternCacheEntryPtr& entry) { auto [it, inserted] = SerializedProgramToPatternCacheHolder.emplace(std::piecewise_construct, std::forward_as_tuple(serializedProgram), std::forward_as_tuple(serializedProgram, entry)); if (!inserted) { RemoveEntryFromLists(&it->second); + entry = it->second.Entry; + } else { + entry->UpdateSizeForCache(); } - entry->UpdateSizeForCache(); - /// New item is inserted, insert it in the back of both LRU lists and recalculate sizes CurrentPatternsSizeBytes += entry->SizeForCache; LRUPatternList.PushBack(&it->second); @@ -69,7 +70,7 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl ClearIfNeeded(); } - void NotifyPatternCompiled(const TString & serializedProgram) { + void NotifyPatternCompiled(const TString& serializedProgram) { auto it = SerializedProgramToPatternCacheHolder.find(serializedProgram); if (it == SerializedProgramToPatternCacheHolder.end()) { return; @@ -77,7 +78,12 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl const auto& entry = it->second.Entry; - Y_ASSERT(entry->Pattern->IsCompiled()); + if (!entry->Pattern->IsCompiled()) { + // This is possible if the old entry got removed from cache while being compiled - and the new entry got in. + // TODO: add metrics for this inefficient cache usage. + // TODO: make this scenario more consistent - don't waste compilation result. + return; + } if (it->second.LinkedInCompiledPatternLRUList()) { return; @@ -113,7 +119,7 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl * Most recently accessed items are in back of the lists, least recently accessed items are in front of the lists. */ struct TPatternCacheHolder : public TIntrusiveListItem, TIntrusiveListItem { - TPatternCacheHolder(TString serializedProgram, std::shared_ptr entry) + TPatternCacheHolder(TString serializedProgram, TPatternCacheEntryPtr entry) : SerializedProgram(std::move(serializedProgram)) , Entry(std::move(entry)) {} @@ -126,8 +132,8 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl return !TIntrusiveListItem::Empty(); } - TString SerializedProgram; - std::shared_ptr Entry; + const TString SerializedProgram; + TPatternCacheEntryPtr Entry; }; void PromoteEntry(TPatternCacheHolder* holder) { @@ -228,52 +234,51 @@ TComputationPatternLRUCache::~TComputationPatternLRUCache() { CleanCache(); } -std::shared_ptr TComputationPatternLRUCache::Find(const TString& serializedProgram) { +TPatternCacheEntryPtr TComputationPatternLRUCache::Find(const TString& serializedProgram) { std::lock_guard lock(Mutex); if (auto it = Cache->Find(serializedProgram)) { ++*Hits; - if ((*it)->Pattern->IsCompiled()) + if (it->Pattern->IsCompiled()) ++*HitsCompiled; - return *it; + return it; } ++*Misses; return {}; } -TComputationPatternLRUCache::TTicket TComputationPatternLRUCache::FindOrSubscribe(const TString& serializedProgram) { +TPatternCacheEntryFuture TComputationPatternLRUCache::FindOrSubscribe(const TString& serializedProgram) { std::lock_guard lock(Mutex); if (auto it = Cache->Find(serializedProgram)) { ++*Hits; - AccessPattern(serializedProgram, *it); - return TTicket(serializedProgram, false, NThreading::MakeFuture>(*it), nullptr); + AccessPattern(serializedProgram, it); + return NThreading::MakeFuture(it); } - auto [notifyIt, isNew] = Notify.emplace(serializedProgram, Nothing()); + auto [notifyIt, isNew] = Notify.emplace(std::piecewise_construct, std::forward_as_tuple(serializedProgram), std::forward_as_tuple()); if (isNew) { ++*Misses; - return TTicket(serializedProgram, true, {}, this); + // First future is empty - so the subscriber can initiate the entry creation. + return {}; } ++*Waits; - auto promise = NThreading::NewPromise>(); + auto promise = NThreading::NewPromise(); auto& subscribers = notifyIt->second; - if (!subscribers) { - subscribers.ConstructInPlace(); - } + subscribers.push_back(promise); - subscribers->push_back(promise); - return TTicket(serializedProgram, false, promise, nullptr); + // Second and next futures are not empty - so subscribers can wait while first one creates the entry. + return promise; } -void TComputationPatternLRUCache::EmplacePattern(const TString& serializedProgram, std::shared_ptr patternWithEnv) { +void TComputationPatternLRUCache::EmplacePattern(const TString& serializedProgram, TPatternCacheEntryPtr& patternWithEnv) { Y_DEBUG_ABORT_UNLESS(patternWithEnv && patternWithEnv->Pattern); - TMaybe>>> subscribers; + TVector> subscribers; { - std::lock_guard lock(Mutex); + std::lock_guard lock(Mutex); Cache->Insert(serializedProgram, patternWithEnv); auto notifyIt = Notify.find(serializedProgram); @@ -288,10 +293,8 @@ void TComputationPatternLRUCache::EmplacePattern(const TString& serializedProgra *SizeCompiledBytes = Cache->PatternsCompiledCodeSizeInBytes(); } - if (subscribers) { - for (auto& subscriber : *subscribers) { - subscriber.SetValue(patternWithEnv); - } + for (auto& subscriber : subscribers) { + subscriber.SetValue(patternWithEnv); } } @@ -300,6 +303,24 @@ void TComputationPatternLRUCache::NotifyPatternCompiled(const TString& serialize Cache->NotifyPatternCompiled(serializedProgram); } +void TComputationPatternLRUCache::NotifyPatternMissing(const TString& serializedProgram) { + TVector>> subscribers; + { + std::lock_guard lock(Mutex); + + auto notifyIt = Notify.find(serializedProgram); + if (notifyIt != Notify.end()) { + subscribers.swap(notifyIt->second); + Notify.erase(notifyIt); + } + } + + for (auto& subscriber : subscribers) { + // It's part of API - to set nullptr as broken promise. + subscriber.SetValue(nullptr); + } +} + size_t TComputationPatternLRUCache::GetSize() const { std::lock_guard lock(Mutex); return Cache->PatternsSize(); @@ -314,7 +335,7 @@ void TComputationPatternLRUCache::CleanCache() { Cache->Clear(); } -void TComputationPatternLRUCache::AccessPattern(const TString & serializedProgram, std::shared_ptr & entry) { +void TComputationPatternLRUCache::AccessPattern(const TString& serializedProgram, TPatternCacheEntryPtr entry) { if (!Configuration.PatternAccessTimesBeforeTryToCompile || entry->Pattern->IsCompiled()) { return; } @@ -326,22 +347,4 @@ void TComputationPatternLRUCache::AccessPattern(const TString & serializedProgra } } -void TComputationPatternLRUCache::NotifyMissing(const TString& serialized) { - TMaybe>>> subscribers; - { - std::lock_guard lock(Mutex); - auto notifyIt = Notify.find(serialized); - if (notifyIt != Notify.end()) { - subscribers.swap(notifyIt->second); - Notify.erase(notifyIt); - } - } - - if (subscribers) { - for (auto& subscriber : *subscribers) { - subscriber.SetValue(nullptr); - } - } -} - } // namespace NKikimr::NMiniKQL diff --git a/ydb/library/yql/minikql/computation/mkql_computation_pattern_cache.h b/ydb/library/yql/minikql/computation/mkql_computation_pattern_cache.h index 3284690192fa..1c8645fdb5e7 100644 --- a/ydb/library/yql/minikql/computation/mkql_computation_pattern_cache.h +++ b/ydb/library/yql/minikql/computation/mkql_computation_pattern_cache.h @@ -53,43 +53,11 @@ struct TPatternCacheEntry { } }; +using TPatternCacheEntryPtr = std::shared_ptr; +using TPatternCacheEntryFuture = NThreading::TFuture; + class TComputationPatternLRUCache { public: - class TTicket : private TNonCopyable { - public: - TTicket(const TString& serialized, bool isOwned, const NThreading::TFuture>& future, TComputationPatternLRUCache* cache) - : Serialized(serialized) - , IsOwned(isOwned) - , Future(future) - , Cache(cache) - {} - - ~TTicket() { - if (Cache) { - Cache->NotifyMissing(Serialized); - } - } - - bool HasFuture() const { - return !IsOwned; - } - - std::shared_ptr GetValueSync() const { - Y_ABORT_UNLESS(HasFuture()); - return Future.GetValueSync(); - } - - void Close() { - Cache = nullptr; - } - - private: - const TString Serialized; - const bool IsOwned; - const NThreading::TFuture> Future; - TComputationPatternLRUCache* Cache; - }; - struct Config { Config(size_t maxSizeBytes, size_t maxCompiledSizeBytes) : MaxSizeBytes(maxSizeBytes) @@ -120,17 +88,17 @@ class TComputationPatternLRUCache { ~TComputationPatternLRUCache(); - static std::shared_ptr CreateCacheEntry(bool useAlloc = true) { + static TPatternCacheEntryPtr CreateCacheEntry(bool useAlloc = true) { return std::make_shared(useAlloc); } - std::shared_ptr Find(const TString& serializedProgram); + TPatternCacheEntryPtr Find(const TString& serializedProgram); + TPatternCacheEntryFuture FindOrSubscribe(const TString& serializedProgram); - TTicket FindOrSubscribe(const TString& serializedProgram); - - void EmplacePattern(const TString& serializedProgram, std::shared_ptr patternWithEnv); + void EmplacePattern(const TString& serializedProgram, TPatternCacheEntryPtr& patternWithEnv); void NotifyPatternCompiled(const TString& serializedProgram); + void NotifyPatternMissing(const TString& serializedProgram); size_t GetSize() const; @@ -159,27 +127,22 @@ class TComputationPatternLRUCache { return PatternsToCompile.size(); } - void GetPatternsToCompile(THashMap> & result) { + void GetPatternsToCompile(THashMap & result) { std::lock_guard lock(Mutex); result.swap(PatternsToCompile); } private: - void AccessPattern(const TString & serializedProgram, std::shared_ptr & entry); - - void NotifyMissing(const TString& serialized); + class TLRUPatternCacheImpl; static constexpr size_t CacheMaxElementsSize = 10000; - friend class TTicket; + void AccessPattern(const TString& serializedProgram, TPatternCacheEntryPtr entry); mutable std::mutex Mutex; - THashMap>>>> Notify; - - class TLRUPatternCacheImpl; - std::unique_ptr Cache; - - THashMap> PatternsToCompile; + THashMap>> Notify; // protected by Mutex + std::unique_ptr Cache; // protected by Mutex + THashMap PatternsToCompile; // protected by Mutex const Config Configuration;