From 15b251dbf2edcb738a55dbf476b00575893f621e Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 20 Feb 2025 12:28:22 -0800 Subject: [PATCH 1/3] temporarily disable zero on device memory during kv_cache alloc --- src/models/kv_cache.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 957a7aba0..709d5ae38 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -187,7 +187,10 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state) sb_kv_caches_.empty() ? OrtValue::CreateTensor(Allocator(), shape_, type_) : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_)); // Zero the memory so we don't leak any data from the previous run - ByteWrapTensor(Device(), *presents_.back()).Zero(); + if (Device().GetType() != DeviceType::WEBGPU) { + // ort c api does have a method to update device memory - temporarily disable. + ByteWrapTensor(Device(), *presents_.back()).Zero(); + } } } catch (const Ort::Exception&) { std::ostringstream oss; From f08f30aec0321cf7ac5f81e06c613c428d814b6d Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 20 Feb 2025 12:50:55 -0800 Subject: [PATCH 2/3] fix comment --- src/models/kv_cache.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 709d5ae38..36630fd43 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -188,7 +188,7 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state) : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_)); // Zero the memory so we don't leak any data from the previous run if (Device().GetType() != DeviceType::WEBGPU) { - // ort c api does have a method to update device memory - temporarily disable. + // ort c api does not have a method to update device memory - temporarily disable. ByteWrapTensor(Device(), *presents_.back()).Zero(); } } From ec1c6da88037ff6380015efa5f78c875267ad52d Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 20 Feb 2025 13:02:07 -0800 Subject: [PATCH 3/3] make the comment easier to understand Co-authored-by: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com> --- src/models/kv_cache.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 36630fd43..8ebe0a2a4 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -187,8 +187,8 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state) sb_kv_caches_.empty() ? OrtValue::CreateTensor(Allocator(), shape_, type_) : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_)); // Zero the memory so we don't leak any data from the previous run + // WebGPU device has no Zero() implementation yet. Since this zeroing is optional we disable it for WebGPU for now if (Device().GetType() != DeviceType::WEBGPU) { - // ort c api does not have a method to update device memory - temporarily disable. ByteWrapTensor(Device(), *presents_.back()).Zero(); } }