Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy committed Dec 27, 2023
1 parent ce4f395 commit 960a847
Show file tree
Hide file tree
Showing 13 changed files with 218 additions and 131 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 5 additions & 11 deletions crates/katana/core/src/constants.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use katana_primitives::contract::{
CompiledContractClass, CompiledContractClassV0, ContractAddress, StorageKey,
};
use katana_primitives::contract::{CompiledContractClass, ContractAddress, StorageKey};
use katana_primitives::utils::class::parse_compiled_class;
use katana_primitives::FieldElement;
use lazy_static::lazy_static;
use starknet::macros::felt;
Expand All @@ -10,11 +9,6 @@ pub const DEFAULT_GAS_PRICE: u128 = 100 * u128::pow(10, 9); // Given in units of
pub const DEFAULT_INVOKE_MAX_STEPS: u32 = 1_000_000;
pub const DEFAULT_VALIDATE_MAX_STEPS: u32 = 1_000_000;

fn parse_legacy_contract_class(content: impl AsRef<str>) -> CompiledContractClass {
let class: CompiledContractClassV0 = serde_json::from_str(content.as_ref()).unwrap();
CompiledContractClass::V0(class)
}

lazy_static! {

// Predefined contract addresses
Expand All @@ -35,9 +29,9 @@ lazy_static! {

// Predefined contract classes

pub static ref ERC20_CONTRACT: CompiledContractClass = parse_legacy_contract_class(include_str!("../contracts/compiled/erc20.json"));
pub static ref UDC_CONTRACT: CompiledContractClass = parse_legacy_contract_class(include_str!("../contracts/compiled/universal_deployer.json"));
pub static ref OZ_V0_ACCOUNT_CONTRACT: CompiledContractClass = parse_legacy_contract_class(include_str!("../contracts/compiled/account.json"));
pub static ref ERC20_CONTRACT: CompiledContractClass = parse_compiled_class(include_str!("../contracts/compiled/erc20.json")).unwrap();
pub static ref UDC_CONTRACT: CompiledContractClass = parse_compiled_class(include_str!("../contracts/compiled/universal_deployer.json")).unwrap();
pub static ref OZ_V0_ACCOUNT_CONTRACT: CompiledContractClass = parse_compiled_class(include_str!("../contracts/compiled/account.json")).unwrap();

pub static ref DEFAULT_PREFUNDED_ACCOUNT_BALANCE: FieldElement = felt!("0x3635c9adc5dea00000"); // 10^21

Expand Down
26 changes: 26 additions & 0 deletions crates/katana/primitives/src/utils/class.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use anyhow::Result;
use cairo_lang_starknet::casm_contract_class::CasmContractClass;
use cairo_lang_starknet::contract_class::ContractClass;

use crate::contract::{CompiledContractClass, CompiledContractClassV0, CompiledContractClassV1};

/// Parse a [`str`] into a [`CompiledContractClass`].
pub fn parse_compiled_class(class: &str) -> Result<CompiledContractClass> {
if let Ok(class) = parse_compiled_class_v1(class) {
Ok(CompiledContractClass::V1(class))
} else {
Ok(CompiledContractClass::V0(parse_compiled_class_v0(class)?))
}
}

/// Parse a [`str`] into a [`CompiledContractClassV1`].
pub fn parse_compiled_class_v1(class: &str) -> Result<CompiledContractClassV1> {
let class: ContractClass = serde_json::from_str(class)?;
let class = CasmContractClass::from_contract_class(class, true)?;
Ok(CompiledContractClassV1::try_from(class)?)
}

/// Parse a [`str`] into a [`CompiledContractClassV0`].
pub fn parse_compiled_class_v0(class: &str) -> Result<CompiledContractClassV0, std::io::Error> {
Ok(serde_json::from_str(class)?)
}
1 change: 1 addition & 0 deletions crates/katana/primitives/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod class;
pub mod transaction;
24 changes: 5 additions & 19 deletions crates/katana/storage/db/benches/codec.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,27 @@
use blockifier::execution::contract_class::ContractClassV1;
use cairo_lang_starknet::casm_contract_class::CasmContractClass;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use katana_db::codecs::{Compress, Decompress};
use katana_db::models::class::StoredContractClass;
use katana_primitives::contract::CompiledContractClass;
use katana_primitives::utils::class::parse_compiled_class;

fn compress_contract(contract: CompiledContractClass) -> Vec<u8> {
let class = StoredContractClass::from(contract);
class.compress()
StoredContractClass::from(contract).compress()
}

fn decompress_contract(compressed: &[u8]) -> CompiledContractClass {
let class = StoredContractClass::decompress(compressed).unwrap();
CompiledContractClass::from(class)
CompiledContractClass::from(StoredContractClass::decompress(compressed).unwrap())
}

fn compress_contract_with_main_codec(c: &mut Criterion) {
let class = {
let class =
serde_json::from_slice(include_bytes!("./artifacts/dojo_world_240.json")).unwrap();
let class = CasmContractClass::from_contract_class(class, true).unwrap();
CompiledContractClass::V1(ContractClassV1::try_from(class).unwrap())
};
let class = parse_compiled_class(include_str!("./artifacts/dojo_world_240.json")).unwrap();

c.bench_function("compress world contract", |b| {
b.iter_with_large_drop(|| compress_contract(black_box(class.clone())))
});
}

fn decompress_contract_with_main_codec(c: &mut Criterion) {
let class = {
let class =
serde_json::from_slice(include_bytes!("./artifacts/dojo_world_240.json")).unwrap();
let class = CasmContractClass::from_contract_class(class, true).unwrap();
CompiledContractClass::V1(ContractClassV1::try_from(class).unwrap())
};

let class = parse_compiled_class(include_str!("./artifacts/dojo_world_240.json")).unwrap();
let compressed = compress_contract(class);

c.bench_function("decompress world contract", |b| {
Expand Down
2 changes: 2 additions & 0 deletions crates/katana/storage/provider/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ fork = [ "in-memory" ]
in-memory = [ ]

[dev-dependencies]
katana-core = { path = "../../core" }
katana-runner = { path = "../../runner" }
lazy_static.workspace = true
rand = "0.8.5"
rstest = "0.18.2"
rstest_reuse = "0.6.0"
serde_json.workspace = true
starknet.workspace = true
tempfile = "3.8.1"
url.workspace = true
24 changes: 14 additions & 10 deletions crates/katana/storage/provider/src/providers/db/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,24 +212,28 @@ impl ContractClassProvider for HistoricalStateProvider {
hash: ClassHash,
) -> Result<Option<CompiledClassHash>> {
// check that the requested class hash was declared before the pinned block number
if !self.tx.get::<ClassDeclarationBlock>(hash)?.is_some_and(|num| num <= self.block_number)
{
return Ok(None);
};

Ok(self.tx.get::<CompiledClassHashes>(hash)?)
if self.tx.get::<ClassDeclarationBlock>(hash)?.is_some_and(|num| num <= self.block_number) {
Ok(self.tx.get::<CompiledClassHashes>(hash)?)
} else {
Ok(None)
}
}

fn class(&self, hash: ClassHash) -> Result<Option<CompiledContractClass>> {
self.compiled_class_hash_of_class_hash(hash).and_then(|_| {
if self.compiled_class_hash_of_class_hash(hash)?.is_some() {
let contract = self.tx.get::<CompiledContractClasses>(hash)?;
Ok(contract.map(CompiledContractClass::from))
})
} else {
Ok(None)
}
}

fn sierra_class(&self, hash: ClassHash) -> Result<Option<SierraClass>> {
self.compiled_class_hash_of_class_hash(hash)
.and_then(|_| self.tx.get::<SierraClasses>(hash).map_err(|e| e.into()))
if self.compiled_class_hash_of_class_hash(hash)?.is_some() {
self.tx.get::<SierraClasses>(hash).map_err(|e| e.into())
} else {
Ok(None)
}
}
}

Expand Down
10 changes: 9 additions & 1 deletion crates/katana/storage/provider/src/providers/fork/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,15 @@ impl ContractClassProvider for SharedStateProvider {
return Ok(class.cloned());
}

let class = self.0.do_get_class_at(hash)?;
let Some(class) = handle_contract_or_class_not_found_err(self.0.do_get_class_at(hash))
.map_err(|e| {
error!(target: "forked_backend", "error while fetching sierra class {hash:#x}: {e}");
e
})?
else {
return Ok(None);
};

match class {
starknet::core::types::ContractClass::Legacy(_) => Ok(None),
starknet::core::types::ContractClass::Sierra(sierra_class) => {
Expand Down
14 changes: 8 additions & 6 deletions crates/katana/storage/provider/src/providers/fork/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,11 @@ impl StateProvider for ForkedSnapshot {

impl ContractClassProvider for ForkedSnapshot {
fn sierra_class(&self, hash: ClassHash) -> Result<Option<SierraClass>> {
if let class @ Some(_) = self.classes.sierra_classes.read().get(&hash).cloned() {
return Ok(class);
if self.inner.compiled_class_hashes.get(&hash).is_some() {
Ok(self.classes.sierra_classes.read().get(&hash).cloned())
} else {
ContractClassProvider::sierra_class(&self.inner.db, hash)
}
ContractClassProvider::sierra_class(&self.inner.db, hash)
}

fn compiled_class_hash_of_class_hash(
Expand All @@ -190,9 +191,10 @@ impl ContractClassProvider for ForkedSnapshot {
}

fn class(&self, hash: ClassHash) -> Result<Option<CompiledContractClass>> {
if let class @ Some(_) = self.classes.compiled_classes.read().get(&hash).cloned() {
return Ok(class);
if self.inner.compiled_class_hashes.get(&hash).is_some() {
Ok(self.classes.compiled_classes.read().get(&hash).cloned())
} else {
ContractClassProvider::class(&self.inner.db, hash)
}
ContractClassProvider::class(&self.inner.db, hash)
}
}
16 changes: 12 additions & 4 deletions crates/katana/storage/provider/src/providers/in_memory/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use crate::traits::state::StateProvider;
use crate::Result;

pub struct StateSnapshot<Db> {
// because the classes are shared between snapshots, when trying to fetch check the compiled
// hash first and then the sierra class to ensure the class should be present in the snapshot.
pub(crate) classes: Arc<SharedContractClasses>,
pub(crate) inner: CacheSnapshotWithoutClasses<Db>,
}
Expand Down Expand Up @@ -145,13 +147,19 @@ impl StateProvider for InMemorySnapshot {

impl ContractClassProvider for InMemorySnapshot {
fn sierra_class(&self, hash: ClassHash) -> Result<Option<SierraClass>> {
let class = self.classes.sierra_classes.read().get(&hash).cloned();
Ok(class)
if self.compiled_class_hash_of_class_hash(hash)?.is_some() {
Ok(self.classes.sierra_classes.read().get(&hash).cloned())
} else {
Ok(None)
}
}

fn class(&self, hash: ClassHash) -> Result<Option<CompiledContractClass>> {
let class = self.classes.compiled_classes.read().get(&hash).cloned();
Ok(class)
if self.compiled_class_hash_of_class_hash(hash)?.is_some() {
Ok(self.classes.compiled_classes.read().get(&hash).cloned())
} else {
Ok(None)
}
}

fn compiled_class_hash_of_class_hash(
Expand Down
22 changes: 11 additions & 11 deletions crates/katana/storage/provider/tests/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::Result;
use katana_primitives::block::{
Block, BlockHashOrNumber, BlockNumber, BlockWithTxHashes, FinalityStatus,
};
use katana_primitives::state::StateUpdates;
use katana_primitives::state::StateUpdatesWithDeclaredClasses;
use katana_provider::providers::db::DbProvider;
use katana_provider::providers::fork::ForkedProvider;
use katana_provider::providers::in_memory::InMemoryProvider;
Expand Down Expand Up @@ -141,22 +141,22 @@ where

#[template]
#[rstest::rstest]
#[case::state_update_at_block_1(1, mock_state_updates().0)]
#[case::state_update_at_block_2(2, mock_state_updates().1)]
#[case::state_update_at_block_3(3, StateUpdates::default())]
#[case::state_update_at_block_5(5, mock_state_updates().2)]
#[case::state_update_at_block_1(1, mock_state_updates()[0].clone())]
#[case::state_update_at_block_2(2, mock_state_updates()[1].clone())]
#[case::state_update_at_block_3(3, StateUpdatesWithDeclaredClasses::default())]
#[case::state_update_at_block_5(5, mock_state_updates()[2].clone())]
fn test_read_state_update<Db>(
#[from(provider_with_states)] provider: BlockchainProvider<Db>,
#[case] block_num: BlockNumber,
#[case] expected_state_update: StateUpdates,
#[case] expected_state_update: StateUpdatesWithDeclaredClasses,
) {
}

#[apply(test_read_state_update)]
fn test_read_state_update_with_in_memory_provider(
#[with(in_memory_provider())] provider: BlockchainProvider<InMemoryProvider>,
#[case] block_num: BlockNumber,
#[case] expected_state_update: StateUpdates,
#[case] expected_state_update: StateUpdatesWithDeclaredClasses,
) -> Result<()> {
test_read_state_update_impl(provider, block_num, expected_state_update)
}
Expand All @@ -167,7 +167,7 @@ fn test_read_state_update_with_fork_provider(
ForkedProvider,
>,
#[case] block_num: BlockNumber,
#[case] expected_state_update: StateUpdates,
#[case] expected_state_update: StateUpdatesWithDeclaredClasses,
) -> Result<()> {
test_read_state_update_impl(provider, block_num, expected_state_update)
}
Expand All @@ -176,20 +176,20 @@ fn test_read_state_update_with_fork_provider(
fn test_read_state_update_with_db_provider(
#[with(db_provider())] provider: BlockchainProvider<DbProvider>,
#[case] block_num: BlockNumber,
#[case] expected_state_update: StateUpdates,
#[case] expected_state_update: StateUpdatesWithDeclaredClasses,
) -> Result<()> {
test_read_state_update_impl(provider, block_num, expected_state_update)
}

fn test_read_state_update_impl<Db>(
provider: BlockchainProvider<Db>,
block_num: BlockNumber,
expected_state_update: StateUpdates,
expected_state_update: StateUpdatesWithDeclaredClasses,
) -> Result<()>
where
Db: StateUpdateProvider,
{
let actual_state_update = provider.state_update(BlockHashOrNumber::from(block_num))?;
assert_eq!(actual_state_update, Some(expected_state_update));
assert_eq!(actual_state_update, Some(expected_state_update.state_updates));
Ok(())
}
Loading

0 comments on commit 960a847

Please sign in to comment.