Skip to content

Commit 7042bd1

Browse files
committed
chore: added test only methods
1 parent b51bc6f commit 7042bd1

File tree

8 files changed

+112
-76
lines changed

8 files changed

+112
-76
lines changed

crates/blockifier/src/bouncer_test.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ use crate::blockifier::transaction_executor::{
1212
use crate::bouncer::{verify_tx_weights_in_bounds, Bouncer, BouncerWeights, BuiltinCount};
1313
use crate::context::BlockContext;
1414
use crate::execution::call_info::ExecutionSummary;
15-
use crate::state::cached_state::{StateChangesKeys, TransactionalState};
16-
use crate::state::visited_pcs::VisitedPcsSet;
15+
use crate::state::cached_state::StateChangesKeys;
1716
use crate::storage_key;
1817
use crate::test_utils::initial_test_state::test_state;
1918
use crate::transaction::errors::TransactionExecutionError;
@@ -185,11 +184,11 @@ fn test_bouncer_try_update(
185184
) {
186185
use cairo_vm::vm::runners::cairo_runner::ExecutionResources;
187186

187+
use crate::state::cached_state::TransactionalState;
188188
use crate::transaction::objects::TransactionResources;
189189

190190
let state = &mut test_state(&BlockContext::create_for_account_testing().chain_info, 0, &[]);
191-
let mut transactional_state: TransactionalState<'_, _, VisitedPcsSet> =
192-
TransactionalState::create_transactional(state);
191+
let mut transactional_state = TransactionalState::create_transactional_for_testing(state);
193192

194193
// Setup the bouncer.
195194
let block_max_capacity = BouncerWeights {

crates/blockifier/src/concurrency/flow_test.rs

+5-8
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use starknet_api::{contract_address, felt, patricia_key};
99
use crate::abi::sierra_types::{SierraType, SierraU128};
1010
use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus};
1111
use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE};
12-
use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedStateProxy};
12+
use crate::concurrency::versioned_state::ThreadSafeVersionedState;
1313
use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps};
1414
use crate::state::state_api::UpdatableState;
1515
use crate::state::visited_pcs::VisitedPcsSet;
@@ -29,7 +29,6 @@ fn scheduler_flow_test(
2929
// transaction sequentially advances a counter by reading the previous value and bumping it by
3030
// 1.
3131

32-
use crate::concurrency::versioned_state::VersionedStateProxy;
3332
use crate::state::visited_pcs::VisitedPcsSet;
3433
let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE));
3534
let versioned_state =
@@ -76,8 +75,7 @@ fn scheduler_flow_test(
7675
Task::AskForTask
7776
}
7877
Task::ValidationTask(tx_index) => {
79-
let state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
80-
versioned_state.pin_version(tx_index);
78+
let state_proxy = versioned_state.pin_version_for_testing(tx_index);
8179
let (reads, writes) =
8280
get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state);
8381
let read_set_valid = state_proxy.validate_reads(&reads);
@@ -129,14 +127,14 @@ fn get_reads_writes_for(
129127
) -> (StateMaps, StateMaps) {
130128
match task {
131129
Task::ExecutionTask(tx_index) => {
132-
let state_proxy: VersionedStateProxy<_, VisitedPcsSet> = match tx_index {
130+
let state_proxy = match tx_index {
133131
0 => {
134132
return (
135133
state_maps_with_single_storage_entry(0),
136134
state_maps_with_single_storage_entry(1),
137135
);
138136
}
139-
_ => versioned_state.pin_version(tx_index - 1),
137+
_ => versioned_state.pin_version_for_testing(tx_index - 1),
140138
};
141139
let tx_written_value = SierraU128::from_storage(
142140
&state_proxy,
@@ -151,8 +149,7 @@ fn get_reads_writes_for(
151149
)
152150
}
153151
Task::ValidationTask(tx_index) => {
154-
let state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
155-
versioned_state.pin_version(tx_index);
152+
let state_proxy = versioned_state.pin_version_for_testing(tx_index);
156153
let tx_written_value = SierraU128::from_storage(
157154
&state_proxy,
158155
&contract_address!(CONTRACT_ADDRESS),

crates/blockifier/src/concurrency/versioned_state.rs

+8
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,14 @@ impl<S: StateReader> ThreadSafeVersionedState<S> {
234234
VersionedStateProxy { tx_index, state: self.0.clone(), _marker: PhantomData }
235235
}
236236

237+
#[cfg(test)]
238+
pub fn pin_version_for_testing(
239+
&self,
240+
tx_index: TxIndex,
241+
) -> VersionedStateProxy<S, crate::state::visited_pcs::VisitedPcsSet> {
242+
VersionedStateProxy { tx_index, state: self.0.clone(), _marker: PhantomData }
243+
}
244+
237245
pub fn into_inner_state(self) -> VersionedState<S> {
238246
Arc::try_unwrap(self.0)
239247
.unwrap_or_else(|_| {

crates/blockifier/src/concurrency/versioned_state_test.rs

+42-44
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ use crate::abi::abi_utils::{get_fee_token_var_address, get_storage_var_address};
1414
use crate::concurrency::test_utils::{
1515
class_hash, contract_address, safe_versioned_state_for_testing,
1616
};
17-
use crate::concurrency::versioned_state::{
18-
ThreadSafeVersionedState, VersionedState, VersionedStateProxy,
19-
};
17+
use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedState};
2018
use crate::concurrency::TxIndex;
2119
use crate::context::BlockContext;
2220
use crate::state::cached_state::{
@@ -73,9 +71,8 @@ fn test_versioned_state_proxy() {
7371
let versioned_state = Arc::new(Mutex::new(VersionedState::new(cached_state)));
7472

7573
let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state));
76-
let versioned_state_proxys: Vec<
77-
VersionedStateProxy<CachedState<DictStateReader, VisitedPcsSet>, VisitedPcsSet>,
78-
> = (0..20).map(|i| safe_versioned_state.pin_version(i)).collect();
74+
let versioned_state_proxys: Vec<_> =
75+
(0..20).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();
7976

8077
// Read initial data
8178
assert_eq!(versioned_state_proxys[5].get_nonce_at(contract_address).unwrap(), nonce);
@@ -210,14 +207,12 @@ fn test_run_parallel_txs(max_resource_bounds: ResourceBoundsMapping) {
210207
))));
211208

212209
let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state));
213-
let mut versioned_state_proxy_1: VersionedStateProxy<_, VisitedPcsSet> =
214-
safe_versioned_state.pin_version(1);
215-
let mut state_1: TransactionalState<'_, _, VisitedPcsSet> =
216-
TransactionalState::create_transactional(&mut versioned_state_proxy_1);
217-
let mut versioned_state_proxy_2: VersionedStateProxy<_, VisitedPcsSet> =
218-
safe_versioned_state.pin_version(2);
219-
let mut state_2: TransactionalState<'_, _, VisitedPcsSet> =
220-
TransactionalState::create_transactional(&mut versioned_state_proxy_2);
210+
let mut versioned_state_proxy_1 = safe_versioned_state.pin_version_for_testing(1);
211+
let mut state_1 =
212+
TransactionalState::create_transactional_for_testing(&mut versioned_state_proxy_1);
213+
let mut versioned_state_proxy_2 = safe_versioned_state.pin_version_for_testing(2);
214+
let mut state_2 =
215+
TransactionalState::create_transactional_for_testing(&mut versioned_state_proxy_2);
221216

222217
// Prepare transactions
223218
let deploy_account_tx_1 = deploy_account_tx(
@@ -288,10 +283,9 @@ fn test_validate_reads(
288283
) {
289284
let storage_key = storage_key!(0x10_u8);
290285

291-
let mut version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
292-
safe_versioned_state.pin_version(1);
293-
let transactional_state: TransactionalState<'_, _, VisitedPcsSet> =
294-
TransactionalState::create_transactional(&mut version_state_proxy);
286+
let mut version_state_proxy = safe_versioned_state.pin_version_for_testing(1);
287+
let transactional_state =
288+
TransactionalState::create_transactional_for_testing(&mut version_state_proxy);
295289

296290
// Validating tx index 0 always succeeds.
297291
assert!(
@@ -380,8 +374,7 @@ fn test_false_validate_reads(
380374
#[case] tx_0_writes: StateMaps,
381375
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>>,
382376
) {
383-
let version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
384-
safe_versioned_state.pin_version(0);
377+
let version_state_proxy = safe_versioned_state.pin_version_for_testing(0);
385378
version_state_proxy.state().apply_writes(0, &tx_0_writes, &HashMap::default());
386379
assert!(!safe_versioned_state.pin_version::<VisitedPcsSet>(1).validate_reads(&tx_1_reads));
387380
}
@@ -398,8 +391,7 @@ fn test_false_validate_reads_declared_contracts(
398391
declared_contracts: HashMap::from([(class_hash!(1_u8), true)]),
399392
..Default::default()
400393
};
401-
let version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
402-
safe_versioned_state.pin_version(0);
394+
let version_state_proxy = safe_versioned_state.pin_version_for_testing(0);
403395
let compiled_contract_calss = FeatureContract::TestContract(CairoVersion::Cairo1).get_class();
404396
let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]);
405397
version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class);
@@ -412,10 +404,12 @@ fn test_apply_writes(
412404
class_hash: ClassHash,
413405
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>>,
414406
) {
415-
let mut versioned_proxy_states: Vec<VersionedStateProxy<_, VisitedPcsSet>> =
416-
(0..2).map(|i| safe_versioned_state.pin_version(i)).collect();
417-
let mut transactional_states: Vec<TransactionalState<'_, _, VisitedPcsSet>> =
418-
versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();
407+
let mut versioned_proxy_states: Vec<_> =
408+
(0..2).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();
409+
let mut transactional_states: Vec<_> = versioned_proxy_states
410+
.iter_mut()
411+
.map(TransactionalState::create_transactional_for_testing)
412+
.collect();
419413

420414
// Transaction 0 class hash.
421415
let class_hash_0 = class_hash!(76_u8);
@@ -429,7 +423,7 @@ fn test_apply_writes(
429423
transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap();
430424
assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1);
431425

432-
safe_versioned_state.pin_version(0).apply_writes(
426+
safe_versioned_state.pin_version_for_testing(0).apply_writes(
433427
&transactional_states[0].cache.borrow().writes,
434428
&transactional_states[0].class_hash_to_class.borrow().clone(),
435429
&VisitedPcsSet::default(),
@@ -447,10 +441,12 @@ fn test_apply_writes_reexecute_scenario(
447441
class_hash: ClassHash,
448442
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>>,
449443
) {
450-
let mut versioned_proxy_states: Vec<VersionedStateProxy<_, VisitedPcsSet>> =
451-
(0..2).map(|i| safe_versioned_state.pin_version(i)).collect();
452-
let mut transactional_states: Vec<TransactionalState<'_, _, VisitedPcsSet>> =
453-
versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();
444+
let mut versioned_proxy_states: Vec<_> =
445+
(0..2).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();
446+
let mut transactional_states: Vec<_> = versioned_proxy_states
447+
.iter_mut()
448+
.map(TransactionalState::create_transactional_for_testing)
449+
.collect();
454450

455451
// Transaction 0 class hash.
456452
let class_hash_0 = class_hash!(76_u8);
@@ -460,7 +456,7 @@ fn test_apply_writes_reexecute_scenario(
460456
// updated.
461457
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash);
462458

463-
safe_versioned_state.pin_version(0).apply_writes(
459+
safe_versioned_state.pin_version_for_testing(0).apply_writes(
464460
&transactional_states[0].cache.borrow().writes,
465461
&transactional_states[0].class_hash_to_class.borrow().clone(),
466462
&VisitedPcsSet::default(),
@@ -471,7 +467,7 @@ fn test_apply_writes_reexecute_scenario(
471467

472468
// TODO: Use re-execution native util once it's ready.
473469
// "Re-execute" the transaction.
474-
let mut versioned_state_proxy = safe_versioned_state.pin_version(1);
470+
let mut versioned_state_proxy = safe_versioned_state.pin_version_for_testing(1);
475471
transactional_states[1] = TransactionalState::create_transactional(&mut versioned_state_proxy);
476472
// The class hash should be updated.
477473
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0);
@@ -483,10 +479,12 @@ fn test_delete_writes(
483479
safe_versioned_state: ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>>,
484480
) {
485481
let num_of_txs = 3;
486-
let mut versioned_proxy_states: Vec<VersionedStateProxy<_, VisitedPcsSet>> =
487-
(0..num_of_txs).map(|i| safe_versioned_state.pin_version(i)).collect();
488-
let mut transactional_states: Vec<TransactionalState<'_, _, VisitedPcsSet>> =
489-
versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect();
482+
let mut versioned_proxy_states: Vec<_> =
483+
(0..num_of_txs).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();
484+
let mut transactional_states: Vec<_> = versioned_proxy_states
485+
.iter_mut()
486+
.map(TransactionalState::create_transactional_for_testing)
487+
.collect();
490488

491489
// Setting 2 instances of the contract to ensure `delete_writes` removes information from
492490
// multiple keys. Class hash values are not checked in this test.
@@ -504,7 +502,7 @@ fn test_delete_writes(
504502
tx_state
505503
.set_contract_class(feature_contract.get_class_hash(), feature_contract.get_class())
506504
.unwrap();
507-
safe_versioned_state.pin_version(i).apply_writes(
505+
safe_versioned_state.pin_version_for_testing(i).apply_writes(
508506
&tx_state.cache.borrow().writes,
509507
&tx_state.class_hash_to_class.borrow(),
510508
&VisitedPcsSet::default(),
@@ -564,7 +562,7 @@ fn test_delete_writes_completeness(
564562
HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_class())]);
565563

566564
let tx_index = 0;
567-
let mut versioned_state_proxy = safe_versioned_state.pin_version(tx_index);
565+
let mut versioned_state_proxy = safe_versioned_state.pin_version_for_testing(tx_index);
568566

569567
versioned_state_proxy.apply_writes(
570568
&state_maps_writes,
@@ -608,13 +606,13 @@ fn test_versioned_proxy_state_flow(
608606
let contract_address = contract_address!("0x1");
609607
let class_hash = ClassHash(felt!(27_u8));
610608

611-
let mut versioned_proxy_states: Vec<VersionedStateProxy<_, VisitedPcsSet>> =
612-
(0..4).map(|i| safe_versioned_state.pin_version(i)).collect();
609+
let mut versioned_proxy_states: Vec<_> =
610+
(0..4).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect();
613611

614-
let mut transactional_states: Vec<TransactionalState<'_, _, VisitedPcsSet>> =
615-
Vec::with_capacity(4);
612+
let mut transactional_states = Vec::with_capacity(4);
616613
for proxy_state in &mut versioned_proxy_states {
617-
transactional_states.push(TransactionalState::create_transactional(proxy_state));
614+
transactional_states
615+
.push(TransactionalState::create_transactional_for_testing(proxy_state));
618616
}
619617

620618
// Clients class hash values.

crates/blockifier/src/concurrency/worker_logic.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use crate::state::cached_state::{
1717
ContractClassMapping, StateChanges, StateMaps, TransactionalState,
1818
};
1919
use crate::state::state_api::{StateReader, UpdatableState};
20-
use crate::state::visited_pcs::VisitedPcs;
20+
use crate::state::visited_pcs::{VisitedPcs, VisitedPcsSet};
2121
use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult};
2222
use crate::transaction::transaction_execution::Transaction;
2323
use crate::transaction::transactions::{ExecutableTransaction, ExecutionFlags};
@@ -45,6 +45,17 @@ pub struct WorkerExecutor<'a, S: StateReader, V: VisitedPcs> {
4545
pub block_context: &'a BlockContext,
4646
pub bouncer: Mutex<&'a mut Bouncer>,
4747
}
48+
impl<'a, S: StateReader> WorkerExecutor<'a, S, VisitedPcsSet> {
49+
#[cfg(test)]
50+
pub fn new_for_testing(
51+
state: ThreadSafeVersionedState<S>,
52+
chunk: &'a [Transaction],
53+
block_context: &'a BlockContext,
54+
bouncer: Mutex<&'a mut Bouncer>,
55+
) -> WorkerExecutor<'a, S, VisitedPcsSet> {
56+
WorkerExecutor::new(state, chunk, block_context, bouncer)
57+
}
58+
}
4859
impl<'a, S: StateReader, V: VisitedPcs> WorkerExecutor<'a, S, V> {
4960
pub fn new(
5061
state: ThreadSafeVersionedState<S>,

0 commit comments

Comments
 (0)