Skip to content

Commit

Permalink
Merge pull request #16497 from MinaProtocol/georgeee/proof-cache-tag-…
Browse files Browse the repository at this point in the history
…20-decouple-transaction_snark_scan_state-types

Decouple Transaction_snark_scan_state types
  • Loading branch information
georgeee authored Feb 6, 2025
2 parents a2905c6 + 4d8834a commit 1b3457d
Show file tree
Hide file tree
Showing 17 changed files with 129 additions and 74 deletions.
9 changes: 7 additions & 2 deletions src/lib/bootstrap_controller/bootstrap_controller.ml
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ let run ~context:(module Context : CONTEXT) ~trust_system ~verifier ~network
| Error err ->
Deferred.return (staged_ledger_data_download_time, None, Error err)
| Ok
( scan_state
( scan_state_uncached
, expected_merkle_root
, pending_coinbases
, protocol_states ) -> (
Expand All @@ -385,7 +385,8 @@ let run ~context:(module Context : CONTEXT) ~trust_system ~verifier ~network
let open Deferred.Or_error.Let_syntax in
let received_staged_ledger_hash =
Staged_ledger_hash.of_aux_ledger_and_coinbase_hash
(Staged_ledger.Scan_state.hash scan_state)
(Staged_ledger.Scan_state.Stable.Latest.hash
scan_state_uncached )
expected_merkle_root pending_coinbases
in
[%log debug]
Expand Down Expand Up @@ -416,6 +417,10 @@ let run ~context:(module Context : CONTEXT) ~trust_system ~verifier ~network
List.map protocol_states
~f:(With_hash.of_data ~hash_data:Protocol_state.hashes)
in
let scan_state =
Staged_ledger.Scan_state.write_all_proofs_to_disk
scan_state_uncached
in
let%bind protocol_states =
Staged_ledger.Scan_state.check_required_protocol_states
scan_state ~protocol_states
Expand Down
4 changes: 2 additions & 2 deletions src/lib/ledger_catchup/normal_catchup.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1033,12 +1033,12 @@ let%test_module "Ledger_catchup tests" =
Rose_tree.equal (Rose_tree.of_list_exn target_best_tip_path)
catchup_breadcrumbs ~f:(fun breadcrumb_tree1 breadcrumb_tree2 ->
let b1 =
Mina_block.Validated.unwrap
Mina_block.Validated.read_all_proofs_from_disk
(Transition_frontier.Breadcrumb.validated_transition
breadcrumb_tree1 )
in
let b2 =
Mina_block.Validated.unwrap
Mina_block.Validated.read_all_proofs_from_disk
(Transition_frontier.Breadcrumb.validated_transition
breadcrumb_tree2 )
in
Expand Down
5 changes: 4 additions & 1 deletion src/lib/ledger_catchup/super_catchup.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1637,7 +1637,10 @@ let%test_module "Ledger_catchup tests" =
(* We force evaluation of state body hash for both blocks for further equality check *)
let _hash1 = Mina_block.Validated.state_body_hash b1 in
let _hash2 = Mina_block.Validated.state_body_hash b2 in
Mina_block.Validated.(Stable.Latest.equal (unwrap b1) (unwrap b2)) )
Mina_block.Validated.(
Stable.Latest.equal
(read_all_proofs_from_disk b1)
(read_all_proofs_from_disk b2)) )
in
if not catchup_breadcrumbs_are_best_tip_path then
failwith
Expand Down
2 changes: 1 addition & 1 deletion src/lib/mina_block/validated_block.ml
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ let is_genesis t =
header t |> Header.protocol_state |> Mina_state.Protocol_state.consensus_state
|> Consensus.Data.Consensus_state.is_genesis_state

let unwrap = Fn.id
let read_all_proofs_from_disk = Fn.id
2 changes: 1 addition & 1 deletion src/lib/mina_block/validated_block.mli
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ val body : t -> Staged_ledger_diff.Body.t

val is_genesis : t -> bool

val unwrap : t -> Stable.Latest.t
val read_all_proofs_from_disk : t -> Stable.Latest.t
4 changes: 2 additions & 2 deletions src/lib/mina_networking/mina_networking.mli
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ module Rpcs : sig
type query = State_hash.t

type response =
( Staged_ledger.Scan_state.t
( Staged_ledger.Scan_state.Stable.Latest.t
* Ledger_hash.t
* Pending_coinbase.t
* Mina_state.Protocol_state.value list )
Expand Down Expand Up @@ -219,7 +219,7 @@ val get_staged_ledger_aux_and_pending_coinbases_at_hash :
t
-> Peer.Id.t
-> State_hash.t
-> ( Staged_ledger.Scan_state.t
-> ( Staged_ledger.Scan_state.Stable.Latest.t
* Ledger_hash.t
* Pending_coinbase.t
* Mina_state.Protocol_state.value list )
Expand Down
12 changes: 9 additions & 3 deletions src/lib/mina_networking/rpcs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ module Get_staged_ledger_aux_and_pending_coinbases_at_hash = struct
type query = State_hash.t

type response =
( Staged_ledger.Scan_state.t
( Staged_ledger.Scan_state.Stable.Latest.t
* Ledger_hash.t
* Pending_coinbase.t
* Mina_state.Protocol_state.value list )
Expand Down Expand Up @@ -304,8 +304,14 @@ module Get_staged_ledger_aux_and_pending_coinbases_at_hash = struct
Actions.
(Requested_unknown_item, Some (receipt_trust_action_message hash)))
>>| const None
| _ ->
return result
| Some (scan_state, expected_merkle_root, pending_coinbases, protocol_states)
->
return
(Some
( Staged_ledger.Scan_state.read_all_proofs_from_disk scan_state
, expected_merkle_root
, pending_coinbases
, protocol_states ) )

let rate_limit_budget = (4, `Per Time.Span.minute)

Expand Down
8 changes: 4 additions & 4 deletions src/lib/proof_cache_tag/proof_cache_tag.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ type t =
| Lmdb of { cache_id : Cache.id; cache_db : Cache.t }
| Identity of Mina_base.Proof.t

let unwrap = function
let read_proof_from_disk = function
| Lmdb t ->
Cache.get t.cache_db t.cache_id
| Identity proof ->
proof

let generate db proof =
let write_proof_to_disk db proof =
match db with
| Lmdb_cache cache_db ->
Lmdb { cache_id = Cache.put cache_db proof; cache_db }
Expand All @@ -30,11 +30,11 @@ module For_tests = struct

let blockchain_dummy =
Lazy.map
~f:(fun dummy -> generate (create_db ()) dummy)
~f:(fun dummy -> write_proof_to_disk (create_db ()) dummy)
Mina_base.Proof.blockchain_dummy

let transaction_dummy =
Lazy.map
~f:(fun dummy -> generate (create_db ()) dummy)
~f:(fun dummy -> write_proof_to_disk (create_db ()) dummy)
Mina_base.Proof.transaction_dummy
end
4 changes: 2 additions & 2 deletions src/lib/proof_cache_tag/proof_cache_tag.mli
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ val create_db :
-> logger:Logger.t
-> (cache_db, [> `Initialization_error of Error.t ]) Deferred.Result.t

val unwrap : t -> Mina_base.Proof.t
val read_proof_from_disk : t -> Mina_base.Proof.t

val generate : cache_db -> Mina_base.Proof.t -> t
val write_proof_to_disk : cache_db -> Mina_base.Proof.t -> t

module For_tests : sig
val blockchain_dummy : t lazy_t
Expand Down
2 changes: 1 addition & 1 deletion src/lib/staged_ledger/staged_ledger.ml
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ module T = struct
; pending_coinbase_collection
} : Staged_ledger_hash.t =
Staged_ledger_hash.of_aux_ledger_and_coinbase_hash
(Scan_state.hash scan_state)
Scan_state.(Stable.Latest.hash @@ read_all_proofs_from_disk scan_state)
(Ledger.merkle_root ledger)
pending_coinbase_collection

Expand Down
10 changes: 8 additions & 2 deletions src/lib/staged_ledger/staged_ledger.mli
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ type t
module Scan_state : sig
[%%versioned:
module Stable : sig
[@@@no_toplevel_latest_type]

module V2 : sig
type t

val hash : t -> Staged_ledger_hash.Aux_hash.t
end
end]

type t

module Job_view : sig
type t [@@deriving sexp, to_yojson]
end
Expand Down Expand Up @@ -46,8 +50,6 @@ module Scan_state : sig
[@@deriving sexp, to_yojson]
end

val hash : t -> Staged_ledger_hash.Aux_hash.t

val empty :
constraint_constants:Genesis_constants.Constraint_constants.t -> unit -> t

Expand Down Expand Up @@ -131,6 +133,10 @@ module Scan_state : sig
Or_error.t )
-> t
-> unit Deferred.Or_error.t

