From 458f9289588f024e09253f340e005c85b21e99e8 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 28 Jan 2025 11:36:49 -0500 Subject: [PATCH] more --- crates/chia-datalayer/src/merkle.rs | 24 ++++++++------ tests/test_datalayer.py | 10 +++--- wheel/python/chia_rs/datalayer.pyi | 51 ++++++++++++++++++----------- wheel/src/api.rs | 2 ++ 4 files changed, 52 insertions(+), 35 deletions(-) diff --git a/crates/chia-datalayer/src/merkle.rs b/crates/chia-datalayer/src/merkle.rs index b6c412c66..5417b8cad 100644 --- a/crates/chia-datalayer/src/merkle.rs +++ b/crates/chia-datalayer/src/merkle.rs @@ -19,13 +19,18 @@ use std::iter::zip; use std::ops::Range; use thiserror::Error; -#[cfg_attr( - feature = "py-bindings", - derive(FromPyObject, IntoPyObject), - pyo3(transparent) -)] +#[cfg_attr(feature = "py-bindings", pyclass(eq, frozen, hash))] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Streamable)] -pub struct TreeIndex(u32); +pub struct TreeIndex(#[pyo3(get, name = "raw")] u32); + +#[cfg(feature = "py-bindings")] +#[pymethods] +impl TreeIndex { + #[new] + pub fn py_new(raw: u32) -> Self { + Self(raw) + } +} impl std::fmt::Display for TreeIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -317,7 +322,7 @@ fn internal_hash(left_hash: &Hash, right_hash: &Hash) -> Hash { Hash(Bytes32::new(hasher.finalize())) } -#[cfg_attr(feature = "py-bindings", pyclass(eq, eq_int))] +#[cfg_attr(feature = "py-bindings", pyclass(eq, eq_int, frozen, hash))] #[repr(u8)] #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Streamable)] pub enum Side { @@ -1423,8 +1428,7 @@ impl MerkleBlob { value: ValueId, hash: Hash, reference_kid: Option, - // TODO: should be a Side, but python has a different Side right now - side: Option, + side: Option, ) -> PyResult<()> { let insert_location = match (reference_kid, side) { (None, None) => InsertLocation::Auto {}, @@ -1436,7 +1440,7 @@ impl MerkleBlob { .ok_or(PyValueError::new_err(format!( "unknown key id passed as insert location reference: {key}" )))?, - side: Side::from_bytes(&[side])?, + side: side, }, _ => { // TODO: use a specific error diff --git a/tests/test_datalayer.py b/tests/test_datalayer.py index 3670278d5..6f9e8d265 100644 --- a/tests/test_datalayer.py +++ b/tests/test_datalayer.py @@ -1,6 +1,6 @@ import pytest -from chia_rs.datalayer import InvalidBlobLengthError, LeafNode, MerkleBlob, KeyId, ValueId +from chia_rs.datalayer import InvalidBlobLengthError, LeafNode, MerkleBlob, KeyId, ValueId, Side from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import int64, uint8 @@ -30,7 +30,7 @@ def test_just_insert_a_bunch() -> None: total_time = 0.0 for i in range(100_000): start = time.monotonic() - merkle_blob.insert(KeyId(i), ValueId(i), HASH) + merkle_blob.insert(KeyId(int64(i)), ValueId(int64(i)), HASH) end = time.monotonic() total_time += end - start @@ -42,10 +42,10 @@ def test_checking_coverage() -> None: merkle_blob = MerkleBlob(blob=bytearray()) for i in range(count): if i % 2 == 0: - merkle_blob.insert(KeyId(i), ValueId(i), bytes32.zeros) + merkle_blob.insert(KeyId(int64(i)), ValueId(int64(i)), bytes32.zeros) else: merkle_blob.insert( - KeyId(i), ValueId(i), bytes32.zeros, KeyId(i - 1), uint8(0) + KeyId(int64(i)), ValueId(int64(i)), bytes32.zeros, KeyId(int64(i - 1)), Side.Left ) keys = { @@ -53,7 +53,7 @@ def test_checking_coverage() -> None: for index, node in merkle_blob.get_nodes_with_indexes() if isinstance(node, LeafNode) } - assert keys == set(KeyId(n) for n in range(count)) + assert keys == set(KeyId(int64(n)) for n in range(count)) def test_invalid_blob_length_raised() -> None: diff --git a/wheel/python/chia_rs/datalayer.pyi b/wheel/python/chia_rs/datalayer.pyi index 79464bb7c..c9fd40706 100644 --- a/wheel/python/chia_rs/datalayer.pyi +++ b/wheel/python/chia_rs/datalayer.pyi @@ -1,3 +1,4 @@ +from enum import Enum from typing import Mapping, Optional, Sequence, Union, Any, ClassVar, final from .sized_bytes import bytes32, bytes100 from .sized_ints import uint8, uint16, uint32, uint64, uint128, int8, int16, int32, int64 @@ -9,7 +10,6 @@ DATA_SIZE: int BLOCK_SIZE: int METADATA_SIZE: int - class FailedLoadingMetadataError(Exception): ... class FailedLoadingNodeError(Exception): ... class InvalidBlobLengthError(Exception): ... @@ -30,6 +30,11 @@ class IndexIsNotAChildError(Exception): ... class CycleFoundError(Exception): ... class BlockIndexOutOfBoundsError(Exception): ... +@final +class Side(Enum): + Left: int = ... + Right: int = ... + @final class KeyId: raw: int64 @@ -42,34 +47,40 @@ class ValueId: def __init__(self, raw: int64) -> None: ... +@final +class TreeIndex: + raw: uint32 + + def __init__(self, raw: uint32) -> None: ... + @final class InternalNode: - def __init__(self, parent: Optional[uint32], hash: bytes32, left: uint32, right: uint32) -> None: ... + def __init__(self, parent: Optional[TreeIndex], hash: bytes32, left: TreeIndex, right: TreeIndex) -> None: ... @property - def parent(self) -> Optional[uint32]: ... + def parent(self) -> Optional[TreeIndex]: ... @property def hash(self) -> bytes: ... @property - def left(self) -> uint32: ... + def left(self) -> TreeIndex: ... @property - def right(self) -> uint32: ... + def right(self) -> TreeIndex: ... @final class LeafNode: - def __init__(self, parent: Optional[uint32], hash: bytes32, key: int64, value: int64) -> None: ... + def __init__(self, parent: Optional[TreeIndex], hash: bytes32, key: KeyId, value: ValueId) -> None: ... @property - def parent(self) -> Optional[uint32]: ... + def parent(self) -> Optional[TreeIndex]: ... @property def hash(self) -> bytes: ... @property - def key(self) -> int64: ... + def key(self) -> KeyId: ... @property - def value(self) -> int64: ... + def value(self) -> ValueId: ... @final @@ -77,9 +88,9 @@ class MerkleBlob: @property def blob(self) -> bytearray: ... @property - def free_indexes(self) -> set[uint32]: ... + def free_indexes(self) -> set[TreeIndex]: ... @property - def key_to_index(self) -> Mapping[int64, uint32]: ... + def key_to_index(self) -> Mapping[KeyId, TreeIndex]: ... @property def check_integrity_on_drop(self) -> bool: ... @@ -88,19 +99,19 @@ class MerkleBlob: blob: bytes, ) -> None: ... - def insert(self, key: int64, value: int64, hash: bytes32, reference_kid: Optional[int64] = None, side: Optional[uint8] = None) -> None: ... - def upsert(self, key: int64, value: int64, new_hash: bytes32) -> None: ... - def delete(self, key: int64) -> None: ... - def get_raw_node(self, index: uint32) -> Union[InternalNode, LeafNode]: ... + def insert(self, key: KeyId, value: ValueId, hash: bytes32, reference_kid: Optional[KeyId] = None, side: Optional[Side] = None) -> None: ... + def upsert(self, key: KeyId, value: ValueId, new_hash: bytes32) -> None: ... + def delete(self, key: KeyId) -> None: ... + def get_raw_node(self, index: TreeIndex) -> Union[InternalNode, LeafNode]: ... def calculate_lazy_hashes(self) -> None: ... - def get_lineage_with_indexes(self, index: uint32) -> list[tuple[uint32, Union[InternalNode, LeafNode]]]:... - def get_nodes_with_indexes(self) -> list[tuple[uint32, Union[InternalNode, LeafNode]]]: ... + def get_lineage_with_indexes(self, index: TreeIndex) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]:... + def get_nodes_with_indexes(self) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]: ... def empty(self) -> bool: ... def get_root_hash(self) -> bytes32: ... def batch_insert(self, keys_values: list[tuple[int64, int64]], hashes: list[bytes32]): ... - def get_hash_at_index(self, index: uint32): ... - def get_keys_values(self) -> dict[int64, int64]: ... - def get_key_index(self, key: int64) -> uint32: ... + def get_hash_at_index(self, index: TreeIndex): ... + def get_keys_values(self) -> dict[KeyId, ValueId]: ... + def get_key_index(self, key: KeyId) -> TreeIndex: ... def __len__(self) -> int: ... diff --git a/wheel/src/api.rs b/wheel/src/api.rs index f9755f650..00cb7c587 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -648,6 +648,8 @@ pub fn add_datalayer_submodule(py: Python<'_>, parent: &Bound<'_, PyModule>) -> datalayer.add_class::()?; datalayer.add_class::()?; datalayer.add_class::()?; + datalayer.add_class::()?; + datalayer.add_class::()?; datalayer.add("BLOCK_SIZE", BLOCK_SIZE)?; datalayer.add("DATA_SIZE", DATA_SIZE)?;