diff --git a/crates/chia-datalayer/src/lib.rs b/crates/chia-datalayer/src/lib.rs index f75124346..7ed47e100 100644 --- a/crates/chia-datalayer/src/lib.rs +++ b/crates/chia-datalayer/src/lib.rs @@ -1,3 +1,3 @@ mod merkle; -pub use merkle::{InsertLocation, MerkleBlob, Node, Side}; +pub use merkle::{InsertLocation, InternalNode, LeafNode, MerkleBlob, Side}; diff --git a/crates/chia-datalayer/src/merkle.rs b/crates/chia-datalayer/src/merkle.rs index 8fcc2d567..01abbc5fa 100644 --- a/crates/chia-datalayer/src/merkle.rs +++ b/crates/chia-datalayer/src/merkle.rs @@ -1,8 +1,7 @@ #[cfg(feature = "py-bindings")] use pyo3::{ - buffer::PyBuffer, - exceptions::{PyAttributeError, PyValueError}, - pyclass, pymethods, FromPyObject, IntoPy, PyObject, PyResult, Python, + buffer::PyBuffer, exceptions::PyValueError, pyclass, pymethods, FromPyObject, IntoPy, PyObject, + PyResult, Python, }; use clvmr::sha2::Sha256; @@ -73,9 +72,6 @@ pub enum Error { #[error("requested insertion at root but tree not empty")] UnableToInsertAsRootOfNonEmptyTree, - #[error("old leaf unexpectedly not a leaf")] - OldLeafUnexpectedlyNotALeaf, - #[error("unable to find a leaf")] UnableToFindALeaf, @@ -90,6 +86,9 @@ pub enum Error { #[error("block index out of range: {0:?}")] BlockIndexOutOfRange(TreeIndex), + + #[error("node not a leaf: {0:?}")] + NodeNotALeaf(InternalNode), } // assumptions @@ -235,152 +234,164 @@ impl NodeMetadata { } } -#[cfg_attr(feature = "py-bindings", pyclass(name = "Node", get_all))] -#[derive(Debug, PartialEq)] -pub struct Node { +fn parent_from_bytes(blob: &DataBytes) -> Parent { + let parent_integer = TreeIndex::from_bytes(&blob[PARENT_RANGE]); + match parent_integer { + NULL_PARENT => None, + _ => Some(parent_integer), + } +} + +fn hash_from_bytes(blob: &DataBytes) -> Hash { + blob[HASH_RANGE].try_into().unwrap() +} + +#[cfg_attr(feature = "py-bindings", pyclass(name = "InternalNode", get_all))] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct InternalNode { parent: Parent, hash: Hash, - specific: NodeSpecific, + left: TreeIndex, + right: TreeIndex, } -// #[cfg_attr(feature = "py-bindings", pyclass(name = "NodeSpecific"))] -#[cfg_attr(feature = "py-bindings", pyclass(name = "NodeSpecific", get_all))] -#[derive(Clone, Debug, PartialEq)] -pub enum NodeSpecific { - Internal { left: TreeIndex, right: TreeIndex }, - Leaf { key: KvId, value: KvId }, -} +impl InternalNode { + #[allow(clippy::unnecessary_wraps)] + pub fn from_bytes(blob: &DataBytes) -> Result { + Ok(Self { + parent: parent_from_bytes(blob), + hash: hash_from_bytes(blob), + left: TreeIndex::from_bytes(&blob[LEFT_RANGE]), + right: TreeIndex::from_bytes(&blob[RIGHT_RANGE]), + }) + } + pub fn to_bytes(&self) -> DataBytes { + let mut blob: DataBytes = [0; DATA_SIZE]; + let parent_integer = self.parent.unwrap_or(NULL_PARENT); + blob[HASH_RANGE].copy_from_slice(&self.hash); + blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_bytes()); + blob[LEFT_RANGE].copy_from_slice(&self.left.to_bytes()); + blob[RIGHT_RANGE].copy_from_slice(&self.right.to_bytes()); -impl NodeSpecific { - // TODO: methods that only handle one variant seem kinda smelly to me, am i right? - pub fn sibling_index(&self, index: TreeIndex) -> TreeIndex { - let NodeSpecific::Internal { right, left } = self else { - panic!("unable to get sibling index from a leaf") - }; + blob + } - if index == *right { - *left - } else if index == *left { - *right + pub fn sibling_index(&self, index: TreeIndex) -> TreeIndex { + if index == self.right { + self.left + } else if index == self.left { + self.right } else { panic!("index not a child: {index}") } } } -impl Node { +#[cfg_attr(feature = "py-bindings", pyclass(name = "LeafNode", get_all))] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct LeafNode { + parent: Parent, + hash: Hash, + key: KvId, + value: KvId, +} + +impl LeafNode { #[allow(clippy::unnecessary_wraps)] - pub fn from_bytes(metadata: &NodeMetadata, blob: DataBytes) -> Result { + pub fn from_bytes(blob: &DataBytes) -> Result { Ok(Self { - parent: Self::parent_from_bytes(&blob), - hash: Self::hash_from_bytes(&blob), - specific: match metadata.node_type { - NodeType::Internal => NodeSpecific::Internal { - left: TreeIndex::from_bytes(&blob[LEFT_RANGE]), - right: TreeIndex::from_bytes(&blob[RIGHT_RANGE]), - }, - NodeType::Leaf => NodeSpecific::Leaf { - key: KvId::from_be_bytes(blob[KEY_RANGE].try_into().unwrap()), - value: KvId::from_be_bytes(blob[VALUE_RANGE].try_into().unwrap()), - }, - }, + parent: parent_from_bytes(blob), + hash: hash_from_bytes(blob), + key: KvId::from_be_bytes(blob[KEY_RANGE].try_into().unwrap()), + value: KvId::from_be_bytes(blob[VALUE_RANGE].try_into().unwrap()), }) } - fn parent_from_bytes(blob: &DataBytes) -> Parent { - let parent_integer = TreeIndex::from_bytes(&blob[PARENT_RANGE]); - match parent_integer { - NULL_PARENT => None, - _ => Some(parent_integer), - } - } - - fn hash_from_bytes(blob: &DataBytes) -> Hash { - blob[HASH_RANGE].try_into().unwrap() - } - pub fn to_bytes(&self) -> DataBytes { let mut blob: DataBytes = [0; DATA_SIZE]; - match self { - Node { - parent, - specific: NodeSpecific::Internal { left, right }, - hash, - } => { - let parent_integer = match parent { - None => NULL_PARENT, - Some(parent) => *parent, - }; - blob[HASH_RANGE].copy_from_slice(hash); - blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_bytes()); - blob[LEFT_RANGE].copy_from_slice(&left.to_bytes()); - blob[RIGHT_RANGE].copy_from_slice(&right.to_bytes()); - } - Node { - parent, - specific: NodeSpecific::Leaf { key, value }, - hash, - } => { - let parent_integer = match parent { - None => NULL_PARENT, - Some(parent) => *parent, - }; - blob[HASH_RANGE].copy_from_slice(hash); - blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_bytes()); - blob[KEY_RANGE].copy_from_slice(&key.to_be_bytes()); - blob[VALUE_RANGE].copy_from_slice(&value.to_be_bytes()); - } - } + let parent_integer = self.parent.unwrap_or(NULL_PARENT); + blob[HASH_RANGE].copy_from_slice(&self.hash); + blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_bytes()); + blob[KEY_RANGE].copy_from_slice(&self.key.to_be_bytes()); + blob[VALUE_RANGE].copy_from_slice(&self.value.to_be_bytes()); blob } } -#[cfg(feature = "py-bindings")] -#[pymethods] +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Node { + Internal(InternalNode), + Leaf(LeafNode), +} + impl Node { - #[getter(left)] - pub fn py_property_left(&self) -> PyResult { - let NodeSpecific::Internal { left, .. } = self.specific else { - return Err(PyAttributeError::new_err( - "Attribute 'left' not present for leaf nodes".to_string(), - )); - }; + fn parent(&self) -> Parent { + match self { + Node::Internal(node) => node.parent, + Node::Leaf(node) => node.parent, + } + } - Ok(left) + fn set_parent(&mut self, parent: Parent) { + match self { + Node::Internal(node) => node.parent = parent, + Node::Leaf(node) => node.parent = parent, + } } - #[getter(right)] - pub fn py_property_right(&self) -> PyResult { - let NodeSpecific::Internal { right, .. } = self.specific else { - return Err(PyAttributeError::new_err( - "Attribute 'right' not present for leaf nodes".to_string(), - )); - }; + fn hash(&self) -> Hash { + match self { + Node::Internal(node) => node.hash, + Node::Leaf(node) => node.hash, + } + } - Ok(right) + fn set_hash(&mut self, hash: &Hash) { + match self { + Node::Internal(ref mut node) => node.hash = *hash, + Node::Leaf(ref mut node) => node.hash = *hash, + } } - #[getter(key)] - pub fn py_property_key(&self) -> PyResult { - let NodeSpecific::Leaf { key, .. } = self.specific else { - return Err(PyAttributeError::new_err( - "Attribute 'key' not present for internal nodes".to_string(), - )); - }; + pub fn from_bytes(metadata: &NodeMetadata, blob: &DataBytes) -> Result { + Ok(match metadata.node_type { + NodeType::Internal => Node::Internal(InternalNode::from_bytes(blob)?), + NodeType::Leaf => Node::Leaf(LeafNode::from_bytes(blob)?), + }) + } - Ok(key) + pub fn to_bytes(&self) -> DataBytes { + match self { + Node::Internal(node) => node.to_bytes(), + Node::Leaf(node) => node.to_bytes(), + } } - #[getter(value)] - pub fn py_property_value(&self) -> PyResult { - let NodeSpecific::Leaf { value, .. } = self.specific else { - return Err(PyAttributeError::new_err( - "Attribute 'value' not present for internal nodes".to_string(), - )); + fn expect_leaf(self, message: &str) -> LeafNode { + let Node::Leaf(leaf) = self else { + let message = message.replace("<>", &format!("{self:?}")); + panic!("{}", message) }; - Ok(value) + leaf + } + + fn try_into_leaf(self) -> Result { + match self { + Node::Leaf(leaf) => Ok(leaf), + Node::Internal(internal) => Err(Error::NodeNotALeaf(internal)), + } + } +} + +#[cfg(feature = "py-bindings")] +impl IntoPy for Node { + fn into_py(self, py: Python<'_>) -> PyObject { + match self { + Node::Internal(node) => node.into_py(py), + Node::Leaf(node) => node.into_py(py), + } } } @@ -408,14 +419,14 @@ impl Block { let data_blob: DataBytes = blob[DATA_RANGE].try_into().unwrap(); let metadata = NodeMetadata::from_bytes(metadata_blob) .map_err(|message| Error::FailedLoadingMetadata(message.to_string()))?; - let node = Node::from_bytes(&metadata, data_blob) + let node = Node::from_bytes(&metadata, &data_blob) .map_err(|message| Error::FailedLoadingNode(message.to_string()))?; Ok(Block { metadata, node }) } pub fn update_hash(&mut self, left: &Hash, right: &Hash) { - self.node.hash = internal_hash(left, right); + self.node.set_hash(&internal_hash(left, right)); self.metadata.dirty = false; } } @@ -431,8 +442,8 @@ fn get_free_indexes_and_keys_values_indexes( for (index, block) in MerkleBlobLeftChildFirstIterator::new(blob) { seen_indexes[index.0 as usize] = true; - if let NodeSpecific::Leaf { key, .. } = block.node.specific { - key_to_index.insert(key, index); + if let Node::Leaf(leaf) = block.node { + key_to_index.insert(leaf.key, index); } } @@ -504,20 +515,18 @@ impl MerkleBlob { self.insert_first(key, value, hash)?; } InsertLocation::Leaf { index, side } => { - let old_leaf = self.get_node(index)?; - let NodeSpecific::Leaf { .. } = old_leaf.specific else { - panic!("requested insertion at leaf but found internal node") - }; + let old_leaf = self.get_node(index)?.try_into_leaf()?; let internal_node_hash = match side { Side::Left => internal_hash(hash, &old_leaf.hash), Side::Right => internal_hash(&old_leaf.hash, hash), }; - let node = Node { + let node = LeafNode { parent: None, hash: *hash, - specific: NodeSpecific::Leaf { key, value }, + key, + value, }; if self.key_to_index.len() == 1 { @@ -537,11 +546,12 @@ impl MerkleBlob { node_type: NodeType::Leaf, dirty: false, }, - node: Node { + node: Node::Leaf(LeafNode { parent: None, - specific: NodeSpecific::Leaf { key, value }, + key, + value, hash: *hash, - }, + }), }; self.clear(); @@ -552,8 +562,8 @@ impl MerkleBlob { fn insert_second( &mut self, - mut node: Node, - old_leaf: &Node, + mut node: LeafNode, + old_leaf: &LeafNode, internal_node_hash: &Hash, side: &Side, ) -> Result<(), Error> { @@ -567,26 +577,16 @@ impl MerkleBlob { node_type: NodeType::Internal, dirty: false, }, - node: Node { + node: Node::Internal(InternalNode { parent: None, - specific: NodeSpecific::Internal { - left: left_index, - right: right_index, - }, + left: left_index, + right: right_index, hash: *internal_node_hash, - }, + }), }; self.insert_entry_to_blob(root_index, &new_internal_block)?; - let NodeSpecific::Leaf { - key: old_leaf_key, - value: old_leaf_value, - } = old_leaf.specific - else { - return Err(Error::OldLeafUnexpectedlyNotALeaf); - }; - node.parent = Some(TreeIndex(0)); let nodes = [ @@ -595,12 +595,10 @@ impl MerkleBlob { Side::Left => right_index, Side::Right => left_index, }, - Node { + LeafNode { parent: Some(TreeIndex(0)), - specific: NodeSpecific::Leaf { - key: old_leaf_key, - value: old_leaf_value, - }, + key: old_leaf.key, + value: old_leaf.value, hash: old_leaf.hash, }, ), @@ -619,7 +617,7 @@ impl MerkleBlob { node_type: NodeType::Leaf, dirty: false, }, - node, + node: Node::Leaf(node), }; self.insert_entry_to_blob(index, &block)?; @@ -630,8 +628,8 @@ impl MerkleBlob { fn insert_third_or_later( &mut self, - mut node: Node, - old_leaf: &Node, + mut node: LeafNode, + old_leaf: &LeafNode, old_leaf_index: TreeIndex, internal_node_hash: &Hash, side: &Side, @@ -646,7 +644,7 @@ impl MerkleBlob { node_type: NodeType::Leaf, dirty: false, }, - node, + node: Node::Leaf(node), }; self.insert_entry_to_blob(new_leaf_index, &new_leaf_block)?; @@ -659,14 +657,12 @@ impl MerkleBlob { node_type: NodeType::Internal, dirty: false, }, - node: Node { + node: Node::Internal(InternalNode { parent: old_leaf.parent, - specific: NodeSpecific::Internal { - left: left_index, - right: right_index, - }, + left: left_index, + right: right_index, hash: *internal_node_hash, - }, + }), }; self.insert_entry_to_blob(new_internal_node_index, &new_internal_block)?; @@ -677,16 +673,11 @@ impl MerkleBlob { self.update_parent(old_leaf_index, Some(new_internal_node_index))?; let mut old_parent_block = self.get_block(old_parent_index)?; - if let NodeSpecific::Internal { - ref mut left, - ref mut right, - .. - } = old_parent_block.node.specific - { - if old_leaf_index == *left { - *left = new_internal_node_index; - } else if old_leaf_index == *right { - *right = new_internal_node_index; + if let Node::Internal(ref mut internal_node, ..) = old_parent_block.node { + if old_leaf_index == internal_node.left { + internal_node.left = new_internal_node_index; + } else if old_leaf_index == internal_node.right { + internal_node.right = new_internal_node_index; } else { panic!("child not a child of its parent"); } @@ -724,11 +715,12 @@ impl MerkleBlob { node_type: NodeType::Leaf, dirty: false, }, - node: Node { + node: Node::Leaf(LeafNode { parent: None, hash, - specific: NodeSpecific::Leaf { key, value }, - }, + key, + value, + }), }; self.insert_entry_to_blob(new_leaf_index, &new_block)?; indexes.push(new_leaf_index); @@ -753,22 +745,23 @@ impl MerkleBlob { let new_internal_node_index = self.get_new_index(); - let block_1 = self.update_parent(index_1, Some(new_internal_node_index))?; - let block_2 = self.update_parent(index_2, Some(new_internal_node_index))?; + let mut hashes = vec![]; + for index in [index_1, index_2] { + let block = self.update_parent(index, Some(new_internal_node_index))?; + hashes.push(block.node.hash()); + } let new_block = Block { metadata: NodeMetadata { node_type: NodeType::Internal, dirty: false, }, - node: Node { + node: Node::Internal(InternalNode { parent: None, - hash: internal_hash(&block_1.node.hash, &block_2.node.hash), - specific: NodeSpecific::Internal { - left: index_1, - right: index_2, - }, - }, + hash: internal_hash(&hashes[0], &hashes[1]), + left: index_1, + right: index_2, + }), }; self.insert_entry_to_blob(new_internal_node_index, &new_block)?; @@ -781,18 +774,15 @@ impl MerkleBlob { if indexes.len() == 1 { // OPT: can we avoid this extra min height leaf traversal? let min_height_leaf = self.get_min_height_leaf()?; - let NodeSpecific::Leaf { key, .. } = min_height_leaf.node.specific else { - panic!() - }; - self.insert_from_leaf(self.key_to_index[&key], indexes[0], &Side::Left)?; + self.insert_from_key(min_height_leaf.key, indexes[0], &Side::Left)?; }; Ok(()) } - fn insert_from_leaf( + fn insert_from_key( &mut self, - old_leaf_index: TreeIndex, + old_leaf_key: KvId, new_index: TreeIndex, side: &Side, ) -> Result<(), Error> { @@ -809,12 +799,12 @@ impl MerkleBlob { } let new_internal_node_index = self.get_new_index(); - let old_leaf = self.get_node(old_leaf_index)?; + let (old_leaf_index, old_leaf) = self.get_leaf_by_key(old_leaf_key)?; let new_node = self.get_node(new_index)?; let new_stuff = Stuff { index: new_index, - hash: new_node.hash, + hash: new_node.hash(), }; let old_stuff = Stuff { index: old_leaf_index, @@ -831,14 +821,12 @@ impl MerkleBlob { node_type: NodeType::Internal, dirty: false, }, - node: Node { + node: Node::Internal(InternalNode { parent: old_leaf.parent, hash: internal_node_hash, - specific: NodeSpecific::Internal { - left: left.index, - right: right.index, - }, - }, + left: left.index, + right: right.index, + }), }; self.insert_entry_to_blob(new_internal_node_index, &block)?; self.update_parent(new_index, Some(new_internal_node_index))?; @@ -849,15 +837,10 @@ impl MerkleBlob { }; let mut parent = self.get_block(old_leaf_parent)?; - if let NodeSpecific::Internal { - ref mut left, - ref mut right, - .. - } = parent.node.specific - { + if let Node::Internal(ref mut internal) = parent.node { match old_leaf_index { - x if x == *left => *left = new_internal_node_index, - x if x == *right => *right = new_internal_node_index, + x if x == internal.left => internal.left = new_internal_node_index, + x if x == internal.right => internal.right = new_internal_node_index, _ => panic!("parent not a child a grandparent"), } } else { @@ -869,21 +852,18 @@ impl MerkleBlob { Ok(()) } - fn get_min_height_leaf(&self) -> Result { - MerkleBlobBreadthFirstIterator::new(&self.blob) + fn get_min_height_leaf(&self) -> Result { + let block = MerkleBlobBreadthFirstIterator::new(&self.blob) .next() - .ok_or(Error::UnableToFindALeaf) + .ok_or(Error::UnableToFindALeaf)?; + + Ok(block + .node + .expect_leaf("unexpectedly found internal node first: <>")) } pub fn delete(&mut self, key: KvId) -> Result<(), Error> { - let leaf_index = *self.key_to_index.get(&key).ok_or(Error::UnknownKey(key))?; - let leaf = self.get_node(leaf_index)?; - - // TODO: maybe some common way to indicate/perform sanity double checks? - // maybe this disappears with unit variants and structs for the data - let NodeSpecific::Leaf { .. } = leaf.specific else { - panic!("key to index cache resulted in internal node") - }; + let (leaf_index, leaf) = self.get_leaf_by_key(key)?; self.key_to_index.remove(&key); let Some(parent_index) = leaf.parent else { @@ -892,16 +872,19 @@ impl MerkleBlob { }; self.free_indexes.insert(leaf_index); - let parent = self.get_node(parent_index)?; - let sibling_index = parent.specific.sibling_index(leaf_index); + let maybe_parent = self.get_node(parent_index)?; + let Node::Internal(parent) = maybe_parent else { + panic!("parent node not internal: {maybe_parent:?}") + }; + let sibling_index = parent.sibling_index(leaf_index); let mut sibling_block = self.get_block(sibling_index)?; let Some(grandparent_index) = parent.parent else { - sibling_block.node.parent = None; + sibling_block.node.set_parent(None); self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?; - if let NodeSpecific::Internal { left, right } = sibling_block.node.specific { - for child_index in [left, right] { + if let Node::Internal(node) = sibling_block.node { + for child_index in [node.left, node.right] { self.update_parent(child_index, Some(TreeIndex(0)))?; } }; @@ -914,18 +897,13 @@ impl MerkleBlob { self.free_indexes.insert(parent_index); let mut grandparent_block = self.get_block(grandparent_index)?; - sibling_block.node.parent = Some(grandparent_index); + sibling_block.node.set_parent(Some(grandparent_index)); self.insert_entry_to_blob(sibling_index, &sibling_block)?; - if let NodeSpecific::Internal { - ref mut left, - ref mut right, - .. - } = grandparent_block.node.specific - { + if let Node::Internal(ref mut internal) = grandparent_block.node { match parent_index { - x if x == *left => *left = sibling_index, - x if x == *right => *right = sibling_index, + x if x == internal.left => internal.left = sibling_index, + x if x == internal.right => internal.right = sibling_index, _ => panic!("parent not a child a grandparent"), } } else { @@ -945,19 +923,16 @@ impl MerkleBlob { }; let mut block = self.get_block(*leaf_index)?; - if let NodeSpecific::Leaf { - value: ref mut inplace_value, - .. - } = block.node.specific - { - block.node.hash.clone_from(new_hash); - *inplace_value = value; - } else { - panic!("expected internal node but found leaf"); - } + // TODO: repeated message + let mut leaf = block.node.clone().expect_leaf(&format!( + "expected leaf for index from key cache: {leaf_index} -> <>" + )); + leaf.hash.clone_from(new_hash); + leaf.value = value; + block.node = Node::Leaf(leaf); self.insert_entry_to_blob(*leaf_index, &block)?; - if let Some(parent) = block.node.parent { + if let Some(parent) = block.node.parent() { self.mark_lineage_as_dirty(parent)?; } @@ -970,21 +945,22 @@ impl MerkleBlob { let mut child_to_parent: HashMap = HashMap::new(); for (index, block) in MerkleBlobParentFirstIterator::new(&self.blob) { - if let Some(parent) = block.node.parent { + if let Some(parent) = block.node.parent() { assert_eq!(child_to_parent.remove(&index), Some(parent)); } - match block.node.specific { - NodeSpecific::Internal { left, right } => { + match block.node { + Node::Internal(node) => { internal_count += 1; - child_to_parent.insert(left, index); - child_to_parent.insert(right, index); + child_to_parent.insert(node.left, index); + child_to_parent.insert(node.right, index); } - NodeSpecific::Leaf { key, .. } => { + Node::Leaf(node) => { leaf_count += 1; let cached_index = self .key_to_index - .get(&key) - .ok_or(Error::IntegrityKeyNotInCache(key))?; + .get(&node.key) + .ok_or(Error::IntegrityKeyNotInCache(node.key))?; + let key = node.key; assert_eq!( *cached_index, index, "key to index cache for {key:?} should be {index:?} got: {cached_index:?}" @@ -1017,7 +993,7 @@ impl MerkleBlob { parent: Option, ) -> Result { let mut block = self.get_block(index)?; - block.node.parent = parent; + block.node.set_parent(parent); self.insert_entry_to_blob(index, &block)?; Ok(block) @@ -1035,7 +1011,7 @@ impl MerkleBlob { block.metadata.dirty = true; self.insert_entry_to_blob(this_index, &block)?; - next_index = block.node.parent; + next_index = block.node.parent(); } Ok(()) @@ -1079,15 +1055,19 @@ impl MerkleBlob { loop { for byte in &seed_bytes { for bit in 0..8 { - match node.specific { - NodeSpecific::Leaf { .. } => { + match node { + Node::Leaf { .. } => { return Ok(InsertLocation::Leaf { index: next_index, side, }) } - NodeSpecific::Internal { left, right, .. } => { - next_index = if byte & (1 << bit) != 0 { left } else { right }; + Node::Internal(internal) => { + next_index = if byte & (1 << bit) != 0 { + internal.left + } else { + internal.right + }; node = self.get_node(next_index)?; } } @@ -1128,19 +1108,16 @@ impl MerkleBlob { && old_block.metadata.node_type == NodeType::Leaf { // TODO: sort of repeating the leaf check above and below. smells a little - if let NodeSpecific::Leaf { - key: old_block_key, .. - } = old_block.node.specific - { - self.key_to_index.remove(&old_block_key); + if let Node::Leaf(old_node) = old_block.node { + self.key_to_index.remove(&old_node.key); }; }; self.blob[block_range(index)].copy_from_slice(&new_block_bytes); } } - if let NodeSpecific::Leaf { key, .. } = block.node.specific { - self.key_to_index.insert(key, index); + if let Node::Leaf(ref node) = block.node { + self.key_to_index.insert(node.key, index); }; self.free_indexes.take(&index); @@ -1156,7 +1133,7 @@ impl MerkleBlob { let block_bytes = self.get_block_bytes(index)?; let data_bytes: DataBytes = block_bytes[DATA_RANGE].try_into().unwrap(); - Ok(Node::hash_from_bytes(&data_bytes)) + Ok(hash_from_bytes(&data_bytes)) } fn get_block_bytes(&self, index: TreeIndex) -> Result { @@ -1172,12 +1149,19 @@ impl MerkleBlob { Ok(self.get_block(index)?.node) } + pub fn get_leaf_by_key(&self, key: KvId) -> Result<(TreeIndex, LeafNode), Error> { + let index = *self.key_to_index.get(&key).ok_or(Error::UnknownKey(key))?; + let leaf = self.get_node(index)?.expect_leaf(&format!( + "expected leaf for index from key cache: {index} -> <>" + )); + + Ok((index, leaf)) + } + pub fn get_parent_index(&self, index: TreeIndex) -> Result { let block = self.get_block_bytes(index)?; - Ok(Node::parent_from_bytes( - block[DATA_RANGE].try_into().unwrap(), - )) + Ok(parent_from_bytes(block[DATA_RANGE].try_into().unwrap())) } pub fn get_lineage_with_indexes( @@ -1189,7 +1173,7 @@ impl MerkleBlob { while let Some(this_index) = next_index { let node = self.get_node(this_index)?; - next_index = node.parent; + next_index = node.parent(); lineage.push((index, node)); } @@ -1219,13 +1203,13 @@ impl MerkleBlob { .filter(|(_, block)| block.metadata.dirty) .collect::>() { - let NodeSpecific::Internal { left, right } = block.node.specific else { + let Node::Internal(ref leaf) = block.node else { panic!("leaves should not be dirty") }; // OPT: obviously inefficient to re-get/deserialize these blocks inside // an iteration that's already doing that - let left_hash = self.get_hash(left)?; - let right_hash = self.get_hash(right)?; + let left_hash = self.get_hash(leaf.left)?; + let right_hash = self.get_hash(leaf.right)?; block.update_hash(&left_hash, &right_hash); self.insert_entry_to_blob(index, &block)?; } @@ -1273,14 +1257,14 @@ impl MerkleBlob { // Ok(()) // } + // TODO: really this is test, not unused #[allow(unused)] fn get_key_value_map(&self) -> HashMap { let mut key_value = HashMap::new(); for (key, index) in &self.key_to_index { - let NodeSpecific::Leaf { value, .. } = self.get_node(*index).unwrap().specific else { - panic!() - }; - key_value.insert(*key, value); + // silly waste of having the index, but test code and type narrowing so, ok i guess + let (_, leaf) = self.get_leaf_by_key(*key).unwrap(); + key_value.insert(*key, leaf.value); } key_value @@ -1295,16 +1279,14 @@ impl PartialEq for MerkleBlob { MerkleBlobLeftChildFirstIterator::new(&other.blob), ) { if (self_block.metadata.dirty || other_block.metadata.dirty) - || self_block.node.hash != other_block.node.hash + || self_block.node.hash() != other_block.node.hash() { return false; } - match self_block.node.specific { + match self_block.node { // NOTE: this is effectively checked by the controlled overall traversal - NodeSpecific::Internal { .. } => {} - NodeSpecific::Leaf { .. } => { - return self_block.node.specific == other_block.node.specific - } + Node::Internal(..) => {} + Node::Leaf(..) => return self_block.node == other_block.node, } } @@ -1446,7 +1428,7 @@ impl MerkleBlob { return Err(PyValueError::new_err("root hash is dirty")); } - Ok(Some(block.node.hash)) + Ok(Some(block.node.hash())) } #[pyo3(name = "batch_insert")] @@ -1508,9 +1490,9 @@ impl Iterator for MerkleBlobLeftChildFirstIterator<'_> { let block_bytes: BlockBytes = self.blob[block_range(item.index)].try_into().unwrap(); let block = Block::from_bytes(block_bytes).unwrap(); - match block.node.specific { - NodeSpecific::Leaf { .. } => return Some((item.index, block)), - NodeSpecific::Internal { left, right } => { + match block.node { + Node::Leaf(..) => return Some((item.index, block)), + Node::Internal(ref node) => { if item.visited { return Some((item.index, block)); }; @@ -1521,11 +1503,11 @@ impl Iterator for MerkleBlobLeftChildFirstIterator<'_> { }); self.deque.push_front(MerkleBlobLeftChildFirstIteratorItem { visited: false, - index: right, + index: node.right, }); self.deque.push_front(MerkleBlobLeftChildFirstIteratorItem { visited: false, - index: left, + index: node.left, }); } } @@ -1559,9 +1541,9 @@ impl Iterator for MerkleBlobParentFirstIterator<'_> { let block_bytes: BlockBytes = self.blob[block_range(index)].try_into().unwrap(); let block = Block::from_bytes(block_bytes).unwrap(); - if let NodeSpecific::Internal { left, right } = block.node.specific { - self.deque.push_back(left); - self.deque.push_back(right); + if let Node::Internal(ref node) = block.node { + self.deque.push_back(node.left); + self.deque.push_back(node.right); } Some((index, block)) @@ -1596,11 +1578,11 @@ impl Iterator for MerkleBlobBreadthFirstIterator<'_> { let block_bytes: BlockBytes = self.blob[block_range(index)].try_into().unwrap(); let block = Block::from_bytes(block_bytes).unwrap(); - match block.node.specific { - NodeSpecific::Leaf { .. } => return Some(block), - NodeSpecific::Internal { left, right } => { - self.deque.push_back(left); - self.deque.push_back(right); + match block.node { + Node::Leaf(..) => return Some(block), + Node::Internal(node) => { + self.deque.push_back(node.left); + self.deque.push_back(node.right); } } } @@ -1705,7 +1687,7 @@ mod tests { } assert_eq!(lineage.len(), 2); let (_, last_node) = lineage.last().unwrap(); - assert_eq!(last_node.parent, None); + assert_eq!(last_node.parent(), None); } #[rstest] @@ -1844,26 +1826,25 @@ mod tests { let sibling = merkle_blob .get_node(merkle_blob.key_to_index[&last_key]) .unwrap(); - let parent = merkle_blob.get_node(sibling.parent.unwrap()).unwrap(); - let NodeSpecific::Internal { left, right } = parent.specific else { + let parent = merkle_blob.get_node(sibling.parent().unwrap()).unwrap(); + let Node::Internal(internal) = parent else { panic!() }; - let NodeSpecific::Leaf { key: left_key, .. } = merkle_blob.get_node(left).unwrap().specific - else { - panic!() - }; - let NodeSpecific::Leaf { key: right_key, .. } = - merkle_blob.get_node(right).unwrap().specific - else { - panic!() - }; + let left = merkle_blob + .get_node(internal.left) + .unwrap() + .expect_leaf("<>"); + let right = merkle_blob + .get_node(internal.right) + .unwrap() + .expect_leaf("<>"); let expected_keys: [KvId; 2] = match side { Side::Left => [pre_count as KvId + 1, pre_count as KvId], Side::Right => [pre_count as KvId, pre_count as KvId + 1], }; - assert_eq!([left_key, right_key], expected_keys); + assert_eq!([left.key, right.key], expected_keys); } #[test] @@ -1920,9 +1901,8 @@ mod tests { #[test] fn test_node_type_from_u8_invalid() { let invalid_value = 2; - let expected = format!("unknown NodeType value: {invalid_value}"); let actual = NodeType::from_u8(invalid_value); - actual.expect_err(&expected); + actual.expect_err("invalid node type value should fail"); } #[test] @@ -1930,17 +1910,14 @@ mod tests { NodeMetadata::dirty_from_bytes([0, 2]).expect_err("invalid value should fail"); } - #[test] - #[should_panic(expected = "unable to get sibling index from a leaf")] - fn test_node_specific_sibling_index_panics_for_leaf() { - let leaf = NodeSpecific::Leaf { key: 0, value: 0 }; - leaf.sibling_index(TreeIndex(0)); - } - #[test] #[should_panic(expected = "index not a child: 2")] fn test_node_specific_sibling_index_panics_for_unknown_sibling() { - let node = NodeSpecific::Internal { + // TODO: this probably shouldn't be a panic? + // maybe depends if it is exported or private? + let node = InternalNode { + parent: None, + hash: sha256_num(0), left: TreeIndex(0), right: TreeIndex(1), }; @@ -1985,46 +1962,39 @@ mod tests { let before_blocks = MerkleBlobLeftChildFirstIterator::new(&small_blob.blob).collect::>(); let (key, index) = small_blob.key_to_index.iter().next().unwrap(); - let node = small_blob.get_node(*index).unwrap(); - let NodeSpecific::Leaf { - key: original_key, - value: original_value, - .. - } = node.specific - else { - panic!() - }; - let new_value = original_value + 1; + let original = small_blob.get_node(*index).unwrap().expect_leaf("<>"); + let new_value = original.value + 1; - small_blob.upsert(*key, new_value, &node.hash).unwrap(); + small_blob.upsert(*key, new_value, &original.hash).unwrap(); let after_blocks = MerkleBlobLeftChildFirstIterator::new(&small_blob.blob).collect::>(); assert_eq!(before_blocks.len(), after_blocks.len()); - for ((before_index, before), (after_index, after)) in zip(before_blocks, after_blocks) { - assert_eq!(before.node.parent, after.node.parent); + for ((before_index, before_block), (after_index, after_block)) in + zip(before_blocks, after_blocks) + { + assert_eq!(before_block.node.parent(), after_block.node.parent()); assert_eq!(before_index, after_index); - let NodeSpecific::Leaf { - key: before_key, - value: before_value, - } = before.node.specific - else { - assert_eq!(before.node.specific, after.node.specific); - continue; + let before: LeafNode = match before_block.node { + Node::Leaf(leaf) => leaf, + Node::Internal(internal) => { + let Node::Internal(after) = after_block.node else { + panic!() + }; + assert_eq!(internal.left, after.left); + assert_eq!(internal.right, after.right); + continue; + } }; - let NodeSpecific::Leaf { - key: after_key, - value: after_value, - } = after.node.specific - else { + let Node::Leaf(after) = after_block.node else { panic!() }; - assert_eq!(before_key, after_key); - if before_key == original_key { - assert_eq!(after_value, new_value); + assert_eq!(before.key, after.key); + if before.key == original.key { + assert_eq!(after.value, new_value); } else { - assert_eq!(before_value, after_value); + assert_eq!(before.value, after.value); } } } diff --git a/crates/chia-datalayer/src/merkle/dot.rs b/crates/chia-datalayer/src/merkle/dot.rs index 09ea70c0b..72b75f201 100644 --- a/crates/chia-datalayer/src/merkle/dot.rs +++ b/crates/chia-datalayer/src/merkle/dot.rs @@ -1,4 +1,6 @@ -use crate::merkle::{MerkleBlob, MerkleBlobLeftChildFirstIterator, Node, NodeSpecific, TreeIndex}; +use crate::merkle::{ + InternalNode, LeafNode, MerkleBlob, MerkleBlobLeftChildFirstIterator, Node, TreeIndex, +}; use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; use url::Url; @@ -53,13 +55,13 @@ impl DotLines { impl Node { pub fn to_dot(&self, index: TreeIndex) -> DotLines { // TODO: can this be done without introducing a blank line? - let node_to_parent = match self.parent { + let node_to_parent = match self.parent() { Some(parent) => format!("node_{index} -> node_{parent};"), None => String::new(), }; - match self.specific { - NodeSpecific::Internal {left, right} => DotLines{ + match self { + Node::Internal ( InternalNode {left, right, ..}) => DotLines{ nodes: vec![ format!("node_{index} [label=\"{index}\"]"), ], @@ -73,7 +75,7 @@ impl Node { ], note: String::new(), }, - NodeSpecific::Leaf {key, value} => DotLines{ + Node::Leaf (LeafNode{key, value, ..}) => DotLines{ nodes: vec![ format!("node_{index} [shape=box, label=\"{index}\\nvalue: {key}\\nvalue: {value}\"];"), ], diff --git a/wheel/generate_type_stubs.py b/wheel/generate_type_stubs.py index 443830ed0..a090b1524 100644 --- a/wheel/generate_type_stubs.py +++ b/wheel/generate_type_stubs.py @@ -390,22 +390,27 @@ def derive_child_sk_unhardened(sk: PrivateKey, index: int) -> PrivateKey: ... @staticmethod def derive_child_pk_unhardened(pk: G1Element, index: int) -> G1Element: ... + @final -class Node: +class InternalNode: @property def parent(self) -> Optional[uint32]: ... @property def hash(self) -> bytes: ... - # TODO: this all needs reviewed and tidied - @property - def specific(self) -> Union: ... - @property def left(self) -> uint32: ... @property def right(self) -> uint32: ... + +@final +class LeafNode: + @property + def parent(self) -> Optional[uint32]: ... + @property + def hash(self) -> bytes: ... + @property def key(self) -> int64: ... @property @@ -428,10 +433,10 @@ def __init__( def insert(self, key: int64, value: int64, hash: bytes32, reference_kid: Optional[int64] = None, side: Optional[uint8] = None) -> None: ... def delete(self, key: int64) -> None: ... - def get_raw_node(self, index: uint32) -> Node: ... + def get_raw_node(self, index: uint32) -> Union[InternalNode, LeafNode]: ... def calculate_lazy_hashes(self) -> None: ... - def get_lineage_with_indexes(self, index: uint32) -> list[tuple[uint32, Node]]:... - def get_nodes_with_indexes(self) -> list[Node]: ... + def get_lineage_with_indexes(self, index: uint32) -> list[tuple[uint32, Union[InternalNode, LeafNode]]]:... + def get_nodes_with_indexes(self) -> list[Union[InternalNode, LeafNode]]: ... def empty(self) -> bool: ... def get_root_hash(self) -> bytes32: ... def batch_insert(self, keys_values: list[tuple[int64, int64]], hashes: list[bytes32]): ... diff --git a/wheel/python/chia_rs/chia_rs.pyi b/wheel/python/chia_rs/chia_rs.pyi index 317552558..7b05716d2 100644 --- a/wheel/python/chia_rs/chia_rs.pyi +++ b/wheel/python/chia_rs/chia_rs.pyi @@ -121,22 +121,27 @@ class AugSchemeMPL: @staticmethod def derive_child_pk_unhardened(pk: G1Element, index: int) -> G1Element: ... + @final -class Node: +class InternalNode: @property def parent(self) -> Optional[uint32]: ... @property def hash(self) -> bytes: ... - # TODO: this all needs reviewed and tidied - @property - def specific(self) -> Union: ... - @property def left(self) -> uint32: ... @property def right(self) -> uint32: ... + +@final +class LeafNode: + @property + def parent(self) -> Optional[uint32]: ... + @property + def hash(self) -> bytes: ... + @property def key(self) -> int64: ... @property @@ -159,10 +164,10 @@ class MerkleBlob: def insert(self, key: int64, value: int64, hash: bytes32, reference_kid: Optional[int64] = None, side: Optional[uint8] = None) -> None: ... def delete(self, key: int64) -> None: ... - def get_raw_node(self, index: uint32) -> Node: ... + def get_raw_node(self, index: uint32) -> Union[InternalNode, LeafNode]: ... def calculate_lazy_hashes(self) -> None: ... - def get_lineage_with_indexes(self, index: uint32) -> list[tuple[uint32, Node]]:... - def get_nodes_with_indexes(self) -> list[Node]: ... + def get_lineage_with_indexes(self, index: uint32) -> list[tuple[uint32, Union[InternalNode, LeafNode]]]:... + def get_nodes_with_indexes(self) -> list[Union[InternalNode, LeafNode]]: ... def empty(self) -> bool: ... def get_root_hash(self) -> bytes32: ... def batch_insert(self, keys_values: list[tuple[int64, int64]], hashes: list[bytes32]): ... diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 91eab80ff..ffe7161a3 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -75,7 +75,7 @@ use chia_bls::{ Signature, }; -use chia_datalayer::{MerkleBlob, Node}; +use chia_datalayer::{InternalNode, LeafNode, MerkleBlob}; #[pyfunction] pub fn compute_merkle_set_root<'p>( @@ -479,7 +479,8 @@ pub fn chia_rs(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; // m.add_class::()?; // m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; // merkle tree m.add_class::()?;