val write_all_proofs_to_disk : Stable.Latest.t -> t

val read_all_proofs_from_disk : t -> Stable.Latest.t
end

module Pre_diff_info : Pre_diff_info.S
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,17 @@ end
module Ledger_proof_with_sok_message = struct
[%%versioned
module Stable = struct
[@@@no_toplevel_latest_type]

module V2 = struct
type t = Ledger_proof.Stable.V2.t * Sok_message.Stable.V1.t
[@@deriving sexp]

let to_latest = Fn.id
end
end]

type t = Ledger_proof.t * Sok_message.t
end

module Available_job = struct
Expand Down Expand Up @@ -155,6 +159,8 @@ type job = Available_job.t
the snarked ledger*)
[%%versioned
module Stable = struct
[@@@no_toplevel_latest_type]

module V2 = struct
type t =
{ scan_state :
Expand Down Expand Up @@ -195,7 +201,15 @@ module Stable = struct
end
end]

[%%define_locally Stable.Latest.(hash)]
type t =
{ scan_state :
( Ledger_proof_with_sok_message.t
, Transaction_with_witness.t )
Parallel_scan.State.t
; previous_incomplete_zkapp_updates :
Transaction_with_witness.t list
* [ `Border_block_continued_in_the_next_tree of bool ]
}

(**********Helpers*************)

Expand Down Expand Up @@ -314,14 +328,13 @@ let total_proofs (works : Transaction_snark_work.t list) =

(*************exposed functions*****************)

module P = struct
type t = Ledger_proof_with_sok_message.t
end

module Make_statement_scanner (Verifier : sig
type t

val verify : verifier:t -> P.t list -> unit Or_error.t Deferred.Or_error.t
val verify :
verifier:t
-> Ledger_proof_with_sok_message.t list
-> unit Or_error.t Deferred.Or_error.t
end) =
struct
module Fold = Parallel_scan.State.Make_foldable (Deferred)
Expand Down Expand Up @@ -376,12 +389,8 @@ struct
end

(*TODO: fold over the pending_coinbase tree and validate the statements?*)
let scan_statement ~constraint_constants ~logger
({ scan_state = tree; previous_incomplete_zkapp_updates = _ } : t)
~statement_check ~verifier :
( Transaction_snark.Statement.t
, [ `Error of Error.t | `Empty ] )
Deferred.Result.t =
let scan_statement (type merge) ~constraint_constants ~logger
~merge_to_statement tree ~statement_check ~verify =
let open Deferred.Or_error.Let_syntax in
let timer = Timer.create ~logger () in
let yield_occasionally =
Expand All @@ -392,7 +401,7 @@ struct
Async.Scheduler.yield () |> Deferred.map ~f:Or_error.return
in
let module Acc = struct
type t = (Transaction_snark.Statement.t * P.t list) option
type t = (Transaction_snark.Statement.t * merge list) option
end in
let write_error description =
sprintf !"Staged_ledger.scan_statement: %s\n" description
Expand Down Expand Up @@ -440,24 +449,23 @@ struct
in
let fold_step_a (acc_statement, acc_pc) job =
match job with
| Parallel_scan.Merge.Job.Part (proof, message) ->
let statement = Ledger_proof.statement proof in
| Parallel_scan.Merge.Job.Part merge ->
let statement = merge_to_statement merge in
let%map acc_stmt =
merge_acc ~proofs:[ (proof, message) ] acc_statement statement
merge_acc ~proofs:[ merge ] acc_statement statement
in
(acc_stmt, acc_pc)
| Empty | Full { status = Parallel_scan.Job_status.Done; _ } ->
return (acc_statement, acc_pc)
| Full { left = proof_1, message_1; right = proof_2, message_2; _ } ->
let stmt1 = Ledger_proof.statement proof_1 in
let stmt2 = Ledger_proof.statement proof_2 in
| Full { left; right; _ } ->
let stmt1 = merge_to_statement left in
let stmt2 = merge_to_statement right in
let%bind merged_statement =
Timer.time timer (sprintf "merge:%s" __LOC__) (fun () ->
Deferred.return (Transaction_snark.Statement.merge stmt1 stmt2) )
in
let%map acc_stmt =
merge_acc acc_statement merged_statement
~proofs:[ (proof_1, message_1); (proof_2, message_2) ]
merge_acc acc_statement merged_statement ~proofs:[ left; right ]
in
(acc_stmt, acc_pc)
in
Expand Down Expand Up @@ -541,7 +549,7 @@ struct
| Ok (None, _) ->
Deferred.return (Error `Empty)
| Ok (Some (res, proofs), _) -> (
match%map.Deferred Verifier.verify ~verifier proofs with
match%map.Deferred verify proofs with
| Ok (Ok ()) ->
Ok res
| Ok (Error err) ->
Expand All @@ -551,8 +559,8 @@ struct
| Error e ->
Deferred.return (Error (`Error e))

let check_invariants t ~constraint_constants ~logger ~statement_check
~verifier ~error_prefix
let check_invariants_impl parallel_scan_state ~merge_to_statement
~constraint_constants ~logger ~statement_check ~verify ~error_prefix
~(last_proof_statement : Transaction_snark.Statement.t option)
~(registers_end :
( Frozen_ledger_hash.t
Expand Down Expand Up @@ -590,8 +598,8 @@ struct
in
match%map
O1trace.sync_thread "validate_transaction_snark_scan_state" (fun () ->
scan_statement t ~constraint_constants ~logger ~statement_check
~verifier )
scan_statement parallel_scan_state ~constraint_constants ~logger
~statement_check ~verify ~merge_to_statement )
with
| Error (`Error e) ->
Error e
Expand Down Expand Up @@ -631,6 +639,11 @@ struct
"nondefault fee token"
in
()

let check_invariants (t : t) ~verifier =
check_invariants_impl t.scan_state
~merge_to_statement:(Fn.compose Ledger_proof.statement fst)
~verify:(Verifier.verify ~verifier)
end

let statement_of_job : job -> Transaction_snark.Statement.t option = function
Expand Down Expand Up @@ -1332,7 +1345,7 @@ let update_metrics t = Parallel_scan.update_metrics t.scan_state
let fill_work_and_enqueue_transactions t ~logger transactions work =
let open Or_error.Let_syntax in
let fill_in_transaction_snark_work tree (works : Transaction_snark_work.t list)
: (Ledger_proof.t * Sok_message.t) list Or_error.t =
: Ledger_proof_with_sok_message.t list Or_error.t =
let next_jobs =
List.(
take
Expand Down Expand Up @@ -1454,3 +1467,11 @@ let check_required_protocol_states t ~protocol_states =
in
let%map () = check_length protocol_states_assoc in
protocol_states_assoc

let write_all_proofs_to_disk
{ Stable.Latest.scan_state; previous_incomplete_zkapp_updates } =
{ scan_state; previous_incomplete_zkapp_updates }

let read_all_proofs_from_disk { scan_state; previous_incomplete_zkapp_updates }
=
{ Stable.Latest.scan_state; previous_incomplete_zkapp_updates }
Loading

0 comments on commit 1b3457d

Please sign in to comment.