Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky committed Jan 28, 2025
1 parent 8e4ade4 commit 458f928
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 35 deletions.
24 changes: 14 additions & 10 deletions crates/chia-datalayer/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@ use std::iter::zip;
use std::ops::Range;
use thiserror::Error;

#[cfg_attr(
feature = "py-bindings",
derive(FromPyObject, IntoPyObject),
pyo3(transparent)
)]
#[cfg_attr(feature = "py-bindings", pyclass(eq, frozen, hash))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Streamable)]
pub struct TreeIndex(u32);
pub struct TreeIndex(#[pyo3(get, name = "raw")] u32);

#[cfg(feature = "py-bindings")]
#[pymethods]
impl TreeIndex {
#[new]
pub fn py_new(raw: u32) -> Self {
Self(raw)
}
}

impl std::fmt::Display for TreeIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -317,7 +322,7 @@ fn internal_hash(left_hash: &Hash, right_hash: &Hash) -> Hash {
Hash(Bytes32::new(hasher.finalize()))
}

#[cfg_attr(feature = "py-bindings", pyclass(eq, eq_int))]
#[cfg_attr(feature = "py-bindings", pyclass(eq, eq_int, frozen, hash))]
#[repr(u8)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Streamable)]
pub enum Side {
Expand Down Expand Up @@ -1423,8 +1428,7 @@ impl MerkleBlob {
value: ValueId,
hash: Hash,
reference_kid: Option<KeyId>,
// TODO: should be a Side, but python has a different Side right now
side: Option<u8>,
side: Option<Side>,
) -> PyResult<()> {
let insert_location = match (reference_kid, side) {
(None, None) => InsertLocation::Auto {},
Expand All @@ -1436,7 +1440,7 @@ impl MerkleBlob {
.ok_or(PyValueError::new_err(format!(
"unknown key id passed as insert location reference: {key}"
)))?,
side: Side::from_bytes(&[side])?,
side: side,
},
_ => {
// TODO: use a specific error
Expand Down
10 changes: 5 additions & 5 deletions tests/test_datalayer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from chia_rs.datalayer import InvalidBlobLengthError, LeafNode, MerkleBlob, KeyId, ValueId
from chia_rs.datalayer import InvalidBlobLengthError, LeafNode, MerkleBlob, KeyId, ValueId, Side
from chia_rs.sized_bytes import bytes32
from chia_rs.sized_ints import int64, uint8

Expand Down Expand Up @@ -30,7 +30,7 @@ def test_just_insert_a_bunch() -> None:
total_time = 0.0
for i in range(100_000):
start = time.monotonic()
merkle_blob.insert(KeyId(i), ValueId(i), HASH)
merkle_blob.insert(KeyId(int64(i)), ValueId(int64(i)), HASH)
end = time.monotonic()
total_time += end - start

Expand All @@ -42,18 +42,18 @@ def test_checking_coverage() -> None:
merkle_blob = MerkleBlob(blob=bytearray())
for i in range(count):
if i % 2 == 0:
merkle_blob.insert(KeyId(i), ValueId(i), bytes32.zeros)
merkle_blob.insert(KeyId(int64(i)), ValueId(int64(i)), bytes32.zeros)
else:
merkle_blob.insert(
KeyId(i), ValueId(i), bytes32.zeros, KeyId(i - 1), uint8(0)
KeyId(int64(i)), ValueId(int64(i)), bytes32.zeros, KeyId(int64(i - 1)), Side.Left
)

keys = {
node.key
for index, node in merkle_blob.get_nodes_with_indexes()
if isinstance(node, LeafNode)
}
assert keys == set(KeyId(n) for n in range(count))
assert keys == set(KeyId(int64(n)) for n in range(count))


def test_invalid_blob_length_raised() -> None:
Expand Down
51 changes: 31 additions & 20 deletions wheel/python/chia_rs/datalayer.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Mapping, Optional, Sequence, Union, Any, ClassVar, final
from .sized_bytes import bytes32, bytes100
from .sized_ints import uint8, uint16, uint32, uint64, uint128, int8, int16, int32, int64
Expand All @@ -9,7 +10,6 @@ DATA_SIZE: int
BLOCK_SIZE: int
METADATA_SIZE: int


class FailedLoadingMetadataError(Exception): ...
class FailedLoadingNodeError(Exception): ...
class InvalidBlobLengthError(Exception): ...
Expand All @@ -30,6 +30,11 @@ class IndexIsNotAChildError(Exception): ...
class CycleFoundError(Exception): ...
class BlockIndexOutOfBoundsError(Exception): ...

@final
class Side(Enum):
Left: int = ...
Right: int = ...

@final
class KeyId:
raw: int64
Expand All @@ -42,44 +47,50 @@ class ValueId:

def __init__(self, raw: int64) -> None: ...

@final
class TreeIndex:
raw: uint32

def __init__(self, raw: uint32) -> None: ...

@final
class InternalNode:
def __init__(self, parent: Optional[uint32], hash: bytes32, left: uint32, right: uint32) -> None: ...
def __init__(self, parent: Optional[TreeIndex], hash: bytes32, left: TreeIndex, right: TreeIndex) -> None: ...

@property
def parent(self) -> Optional[uint32]: ...
def parent(self) -> Optional[TreeIndex]: ...
@property
def hash(self) -> bytes: ...

@property
def left(self) -> uint32: ...
def left(self) -> TreeIndex: ...
@property
def right(self) -> uint32: ...
def right(self) -> TreeIndex: ...


@final
class LeafNode:
def __init__(self, parent: Optional[uint32], hash: bytes32, key: int64, value: int64) -> None: ...
def __init__(self, parent: Optional[TreeIndex], hash: bytes32, key: KeyId, value: ValueId) -> None: ...

@property
def parent(self) -> Optional[uint32]: ...
def parent(self) -> Optional[TreeIndex]: ...
@property
def hash(self) -> bytes: ...

@property
def key(self) -> int64: ...
def key(self) -> KeyId: ...
@property
def value(self) -> int64: ...
def value(self) -> ValueId: ...


@final
class MerkleBlob:
@property
def blob(self) -> bytearray: ...
@property
def free_indexes(self) -> set[uint32]: ...
def free_indexes(self) -> set[TreeIndex]: ...
@property
def key_to_index(self) -> Mapping[int64, uint32]: ...
def key_to_index(self) -> Mapping[KeyId, TreeIndex]: ...
@property
def check_integrity_on_drop(self) -> bool: ...

Expand All @@ -88,19 +99,19 @@ class MerkleBlob:
blob: bytes,
) -> None: ...

def insert(self, key: int64, value: int64, hash: bytes32, reference_kid: Optional[int64] = None, side: Optional[uint8] = None) -> None: ...
def upsert(self, key: int64, value: int64, new_hash: bytes32) -> None: ...
def delete(self, key: int64) -> None: ...
def get_raw_node(self, index: uint32) -> Union[InternalNode, LeafNode]: ...
def insert(self, key: KeyId, value: ValueId, hash: bytes32, reference_kid: Optional[KeyId] = None, side: Optional[Side] = None) -> None: ...
def upsert(self, key: KeyId, value: ValueId, new_hash: bytes32) -> None: ...
def delete(self, key: KeyId) -> None: ...
def get_raw_node(self, index: TreeIndex) -> Union[InternalNode, LeafNode]: ...
def calculate_lazy_hashes(self) -> None: ...
def get_lineage_with_indexes(self, index: uint32) -> list[tuple[uint32, Union[InternalNode, LeafNode]]]:...
def get_nodes_with_indexes(self) -> list[tuple[uint32, Union[InternalNode, LeafNode]]]: ...
def get_lineage_with_indexes(self, index: TreeIndex) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]:...
def get_nodes_with_indexes(self) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]: ...
def empty(self) -> bool: ...
def get_root_hash(self) -> bytes32: ...
def batch_insert(self, keys_values: list[tuple[int64, int64]], hashes: list[bytes32]): ...
def get_hash_at_index(self, index: uint32): ...
def get_keys_values(self) -> dict[int64, int64]: ...
def get_key_index(self, key: int64) -> uint32: ...
def get_hash_at_index(self, index: TreeIndex): ...
def get_keys_values(self) -> dict[KeyId, ValueId]: ...
def get_key_index(self, key: KeyId) -> TreeIndex: ...

def __len__(self) -> int: ...

Expand Down
2 changes: 2 additions & 0 deletions wheel/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,8 @@ pub fn add_datalayer_submodule(py: Python<'_>, parent: &Bound<'_, PyModule>) ->
datalayer.add_class::<LeafNode>()?;
datalayer.add_class::<KeyId>()?;
datalayer.add_class::<ValueId>()?;
datalayer.add_class::<Side>()?;
datalayer.add_class::<TreeIndex>()?;

datalayer.add("BLOCK_SIZE", BLOCK_SIZE)?;
datalayer.add("DATA_SIZE", DATA_SIZE)?;
Expand Down

0 comments on commit 458f928

Please sign in to comment.