Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 34 additions & 76 deletions crates/vapp/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ use crate::{
storage::{Storage, StorageError, StorageKey, StorageValue},
};

/// The number of top layers to cache in [`MerkleStorage::compute_node`]. Caching only the top
/// layers avoids rebuilding expensive high-level subtree hashes across multiple [`proof`] calls
/// while keeping memory bounded. With 1M leaves in a 160-bit tree, caching 20 layers stores
/// ~1.7M entries (~136 MB) instead of ~142M entries (~22 GB) for the full cache.
const COMPUTE_NODE_CACHE_LAYERS: usize = 20;

/// Merkle tree with key type K and value type V.
///
/// This implementation supports `2^K::bits()` possible indices and uses sparse storage to
Expand Down Expand Up @@ -92,7 +98,7 @@ pub trait MerkleTreeHasher {

impl<K: StorageKey, V: StorageValue, H: MerkleTreeHasher> MerkleStorage<K, V, H> {
/// Compute the merkle root from scratch.
pub fn root(&mut self) -> B256 {
pub fn root(&self) -> B256 {
let num_bits = K::bits();

// If no leaves, return the precomputed empty tree root.
Expand Down Expand Up @@ -205,89 +211,41 @@ impl<K: StorageKey, V: StorageValue, H: MerkleTreeHasher> MerkleStorage<K, V, H>
}
}

/// Compute a node hash bottom-up, caching intermediate results.
fn compute_node(&mut self, target_layer: usize, target_index: U256) -> B256 {
// Build a stack of (layer, index) pairs that need to be computed.
let mut stack = Vec::new();
let mut to_compute = vec![(target_layer, target_index)];

// Find all nodes that need computation (not cached and not empty).
while let Some((layer, index)) = to_compute.pop() {
if layer == 0 {
// Leaf node - no dependencies.
continue;
}

if self.cache.contains_key(&(layer, index)) {
// Already cached.
continue;
}

if self.is_subtree_empty(layer, index) {
// Empty subtree - cache zero hash.
self.cache.insert((layer, index), self.zero_hashes[layer]);
continue;
}

// Add to computation stack.
stack.push((layer, index));

// Add children to be computed first.
let left_child = index << 1;
let right_child = left_child | U256::from(1);

to_compute.push((layer - 1, left_child));
to_compute.push((layer - 1, right_child));
/// Recursively compute the hash of the node at (`layer`, `index`), caching only the top
/// [`COMPUTE_NODE_CACHE_LAYERS`] layers to bound memory usage. The recursion depth equals
/// `K::bits()` (e.g. 160 for address/request-id keys), which is well within the default
/// thread stack size.
fn compute_node(&mut self, layer: usize, index: U256) -> B256 {
// Base case: leaf layer.
if layer == 0 {
return self
.leaves
.get(&index)
.map_or(self.zero_hashes[0], |v| H::hash(v));
}

// Compute hashes bottom-up.
while let Some((layer, index)) = stack.pop() {
if self.cache.contains_key(&(layer, index)) {
continue; // Already computed.
}

let left_child = index << 1;
let right_child = left_child | U256::from(1);
// Return cached value if available.
if let Some(&cached) = self.cache.get(&(layer, index)) {
return cached;
}

let left_hash = if layer == 1 {
// Children are leaves.
if let Some(value) = self.leaves.get(&left_child) {
H::hash(value)
} else {
self.zero_hashes[0]
}
} else {
// Children are internal nodes - should be cached now.
self.cache
.get(&(layer - 1, left_child))
.copied()
.unwrap_or(self.zero_hashes[layer - 1])
};
// Empty subtree short-circuit.
if self.is_subtree_empty(layer, index) {
return self.zero_hashes[layer];
}

let right_hash = if layer == 1 {
// Children are leaves.
if let Some(value) = self.leaves.get(&right_child) {
H::hash(value)
} else {
self.zero_hashes[0]
}
} else {
// Children are internal nodes - should be cached now.
self.cache
.get(&(layer - 1, right_child))
.copied()
.unwrap_or(self.zero_hashes[layer - 1])
};
// Recurse into left and right children.
let left_hash = self.compute_node(layer - 1, index << 1);
let right_hash = self.compute_node(layer - 1, (index << 1) | U256::from(1));
let hash = H::hash_pair(&left_hash, &right_hash);

let hash = H::hash_pair(&left_hash, &right_hash);
// Only cache the top layers to bound memory usage.
let cache_threshold = K::bits().saturating_sub(COMPUTE_NODE_CACHE_LAYERS);
if layer >= cache_threshold {
self.cache.insert((layer, index), hash);
}

// Return the computed hash.
self.cache
.get(&(target_layer, target_index))
.copied()
.unwrap_or(self.zero_hashes[target_layer])
hash
}

/// Get the set of keys that have been touched (read or written).
Expand Down
Loading