diff --git a/crates/chia-protocol/src/block_record.rs b/crates/chia-protocol/src/block_record.rs index c90e0d77b..3391ec520 100644 --- a/crates/chia-protocol/src/block_record.rs +++ b/crates/chia-protocol/src/block_record.rs @@ -1,6 +1,9 @@ -use chia_streamable_macro::streamable; - +#[cfg(feature = "py-bindings")] +use crate::pot_iterations::{calculate_ip_iters, calculate_sp_iters}; use crate::{Bytes32, ClassgroupElement, Coin, SubEpochSummary}; +use chia_streamable_macro::streamable; +#[cfg(feature = "py-bindings")] +use pyo3::exceptions::PyValueError; #[cfg(feature = "py-bindings")] use pyo3::prelude::*; @@ -67,13 +70,48 @@ impl BlockRecord { pub fn is_challenge_block(&self, min_blocks_per_challenge_block: u8) -> bool { self.deficit == min_blocks_per_challenge_block - 1 } -} - -#[cfg(feature = "py-bindings")] -use pyo3::types::PyDict; -#[cfg(feature = "py-bindings")] -use pyo3::exceptions::PyValueError; + // fn calculate_sp_interval_iters(&self, num_sps_sub_slot: u64) -> PyResult { + // if self.sub_slot_iters % num_sps_sub_slot != 0 { + // return Err(PyValueError::new_err( + // "sub_slot_iters % constants.NUM_SPS_SUB_SLOT != 0", + // )); + // } + // Ok(self.sub_slot_iters / num_sps_sub_slot) + // } + + // fn calculate_sp_iters(&self, num_sps_sub_slot: u32) -> PyResult { + // if self.signage_point_index as u32 >= num_sps_sub_slot { + // return Err(PyValueError::new_err("SP index too high")); + // } + // Ok(self.calculate_sp_interval_iters(num_sps_sub_slot as u64)? + // * self.signage_point_index as u64) + // } + + // fn calculate_ip_iters( + // &self, + // num_sps_sub_slot: u32, + // num_sp_intervals_extra: u8, + // ) -> PyResult { + // let sp_iters = self.calculate_sp_iters(num_sps_sub_slot)?; + // let sp_interval_iters = self.calculate_sp_interval_iters(num_sps_sub_slot as u64)?; + // if sp_iters % sp_interval_iters != 0 || sp_iters >= self.sub_slot_iters { + // return Err(PyValueError::new_err(format!( + // "Invalid sp iters {sp_iters} for this ssi {}", + // self.sub_slot_iters + // ))); + // } else if self.required_iters >= sp_interval_iters || self.required_iters == 0 { + // return Err(PyValueError::new_err(format!( + // "Required iters {} is not below the sp interval iters {} {} or not >=0", + // self.required_iters, sp_interval_iters, self.sub_slot_iters + // ))); + // } + // Ok( + // (sp_iters + num_sp_intervals_extra as u64 * sp_interval_iters + self.required_iters) + // % self.sub_slot_iters, + // ) + // } +} #[cfg(feature = "py-bindings")] use chia_traits::ChiaToPython; @@ -102,16 +140,11 @@ impl BlockRecord { )) } - // TODO: at some point it would be nice to port - // chia.consensus.pot_iterations to rust, and make this less hacky - fn sp_sub_slot_total_iters_impl( - &self, - py: Python<'_>, - constants: &Bound<'_, PyAny>, - ) -> PyResult { + // TODO: these could be implemented as a total port of pot iterations + fn sp_sub_slot_total_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult { let ret = self .total_iters - .checked_sub(self.ip_iters_impl(py, constants)? as u128) + .checked_sub(self.ip_iters_impl(constants)? as u128) .ok_or(PyValueError::new_err("uint128 overflow"))?; if self.overflow { ret.checked_sub(self.sub_slot_iters as u128) @@ -121,48 +154,38 @@ impl BlockRecord { } } - fn ip_sub_slot_total_iters_impl( - &self, - py: Python<'_>, - constants: &Bound<'_, PyAny>, - ) -> PyResult { + fn ip_sub_slot_total_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult { self.total_iters - .checked_sub(self.ip_iters_impl(py, constants)? as u128) + .checked_sub(self.ip_iters_impl(constants)? as u128) .ok_or(PyValueError::new_err("uint128 overflow")) } - fn sp_iters_impl(&self, py: Python<'_>, constants: &Bound<'_, PyAny>) -> PyResult { - let ctx = PyDict::new(py); - ctx.set_item("sub_slot_iters", self.sub_slot_iters)?; - ctx.set_item("signage_point_index", self.signage_point_index)?; - ctx.set_item("constants", constants)?; - py.run( - c"from chia.consensus.pot_iterations import calculate_ip_iters, calculate_sp_iters\n\ - ret = calculate_sp_iters(constants, sub_slot_iters, signage_point_index)\n", - None, - Some(&ctx), - )?; - ctx.get_item("ret").unwrap().unwrap().extract::() - } - - fn ip_iters_impl(&self, py: Python<'_>, constants: &Bound<'_, PyAny>) -> PyResult { - let ctx = PyDict::new(py); - ctx.set_item("sub_slot_iters", self.sub_slot_iters)?; - ctx.set_item("signage_point_index", self.signage_point_index)?; - ctx.set_item("required_iters", self.required_iters)?; - ctx.set_item("constants", constants)?; - py.run( - c"from chia.consensus.pot_iterations import calculate_ip_iters, calculate_sp_iters\n\ - ret = calculate_ip_iters(constants, sub_slot_iters, signage_point_index, required_iters)\n", - None, - Some(&ctx), - )?; - ctx.get_item("ret").unwrap().unwrap().extract::() - } - - fn sp_total_iters_impl(&self, py: Python<'_>, constants: &Bound<'_, PyAny>) -> PyResult { - self.sp_sub_slot_total_iters_impl(py, constants)? - .checked_add(self.sp_iters_impl(py, constants)? as u128) + fn sp_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult { + let num_sps_sub_slot = constants.getattr("NUM_SPS_SUB_SLOT")?.extract::()?; + calculate_sp_iters( + num_sps_sub_slot, + self.sub_slot_iters, + self.signage_point_index as u32, + ) + } + + fn ip_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult { + let num_sps_sub_slot = constants.getattr("NUM_SPS_SUB_SLOT")?.extract::()?; + let num_sp_intervals_extra = constants + .getattr("NUM_SP_INTERVALS_EXTRA")? + .extract::()?; + calculate_ip_iters( + num_sps_sub_slot, + num_sp_intervals_extra, + self.sub_slot_iters, + self.signage_point_index as u32, + self.required_iters, + ) + } + + fn sp_total_iters_impl(&self, constants: &Bound<'_, PyAny>) -> PyResult { + self.sp_sub_slot_total_iters_impl(constants)? + .checked_add(self.sp_iters_impl(constants)? as u128) .ok_or(PyValueError::new_err("uint128 overflow")) } @@ -171,7 +194,7 @@ impl BlockRecord { py: Python<'a>, constants: &Bound<'_, PyAny>, ) -> PyResult> { - ChiaToPython::to_python(&self.sp_sub_slot_total_iters_impl(py, constants)?, py) + ChiaToPython::to_python(&self.sp_sub_slot_total_iters_impl(constants)?, py) } fn ip_sub_slot_total_iters<'a>( @@ -179,7 +202,7 @@ impl BlockRecord { py: Python<'a>, constants: &Bound<'_, PyAny>, ) -> PyResult> { - ChiaToPython::to_python(&self.ip_sub_slot_total_iters_impl(py, constants)?, py) + ChiaToPython::to_python(&self.ip_sub_slot_total_iters_impl(constants)?, py) } fn sp_iters<'a>( @@ -187,7 +210,7 @@ impl BlockRecord { py: Python<'a>, constants: &Bound<'_, PyAny>, ) -> PyResult> { - ChiaToPython::to_python(&self.sp_iters_impl(py, constants)?, py) + ChiaToPython::to_python(&self.sp_iters_impl(constants)?, py) } fn ip_iters<'a>( @@ -195,7 +218,7 @@ impl BlockRecord { py: Python<'a>, constants: &Bound<'_, PyAny>, ) -> PyResult> { - ChiaToPython::to_python(&self.ip_iters_impl(py, constants)?, py) + ChiaToPython::to_python(&self.ip_iters_impl(constants)?, py) } fn sp_total_iters<'a>( @@ -203,6 +226,6 @@ impl BlockRecord { py: Python<'a>, constants: &Bound<'_, PyAny>, ) -> PyResult> { - ChiaToPython::to_python(&self.sp_total_iters_impl(py, constants)?, py) + ChiaToPython::to_python(&self.sp_total_iters_impl(constants)?, py) } } diff --git a/crates/chia-protocol/src/lib.rs b/crates/chia-protocol/src/lib.rs index 544979383..53d6a6c7a 100644 --- a/crates/chia-protocol/src/lib.rs +++ b/crates/chia-protocol/src/lib.rs @@ -13,6 +13,10 @@ mod fullblock; mod header_block; mod peer_info; mod pool_target; +#[cfg(feature = "py-bindings")] +mod pos_quality; +#[cfg(feature = "py-bindings")] +mod pot_iterations; mod program; mod proof_of_space; mod reward_chain_block; @@ -44,6 +48,10 @@ pub use crate::fullblock::*; pub use crate::header_block::*; pub use crate::peer_info::*; pub use crate::pool_target::*; +#[cfg(feature = "py-bindings")] +pub use crate::pos_quality::*; +#[cfg(feature = "py-bindings")] +pub use crate::pot_iterations::*; pub use crate::program::*; pub use crate::proof_of_space::*; pub use crate::reward_chain_block::*; diff --git a/crates/chia-protocol/src/pos_quality.rs b/crates/chia-protocol/src/pos_quality.rs new file mode 100644 index 000000000..ae13f1eac --- /dev/null +++ b/crates/chia-protocol/src/pos_quality.rs @@ -0,0 +1,19 @@ +// The actual space in bytes of a plot, is _expected_plot_size(k) * UI_ACTUAL_SPACE_CONSTANT_FACTO +// This is not used in consensus, only for display purposes + +pub const UI_ACTUAL_SPACE_CONSTANT_FACTOR: f32 = 0.78; + +// TODO: Update this when new plot format releases +#[cfg(feature = "py-bindings")] +#[pyo3::pyfunction] +pub fn expected_plot_size(k: u32) -> pyo3::PyResult { + // """ + // Given the plot size parameter k (which is between 32 and 59), computes the + // expected size of the plot in bytes (times a constant factor). This is based on efficient encoding + // of the plot, and aims to be scale agnostic, so larger plots don't + // necessarily get more rewards per byte. The +1 is added to give half a bit more space per entry, which + // is necessary to store the entries in the plot. + // """ + + Ok((2 * k as u64 + 1) * (1_u64 << (k - 1))) +} diff --git a/crates/chia-protocol/src/pot_iterations.rs b/crates/chia-protocol/src/pot_iterations.rs new file mode 100644 index 000000000..321a83624 --- /dev/null +++ b/crates/chia-protocol/src/pot_iterations.rs @@ -0,0 +1,247 @@ +// use crate::Bytes32; +// use chia_sha2::Sha256; +// use std::convert::TryInto; +// use crate::pos_quality::expected_plot_size; + +#[cfg(feature = "py-bindings")] +#[pyo3::pyfunction] +pub fn is_overflow_block( + num_sps_sub_slot: u32, + num_sp_intervals_extra: u8, + signage_point_index: u32, +) -> pyo3::PyResult { + if signage_point_index >= num_sps_sub_slot { + return Err(pyo3::exceptions::PyValueError::new_err("SP index too high")); + } + Ok(signage_point_index >= num_sps_sub_slot - num_sp_intervals_extra as u32) +} + +#[cfg(feature = "py-bindings")] +#[pyo3::pyfunction] +pub fn calculate_sp_interval_iters( + num_sps_sub_slot: u32, + sub_slot_iters: u64, +) -> pyo3::PyResult { + if sub_slot_iters % num_sps_sub_slot as u64 != 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "ssi % num_sps_sub_slot != 0", + )); + } + Ok(sub_slot_iters / num_sps_sub_slot as u64) +} + +#[cfg(feature = "py-bindings")] +#[pyo3::pyfunction] +pub fn calculate_sp_iters( + num_sps_sub_slot: u32, + sub_slot_iters: u64, + signage_point_index: u32, +) -> pyo3::PyResult { + if signage_point_index >= num_sps_sub_slot { + return Err(pyo3::exceptions::PyValueError::new_err("SP index too high")); + } + Ok(calculate_sp_interval_iters(num_sps_sub_slot, sub_slot_iters)? * signage_point_index as u64) +} + +#[cfg(feature = "py-bindings")] +#[pyo3::pyfunction] +pub fn calculate_ip_iters( + num_sps_sub_slot: u32, + num_sp_intervals_extra: u8, + sub_slot_iters: u64, + signage_point_index: u32, + required_iters: u64, +) -> pyo3::PyResult { + let sp_interval_iters = calculate_sp_interval_iters(num_sps_sub_slot, sub_slot_iters)?; + let sp_iters = calculate_sp_iters(num_sps_sub_slot, sub_slot_iters, signage_point_index)?; + if sp_iters % sp_interval_iters != 0 || sp_iters > sub_slot_iters { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Invalid sp iters {sp_iters} for this ssi {sub_slot_iters}", + ))); + } else if required_iters >= sp_interval_iters || required_iters == 0 { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Required iters {required_iters} is not below the sp interval iters {sp_interval_iters} {sub_slot_iters} or not >=0", + ))); + } + Ok( + (sp_iters + num_sp_intervals_extra as u64 * sp_interval_iters + required_iters) + % sub_slot_iters, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + static NUM_SPS_SUB_SLOT: u32 = 32; + static NUM_SP_INTERVALS_EXTRA: u8 = 3; + + #[test] + fn test_is_overflow_block() { + assert!( + !is_overflow_block(NUM_SPS_SUB_SLOT, NUM_SP_INTERVALS_EXTRA, 27) + .expect("valid SP index") + ); + assert!( + !is_overflow_block(NUM_SPS_SUB_SLOT, NUM_SP_INTERVALS_EXTRA, 28) + .expect("valid SP index") + ); + assert!( + is_overflow_block(NUM_SPS_SUB_SLOT, NUM_SP_INTERVALS_EXTRA, 29) + .expect("valid SP index") + ); + assert!( + is_overflow_block(NUM_SPS_SUB_SLOT, NUM_SP_INTERVALS_EXTRA, 30) + .expect("valid SP index") + ); + assert!( + is_overflow_block(NUM_SPS_SUB_SLOT, NUM_SP_INTERVALS_EXTRA, 31) + .expect("valid SP index") + ); + assert!(is_overflow_block(NUM_SPS_SUB_SLOT, NUM_SP_INTERVALS_EXTRA, 32).is_err()); + } + + #[test] + fn test_calculate_sp_iters() { + let ssi: u64 = 100_001 * 64 * 4; + assert!(calculate_sp_iters(NUM_SPS_SUB_SLOT, ssi, 32).is_err()); + calculate_sp_iters(NUM_SPS_SUB_SLOT, ssi, 31).expect("valid_result"); + } + + #[test] + fn test_calculate_ip_iters() { + // # num_sps_sub_slot: u32, + // # num_sp_intervals_extra: u8, + // # sub_slot_iters: u64, + // # signage_point_index: u8, + // # required_iters: u64, + let ssi: u64 = 100_001 * 64 * 4; + let sp_interval_iters = ssi / NUM_SPS_SUB_SLOT as u64; + + // Invalid signage point index + assert!( + calculate_ip_iters(NUM_SPS_SUB_SLOT, NUM_SP_INTERVALS_EXTRA, ssi, 123, 100_000) + .is_err() + ); + + let sp_iters = sp_interval_iters * 13; + + // required_iters too high + assert!(calculate_ip_iters( + NUM_SPS_SUB_SLOT, + NUM_SP_INTERVALS_EXTRA, + ssi, + sp_interval_iters.try_into().unwrap(), + sp_interval_iters + ) + .is_err()); + + // required_iters too high + assert!(calculate_ip_iters( + NUM_SPS_SUB_SLOT, + NUM_SP_INTERVALS_EXTRA, + ssi, + sp_interval_iters.try_into().unwrap(), + sp_interval_iters * 12 + ) + .is_err()); + + // required_iters too low (0) + assert!(calculate_ip_iters( + NUM_SPS_SUB_SLOT, + NUM_SP_INTERVALS_EXTRA, + ssi, + sp_interval_iters.try_into().unwrap(), + 0 + ) + .is_err()); + + let required_iters = sp_interval_iters - 1; + let ip_iters = calculate_ip_iters( + NUM_SPS_SUB_SLOT, + NUM_SP_INTERVALS_EXTRA, + ssi, + 13, + required_iters, + ) + .expect("should be valid"); + assert_eq!( + ip_iters, + sp_iters + (NUM_SP_INTERVALS_EXTRA as u64 * sp_interval_iters) + required_iters + ); + + let required_iters = 1_u64; + let ip_iters = calculate_ip_iters( + NUM_SPS_SUB_SLOT, + NUM_SP_INTERVALS_EXTRA, + ssi, + 13, + required_iters, + ) + .expect("valid"); + assert_eq!( + ip_iters, + sp_iters + (NUM_SP_INTERVALS_EXTRA as u64 * sp_interval_iters) + required_iters + ); + + let required_iters: u64 = ssi * 4 / 300; + let ip_iters = calculate_ip_iters( + NUM_SPS_SUB_SLOT, + NUM_SP_INTERVALS_EXTRA, + ssi, + 13, + required_iters, + ) + .expect("valid"); + assert_eq!( + ip_iters, + sp_iters + (NUM_SP_INTERVALS_EXTRA as u64 * sp_interval_iters) + required_iters + ); + assert!(sp_iters < ip_iters); + + // Overflow + let sp_iters = sp_interval_iters * (NUM_SPS_SUB_SLOT - 1) as u64; + let ip_iters = calculate_ip_iters( + NUM_SPS_SUB_SLOT, + NUM_SP_INTERVALS_EXTRA, + ssi, + NUM_SPS_SUB_SLOT - 1_u32, + required_iters, + ) + .expect("valid"); + assert_eq!( + ip_iters, + (sp_iters + (NUM_SP_INTERVALS_EXTRA as u64 * sp_interval_iters) + required_iters) % ssi + ); + assert!(sp_iters > ip_iters); + } +} + +// TODO: enable and fix below + +// #[cfg(feature = "py-bindings")] +// #[pyo3::pyfunction] +// pub fn calculate_iterations_quality( +// difficulty_constant_factor: u128, +// quality_string: Bytes32, +// size: u32, +// difficulty: u64, +// cc_sp_output_hash: Bytes32, +// ) -> pyo3::PyResult { +// // Hash the concatenation of `quality_string` and `cc_sp_output_hash` +// let mut hasher = Sha256::new(); +// hasher.update(quality_string); +// hasher.update(cc_sp_output_hash); +// let sp_quality_string = hasher.finalize(); + +// // Convert the hash bytes to a big-endian u128 integer +// let sp_quality_value = u128::from_be_bytes(sp_quality_string[..16]); + +// // Expected plot size calculation function +// let plot_size = expected_plot_size(size); + +// // Calculate the number of iterations +// let iters = (difficulty as u128 * difficulty_constant_factor * sp_quality_value) +// / ((1_u128 << 256) * plot_size as u128); + +// Ok(iters.max(1) as u64) +// } diff --git a/tests/test_block_record_fidelity.py b/tests/test_block_record_fidelity.py index e37ec5ef2..10698035e 100644 --- a/tests/test_block_record_fidelity.py +++ b/tests/test_block_record_fidelity.py @@ -1,5 +1,5 @@ from typing import Optional, Any, Callable - +from pytest import raises import sys import time from chia_rs import BlockRecord, ClassgroupElement @@ -62,15 +62,15 @@ def get_hash(rng: Random) -> bytes32: return bytes32.random(rng) -def get_block_record(rng: Random) -> BlockRecord: +def get_block_record(rng: Random, ssi=None, ri=None, spi=None) -> BlockRecord: height = get_u32(rng) weight = get_u128(rng) iters = get_u128(rng) - sp_index = get_u4(rng) + sp_index = spi if spi is not None else get_u4(rng) vdf_out = get_classgroup_element(rng) infused_challenge = get_optional(rng, get_classgroup_element) - sub_slot_iters = get_ssi(rng) - required_iters = get_u64(rng) + sub_slot_iters = ssi if ssi is not None else get_ssi(rng) + required_iters = ri if ri is not None else get_u64(rng) deficit = get_u8(rng) overflow = get_bool(rng) prev_tx_height = get_u32(rng) @@ -129,3 +129,28 @@ def wrap_call(expr: str, br: Any) -> str: return f"V:{ret}" except Exception as e: return f"E:{e}" + + +# TODO: more thoroughly check these new functions which use self +def test_calculate_sp_iters(): + ssi: uint64 = uint64(100001 * 64 * 4) + rng = Random() + rng.seed(1337) + br = get_block_record(rng, ssi=ssi, spi=31) + res = br.sp_iters_impl(DEFAULT_CONSTANTS) + assert res is not None + + +def test_calculate_ip_iters(): + ssi: uint64 = uint64(100001 * 64 * 4) + sp_interval_iters = ssi // 32 + ri = sp_interval_iters - 1 + rng = Random() + rng.seed(1337) + br = get_block_record(rng, ssi=ssi, spi=31, ri=ri) + with raises(ValueError): + res = br.ip_iters_impl(DEFAULT_CONSTANTS) + + br = get_block_record(rng, ssi=ssi, spi=13, ri=1) + res = br.ip_iters_impl(DEFAULT_CONSTANTS) + assert res is not None diff --git a/tests/test_blscache.py b/tests/test_blscache.py index 44fb7cb7c..edd8f2450 100644 --- a/tests/test_blscache.py +++ b/tests/test_blscache.py @@ -14,9 +14,6 @@ ) from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import uint8, uint16, uint32, uint64, uint128 -from chia.util.hash import std_hash -from chia.util.lru_cache import LRUCache -from chia.types.blockchain_format.program import Program as ChiaProgram import pytest @@ -174,7 +171,6 @@ def test_cached_bls(): # Use a small cache which can not accommodate all pairings bls_cache = BLSCache(n_keys // 2) - local_cache = LRUCache(n_keys // 2) # Verify signatures and cache pairings one at a time for pk, msg, sig in zip(pks_half, msgs_half, sigs_half): assert bls_cache.aggregate_verify([pk], [msg], sig) @@ -221,13 +217,10 @@ def test_cached_bls_repeat_pk(): cached_bls = BLSCache() n_keys = 400 seed = b"a" * 32 - sks = [AugSchemeMPL.key_gen(seed) for i in range(n_keys)] + [ - AugSchemeMPL.key_gen(std_hash(seed)) - ] + sks = [AugSchemeMPL.key_gen(seed) for _ in range(n_keys)] pks = [sk.get_g1() for sk in sks] - pks_bytes = [bytes(sk.get_g1()) for sk in sks] - msgs = [("msg-%d" % (i,)).encode() for i in range(n_keys + 1)] + msgs = [("msg-%d" % (i,)).encode() for i in range(n_keys)] sigs = [AugSchemeMPL.sign(sk, msg) for sk, msg in zip(sks, msgs)] agg_sig = AugSchemeMPL.aggregate(sigs) @@ -304,11 +297,7 @@ def test_validate_clvm_and_sig(): ) sig = AugSchemeMPL.sign( sk, - ( - ChiaProgram.to("hello").as_atom() - + coin.name() - + DEFAULT_CONSTANTS.AGG_SIG_ME_ADDITIONAL_DATA - ), # noqa + (b"hello" + coin.name() + DEFAULT_CONSTANTS.AGG_SIG_ME_ADDITIONAL_DATA), # noqa ) new_spend = SpendBundle(coin_spends, sig) diff --git a/tests/test_pot_iterations.py b/tests/test_pot_iterations.py new file mode 100644 index 000000000..b3f078cb4 --- /dev/null +++ b/tests/test_pot_iterations.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +from pytest import raises + +from run_gen import DEFAULT_CONSTANTS + +# from chia.consensus.pos_quality import _expected_plot_size +from chia_rs import ( + calculate_ip_iters, + # calculate_iterations_quality, + calculate_sp_iters, + is_overflow_block, +) + +# from chia.util.hash import std_hash +from chia_rs.sized_ints import uint8, uint16, uint32, uint64 # , uint128 + +test_constants = DEFAULT_CONSTANTS.replace( + NUM_SPS_SUB_SLOT=uint32(32), SUB_SLOT_TIME_TARGET=uint16(300) +) + + +class TestPotIterations: + def test_is_overflow_block(self): + assert not is_overflow_block( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + uint8(27), + ) + assert not is_overflow_block( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + uint8(28), + ) + assert is_overflow_block( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + uint8(29), + ) + assert is_overflow_block( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + uint8(30), + ) + assert is_overflow_block( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + uint8(31), + ) + with raises(ValueError): + assert is_overflow_block( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + uint8(32), + ) + + def test_calculate_sp_iters(self): + ssi: uint64 = uint64(100001 * 64 * 4) + with raises(ValueError): + calculate_sp_iters(test_constants.NUM_SPS_SUB_SLOT, ssi, uint8(32)) + calculate_sp_iters(test_constants.NUM_SPS_SUB_SLOT, ssi, uint8(31)) + + def test_calculate_ip_iters(self): + # num_sps_sub_slot: u32, + # num_sp_intervals_extra: u8, + # sub_slot_iters: u64, + # signage_point_index: u8, + # required_iters: u64, + ssi: uint64 = uint64(100001 * 64 * 4) + sp_interval_iters = ssi // test_constants.NUM_SPS_SUB_SLOT + + with raises(ValueError): + # Invalid signage point index + calculate_ip_iters( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + ssi, + uint8(123), + uint64(100000), + ) + + sp_iters = sp_interval_iters * 13 + + with raises(ValueError): + # required_iters too high + calculate_ip_iters( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + ssi, + sp_interval_iters, + sp_interval_iters, + ) + + with raises(ValueError): + # required_iters too high + calculate_ip_iters( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + ssi, + sp_interval_iters, + sp_interval_iters * 12, + ) + + with raises(ValueError): + # required_iters too low (0) + calculate_ip_iters( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + ssi, + sp_interval_iters, + uint64(0), + ) + + required_iters = sp_interval_iters - 1 + ip_iters = calculate_ip_iters( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + ssi, + uint8(13), + required_iters, + ) + assert ( + ip_iters + == sp_iters + + test_constants.NUM_SP_INTERVALS_EXTRA * sp_interval_iters + + required_iters + ) + + required_iters = uint64(1) + ip_iters = calculate_ip_iters( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + ssi, + uint8(13), + required_iters, + ) + assert ( + ip_iters + == sp_iters + + test_constants.NUM_SP_INTERVALS_EXTRA * sp_interval_iters + + required_iters + ) + + required_iters = uint64(int(ssi * 4 / 300)) + ip_iters = calculate_ip_iters( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + ssi, + uint8(13), + required_iters, + ) + assert ( + ip_iters + == sp_iters + + test_constants.NUM_SP_INTERVALS_EXTRA * sp_interval_iters + + required_iters + ) + assert sp_iters < ip_iters + + # Overflow + sp_iters = sp_interval_iters * (test_constants.NUM_SPS_SUB_SLOT - 1) + ip_iters = calculate_ip_iters( + test_constants.NUM_SPS_SUB_SLOT, + test_constants.NUM_SP_INTERVALS_EXTRA, + ssi, + uint8(test_constants.NUM_SPS_SUB_SLOT - 1), + required_iters, + ) + assert ( + ip_iters + == ( + sp_iters + + test_constants.NUM_SP_INTERVALS_EXTRA * sp_interval_iters + + required_iters + ) + % ssi + ) + assert sp_iters > ip_iters diff --git a/tests/test_program_fidelity.py b/tests/test_program_fidelity.py deleted file mode 100644 index b8753aefe..000000000 --- a/tests/test_program_fidelity.py +++ /dev/null @@ -1,133 +0,0 @@ -from typing import Optional - -import string -import chia_rs -from chia.types.blockchain_format.program import Program as ChiaProgram -from chia.types.blockchain_format.serialized_program import SerializedProgram -from random import Random - - -def rand_bytes(rnd: Random) -> bytes: - size = rnd.randint(0, 4) - ret = bytearray() - for _ in range(size): - ret.append(rnd.getrandbits(8)) - return bytes(ret) - - -def rand_string(rnd: Random) -> str: - size = rnd.randint(1, 10) - return "".join(rnd.choices(string.ascii_uppercase + string.digits, k=size)) - - -def rand_int(rnd: Random) -> int: - return rnd.randint(0, 100000000000000) - - -def rand_list(rnd: Random) -> list: - size = rnd.randint(0, 3) - ret = [] - for _ in range(size): - ret.append(rand_object(rnd)) - return ret - - -def rand_program(rnd: Random) -> ChiaProgram: - return ChiaProgram.from_bytes(b"\xff\x01\xff\x04\x01") - - -def rand_rust_program(rnd: Random) -> chia_rs.Program: - return chia_rs.Program.from_bytes(b"\xff\x01\xff\x04\x01") - - -def rand_optional(rnd: Random) -> Optional[object]: - if rnd.randint(0, 1) == 0: - return None - return rand_object(rnd) - - -def rand_object(rnd: Random) -> object: - types = [ - rand_optional, - rand_int, - rand_string, - rand_bytes, - rand_program, - rand_list, - rand_rust_program, - ] - return rnd.sample(types, 1)[0](rnd) - - -def recursive_replace(o: object) -> object: - if isinstance(o, list): - ret = [] - for i in o: - ret.append(recursive_replace(i)) - return ret - elif isinstance(o, chia_rs.Program): - return SerializedProgram.from_bytes(o.to_bytes()) - else: - return o - - -def test_run_program() -> None: - - rust_identity = chia_rs.Program.from_bytes(b"\x01") - py_identity = SerializedProgram.from_bytes(b"\x01") - - rnd = Random() - for _ in range(10000): - args = rand_object(rnd) - - # the python SerializedProgram treats itself specially, the rust - # Program treats itself specially, but they don't recognize each other, - # so they will produce different results in this regard - rust_ret = rust_identity._run(10000, 0, args) - - # Replace rust Program with the python SerializedProgram. - args = recursive_replace(args) - - py_ret = py_identity._run(10000, 0, args) - - assert rust_ret == py_ret - - -def test_tree_hash() -> None: - - rnd = Random() - for _ in range(10000): - py_prg = ChiaProgram.to(rand_object(rnd)) - rust_prg = chia_rs.Program.from_bytes(bytes(py_prg)) - - assert py_prg.get_tree_hash() == rust_prg.get_tree_hash() - - -def test_uncurry() -> None: - - rnd = Random() - for _ in range(10000): - py_prg = ChiaProgram.to(rand_object(rnd)) - py_prg = py_prg.curry(rand_object(rnd)) - rust_prg = chia_rs.Program.from_program(py_prg) - assert py_prg.uncurry() == rust_prg.uncurry() - - py_prg = py_prg.curry(rand_object(rnd), rand_object(rnd)) - rust_prg = chia_rs.Program.from_program(py_prg) - assert py_prg.uncurry() == rust_prg.uncurry() - - -def test_round_trip() -> None: - - rnd = Random() - for _ in range(10000): - obj = rand_object(rnd) - py_prg = ChiaProgram.to(obj) - rust_prg = chia_rs.Program.from_program(py_prg) - rust_prg2 = chia_rs.Program.to(obj) - - assert py_prg == rust_prg.to_program() - assert py_prg == rust_prg2.to_program() - - assert bytes(py_prg) == bytes(rust_prg) - assert bytes(py_prg) == bytes(rust_prg2) diff --git a/wheel/generate_type_stubs.py b/wheel/generate_type_stubs.py index 48ec6a9c6..99761952f 100644 --- a/wheel/generate_type_stubs.py +++ b/wheel/generate_type_stubs.py @@ -336,6 +336,35 @@ def get_flags_for_height_and_constants( constants: ConsensusConstants ) -> int: ... +def calculate_ip_iters( + num_sps_sub_slot: uint32, + num_sp_intervals_extra: uint8, + sub_slot_iters: uint64, + signage_point_index: uint32, + required_iters: uint64, +) -> uint64: ... + +def calculate_sp_iters( + num_sps_sub_slot: uint32, + sub_slot_iters: uint64, + signage_point_index: uint32, +) -> uint64: ... + +def calculate_sp_interval_iters( + num_sps_sub_slot: uint32, + sub_slot_iters: uint64, +) -> uint64: ... + +def is_overflow_block( + num_sps_sub_slot: uint32, + num_sp_intervals_extra: uint8, + signage_point_index: uint32, +) -> bool: ... + +def expected_plot_size( + k: int +) -> int: ... + NO_UNKNOWN_CONDS: int = ... STRICT_ARGS_COUNT: int = ... diff --git a/wheel/python/chia_rs/chia_rs.pyi b/wheel/python/chia_rs/chia_rs.pyi index af0306b92..abcb4e0ac 100644 --- a/wheel/python/chia_rs/chia_rs.pyi +++ b/wheel/python/chia_rs/chia_rs.pyi @@ -65,6 +65,35 @@ def get_flags_for_height_and_constants( constants: ConsensusConstants ) -> int: ... +def calculate_ip_iters( + num_sps_sub_slot: uint32, + num_sp_intervals_extra: uint8, + sub_slot_iters: uint64, + signage_point_index: uint32, + required_iters: uint64, +) -> uint64: ... + +def calculate_sp_iters( + num_sps_sub_slot: uint32, + sub_slot_iters: uint64, + signage_point_index: uint32, +) -> uint64: ... + +def calculate_sp_interval_iters( + num_sps_sub_slot: uint32, + sub_slot_iters: uint64, +) -> uint64: ... + +def is_overflow_block( + num_sps_sub_slot: uint32, + num_sp_intervals_extra: uint8, + signage_point_index: uint32, +) -> bool: ... + +def expected_plot_size( + k: int +) -> int: ... + NO_UNKNOWN_CONDS: int = ... STRICT_ARGS_COUNT: int = ... diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 70b2c5916..0f6680e00 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -16,6 +16,10 @@ use chia_consensus::spendbundle_conditions::get_conditions_from_spendbundle; use chia_consensus::spendbundle_validation::{ get_flags_for_height_and_constants, validate_clvm_and_signature, }; +use chia_protocol::{ + calculate_ip_iters, calculate_sp_interval_iters, calculate_sp_iters, expected_plot_size, + is_overflow_block, +}; use chia_protocol::{ BlockRecord, Bytes32, ChallengeBlockInfo, ChallengeChainSubSlot, ClassgroupElement, Coin, CoinSpend, CoinState, CoinStateFilters, CoinStateUpdate, EndOfSubSlotBundle, FeeEstimate, @@ -458,6 +462,13 @@ pub fn chia_rs(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { )?; m.add_class::()?; + // pot functions + m.add_function(wrap_pyfunction!(calculate_sp_interval_iters, m)?)?; + m.add_function(wrap_pyfunction!(calculate_sp_iters, m)?)?; + m.add_function(wrap_pyfunction!(calculate_ip_iters, m)?)?; + m.add_function(wrap_pyfunction!(is_overflow_block, m)?)?; + m.add_function(wrap_pyfunction!(expected_plot_size, m)?)?; + // constants m.add_class::()?;