diff --git a/Cargo.lock b/Cargo.lock index 3e76c8d84..4eb8ac199 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -578,6 +578,7 @@ dependencies = [ "criterion", "ff_ext", "generic_static", + "gkr_iop", "glob", "itertools 0.13.0", "mpcs", @@ -597,6 +598,7 @@ dependencies = [ "serde_json", "strum", "strum_macros", + "subprotocols", "sumcheck", "tempfile", "thread_local", @@ -1172,16 +1174,21 @@ name = "gkr_iop" version = "0.1.0" dependencies = [ "ark-std", + "criterion", "ff_ext", "itertools 0.13.0", "multilinear_extensions", + "ndarray", "p3-field", "p3-goldilocks", "rand", "rayon", + "serde", "subprotocols", "thiserror 1.0.69", + "tiny-keccak", "transcript", + "witness", ] [[package]] @@ -1487,6 +1494,16 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.4" @@ -1589,6 +1606,21 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "nimue" version = "0.2.0" @@ -2089,6 +2121,21 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "poseidon" version = "0.1.0" @@ -2296,6 +2343,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -2683,6 +2736,7 @@ dependencies = [ "p3-goldilocks", "rand", "rayon", + "serde", "thiserror 1.0.69", "transcript", ] diff --git a/Cargo.toml b/Cargo.toml index 4fd37c506..608af8c6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ cfg-if = "1.0" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" itertools = "0.13" +ndarray = "*" num-bigint = { version = "0.4.6" } num-derive = "0.4" num-traits = "0.2" @@ -60,7 +61,7 @@ rand_core = "0.6" rand_xorshift = "0.3" rayon = "1.10" secp = "0.4.1" -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" strum = "0.26" strum_macros = "0.26" diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index b6fd82ac2..1d6cc7df8 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -20,6 +20,10 @@ pub trait SyscallSpec { const REG_OPS_COUNT: usize; const MEM_OPS_COUNT: usize; const CODE: u32; + + const HAS_LOOKUPS: bool = false; + + const GKR_OUTPUTS: usize = 0; } /// Trace the inputs and effects of a syscall. diff --git a/ceno_emul/src/syscalls/keccak_permute.rs b/ceno_emul/src/syscalls/keccak_permute.rs index 3748fc3ea..7afd3ef83 100644 --- a/ceno_emul/src/syscalls/keccak_permute.rs +++ b/ceno_emul/src/syscalls/keccak_permute.rs @@ -16,6 +16,7 @@ impl SyscallSpec for KeccakSpec { const REG_OPS_COUNT: usize = 2; const MEM_OPS_COUNT: usize = KECCAK_WORDS; const CODE: u32 = ceno_rt::syscalls::KECCAK_PERMUTE; + const HAS_LOOKUPS: bool = true; } /// Wrapper type for the keccak_permute argument that implements conversions diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index 6353a0cf9..d848652d9 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -231,7 +231,7 @@ fn test_ceno_rt_keccak() -> Result<()> { let steps = run(&mut state)?; // Expect the program to have written successive states between Keccak permutations. - const ITERATIONS: usize = 3; + const ITERATIONS: usize = 4; let keccak_outs = sample_keccak_f(ITERATIONS); let all_messages = read_all_messages(&state); diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 3c23fe0b1..1d525e471 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -19,9 +19,11 @@ serde_json.workspace = true base64 = "0.22" ceno_emul = { path = "../ceno_emul" } ff_ext = { path = "../ff_ext" } +gkr_iop = { path = "../gkr_iop" } mpcs = { path = "../mpcs" } multilinear_extensions = { version = "0", path = "../multilinear_extensions" } p3 = { path = "../p3" } +subprotocols = { path = "../subprotocols" } sumcheck = { version = "0", path = "../sumcheck" } transcript = { path = "../transcript" } witness = { path = "../witness" } diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index b0e6cd2b0..b35af5385 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -95,6 +95,7 @@ fn bench_add(c: &mut Criterion) { "ADD", &prover.pk.pp, &circuit_pk, + None, polys.into_iter().map(|mle| mle.into()).collect_vec(), commit, &[], diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 58efee2ae..634c35d56 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -1,5 +1,7 @@ use ceno_emul::StepRecord; use ff_ext::ExtensionField; +use gkr_iop::{ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuitWitness}; +use itertools::Itertools; use multilinear_extensions::util::max_usable_threads; use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, @@ -67,3 +69,170 @@ pub trait Instruction { Ok((raw_witin, lk_multiplicity)) } } + +pub struct GKRinfo { + pub and_lookups: usize, + pub xor_lookups: usize, + pub range_lookups: usize, + pub aux_wits: usize, +} + +impl GKRinfo { + fn lookup_total(&self) -> usize { + self.and_lookups + self.xor_lookups + self.range_lookups + } +} + +// Trait that should be implemented by opcodes which use the +// GKR-IOP prover. Presently, such opcodes should also implement +// the Instruction trait and in general methods of GKRIOPInstruction +// will call corresponding methods of Instruction and do something extra +// with respect to syncing state between GKR-IOP and old-style circuits/witnesses +pub trait GKRIOPInstruction +where + Self: Instruction, +{ + type Layout: ProtocolWitnessGenerator + ProtocolBuilder; + + /// Similar to Instruction::construct_circuit; generally + /// meant to extend InstructionConfig with GKR-specific + /// fields + #[allow(unused_variables)] + fn construct_circuit_with_gkr_iop( + cb: &mut CircuitBuilder, + ) -> Result { + unimplemented!(); + } + + /// Should generate phase1 witness for GKR from step records + fn phase1_witness_from_steps( + layout: &Self::Layout, + steps: &[StepRecord], + ) -> Vec>; + + /// Similar to Instruction::assign_instance, but with access to GKR lookups and wits + fn assign_instance_with_gkr_iop( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + lookups: &[E::BaseField], + aux_wits: &[E::BaseField], + ) -> Result<(), ZKVMError>; + + /// Similar to Instruction::assign_instances, but with access to the GKR layout. + #[allow(clippy::type_complexity)] + fn assign_instances_with_gkr_iop( + config: &Self::InstructionConfig, + num_witin: usize, + steps: Vec, + gkr_layout: &Self::Layout, + ) -> Result< + ( + RowMajorMatrix, + GKRCircuitWitness, + LkMultiplicity, + ), + ZKVMError, + > { + let nthreads = max_usable_threads(); + let num_instance_per_batch = if steps.len() > 256 { + steps.len().div_ceil(nthreads) + } else { + steps.len() + } + .max(1); + let lk_multiplicity = LkMultiplicity::default(); + let mut raw_witin = + RowMajorMatrix::::new(steps.len(), num_witin, Self::padding_strategy()); + let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + + let gkr_witness = + gkr_layout.gkr_witness(&Self::phase1_witness_from_steps(gkr_layout, &steps), &[]); + + let (lookups, aux_wits) = { + // Extract lookups and auxiliary witnesses from GKR protocol + // Here we assume that the gkr_witness's last layer is a convenience layer which holds + // the output records for all instances; further, we assume that the last ```Self::lookup_count()``` + // elements of this layer are the lookup arguments. + let mut lookups = vec![vec![]; steps.len()]; + let mut aux_wits: Vec> = vec![vec![]; steps.len()]; + let last_layer = gkr_witness.layers.last().unwrap().bases.clone(); + let len = last_layer.len(); + let lookup_total = Self::gkr_info().lookup_total(); + + let non_lookups = if len == 0 { + 0 + } else { + assert!(len >= lookup_total); + len - lookup_total + }; + + for witness in last_layer.iter().skip(non_lookups) { + for i in 0..witness.len() { + lookups[i].push(witness[i]); + } + } + + let n_layers = gkr_witness.layers.len(); + + for i in 0..steps.len() { + // Omit last layer, which stores outputs and not real witnesses + for layer in gkr_witness.layers[..n_layers - 1].iter() { + for base in layer.bases.iter() { + aux_wits[i].push(base[i]); + } + } + } + + (lookups, aux_wits) + }; + + raw_witin_iter + .zip( + steps + .iter() + .enumerate() + .collect_vec() + .par_chunks(num_instance_per_batch), + ) + .flat_map(|(instances, steps)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip(steps) + .map(|(instance, (i, step))| { + Self::assign_instance_with_gkr_iop( + config, + instance, + &mut lk_multiplicity, + step, + &lookups[*i], + &aux_wits[*i], + ) + }) + .collect::>() + }) + .collect::>()?; + + raw_witin.padding_by_strategy(); + Ok((raw_witin, gkr_witness, lk_multiplicity)) + } + + /// Lookup and witness counts used by GKR proof + fn gkr_info() -> GKRinfo; + + /// Returns corresponding column in RMM for the i-th + /// output evaluation of the GKR proof + #[allow(unused_variables)] + fn output_evals_map(i: usize) -> usize { + unimplemented!(); + } + + /// Returns corresponding column in RMM for the i-th + /// witness of the GKR proof + #[allow(unused_variables)] + fn witness_map(i: usize) -> usize { + unimplemented!(); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index a4e572935..4fa0b0e5f 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -1,8 +1,8 @@ use std::marker::PhantomData; -use ceno_emul::{Change, InsnKind, StepRecord, SyscallSpec}; -use ff_ext::ExtensionField; -use itertools::Itertools; +use ceno_emul::{Change, InsnKind, KeccakSpec, StepRecord, SyscallSpec}; +use ff_ext::{ExtensionField, SmallField}; +use itertools::{Itertools, zip_eq}; use super::{super::insn_base::WriteMEM, dummy_circuit::DummyConfig}; use crate::{ @@ -11,7 +11,7 @@ use crate::{ error::ZKVMError, expression::{ToExpr, WitIn}, instructions::{ - Instruction, + GKRIOPInstruction, GKRinfo, Instruction, riscv::{constants::UInt, insn_base::WriteRD}, }, set_val, @@ -19,6 +19,11 @@ use crate::{ }; use ff_ext::FieldInto; +use gkr_iop::{ + ProtocolWitnessGenerator, + precompiles::{AND_LOOKUPS, KeccakLayout, KeccakTrace, RANGE_LOOKUPS, XOR_LOOKUPS}, +}; + /// LargeEcallDummy can handle any instruction and produce its effects, /// including multiple memory operations. /// @@ -29,7 +34,7 @@ impl Instruction for LargeEcallDummy type InstructionConfig = LargeEcallConfig; fn name() -> String { - format!("{}_DUMMY", S::NAME) + S::NAME.to_owned() } fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let dummy_insn = DummyConfig::construct_circuit( @@ -56,8 +61,8 @@ impl Instruction for LargeEcallDummy let mem_writes = (0..S::MEM_OPS_COUNT) .map(|i| { - let val_before = cb.create_witin(|| format!("mem_before_{}", i)); - let val_after = cb.create_witin(|| format!("mem_after_{}", i)); + let val_before = cb.create_witin(|| format!("mem_before_{}_READ_ARG", i)); + let val_after = cb.create_witin(|| format!("mem_after_{}_WRITE_ARG", i)); let addr = cb.create_witin(|| format!("addr_{}", i)); WriteMEM::construct_circuit( cb, @@ -70,11 +75,17 @@ impl Instruction for LargeEcallDummy }) .collect::, _>>()?; + // Will be filled in by GKR Instruction trait + let lookups = vec![]; + let aux_wits = vec![]; + Ok(LargeEcallConfig { dummy_insn, start_addr, reg_writes, mem_writes, + lookups, + aux_wits, }) } @@ -111,6 +122,135 @@ impl Instruction for LargeEcallDummy } } +impl GKRIOPInstruction for LargeEcallDummy { + type Layout = KeccakLayout; + + fn gkr_info() -> crate::instructions::GKRinfo { + GKRinfo { + and_lookups: 3 * AND_LOOKUPS, + xor_lookups: 3 * XOR_LOOKUPS, + range_lookups: RANGE_LOOKUPS, + aux_wits: 40144, + } + } + + fn construct_circuit_with_gkr_iop( + cb: &mut CircuitBuilder, + ) -> Result { + let mut partial_config = Self::construct_circuit(cb)?; + + assert!(partial_config.lookups.is_empty()); + assert!(partial_config.aux_wits.is_empty()); + + // TODO: capacity + let mut lookups = vec![]; + let mut aux_wits = vec![]; + + for i in 0..AND_LOOKUPS { + let a = cb.create_witin(|| format!("and_lookup_{i}_a_LOOKUP_ARG")); + let b = cb.create_witin(|| format!("and_lookup_{i}_b_LOOKUP_ARG")); + let c = cb.create_witin(|| format!("and_lookup_{i}_c_LOOKUP_ARG")); + cb.lookup_and_byte(a.into(), b.into(), c.into())?; + lookups.extend(vec![a, b, c]); + } + for i in 0..XOR_LOOKUPS { + let a = cb.create_witin(|| format!("xor_lookup_{i}_a_LOOKUP_ARG")); + let b = cb.create_witin(|| format!("xor_lookup_{i}_b_LOOKUP_ARG")); + let c = cb.create_witin(|| format!("xor_lookup_{i}_c_LOOKUP_ARG")); + cb.lookup_xor_byte(a.into(), b.into(), c.into())?; + lookups.extend(vec![a, b, c]); + } + for i in 0..RANGE_LOOKUPS { + let wit = cb.create_witin(|| format!("range_lookup_{i}_LOOKUP_ARG")); + cb.assert_ux::<_, _, 16>(|| "nada", wit.into())?; + lookups.push(wit); + } + + for i in 0..40144 { + aux_wits.push(cb.create_witin(|| format!("{i}_GKR_WITNESS"))); + } + + partial_config.lookups = lookups; + partial_config.aux_wits = aux_wits; + + Ok(partial_config) + } + + // TODO: make this nicer without access to config + // one alternative: the verifier uses the namespaces + // contained in the constraint system to select + // gkr-specific fields + fn output_evals_map(i: usize) -> usize { + if i < 50 { + 27 + 6 * i + } else if i < 100 { + 26 + 6 * (i - 50) + } else { + 326 + i - 100 + } + } + + fn phase1_witness_from_steps( + layout: &Self::Layout, + steps: &[StepRecord], + ) -> Vec::BaseField>> { + let instances = steps + .iter() + .map(|step| { + step.syscall() + .unwrap() + .mem_ops + .iter() + .map(|op| op.value.before) + .collect_vec() + .try_into() + .unwrap() + }) + .collect_vec(); + + layout.phase1_witness(KeccakTrace { instances }) + } + + fn assign_instance_with_gkr_iop( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + lookups: &[E::BaseField], + aux_wits: &[E::BaseField], + ) -> Result<(), ZKVMError> { + Self::assign_instance(config, instance, lk_multiplicity, step)?; + + let mut wit_iter = lookups.iter().map(|f| f.to_canonical_u64()); + let mut var_iter = config.lookups.iter(); + + let mut pop_arg = || -> u64 { + let wit = wit_iter.next().unwrap(); + let var = var_iter.next().unwrap(); + set_val!(instance, var, wit); + wit + }; + + for _i in 0..AND_LOOKUPS { + lk_multiplicity.lookup_and_byte(pop_arg(), pop_arg()); + pop_arg(); + } + for _i in 0..XOR_LOOKUPS { + lk_multiplicity.lookup_xor_byte(pop_arg(), pop_arg()); + pop_arg(); + } + for _i in 0..RANGE_LOOKUPS { + lk_multiplicity.assert_ux::<16>(pop_arg()); + } + + for (aux_wit_var, aux_wit) in zip_eq(config.aux_wits.iter(), aux_wits) { + set_val!(instance, aux_wit_var, (aux_wit.to_canonical_u64())); + } + assert!(var_iter.next().is_none()); + + Ok(()) + } +} #[derive(Debug)] pub struct LargeEcallConfig { dummy_insn: DummyConfig, @@ -119,4 +259,7 @@ pub struct LargeEcallConfig { start_addr: WitIn, mem_writes: Vec<(WitIn, Change, WriteMEM)>, + + aux_wits: Vec, + lookups: Vec, } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 0b6e5a824..47fb43e4b 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -466,7 +466,7 @@ pub struct DummyExtraConfig { impl DummyExtraConfig { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { let ecall_config = cs.register_opcode_circuit::>(); - let keccak_config = cs.register_opcode_circuit::>(); + let keccak_config = cs.register_keccakf_circuit(); let secp256k1_add_config = cs.register_opcode_circuit::>(); let secp256k1_double_config = @@ -509,10 +509,10 @@ impl DummyExtraConfig { fixed: &mut ZKVMFixedTraces, ) { fixed.register_opcode_circuit::>(cs); - fixed.register_opcode_circuit::>(cs); + fixed.register_keccakf_circuit(cs); fixed.register_opcode_circuit::>(cs); - fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); @@ -562,11 +562,8 @@ impl DummyExtraConfig { } } - witness.assign_opcode_circuit::>( - cs, - &self.keccak_config, - keccak_steps, - )?; + witness.assign_keccakf_circuit(cs, &self.keccak_config, keccak_steps)?; + witness.assign_opcode_circuit::>( cs, &self.secp256k1_add_config, diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index f97a1886d..f32d72800 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -30,6 +30,10 @@ impl ZKVMConstraintSystem { assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); } + // TODO: get keccak cs manually + // assert_eq!(keccak_cs.num_fixed, 0); + vm_pk.keccak_pk = self.keccak_gkr_iop.key_gen(&vm_pk.pp, None); + vm_pk.initial_global_state_expr = self.initial_global_state_expr; vm_pk.finalize_global_state_expr = self.finalize_global_state_expr; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 4571fec73..02915479e 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,4 +1,5 @@ use ff_ext::ExtensionField; +use gkr_iop::gkr::{Evaluation, GKRCircuit, GKRProverOutput}; use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; use p3::field::PrimeCharacteristicRing; @@ -6,6 +7,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::BTreeMap, fmt::{self, Debug}, + marker::PhantomData, ops::Div, }; use sumcheck::structs::IOPProverMessage; @@ -50,6 +52,17 @@ pub struct ZKVMOpcodeProof pub wits_commit: PCS::Commitment, pub wits_opening_proof: PCS::Proof, pub wits_in_evals: Vec, + + pub gkr_opcode_proof: Option>, +} + +#[derive(Clone, Serialize)] +// WARN/TODO: depends on serde's `arc` feature which might not behave correctly +pub struct GKROpcodeProof> { + output_evals: Vec, + prover_output: GKRProverOutput>, + circuit: GKRCircuit, + _marker: PhantomData, } #[derive(Clone, Serialize, Deserialize)] diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 66d8ce625..4cbab1e00 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,6 +1,9 @@ +use ceno_emul::KeccakSpec; use ff_ext::ExtensionField; +use gkr_iop::{evaluation::PointAndEval, gkr::GKRCircuitWitness}; use std::{ collections::{BTreeMap, BTreeSet, HashMap}, + marker::PhantomData, sync::Arc, }; @@ -24,7 +27,9 @@ use witness::{RowMajorMatrix, next_pow2_instance_padding}; use crate::{ error::ZKVMError, expression::Instance, + instructions::{Instruction, riscv::dummy::LargeEcallDummy}, scheme::{ + GKROpcodeProof, constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, @@ -32,7 +37,8 @@ use crate::{ }, }, structs::{ - Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, + GKRIOPProvingKey, KeccakGKRIOP, Point, ProvingKey, TowerProofs, TowerProver, + TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, utils::{add_mle_list_by_expr, get_challenge_pows, optimal_sumcheck_threads}, }; @@ -93,6 +99,7 @@ impl> ZKVMProver { let mut commitments = BTreeMap::new(); let mut wits = BTreeMap::new(); let mut structural_wits = BTreeMap::new(); + let mut keccak_gkr_wit = Some(witnesses.keccak_gkr_wit.clone()); let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); // commit to opcode circuits first and then commit to table circuits, sorted by name @@ -183,10 +190,21 @@ impl> ZKVMProver { cs.w_expressions.len(), cs.lk_expressions.len(), ); + + // Only Keccak has non-empty GKR-IOP component + let gkr_iop_pk = if *circuit_name + == as Instruction>::name() + { + Some((&self.pk.keccak_pk, keccak_gkr_wit.take().unwrap())) + } else { + None + }; + let opcode_proof = self.create_opcode_proof( circuit_name, &self.pk.pp, pk, + gkr_iop_pk, witness, wits_commit, &pi, @@ -240,12 +258,17 @@ impl> ZKVMProver { /// 1: witness layer inferring from input -> output /// 2: proof (sumcheck reduce) from output to input #[allow(clippy::too_many_arguments)] + #[allow(clippy::type_complexity)] #[tracing::instrument(skip_all, name = "create_opcode_proof", fields(circuit_name=name,profiling_2), level="trace")] pub fn create_opcode_proof( &self, name: &str, pp: &PCS::ProverParam, circuit_pk: &ProvingKey, + gkr_iop_pk: Option<( + &GKRIOPProvingKey>, + GKRCircuitWitness, + )>, witnesses: Vec>, wits_commit: PCS::CommitmentWithWitness, pi: &[ArcMultilinearExtension<'_, E>], @@ -305,6 +328,7 @@ impl> ZKVMProver { // infer all tower witness after last layer let span = entered_span!("tower_witness_r_layers"); + let r_wit_layers = infer_tower_product_witness( log2_num_instances + log2_r_count, r_records_last_layer, @@ -604,6 +628,7 @@ impl> ZKVMProver { distrinct_zerocheck_terms_set.len() + 1 // +1 from sel_non_lc_zero_sumcheck } ); + let mut main_sel_evals_iter = main_sel_evals.into_iter(); main_sel_evals_iter.next(); // skip sel_r let r_records_in_evals = (0..r_counts_per_instance) @@ -627,6 +652,7 @@ impl> ZKVMProver { distrinct_zerocheck_terms_set.len() + 1 } ); + let input_open_point = main_sel_sumcheck_proofs.point.clone(); assert!(input_open_point.len() == log2_num_instances); exit_span!(main_sel_span); @@ -661,8 +687,48 @@ impl> ZKVMProver { opening_dur.elapsed(), ); exit_span!(pcs_open_span); + let wits_commit = PCS::get_pure_commitment(&wits_commit); + let gkr_opcode_proof = if let Some((gkr_iop_pk, gkr_wit)) = gkr_iop_pk { + let gkr_iop_pk = gkr_iop_pk.clone(); + let gkr_circuit = gkr_iop_pk.vk.get_state().chip.gkr_circuit(); + + let point = Arc::new(input_open_point); + + let out_evals = gkr_wit + .layers + .last() + .unwrap() + .bases + .iter() + .map(|base| PointAndEval { + point: point.clone(), + eval: subprotocols::utils::evaluate_mle_ext(base, &point), + }) + .collect_vec(); + + let prover_output = gkr_circuit + .prove(gkr_wit, &out_evals, &[], transcript) + .expect("Failed to prove phase"); + // unimplemented!("cannot fully handle GKRIOP component yet") + + let _gkr_open_point = prover_output.opening_evaluations[0].point.clone(); + // TODO: open polynomials for GKR proof + + let output_evals = out_evals.into_iter().map(|pae| pae.eval).collect_vec(); + + Some(GKROpcodeProof { + output_evals, + prover_output, + circuit: gkr_circuit, + _marker: PhantomData, + }) + } else { + None + }; + + // extend with Optio(gkr evals (not combined)) Ok(ZKVMOpcodeProof { num_instances, record_r_out_evals, @@ -679,6 +745,7 @@ impl> ZKVMProver { wits_commit, wits_opening_proof, wits_in_evals, + gkr_opcode_proof, }) } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 6b1a2151e..98ef23176 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -139,6 +139,7 @@ fn test_rw_lk_expression_combination() { name.as_str(), &prover.pk.pp, prover.pk.circuit_pks.get(&name).unwrap(), + None, wits_in, commit, &[], diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 53178f0aa..cf80b4eee 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,6 +1,7 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, sync::Arc}; use ark_std::iterable::Iterable; +use ceno_emul::{KeccakSpec, SyscallSpec}; use ff_ext::ExtensionField; use itertools::{Itertools, chain, interleave, izip}; @@ -17,6 +18,7 @@ use witness::next_pow2_instance_padding; use crate::{ error::ZKVMError, expression::{Instance, StructuralWitIn}, + instructions::{GKRIOPInstruction, riscv::dummy::LargeEcallDummy}, scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, utils::eval_by_expr_with_instance, @@ -273,6 +275,7 @@ impl> ZKVMVerifier cs.w_expressions.len(), cs.lk_expressions.len(), ); + let (log2_r_count, log2_w_count, log2_lk_count) = ( ceil_log2(r_counts_per_instance), ceil_log2(w_counts_per_instance), @@ -307,6 +310,7 @@ impl> ZKVMVerifier num_product_fanin, transcript, )?; + assert!(record_evals.len() == 2, "[r_record, w_record]"); assert!(logup_q_evals.len() == 1, "[lk_q_record]"); assert!(logup_p_evals.len() == 1, "[lk_p_record]"); @@ -402,6 +406,43 @@ impl> ZKVMVerifier ) }; + if name == KeccakSpec::NAME { + // TODO: remove clone + let gkr_iop = proof + .gkr_opcode_proof + .clone() + .expect("Keccak syscall should contain GKR-IOP proof"); + + // Match output_evals with EcallDummy polynomials + for (i, gkr_out_eval) in gkr_iop.output_evals.iter().enumerate() { + assert_eq!( + *gkr_out_eval, + proof.wits_in_evals[LargeEcallDummy::::output_evals_map(i)], + "{i}" + ); + } + // Verify GKR proof + let point = Arc::new(input_opening_point.clone()); + let out_evals = gkr_iop + .output_evals + .iter() + .map(|eval| gkr_iop::evaluation::PointAndEval { + point: point.clone(), + eval: *eval, + }) + .collect_vec(); + + gkr_iop + .circuit + .verify( + gkr_iop.prover_output.gkr_proof.clone(), + &out_evals, + &[], + transcript, + ) + .expect("GKR-IOP verify failure"); + } + let computed_evals = [ // read *alpha_read diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 2fe82f980..768310e85 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -2,13 +2,14 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, expression::Expression, - instructions::Instruction, + instructions::{GKRIOPInstruction, Instruction, riscv::dummy::LargeEcallDummy}, state::StateCircuit, tables::{RMMCollections, TableCircuit}, witness::LkMultiplicity, }; -use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; +use ceno_emul::{CENO_PLATFORM, KeccakSpec, Platform, StepRecord, SyscallSpec}; use ff_ext::ExtensionField; +use gkr_iop::{gkr::GKRCircuitWitness, precompiles::KeccakLayout}; use itertools::{Itertools, chain}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ @@ -45,7 +46,7 @@ pub struct TowerProverSpec<'a, E: ExtensionField> { pub witness: Vec>>, } -pub type WitnessId = u16; +pub type WitnessId = u32; pub type ChallengeId = u16; #[derive(Copy, Clone, Debug, EnumIter, PartialEq, Eq, Hash)] @@ -130,6 +131,79 @@ impl> VerifyingKey } } +#[derive(Clone, Debug)] +pub struct GKRIOPProvingKey, State> { + pub fixed_traces: Option>>, + pub fixed_commit_wd: Option, + pub vk: GKRIOPVerifyingKey, +} + +impl, State: Default> Default + for GKRIOPProvingKey +{ + fn default() -> Self { + Self { + fixed_traces: None, + fixed_commit_wd: None, + vk: GKRIOPVerifyingKey::default(), + } + } +} + +#[derive(Clone, Debug)] +pub struct GKRIOPVerifyingKey, State> { + pub(crate) state: State, + pub fixed_commit: Option, +} + +impl, State: Default> Default + for GKRIOPVerifyingKey +{ + fn default() -> Self { + Self { + state: State::default(), + fixed_commit: None, + } + } +} + +impl, State> + GKRIOPVerifyingKey +{ + pub fn get_state(&self) -> &State { + &self.state + } +} + +#[derive(Clone, Default, Debug)] +pub struct KeccakGKRIOP { + pub chip: gkr_iop::chip::Chip, + pub layout: KeccakLayout, +} + +impl KeccakGKRIOP { + pub fn key_gen>( + self, + pp: &PCS::ProverParam, + fixed_traces: Option>, + ) -> GKRIOPProvingKey> { + // transpose from row-major to column-major + let fixed_traces_polys = fixed_traces.as_ref().map(|rmm| rmm.to_mles()); + + let fixed_commit_wd = fixed_traces.map(|traces| PCS::batch_commit(pp, traces).unwrap()); + let fixed_commit = fixed_commit_wd.as_ref().map(PCS::get_pure_commitment); + + GKRIOPProvingKey { + fixed_traces: fixed_traces_polys, + fixed_commit_wd, + vk: GKRIOPVerifyingKey { + state: self, + fixed_commit, + }, + } + } +} + #[derive(Clone, Debug)] pub struct ProgramParams { pub platform: Platform, @@ -154,6 +228,7 @@ pub struct ZKVMConstraintSystem { pub(crate) circuit_css: BTreeMap>, pub(crate) initial_global_state_expr: Expression, pub(crate) finalize_global_state_expr: Expression, + pub keccak_gkr_iop: KeccakGKRIOP, pub params: ProgramParams, } @@ -164,6 +239,7 @@ impl Default for ZKVMConstraintSystem { initial_global_state_expr: Expression::ZERO, finalize_global_state_expr: Expression::ZERO, params: ProgramParams::default(), + keccak_gkr_iop: KeccakGKRIOP::default(), } } } @@ -175,6 +251,30 @@ impl ZKVMConstraintSystem { ..Default::default() } } + + pub fn register_keccakf_circuit( + &mut self, + ) -> as Instruction>::InstructionConfig { + // Add GKR-IOP instance + let params = gkr_iop::precompiles::KeccakParams {}; + let (layout, chip) = as gkr_iop::ProtocolBuilder>::build(params); + self.keccak_gkr_iop = KeccakGKRIOP { layout, chip }; + + let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", KeccakSpec::NAME)); + let mut circuit_builder = + CircuitBuilder::::new_with_params(&mut cs, self.params.clone()); + let config = + LargeEcallDummy::::construct_circuit_with_gkr_iop(&mut circuit_builder) + .unwrap(); + assert!( + self.circuit_css + .insert(KeccakSpec::NAME.to_owned(), cs) + .is_none() + ); + + config + } + pub fn register_opcode_circuit>(&mut self) -> OC::InstructionConfig { let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); let mut circuit_builder = @@ -220,6 +320,14 @@ pub struct ZKVMFixedTraces { } impl ZKVMFixedTraces { + pub fn register_keccakf_circuit(&mut self, _cs: &ZKVMConstraintSystem) { + assert!( + self.circuit_fixed_traces + .insert(LargeEcallDummy::::name(), None) + .is_none() + ); + } + pub fn register_opcode_circuit>(&mut self, _cs: &ZKVMConstraintSystem) { assert!(self.circuit_fixed_traces.insert(OC::name(), None).is_none()); } @@ -244,6 +352,7 @@ impl ZKVMFixedTraces { #[derive(Default, Clone)] pub struct ZKVMWitnesses { + pub keccak_gkr_wit: GKRCircuitWitness, witnesses_opcodes: BTreeMap>, witnesses_tables: BTreeMap>, lk_mlts: BTreeMap, @@ -263,6 +372,43 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } + pub fn assign_keccakf_circuit( + &mut self, + css: &ZKVMConstraintSystem, + config: & as Instruction>::InstructionConfig, + records: Vec, + ) -> Result<(), ZKVMError> { + let cs = css + .get_cs(&LargeEcallDummy::::name()) + .unwrap(); + let (witness, gkr_witness, logup_multiplicity) = + LargeEcallDummy::::assign_instances_with_gkr_iop( + config, + cs.num_witin as usize, + records, + &css.keccak_gkr_iop.layout, + )?; + self.keccak_gkr_wit = gkr_witness; + + assert!( + self.witnesses_opcodes + .insert(LargeEcallDummy::::name(), witness) + .is_none() + ); + assert!( + !self + .witnesses_tables + .contains_key(&LargeEcallDummy::::name()) + ); + assert!( + self.lk_mlts + .insert(LargeEcallDummy::::name(), logup_multiplicity) + .is_none() + ); + + Ok(()) + } + pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, @@ -337,7 +483,6 @@ impl ZKVMWitnesses { input, )?; assert!(self.witnesses_tables.insert(TC::name(), witness).is_none()); - assert!(!self.witnesses_opcodes.contains_key(&TC::name())); Ok(()) @@ -364,6 +509,7 @@ pub struct ZKVMProvingKey> pub vp: PCS::VerifierParam, // pk for opcode and table circuits pub circuit_pks: BTreeMap>, + pub keccak_pk: GKRIOPProvingKey>, // expression for global state in/out pub initial_global_state_expr: Expression, @@ -376,6 +522,7 @@ impl> ZKVMProvingKey( challenges: &[E], // sumcheck batch challenge alpha: E, -) -> BTreeSet { +) -> BTreeSet { assert!(expr.is_monomial_form()); let monomial_terms = expr.evaluate( &|_| unreachable!(), @@ -287,7 +287,7 @@ pub fn add_mle_list_by_expr<'a, E: ExtensionField>( monomial_terms .into_iter() .flat_map(|(_, monomial_term)| monomial_term.into_iter().collect_vec()) - .collect::>() + .collect::>() } #[cfg(test)] diff --git a/examples/examples/ceno_rt_keccak.rs b/examples/examples/ceno_rt_keccak.rs index fa76acf7d..46f950142 100644 --- a/examples/examples/ceno_rt_keccak.rs +++ b/examples/examples/ceno_rt_keccak.rs @@ -6,17 +6,18 @@ extern crate ceno_rt; use ceno_rt::{info_out, syscalls::syscall_keccak_permute}; use core::slice; -const ITERATIONS: usize = 3; +const ITERATIONS: usize = 4; fn main() { let mut state = [0_u64; 25]; for _ in 0..ITERATIONS { syscall_keccak_permute(&mut state); - log_state(&state); + // log_state(&state); } } +#[allow(dead_code)] fn log_state(state: &[u64; 25]) { let out = unsafe { slice::from_raw_parts(state.as_ptr() as *const u8, state.len() * size_of::()) diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index bbe7e85e0..e366c5db6 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -14,10 +14,25 @@ ark-std.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } +ndarray.workspace = true p3-field.workspace = true p3-goldilocks.workspace = true rand.workspace = true rayon.workspace = true +serde.workspace = true subprotocols = { path = "../subprotocols" } thiserror = "1" +tiny-keccak.workspace = true transcript = { path = "../transcript" } +witness = { path = "../witness" } + +[dev-dependencies] +criterion.workspace = true + +[[bench]] +harness = false +name = "bitwise_keccakf" + +[[bench]] +harness = false +name = "lookup_keccakf" diff --git a/gkr_iop/benches/bitwise_keccakf.rs b/gkr_iop/benches/bitwise_keccakf.rs new file mode 100644 index 000000000..902b63a1b --- /dev/null +++ b/gkr_iop/benches/bitwise_keccakf.rs @@ -0,0 +1,37 @@ +use std::time::Duration; + +use criterion::*; +use gkr_iop::precompiles::run_keccakf; +use rand::{Rng, SeedableRng}; +criterion_group!(benches, keccak_f_fn); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; + +fn keccak_f_fn(c: &mut Criterion) { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group("keccak_f".to_string()); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function(BenchmarkId::new("keccak_f", "keccak_f"), |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); + + let instant = std::time::Instant::now(); + #[allow(clippy::unit_arg)] + black_box(run_keccakf(state, false, false)); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }); + + group.finish(); +} diff --git a/gkr_iop/benches/lookup_keccakf.rs b/gkr_iop/benches/lookup_keccakf.rs new file mode 100644 index 000000000..a9ab4f5f0 --- /dev/null +++ b/gkr_iop/benches/lookup_keccakf.rs @@ -0,0 +1,39 @@ +use std::time::Duration; + +use criterion::*; +use gkr_iop::precompiles::run_faster_keccakf; + +use rand::{Rng, SeedableRng}; +criterion_group!(benches, keccak_f_fn); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; + +fn keccak_f_fn(c: &mut Criterion) { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group("keccakf"); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function(BenchmarkId::new("keccakf", "keccakf"), |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); + let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); + + let instant = std::time::Instant::now(); + #[allow(clippy::unit_arg)] + black_box(run_faster_keccakf(vec![state1, state2], false, false)); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }); + + group.finish(); +} diff --git a/gkr_iop/examples/multi_layer_logup.rs b/gkr_iop/examples/multi_layer_logup.rs index b97d3c9ba..2ac39416b 100644 --- a/gkr_iop/examples/multi_layer_logup.rs +++ b/gkr_iop/examples/multi_layer_logup.rs @@ -64,14 +64,15 @@ impl ProtocolBuilder for TowerChipLayout { let height = self.params.height; let lookup_challenge = Expression::Const(self.lookup_challenge.clone()); - self.output_cumulative_sum = chip.allocate_output_evals(); + self.output_cumulative_sum = chip.allocate_output_evals::<2>().try_into().unwrap(); // Tower layers let ([updated_table, count], challenges) = (0..height).fold( (self.output_cumulative_sum.clone(), vec![]), |([den, num], challenges), i| { let [den_0, den_1, num_0, num_1] = if i == height - 1 { - // Allocate witnesses in the extension field, except numerator inputs in the base field. + // Allocate witnesses in the extension field, except numerator inputs in the + // base field. let ([num_0, num_1], [den_0, den_1]) = chip.allocate_wits_in_layer(); [den_0, den_1, num_0, num_1] } else { @@ -109,6 +110,7 @@ impl ProtocolBuilder for TowerChipLayout { in_bases, in_exts, vec![den, num], + vec![], )); let [challenge] = chip.allocate_challenges(); ( @@ -138,6 +140,7 @@ impl ProtocolBuilder for TowerChipLayout { vec![table.1.clone()], vec![], vec![updated_table], + vec![], )); chip.allocate_base_opening(self.committed_table_id, table.1); diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs index ac742fe04..e8c5079e1 100644 --- a/gkr_iop/src/chip/builder.rs +++ b/gkr_iop/src/chip/builder.rs @@ -1,5 +1,6 @@ use std::array; +use itertools::Itertools; use subprotocols::expression::{Constant, Witness}; use crate::{ @@ -22,10 +23,11 @@ impl Chip { array::from_fn(|i| i + self.n_committed_exts - N) } - /// Allocate `Witness` and `EvalExpression` for the input polynomials in a layer. - /// Where `Witness` denotes the index and `EvalExpression` denotes the position - /// to place the evaluation of the polynomial after processing the layer prover - /// for each polynomial. This should be called at most once for each layer! + /// Allocate `Witness` and `EvalExpression` for the input polynomials in a + /// layer. Where `Witness` denotes the index and `EvalExpression` + /// denotes the position to place the evaluation of the polynomial after + /// processing the layer prover for each polynomial. This should be + /// called at most once for each layer! #[allow(clippy::type_complexity)] pub fn allocate_wits_in_layer( &mut self, @@ -51,9 +53,15 @@ impl Chip { } /// Generate the evaluation expression for each output. - pub fn allocate_output_evals(&mut self) -> [EvalExpression; N] { + pub fn allocate_output_evals(&mut self) -> Vec +// -> [EvalExpression; N] + { self.n_evaluations += N; - array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) + // array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) + // TODO: hotfix to avoid stack overflow, fix later + (0..N) + .map(|i| EvalExpression::Single(i + self.n_evaluations - N)) + .collect_vec() } /// Allocate challenges. @@ -62,14 +70,16 @@ impl Chip { array::from_fn(|i| Constant::Challenge(i + self.n_challenges - N)) } - /// Allocate a PCS opening action to a base polynomial with index `wit_index`. - /// The `EvalExpression` represents the expression to compute the evaluation. + /// Allocate a PCS opening action to a base polynomial with index + /// `wit_index`. The `EvalExpression` represents the expression to + /// compute the evaluation. pub fn allocate_base_opening(&mut self, wit_index: usize, eval: EvalExpression) { self.base_openings.push((wit_index, eval)); } - /// Allocate a PCS opening action to an ext polynomial with index `wit_index`. - /// The `EvalExpression` represents the expression to compute the evaluation. + /// Allocate a PCS opening action to an ext polynomial with index + /// `wit_index`. The `EvalExpression` represents the expression to + /// compute the evaluation. pub fn allocate_ext_opening(&mut self, wit_index: usize, eval: EvalExpression) { self.ext_openings.push((wit_index, eval)); } diff --git a/gkr_iop/src/chip/protocol.rs b/gkr_iop/src/chip/protocol.rs index a633791eb..0374d8e6c 100644 --- a/gkr_iop/src/chip/protocol.rs +++ b/gkr_iop/src/chip/protocol.rs @@ -4,13 +4,13 @@ use super::Chip; impl Chip { /// Extract information from Chip that required in the GKR phase. - pub fn gkr_circuit(&'_ self) -> GKRCircuit<'_> { + pub fn gkr_circuit(&self) -> GKRCircuit { GKRCircuit { - layers: &self.layers, + layers: self.layers.clone(), n_challenges: self.n_challenges, n_evaluations: self.n_evaluations, - base_openings: &self.base_openings, - ext_openings: &self.ext_openings, + base_openings: self.base_openings.clone(), + ext_openings: self.ext_openings.clone(), } } } diff --git a/gkr_iop/src/evaluation.rs b/gkr_iop/src/evaluation.rs index 365f1a250..264486c83 100644 --- a/gkr_iop/src/evaluation.rs +++ b/gkr_iop/src/evaluation.rs @@ -3,10 +3,12 @@ use std::sync::Arc; use ff_ext::ExtensionField; use itertools::{Itertools, izip}; use multilinear_extensions::virtual_poly::build_eq_x_r_vec_sequential; +use serde::Serialize; use subprotocols::expression::{Constant, Point}; -/// Evaluation expression for the gkr layer reduction and PCS opening preparation. -#[derive(Clone, Debug)] +/// Evaluation expression for the gkr layer reduction and PCS opening +/// preparation. +#[derive(Clone, Debug, Serialize)] pub enum EvalExpression { /// Single entry in the evaluation vector. Single(usize), @@ -14,9 +16,10 @@ pub enum EvalExpression { Linear(usize, Constant, Constant), /// Merging multiple evaluations which denotes a partition of the original /// polynomial. `(usize, Constant)` denote the modification of the point. - /// For example, when it receive a point `(p0, p1, p2, p3)` from a succeeding - /// layer, `vec![(2, c0), (4, c1)]` will modify the point to `(p0, p1, c0, p2, c1, p3)`. - /// where the indices specify how the partition applied to the original polynomial. + /// For example, when it receive a point `(p0, p1, p2, p3)` from a + /// succeeding layer, `vec![(2, c0), (4, c1)]` will modify the point to + /// `(p0, p1, c0, p2, c1, p3)`. where the indices specify how the + /// partition applied to the original polynomial. Partition(Vec>, Vec<(usize, Constant)>), } diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 1e632a18f..97906626b 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -1,6 +1,7 @@ use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use layer::{Layer, LayerWitness}; +use serde::Serialize; use subprotocols::{expression::Point, sumcheck::SumcheckProof}; use transcript::Transcript; @@ -12,28 +13,31 @@ use crate::{ pub mod layer; pub mod mock; -#[derive(Clone, Debug)] -pub struct GKRCircuit<'a> { - pub layers: &'a [Layer], +#[derive(Clone, Debug, Serialize)] +pub struct GKRCircuit { + pub layers: Vec, pub n_challenges: usize, pub n_evaluations: usize, - pub base_openings: &'a [(usize, EvalExpression)], - pub ext_openings: &'a [(usize, EvalExpression)], + pub base_openings: Vec<(usize, EvalExpression)>, + pub ext_openings: Vec<(usize, EvalExpression)>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct GKRCircuitWitness { pub layers: Vec>, } +#[derive(Clone, Serialize)] pub struct GKRProverOutput { pub gkr_proof: GKRProof, pub opening_evaluations: Vec, } +#[derive(Clone, Serialize)] pub struct GKRProof(pub Vec>); +#[derive(Clone, Debug, Serialize)] pub struct Evaluation { pub value: E, pub point: Point, @@ -42,7 +46,7 @@ pub struct Evaluation { pub struct GKRClaims(pub Vec); -impl GKRCircuit<'_> { +impl GKRCircuit { pub fn prove( &self, circuit_wit: GKRCircuitWitness, @@ -56,7 +60,7 @@ impl GKRCircuit<'_> { let mut evaluations = out_evals.to_vec(); evaluations.resize(self.n_evaluations, PointAndEval::default()); let mut challenges = challenges.to_vec(); - let sumcheck_proofs = izip!(self.layers, circuit_wit.layers) + let sumcheck_proofs = izip!(&self.layers, circuit_wit.layers) .map(|(layer, layer_wit)| { layer.prove(layer_wit, &mut evaluations, &mut challenges, transcript) }) @@ -85,7 +89,7 @@ impl GKRCircuit<'_> { let mut challenges = challenges.to_vec(); let mut evaluations = out_evals.to_vec(); evaluations.resize(self.n_evaluations, PointAndEval::default()); - for (layer, layer_proof) in izip!(self.layers, sumcheck_proofs) { + for (layer, layer_proof) in izip!(&self.layers, sumcheck_proofs) { layer.verify(layer_proof, &mut evaluations, &mut challenges, transcript)?; } @@ -99,7 +103,7 @@ impl GKRCircuit<'_> { evaluations: &[PointAndEval], challenges: &[E], ) -> Vec> { - chain!(self.base_openings, self.ext_openings) + chain!(&self.base_openings, &self.ext_openings) .map(|(poly, eval)| { let poly = *poly; let PointAndEval { point, eval: value } = eval.evaluate(evaluations, challenges); diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 2f23a6d77..ef0005d18 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -2,6 +2,7 @@ use ark_std::log2; use ff_ext::ExtensionField; use itertools::{chain, izip}; use linear_layer::LinearLayer; +use serde::Serialize; use subprotocols::{ expression::{Constant, Expression, Point}, sumcheck::{SumcheckClaims, SumcheckProof, SumcheckProverOutput}, @@ -20,25 +21,26 @@ pub mod linear_layer; pub mod sumcheck_layer; pub mod zerocheck_layer; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize)] pub enum LayerType { Sumcheck, Zerocheck, Linear, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize)] pub struct Layer { pub name: String, pub ty: LayerType, /// Challenges generated at the beginning of the layer protocol. pub challenges: Vec, - /// Expressions to prove in this layer. For zerocheck and linear layers, each - /// expression corresponds to an output. While in sumcheck, there is only 1 - /// expression, which corresponds to the sum of all outputs. This design is - /// for the convenience when building the following expression: - /// `e_0 + beta * e_1 = sum_x (eq(p_0, x) + beta * eq(p_1, x)) expr(x)`. - /// where `vec![e_0, beta * e_1]` will be the output evaluation expressions. + /// Expressions to prove in this layer. For zerocheck and linear layers, + /// each expression corresponds to an output. While in sumcheck, there + /// is only 1 expression, which corresponds to the sum of all outputs. + /// This design is for the convenience when building the following + /// expression: `e_0 + beta * e_1 = sum_x (eq(p_0, x) + beta * + /// eq(p_1, x)) expr(x)`. where `vec![e_0, beta * e_1]` will be the + /// output evaluation expressions. pub exprs: Vec, /// Positions to place the evaluations of the base inputs of this layer. pub in_bases: Vec, @@ -47,9 +49,12 @@ pub struct Layer { /// The expressions of the evaluations from the succeeding layers, which are /// connected to the outputs of this layer. pub outs: Vec, + + // For debugging purposes + pub expr_names: Vec, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct LayerWitness { pub bases: Vec>, pub exts: Vec>, @@ -66,7 +71,15 @@ impl Layer { in_bases: Vec, in_exts: Vec, outs: Vec, + expr_names: Vec, ) -> Self { + let mut expr_names = expr_names; + if expr_names.len() < exprs.len() { + expr_names.extend(vec![ + "unavailable".to_string(); + exprs.len() - expr_names.len() + ]); + } Self { name, ty, @@ -75,6 +88,7 @@ impl Layer { in_bases, in_exts, outs, + expr_names, } } diff --git a/gkr_iop/src/gkr/layer/linear_layer.rs b/gkr_iop/src/gkr/layer/linear_layer.rs index 40d8a6ec8..8387c4d80 100644 --- a/gkr_iop/src/gkr/layer/linear_layer.rs +++ b/gkr_iop/src/gkr/layer/linear_layer.rs @@ -80,7 +80,7 @@ impl LinearLayer for Layer { transcript.append_field_element_exts(&ext_mle_evals); transcript.append_field_element_exts(&base_mle_evals); - for (sigma, expr) in izip!(sigmas, &self.exprs) { + for (sigma, expr, expr_name) in izip!(sigmas, &self.exprs, &self.expr_names) { let got = expr.evaluate( &ext_mle_evals, &base_mle_evals, @@ -91,7 +91,7 @@ impl LinearLayer for Layer { if *sigma != got { return Err(BackendError::LayerVerificationFailed( self.name.clone(), - VerifierError::::ClaimNotMatch(expr.clone(), *sigma, got), + VerifierError::::ClaimNotMatch(expr.clone(), *sigma, got, expr_name.clone()), )); } } diff --git a/gkr_iop/src/gkr/layer/sumcheck_layer.rs b/gkr_iop/src/gkr/layer/sumcheck_layer.rs index afb65d9a2..01bd00e8f 100644 --- a/gkr_iop/src/gkr/layer/sumcheck_layer.rs +++ b/gkr_iop/src/gkr/layer/sumcheck_layer.rs @@ -66,6 +66,7 @@ impl SumcheckLayer for Layer { proof, challenges, transcript, + vec![], ); verifier_state diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index e9e9bc577..4df79c405 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -63,6 +63,7 @@ impl ZerocheckLayer for Layer { let verifier_state = ZerocheckVerifierState::new( sigmas, self.exprs.clone(), + vec![], out_points, proof, challenges, diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 660f5156e..3c359bdfe 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -35,7 +35,7 @@ pub enum MockProverError { impl MockProver { pub fn check( - circuit: GKRCircuit<'_>, + circuit: GKRCircuit, circuit_wit: &GKRCircuitWitness, mut evaluations: Vec>, mut challenges: Vec, diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index dff90eb6d..06c83f486 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -9,6 +9,7 @@ pub mod chip; pub mod error; pub mod evaluation; pub mod gkr; +pub mod precompiles; pub mod utils; pub trait ProtocolBuilder: Sized { @@ -48,12 +49,14 @@ where fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness; } -// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, `gkr_phase` and `opening_phase`. +// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, +// `gkr_phase` and `opening_phase`. pub struct ProtocolProver, PCS>( PhantomData<(E, Trans, PCS)>, ); -// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, `gkr_phase` and `opening_phase`. +// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, +// `gkr_phase` and `opening_phase`. pub struct ProtocolVerifier, PCS>( PhantomData<(E, Trans, PCS)>, ); diff --git a/gkr_iop/src/precompiles/bitwise_keccakf.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs new file mode 100644 index 000000000..153e01d2b --- /dev/null +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -0,0 +1,573 @@ +use std::{array::from_fn, marker::PhantomData, sync::Arc}; + +use crate::{ + ProtocolBuilder, ProtocolWitnessGenerator, + chip::Chip, + evaluation::{EvalExpression, PointAndEval}, + gkr::{ + GKRCircuitWitness, GKRProverOutput, + layer::{Layer, LayerType, LayerWitness}, + }, +}; +use ff_ext::ExtensionField; +use itertools::{Itertools, chain, iproduct}; +use p3_field::{Field, PrimeCharacteristicRing, extension::BinomialExtensionField}; +use p3_goldilocks::Goldilocks; + +use subprotocols::expression::{Constant, Expression, Witness}; +use tiny_keccak::keccakf; +use transcript::BasicTranscript; + +type E = BinomialExtensionField; +#[derive(Clone, Debug, Default)] +struct KeccakParams {} + +#[derive(Clone, Debug, Default)] +struct KeccakLayout { + _params: KeccakParams, + + committed_bits_id: usize, + + _result: Vec, + _marker: PhantomData, +} + +const X: usize = 5; +const Y: usize = 5; +const Z: usize = 64; +const STATE_SIZE: usize = X * Y * Z; +const C_SIZE: usize = X * Z; +const D_SIZE: usize = X * Z; + +fn to_xyz(i: usize) -> (usize, usize, usize) { + assert!(i < STATE_SIZE); + (i / 64 % 5, (i / 64) / 5, i % 64) +} + +fn from_xyz(x: usize, y: usize, z: usize) -> usize { + assert!(x < 5 && y < 5 && z < 64); + 64 * (5 * y + x) + z +} + +fn xor(a: F, b: F) -> F { + a + b - a * b - a * b +} + +fn and_expr(a: Expression, b: Expression) -> Expression { + a.clone() * b.clone() +} + +fn not_expr(a: Expression) -> Expression { + one_expr() - a +} + +fn xor_expr(a: Expression, b: Expression) -> Expression { + a.clone() + b.clone() - Expression::Const(Constant::Base(2)) * a * b +} + +fn zero_expr() -> Expression { + Expression::Const(Constant::Base(0)) +} + +fn one_expr() -> Expression { + Expression::Const(Constant::Base(1)) +} + +fn c(x: usize, z: usize, bits: &[F]) -> F { + (0..5) + .map(|y| bits[from_xyz(x, y, z)]) + .fold(F::ZERO, |acc, x| xor(acc, x)) +} + +fn c_expr(x: usize, z: usize, state_wits: &[Witness]) -> Expression { + (0..5) + .map(|y| Expression::from(state_wits[from_xyz(x, y, z)])) + .fold(zero_expr(), xor_expr) +} + +fn from_xz(x: usize, z: usize) -> usize { + x * 64 + z +} + +fn d(x: usize, z: usize, c_vals: &[F]) -> F { + let lhs = from_xz((x + 5 - 1) % 5, z); + let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64); + xor(c_vals[lhs], c_vals[rhs]) +} + +fn d_expr(x: usize, z: usize, c_wits: &[Witness]) -> Expression { + let lhs = from_xz((x + 5 - 1) % 5, z); + let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64); + xor_expr(c_wits[lhs].into(), c_wits[rhs].into()) +} + +fn theta(bits: Vec) -> Vec { + assert_eq!(bits.len(), STATE_SIZE); + + let c_vals = iproduct!(0..5, 0..64) + .map(|(x, z)| c(x, z, &bits)) + .collect_vec(); + + let d_vals = iproduct!(0..5, 0..64) + .map(|(x, z)| d(x, z, &c_vals)) + .collect_vec(); + + bits.iter() + .enumerate() + .map(|(i, bit)| { + let (x, _, z) = to_xyz(i); + xor(*bit, d_vals[from_xz(x, z)]) + }) + .collect() +} + +fn and(a: F, b: F) -> F { + a * b +} + +fn not(a: F) -> F { + F::ONE - a +} + +fn u64s_to_bools(state64: &[u64]) -> Vec { + state64 + .iter() + .flat_map(|&word| (0..64).map(move |i| ((word >> i) & 1) == 1)) + .collect() +} + +fn chi(bits: &[F]) -> Vec { + assert_eq!(bits.len(), STATE_SIZE); + + bits.iter() + .enumerate() + .map(|(i, bit)| { + let (x, y, z) = to_xyz(i); + let rhs = and( + not(bits[from_xyz((x + 1) % X, y, z)]), + bits[from_xyz((x + 2) % X, y, z)], + ); + xor(*bit, rhs) + }) + .collect() +} + +const ROUNDS: usize = 24; + +const RC: [u64; ROUNDS] = [ + 1u64, + 0x8082u64, + 0x800000000000808au64, + 0x8000000080008000u64, + 0x808bu64, + 0x80000001u64, + 0x8000000080008081u64, + 0x8000000000008009u64, + 0x8au64, + 0x88u64, + 0x80008009u64, + 0x8000000au64, + 0x8000808bu64, + 0x800000000000008bu64, + 0x8000000000008089u64, + 0x8000000000008003u64, + 0x8000000000008002u64, + 0x8000000000000080u64, + 0x800au64, + 0x800000008000000au64, + 0x8000000080008081u64, + 0x8000000000008080u64, + 0x80000001u64, + 0x8000000080008008u64, +]; + +fn iota(bits: &[F], round_value: u64) -> Vec { + assert_eq!(bits.len(), STATE_SIZE); + let mut ret = bits.to_vec(); + + let cast = |x| match x { + 0 => F::ZERO, + 1 => F::ONE, + _ => unreachable!(), + }; + + for z in 0..Z { + ret[from_xyz(0, 0, z)] = xor(bits[from_xyz(0, 0, z)], cast((round_value >> z) & 1)); + } + + ret +} + +fn iota_expr(bits: &[Witness], index: usize, round_value: u64) -> Expression { + assert_eq!(bits.len(), STATE_SIZE); + let (x, y, z) = to_xyz(index); + + if x > 0 || y > 0 { + bits[index].into() + } else { + let round_bit = Expression::Const(Constant::Base( + ((round_value >> index) & 1).try_into().unwrap(), + )); + xor_expr(bits[from_xyz(0, 0, z)].into(), round_bit) + } +} + +fn chi_expr(i: usize, bits: &[Witness]) -> Expression { + assert_eq!(bits.len(), STATE_SIZE); + + let (x, y, z) = to_xyz(i); + let rhs = and_expr( + not_expr(bits[from_xyz((x + 1) % X, y, z)].into()), + bits[from_xyz((x + 2) % X, y, z)].into(), + ); + xor_expr((bits[i]).into(), rhs) +} + +impl ProtocolBuilder for KeccakLayout { + type Params = KeccakParams; + + fn init(params: Self::Params) -> Self { + Self { + _params: params, + ..Default::default() + } + } + + fn build_commit_phase(&mut self, chip: &mut Chip) { + [self.committed_bits_id] = chip.allocate_committed_base(); + } + + fn build_gkr_phase(&mut self, chip: &mut Chip) { + let final_output = chip.allocate_output_evals::(); + + (0..ROUNDS).rev().fold(final_output, |round_output, round| { + let (chi_output, _) = chip.allocate_wits_in_layer::(); + + let exprs = (0..STATE_SIZE) + .map(|i| iota_expr(&chi_output.iter().map(|e| e.0).collect_vec(), i, RC[round])) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Iota:: compute output"), + LayerType::Zerocheck, + exprs, + vec![], + chi_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + round_output.to_vec(), + vec![], + )); + + let (theta_output, _) = chip.allocate_wits_in_layer::(); + + // Apply the effects of the rho + pi permutation directly o the argument of chi + // No need for a separate layer + let perm = rho_and_pi_permutation(); + let permuted = (0..STATE_SIZE) + .map(|i| theta_output[perm[i]].0) + .collect_vec(); + + let exprs = (0..STATE_SIZE) + .map(|i| chi_expr(i, &permuted)) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Chi:: apply rho, pi and chi"), + LayerType::Zerocheck, + exprs, + vec![], + theta_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + chi_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); + + let (d_and_state, _) = chip.allocate_wits_in_layer::<{ D_SIZE + STATE_SIZE }, 0>(); + let (d, state2) = d_and_state.split_at(D_SIZE); + + // Compute post-theta state using original state and D[][] values + let exprs = (0..STATE_SIZE) + .map(|i| { + let (x, _, z) = to_xyz(i); + xor_expr(state2[i].0.into(), d[from_xz(x, z)].0.into()) + }) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute output"), + LayerType::Zerocheck, + exprs, + vec![], + d_and_state.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + theta_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); + + let (c, []) = chip.allocate_wits_in_layer::<{ C_SIZE }, 0>(); + + let c_wits = c.iter().map(|e| e.0).collect_vec(); + // Compute D[][] from C[][] values + let d_exprs = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| d_expr(x, z, &c_wits)) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute D[x][z]"), + LayerType::Zerocheck, + d_exprs, + vec![], + c.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + d.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); + + let (state, []) = chip.allocate_wits_in_layer::(); + let state_wits = state.iter().map(|s| s.0).collect_vec(); + + // Compute C[][] from state + let c_exprs = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| c_expr(x, z, &state_wits)) + .collect_vec(); + + // Copy state + let id_exprs: Vec = + (0..STATE_SIZE).map(|i| state_wits[i].into()).collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute C[x][z]"), + LayerType::Zerocheck, + chain!(c_exprs, id_exprs).collect_vec(), + vec![], + state.iter().map(|t| t.1.clone()).collect_vec(), + vec![], + chain!( + c.iter().map(|e| e.1.clone()), + state2.iter().map(|e| e.1.clone()) + ) + .collect_vec(), + vec![], + )); + + state.iter().map(|e| e.1.clone()).collect_vec() + }); + + // Skip base opening allocation + } +} + +pub struct KeccakTrace { + pub bits: [bool; STATE_SIZE], +} + +impl ProtocolWitnessGenerator for KeccakLayout +where + E: ExtensionField, +{ + type Trace = KeccakTrace; + + fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { + let mut res = vec![vec![]; 1]; + res[0] = phase1 + .bits + .into_iter() + .map(|b| E::BaseField::from_u64(b as u64)) + .collect(); + res + } + + fn gkr_witness(&self, phase1: &[Vec], _challenges: &[E]) -> GKRCircuitWitness { + let mut bits = phase1[self.committed_bits_id].clone(); + + let n_layers = 100; + let mut layer_wits = Vec::>::with_capacity(n_layers + 1); + + #[allow(clippy::needless_range_loop)] + for round in 0..24 { + if round == 0 { + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + } + + let c_wits = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| c(x, z, &bits)) + .collect_vec(); + + layer_wits.push(LayerWitness::new( + chain!( + c_wits.clone().into_iter().map(|b| vec![b]), + // Note: it seems test pass even if this is uncommented. + // Maybe it's good to assert there are no unused witnesses + // bits.clone().into_iter().map(|b| vec![b]) + ) + .collect_vec(), + vec![], + )); + + let d_wits = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| d(x, z, &c_wits)) + .collect_vec(); + + layer_wits.push(LayerWitness::new( + chain!( + d_wits.clone().into_iter().map(|b| vec![b]), + bits.clone().into_iter().map(|b| vec![b]) + ) + .collect_vec(), + vec![], + )); + + bits = theta(bits); + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + + bits = chi(&pi(&rho(&bits))); + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + + if round < 23 { + bits = iota(&bits, RC[round]); + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + } + } + + // Assumes one input instance + let total_witness_size: usize = layer_wits.iter().map(|layer| layer.bases.len()).sum(); + dbg!(total_witness_size); + + layer_wits.reverse(); + + GKRCircuitWitness { layers: layer_wits } + } +} + +// based on +// https://github.com/0xPolygonHermez/zkevm-prover/blob/main/tools/sm/keccak_f/keccak_rho.cpp +fn rho(state: &[T]) -> Vec { + assert_eq!(state.len(), STATE_SIZE); + let (mut x, mut y) = (1, 0); + let mut ret = [T::default(); STATE_SIZE]; + + for z in 0..Z { + ret[from_xyz(0, 0, z)] = state[from_xyz(0, 0, z)]; + } + for t in 0..24 { + for z in 0..Z { + let new_z = (1000 * Z + z - (t + 1) * (t + 2) / 2) % Z; + ret[from_xyz(x, y, z)] = state[from_xyz(x, y, new_z)]; + } + (x, y) = (y, (2 * x + 3 * y) % Y); + } + + ret.to_vec() +} + +// https://github.com/0xPolygonHermez/zkevm-prover/blob/main/tools/sm/keccak_f/keccak_pi.cpp +fn pi(state: &[T]) -> Vec { + assert_eq!(state.len(), STATE_SIZE); + let mut ret = [T::default(); STATE_SIZE]; + + iproduct!(0..X, 0..Y, 0..Z) + .map(|(x, y, z)| ret[from_xyz(x, y, z)] = state[from_xyz((x + 3 * y) % X, x, z)]) + .count(); + + ret.to_vec() +} + +// Combines rho and pi steps into a single permutation +fn rho_and_pi_permutation() -> Vec { + let perm: [usize; STATE_SIZE] = from_fn(|i| i); + pi(&rho(&perm)) +} + +pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) { + let params = KeccakParams {}; + let (layout, chip) = KeccakLayout::build(params); + let gkr_circuit = chip.gkr_circuit(); + + let bits = u64s_to_bools(&state); + + let phase1_witness = layout.phase1_witness(KeccakTrace { + bits: bits.try_into().unwrap(), + }); + let mut prover_transcript = BasicTranscript::::new(b"protocol"); + + // Omit the commit phase1 and phase2. + let gkr_witness = layout.gkr_witness(&phase1_witness, &[]); + + let out_evals = { + let point = Arc::new(vec![]); + + let last_witness = gkr_witness.layers[0] + .bases + .clone() + .into_iter() + .flatten() + .collect_vec(); + + // Last witness is missing the final sub-round; apply it now + let expected_result_manual = iota(&last_witness, RC[23]); + + if test { + let mut state = state; + keccakf(&mut state); + let state = u64s_to_bools(&state) + .into_iter() + .map(|b| Goldilocks::from_u64(b as u64)) + .collect_vec(); + assert_eq!(state, expected_result_manual); + } + + expected_result_manual + .iter() + .map(|bit| PointAndEval { + point: point.clone(), + eval: E::from_bases(&[*bit, Goldilocks::ZERO]), + }) + .collect_vec() + }; + + let GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) + .expect("Failed to prove phase"); + + if verify { + { + let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + + gkr_circuit + .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) + .expect("GKR verify failed"); + + // Omit the PCS opening phase. + } + } +} + +#[cfg(test)] +mod tests { + use rand::{Rng, SeedableRng}; + + use super::*; + + #[test] + fn test_keccakf() { + for _ in 0..3 { + let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(random_u64); + let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); + run_keccakf(state, true, true); + } + } +} diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs new file mode 100644 index 000000000..50dd09384 --- /dev/null +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -0,0 +1,1205 @@ +use std::{cmp::Ordering, marker::PhantomData, sync::Arc}; + +use crate::{ + ProtocolBuilder, ProtocolWitnessGenerator, + chip::Chip, + evaluation::{EvalExpression, PointAndEval}, + gkr::{ + GKRCircuitWitness, GKRProverOutput, + layer::{Layer, LayerType, LayerWitness}, + }, + precompiles::utils::{MaskRepresentation, nest, not8_expr, zero_expr}, +}; +use ndarray::{ArrayView, Ix2, Ix3, s}; + +use super::utils::{CenoLookup, u64s_to_felts, zero_eval}; +use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; + +use ff_ext::{ExtensionField, SmallField}; +use itertools::{Itertools, chain, iproduct, zip_eq}; +use p3_goldilocks::Goldilocks; +use subprotocols::expression::{Constant, Expression, Witness}; +use tiny_keccak::keccakf; +use transcript::BasicTranscript; + +type E = BinomialExtensionField; + +#[derive(Clone, Debug, Default)] +pub struct KeccakParams {} + +#[derive(Clone, Debug, Default)] +pub struct KeccakLayout { + _params: KeccakParams, + _input_columns: Vec, + _result: Vec, + _marker: PhantomData, +} + +fn expansion_expr(expansion: &[(usize, Witness)]) -> Expression { + let (total, ret) = expansion + .iter() + .rev() + .fold((0, zero_expr()), |acc, (sz, felt)| { + ( + acc.0 + sz, + acc.1 * Expression::Const(Constant::Base(1 << sz)) + (*felt).into(), + ) + }); + + assert_eq!(total, SIZE); + ret +} + +/// Compute an adequate split of 64-bits into chunks for performing a rotation +/// by `delta`. The first element of the return value is the vec of chunk sizes. +/// The second one is the length of its suffix that needs to be rotated +fn rotation_split(delta: usize) -> (Vec, usize) { + let delta = delta % 64; + + if delta == 0 { + return (vec![32, 32], 0); + } + + // This split meets all requirements except for <= 16 sizes + let split32 = match delta.cmp(&32) { + Ordering::Less => vec![32 - delta, delta, 32 - delta, delta], + Ordering::Equal => vec![32, 32], + Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32], + }; + + // Split off large chunks + let split16 = split32 + .into_iter() + .flat_map(|size| { + assert!(size < 32); + if size <= 16 { + vec![size] + } else { + vec![16, size - 16] + } + }) + .collect_vec(); + + let mut sum = 0; + for (i, size) in split16.iter().rev().enumerate() { + sum += size; + if sum == delta { + return (split16, i + 1); + } + } + + panic!(); +} + +struct ConstraintSystem { + expressions: Vec, + expr_names: Vec, + evals: Vec, + and_lookups: Vec, + xor_lookups: Vec, + range_lookups: Vec, +} + +impl ConstraintSystem { + fn new() -> Self { + ConstraintSystem { + expressions: vec![], + evals: vec![], + expr_names: vec![], + and_lookups: vec![], + xor_lookups: vec![], + range_lookups: vec![], + } + } + + fn add_constraint(&mut self, expr: Expression, name: String) { + self.expressions.push(expr); + self.evals.push(zero_eval()); + self.expr_names.push(name); + } + + fn lookup_and8(&mut self, a: Expression, b: Expression, c: Expression) { + self.and_lookups.push(CenoLookup::And(a, b, c)); + } + + fn lookup_xor8(&mut self, a: Expression, b: Expression, c: Expression) { + self.xor_lookups.push(CenoLookup::Xor(a, b, c)); + } + + /// Generates U16 lookups to prove that `value` fits on `size < 16` bits. + /// In general it can be done by two U16 checks: one for `value` and one for + /// `value << (16 - size)`. + fn lookup_range(&mut self, value: Expression, size: usize) { + assert!(size <= 16); + self.range_lookups.push(CenoLookup::U16(value.clone())); + if size < 16 { + self.range_lookups.push(CenoLookup::U16( + value * Expression::Const(Constant::Base(1 << (16 - size))), + )) + } + } + + fn constrain_eq(&mut self, lhs: Expression, rhs: Expression, name: String) { + self.add_constraint(lhs - rhs, name); + } + + // Constrains that lhs and rhs encode the same value of SIZE bits + // WARNING: Assumes that forall i, (lhs[i].1 < (2 ^ lhs[i].0)) + // This needs to be constrained separately + fn constrain_reps_eq( + &mut self, + lhs: &[(usize, Witness)], + rhs: &[(usize, Witness)], + name: String, + ) { + self.add_constraint( + expansion_expr::(lhs) - expansion_expr::(rhs), + name, + ); + } + + /// Checks that `rot8` is equal to `input8` left-rotated by `delta`. + /// `rot8` and `input8` each consist of 8 chunks of 8-bits. + /// + /// `split_rep` is a chunk representation of the input which + /// allows to reduce the required rotation to an array rotation. It may use + /// non-uniform chunks. + /// + /// For example, when `delta = 2`, the 64 bits are split into chunks of + /// sizes `[16a, 14b, 2c, 16d, 14e, 2f]` (here the first chunks contains the + /// least significant bits so a left rotation will become a right rotation + /// of the array). To perform the required rotation, we can + /// simply rotate the array: [2f, 16a, 14b, 2c, 16d, 14e]. + /// + /// In the first step, we check that `rot8` and `split_rep` represent the + /// same 64 bits. In the second step we check that `rot8` and the appropiate + /// array rotation of `split_rep` represent the same 64 bits. + /// + /// This type of representation-equality check is done by packing chunks + /// into sizes of exactly 32 (so for `delta = 2` we compare [16a, 14b, + /// 2c] to the first 4 elements of `rot8`). In addition, we do range + /// checks on `split_rep` which check that the felts meet the required + /// sizes. + /// + /// This algorithm imposes the following general requirements for + /// `split_rep`: + /// - There exists a suffix of `split_rep` which sums to exactly `delta`. + /// This suffix can contain several elements. + /// - Chunk sizes are at most 16 (so they can be range-checked) or they are + /// exactly equal to 32. + /// - There exists a prefix of chunks which sums exactly to 32. This must + /// hold for the rotated array as well. + /// - The number of chunks should be as small as possible. + /// + /// Consult the method `rotation_split` to see how splits are computed for a + /// given `delta + /// + /// Note that the function imposes range checks on chunk values, but it + /// makes two exceptions: + /// 1. It doesn't check the 8-bit reps (input and output). This is + /// because all 8-bit reps in the global circuit are implicitly + /// range-checked because they are lookup arguments. + /// 2. It doesn't range-check 32-bit chunks. This is because a 32-bit + /// chunk value is checked to be equal to the composition of 4 8-bit + /// chunks. As mentioned in 1., these can be trusted to be range + /// checked, so the resulting 32-bit is correct by construction as + /// well. + fn constrain_left_rotation64( + &mut self, + input8: &[Witness], + split_rep: &[(usize, Witness)], + rot8: &[Witness], + delta: usize, + label: String, + ) { + assert_eq!(input8.len(), 8); + assert_eq!(rot8.len(), 8); + + // Assert that the given split witnesses are correct for this delta + let (sizes, chunks_rotation) = rotation_split(delta); + assert_eq!(sizes, split_rep.iter().map(|e| e.0).collect_vec()); + + // Lookup ranges + for (size, elem) in split_rep { + if *size != 32 { + self.lookup_range((*elem).into(), *size); + } + } + + // constrain the fact that rep8 and repX.rotate_left(chunks_rotation) are + // the same 64 bitstring + let mut helper = |rep8: &[Witness], rep_x: &[(usize, Witness)], chunks_rotation: usize| { + // Do the same thing for the two 32-bit halves + let mut rep_x = rep_x.to_owned(); + rep_x.rotate_right(chunks_rotation); + + for i in 0..2 { + // The respective 4 elements in the byte representation + let lhs = rep8[4 * i..4 * (i + 1)] + .iter() + .map(|wit| (8, *wit)) + .collect_vec(); + let cnt = rep_x.len() / 2; + let rhs = &rep_x[cnt * i..cnt * (i + 1)]; + + assert_eq!(rhs.iter().map(|e| e.0).sum::(), 32); + + self.constrain_reps_eq::<32>( + &lhs, + rhs, + format!( + "rotation internal {label}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", + sizes + ), + ); + } + }; + + helper(input8, split_rep, 0); + helper(rot8, split_rep, chunks_rotation); + } +} + +const ROUNDS: usize = 24; + +const RC: [u64; ROUNDS] = [ + 1u64, + 0x8082u64, + 0x800000000000808au64, + 0x8000000080008000u64, + 0x808bu64, + 0x80000001u64, + 0x8000000080008081u64, + 0x8000000000008009u64, + 0x8au64, + 0x88u64, + 0x80008009u64, + 0x8000000au64, + 0x8000808bu64, + 0x800000000000008bu64, + 0x8000000000008089u64, + 0x8000000000008003u64, + 0x8000000000008002u64, + 0x8000000000000080u64, + 0x800au64, + 0x800000008000000au64, + 0x8000000080008081u64, + 0x8000000000008080u64, + 0x80000001u64, + 0x8000000080008008u64, +]; + +const ROTATION_CONSTANTS: [[usize; 5]; 5] = [ + [0, 1, 62, 28, 27], + [36, 44, 6, 55, 20], + [3, 10, 43, 25, 39], + [41, 45, 15, 21, 8], + [18, 2, 61, 56, 14], +]; + +pub const KECCAK_INPUT_SIZE: usize = 50; +pub const KECCAK_OUTPUT_SIZE: usize = 50; + +pub const AND_LOOKUPS_PER_ROUND: usize = 200; +pub const XOR_LOOKUPS_PER_ROUND: usize = 608; +pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; +pub const LOOKUP_FELTS_PER_ROUND: usize = + 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; + +pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND; +pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND; +pub const RANGE_LOOKUPS: usize = ROUNDS * RANGE_LOOKUPS_PER_ROUND; + +macro_rules! allocate_and_split { + ($chip:expr, $total:expr, $( $size:expr ),* ) => {{ + let (witnesses, _) = $chip.allocate_wits_in_layer::<$total, 0>(); + let mut iter = witnesses.into_iter(); + ( + $( + iter.by_ref().take($size).collect_vec(), + )* + ) + }}; + } + +impl ProtocolBuilder for KeccakLayout { + type Params = KeccakParams; + + fn init(params: Self::Params) -> Self { + Self { + _params: params, + ..Default::default() + } + } + + fn build_commit_phase(&mut self, chip: &mut Chip) { + let _ = chip.allocate_committed_base::<{ 50 + 40144 }>(); + } + + fn build_gkr_phase(&mut self, chip: &mut Chip) { + let final_outputs = + chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS }>(); + + let mut final_outputs_iter = final_outputs.iter(); + + let [keccak_output32, keccak_input32, lookup_outputs] = [ + KECCAK_OUTPUT_SIZE, + KECCAK_INPUT_SIZE, + // LOOKUPS_PER_ROUND + LOOKUP_FELTS_PER_ROUND * ROUNDS, + ] + .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); + + let keccak_output32 = keccak_output32.to_vec(); + let keccak_input32 = keccak_input32.to_vec(); + let lookup_outputs = lookup_outputs.to_vec(); + + let (keccak_output8, []) = chip.allocate_wits_in_layer::<200, 0>(); + + let keccak_output8: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &keccak_output8).unwrap(); + + let mut expressions = vec![]; + let mut evals = vec![]; + let mut expr_names = vec![]; + + let mut global_and_lookup = 0; + let mut global_xor_lookup = 3 * AND_LOOKUPS; + let mut global_range_lookup = 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS; + + let mut openings = 0; + + for x in 0..5 { + for y in 0..5 { + for k in 0..2 { + // create an expression combining 4 elements of state8 into a single 32-bit felt + let expr = expansion_expr::<32>( + &keccak_output8 + .slice(s![x, y, 4 * k..4 * (k + 1)]) + .iter() + .map(|e| (8, e.0)) + .collect_vec(), + ); + expressions.push(expr); + evals.push(keccak_output32[evals.len()].clone()); + expr_names.push(format!("build 32-bit output: {x}, {y}, {k}")); + } + } + } + + chip.add_layer(Layer::new( + "build 32-bit output".to_string(), + LayerType::Zerocheck, + expressions, + vec![], + keccak_output8 + .into_iter() + .map(|e| e.1.clone()) + .collect_vec(), + vec![], + evals, + expr_names, + )); + + let state8_loop = (0..ROUNDS).rev().fold( + keccak_output8.iter().map(|e| e.1.clone()).collect_vec(), + |round_output, round| { + #[allow(non_snake_case)] + let ( + state8, + c_aux, + c_temp, + c_rot, + d, + theta_output, + rotation_witness, + rhopi_output, + nonlinear, + chi_output, + iota_output, + ) = allocate_and_split!( + chip, 1656, 200, 200, 30, 40, 40, 200, 146, 200, 200, 200, 200 + ); + + let total_witnesses = 200 + 200 + 30 + 40 + 40 + 200 + 146 + 200 + 200 + 200 + 200; + // dbg!(total_witnesses); + assert_eq!(1656, total_witnesses); + + let bases = chain!( + state8.clone(), + c_aux.clone(), + c_temp.clone(), + c_rot.clone(), + d.clone(), + theta_output.clone(), + rotation_witness.clone(), + rhopi_output.clone(), + nonlinear.clone(), + chi_output.clone(), + iota_output.clone(), + ) + .collect_vec(); + + for wit in &bases { + chip.allocate_base_opening(openings, wit.clone().1); + openings += 1; + } + + // TODO: ndarrays can be replaced with normal arrays + + // Input state of the round in 8-bit chunks + let state8: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + + let mut system = ConstraintSystem::new(); + + // The purpose is to compute the auxiliary array + // c[i] = XOR (state[j][i]) for j in 0..5 + // We unroll it into + // c_aux[i][j] = XOR (state[k][i]) for k in 0..j + // We use c_aux[i][4] instead of c[i] + // c_aux is also stored in 8-bit chunks + let c_aux: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &c_aux).unwrap(); + + for i in 0..5 { + for k in 0..8 { + // Initialize first element + system.constrain_eq( + state8[[0, i, k]].0.into(), + c_aux[[i, 0, k]].0.into(), + "init c_aux".to_string(), + ); + } + for j in 1..5 { + // Check xor using lookups over all chunks + for k in 0..8 { + system.lookup_xor8( + c_aux[[i, j - 1, k]].0.into(), + state8[[j, i, k]].0.into(), + c_aux[[i, j, k]].0.into(), + ); + } + } + } + + // Compute c_rot[i] = c[i].rotate_left(1) + // To understand how rotations are performed in general, consult the + // documentation of `constrain_left_rotation64`. Here c_temp is the split + // witness for a 1-rotation. + + let c_temp: ArrayView<(Witness, EvalExpression), Ix2> = + ArrayView::from_shape((5, 6), &c_temp).unwrap(); + let c_rot: ArrayView<(Witness, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &c_rot).unwrap(); + + let (sizes, _) = rotation_split(1); + + for i in 0..5 { + assert_eq!(c_temp.slice(s![i, ..]).iter().len(), sizes.iter().len()); + + system.constrain_left_rotation64( + &c_aux.slice(s![i, 4, ..]).iter().map(|e| e.0).collect_vec(), + &zip_eq(c_temp.slice(s![i, ..]).iter(), sizes.iter()) + .map(|(e, sz)| (*sz, e.0)) + .collect_vec(), + &c_rot.slice(s![i, ..]).iter().map(|e| e.0).collect_vec(), + 1, + "theta rotation".to_string(), + ); + } + + // d is computed simply as XOR of required elements of c (and rotations) + // again stored as 8-bit chunks + let d: ArrayView<(Witness, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &d).unwrap(); + + for i in 0..5 { + for k in 0..8 { + system.lookup_xor8( + c_aux[[(i + 5 - 1) % 5, 4, k]].0.into(), + c_rot[[(i + 1) % 5, k]].0.into(), + d[[i, k]].0.into(), + ) + } + } + + // output state of the Theta sub-round, simple XOR, in 8-bit chunks + let theta_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &theta_output).unwrap(); + + for i in 0..5 { + for j in 0..5 { + for k in 0..8 { + system.lookup_xor8( + state8[[j, i, k]].0.into(), + d[[i, k]].0.into(), + theta_output[[j, i, k]].0.into(), + ) + } + } + } + + // output state after applying both Rho and Pi sub-rounds + // sub-round Pi is a simple permutation of 64-bit lanes + // sub-round Rho requires rotations + let rhopi_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &rhopi_output).unwrap(); + + // iterator over split witnesses + let mut rotation_witness = rotation_witness.iter(); + + for i in 0..5 { + #[allow(clippy::needless_range_loop)] + for j in 0..5 { + let arg = theta_output + .slice(s!(j, i, ..)) + .iter() + .map(|e| e.0) + .collect_vec(); + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]); + let many = sizes.len(); + let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) + .map(|(sz, (wit, _))| (sz, *wit)) + .collect_vec(); + let arg_rotated = rhopi_output + .slice(s!((2 * i + 3 * j) % 5, j, ..)) + .iter() + .map(|e| e.0) + .collect_vec(); + system.constrain_left_rotation64( + &arg, + &rep_split, + &arg_rotated, + ROTATION_CONSTANTS[j][i], + format!("RHOPI {i}, {j}"), + ); + } + } + + let chi_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &chi_output).unwrap(); + + // for the Chi sub-round, we use an intermediate witness storing the result of + // the required AND + let nonlinear: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &nonlinear).unwrap(); + + for i in 0..5 { + for j in 0..5 { + for k in 0..8 { + system.lookup_and8( + not8_expr(rhopi_output[[j, (i + 1) % 5, k]].0.into()), + rhopi_output[[j, (i + 2) % 5, k]].0.into(), + nonlinear[[j, i, k]].0.into(), + ); + + system.lookup_xor8( + rhopi_output[[j, i, k]].0.into(), + nonlinear[[j, i, k]].0.into(), + chi_output[[j, i, k]].0.into(), + ); + } + } + } + + // TODO: 24/25 elements stay the same after Iota; eliminate duplication? + let iota_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &iota_output).unwrap(); + + for i in 0..5 { + for j in 0..5 { + if i == 0 && j == 0 { + for k in 0..8 { + system.lookup_xor8( + chi_output[[j, i, k]].0.into(), + Expression::Const(Constant::Base( + ((RC[round] >> (k * 8)) & 0xFF) as i64, + )), + iota_output[[j, i, k]].0.into(), + ); + } + } else { + for k in 0..8 { + system.constrain_eq( + iota_output[[j, i, k]].0.into(), + chi_output[[j, i, k]].0.into(), + "nothing special".to_string(), + ); + } + } + } + } + + let ConstraintSystem { + mut expressions, + mut expr_names, + mut evals, + and_lookups, + xor_lookups, + range_lookups, + .. + } = system; + + iota_output + .into_iter() + .enumerate() + .map(|(i, val)| { + expressions.push(val.0.into()); + expr_names.push(format!("iota_output {i}")); + evals.push(round_output[i].clone()); + }) + .count(); + + for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) + .flatten() + .enumerate() + { + expressions.push(lookup); + expr_names.push(format!("round {round}: {i}th lookup felt")); + let idx = if i < 3 * AND_LOOKUPS_PER_ROUND { + &mut global_and_lookup + } else if i < 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND { + &mut global_xor_lookup + } else { + &mut global_range_lookup + }; + evals.push(lookup_outputs[*idx].clone()); + *idx += 1; + } + + chip.add_layer(Layer::new( + format!("Round {round}"), + LayerType::Zerocheck, + expressions, + vec![], + bases.into_iter().map(|e| e.1).collect_vec(), + vec![], + evals, + expr_names, + )); + + state8.into_iter().map(|e| e.1.clone()).collect_vec() + }, + ); + + assert!(global_and_lookup == 3 * AND_LOOKUPS); + assert!(global_xor_lookup == 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS); + assert!(global_range_lookup == LOOKUP_FELTS_PER_ROUND * ROUNDS); + + let (state8, _) = chip.allocate_wits_in_layer::<200, 0>(); + + let state8: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + + let mut expressions = vec![]; + let mut evals = vec![]; + let mut expr_names = vec![]; + + for x in 0..5 { + for y in 0..5 { + for k in 0..2 { + // create an expression combining 4 elements of state8 into a single 32-bit felt + let expr = expansion_expr::<32>( + state8 + .slice(s![x, y, 4 * k..4 * (k + 1)]) + .iter() + .map(|e| (8, e.0)) + .collect_vec() + .as_slice(), + ); + expressions.push(expr); + evals.push(keccak_input32[evals.len()].clone()); + expr_names.push(format!("build 32-bit input: {x}, {y}, {k}")); + } + } + } + + // TODO: eliminate this duplication + zip_eq(state8.iter(), state8_loop.iter()) + .map(|(e, e_loop)| { + expressions.push(e.0.into()); + evals.push(e_loop.clone()); + expr_names.push("state8 identity".to_string()); + }) + .count(); + + chip.add_layer(Layer::new( + "build 32-bit input".to_string(), + LayerType::Zerocheck, + expressions, + vec![], + state8.into_iter().map(|e| e.1.clone()).collect_vec(), + vec![], + evals, + expr_names, + )); + + // TODO: allocate everything + chip.allocate_base_opening(0, state8[[0, 0, 0]].clone().1); + chip.allocate_base_opening(1, state8[[0, 0, 1]].clone().1); + } +} + +#[derive(Clone, Default)] +pub struct KeccakTrace { + pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, +} + +impl ProtocolWitnessGenerator for KeccakLayout +where + E: ExtensionField, +{ + type Trace = KeccakTrace; + + fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { + let mut poly = vec![vec![]; KECCAK_INPUT_SIZE]; + for instance in phase1.instances { + let felts = u64s_to_felts::(instance.into_iter().map(|e| e as u64).collect_vec()); + for i in 0..KECCAK_INPUT_SIZE { + poly[i].push(felts[i]); + } + } + poly + } + + fn gkr_witness(&self, phase1: &[Vec], _challenges: &[E]) -> GKRCircuitWitness { + let n_layers = 24 + 2 + 1; + let mut layer_wits = vec![ + LayerWitness { + bases: vec![], + exts: vec![], + num_vars: 1 + }; + n_layers + ]; + + let num_instances = phase1[0].len(); + + for i in 0..num_instances { + fn conv64to8(input: u64) -> [u64; 8] { + MaskRepresentation::new(vec![(64, input).into()]) + .convert(vec![8; 8]) + .values() + .try_into() + .unwrap() + } + + let mut com_state = vec![]; + #[allow(clippy::needless_range_loop)] + for j in 0..KECCAK_INPUT_SIZE { + com_state.push(phase1[j][i]); + } + + let mut and_lookups: Vec> = vec![vec![]; ROUNDS]; + let mut xor_lookups: Vec> = vec![vec![]; ROUNDS]; + let mut range_lookups: Vec> = vec![vec![]; ROUNDS]; + + let mut add_and = |a: u64, b: u64, round: usize| { + let c = a & b; + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + and_lookups[round].extend(vec![a, b, c]); + }; + + let mut add_xor = |a: u64, b: u64, round: usize| { + let c = a ^ b; + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + xor_lookups[round].extend(vec![a, b, c]); + }; + + let mut add_range = |value: u64, size: usize, round: usize| { + assert!(size <= 16, "{size}"); + range_lookups[round].push(value); + if size < 16 { + range_lookups[round].push(value << (16 - size)); + assert!(value << (16 - size) < (1 << 16)); + } + }; + + let state32 = com_state + .into_iter() + // TODO double check assumptions about canonical + .map(|e| e.to_canonical_u64()) + .collect_vec(); + + let mut state64 = [[0u64; 5]; 5]; + let mut state8 = [[[0u64; 8]; 5]; 5]; + + zip_eq(iproduct!(0..5, 0..5), state32.clone().iter().tuples()) + .map(|((x, y), (lo, hi))| { + state64[x][y] = lo | (hi << 32); + }) + .count(); + + for x in 0..5 { + for y in 0..5 { + state8[x][y] = conv64to8(state64[x][y]); + } + } + + let mut curr_layer = 0; + let mut push_instance = |wits: Vec| { + let felts = u64s_to_felts::(wits); + if layer_wits[curr_layer].bases.is_empty() { + layer_wits[curr_layer] = LayerWitness::new(nest::(&felts), vec![]); + } else { + assert_eq!(felts.len(), layer_wits[curr_layer].bases.len()); + for (i, base) in layer_wits[curr_layer].bases.iter_mut().enumerate() { + base.push(felts[i]); + } + } + curr_layer += 1; + }; + + push_instance(state8.into_iter().flatten().flatten().collect_vec()); + + #[allow(clippy::needless_range_loop)] + for round in 0..ROUNDS { + let mut c_aux64 = [[0u64; 5]; 5]; + let mut c_aux8 = [[[0u64; 8]; 5]; 5]; + + for i in 0..5 { + c_aux64[i][0] = state64[0][i]; + c_aux8[i][0] = conv64to8(c_aux64[i][0]); + for j in 1..5 { + c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1]; + c_aux8[i][j] = conv64to8(c_aux64[i][j]); + + for k in 0..8 { + add_xor(c_aux8[i][j - 1][k], state8[j][i][k], round); + } + } + } + + let mut c64 = [0u64; 5]; + let mut c8 = [[0u64; 8]; 5]; + + for x in 0..5 { + c64[x] = c_aux64[x][4]; + c8[x] = conv64to8(c64[x]); + } + + let mut c_temp = [[0u64; 6]; 5]; + for i in 0..5 { + let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) + .convert(vec![16, 15, 1, 16, 15, 1]); + c_temp[i] = rep.values().try_into().unwrap(); + for mask in rep.rep { + add_range(mask.value, mask.size, round); + } + } + + let mut crot64 = [0u64; 5]; + let mut crot8 = [[0u64; 8]; 5]; + for i in 0..5 { + crot64[i] = c64[i].rotate_left(1); + crot8[i] = conv64to8(crot64[i]); + } + + let mut d64 = [0u64; 5]; + let mut d8 = [[0u64; 8]; 5]; + for x in 0..5 { + d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1); + d8[x] = conv64to8(d64[x]); + for k in 0..8 { + add_xor(c_aux8[(x + 4) % 5][4][k], crot8[(x + 1) % 5][k], round); + } + } + + let mut theta_state64 = state64; + let mut theta_state8 = [[[0u64; 8]; 5]; 5]; + let mut rotation_witness = vec![]; + + for x in 0..5 { + for y in 0..5 { + theta_state64[y][x] ^= d64[x]; + theta_state8[y][x] = conv64to8(theta_state64[y][x]); + + for k in 0..8 { + add_xor(state8[y][x][k], d8[x][k], round); + } + + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); + let rep = MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) + .convert(sizes); + for mask in rep.rep.iter() { + if mask.size != 32 { + add_range(mask.value, mask.size, round); + } + } + rotation_witness.extend(rep.values()); + } + } + + // Rho and Pi steps + let mut rhopi_output64 = [[0u64; 5]; 5]; + let mut rhopi_output8 = [[[0u64; 8]; 5]; 5]; + + for x in 0..5 { + for y in 0..5 { + rhopi_output64[(2 * x + 3 * y) % 5][y % 5] = + theta_state64[y][x].rotate_left(ROTATION_CONSTANTS[y][x] as u32); + } + } + + for x in 0..5 { + for y in 0..5 { + rhopi_output8[x][y] = conv64to8(rhopi_output64[x][y]); + } + } + + // Chi step + + let mut nonlinear64 = [[0u64; 5]; 5]; + let mut nonlinear8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + nonlinear64[y][x] = + !rhopi_output64[y][(x + 1) % 5] & rhopi_output64[y][(x + 2) % 5]; + nonlinear8[y][x] = conv64to8(nonlinear64[y][x]); + + for k in 0..8 { + add_and( + 0xFF - rhopi_output8[y][(x + 1) % 5][k], + rhopi_output8[y][(x + 2) % 5][k], + round, + ); + } + } + } + + let mut chi_output64 = [[0u64; 5]; 5]; + let mut chi_output8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + chi_output64[y][x] = nonlinear64[y][x] ^ rhopi_output64[y][x]; + chi_output8[y][x] = conv64to8(chi_output64[y][x]); + for k in 0..8 { + add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k], round) + } + } + } + + // Iota step + let mut iota_output64 = chi_output64; + let mut iota_output8 = [[[0u64; 8]; 5]; 5]; + iota_output64[0][0] ^= RC[round]; + + for k in 0..8 { + add_xor(chi_output8[0][0][k], (RC[round] >> (k * 8)) & 0xFF, round); + } + + for x in 0..5 { + for y in 0..5 { + iota_output8[x][y] = conv64to8(iota_output64[x][y]); + } + } + + let all_wits64 = [ + state8.into_iter().flatten().flatten().collect_vec(), + c_aux8.into_iter().flatten().flatten().collect_vec(), + c_temp.into_iter().flatten().collect_vec(), + crot8.into_iter().flatten().collect_vec(), + d8.into_iter().flatten().collect_vec(), + theta_state8.into_iter().flatten().flatten().collect_vec(), + rotation_witness, + rhopi_output8.into_iter().flatten().flatten().collect_vec(), + nonlinear8.into_iter().flatten().flatten().collect_vec(), + chi_output8.into_iter().flatten().flatten().collect_vec(), + iota_output8.into_iter().flatten().flatten().collect_vec(), + ]; + + push_instance(all_wits64.into_iter().flatten().collect_vec()); + + state8 = iota_output8; + state64 = iota_output64; + } + + let mut keccak_output32 = vec![vec![vec![0; 2]; 5]; 5]; + + for x in 0..5 { + for y in 0..5 { + keccak_output32[x][y] = MaskRepresentation::from( + state8[x][y].into_iter().map(|e| (8, e)).collect_vec(), + ) + .convert(vec![32; 2]) + .values(); + } + } + + push_instance(state8.into_iter().flatten().flatten().collect_vec()); + + // For temporary convenience, use one extra layer to store the correct outputs + // of the circuit This is not used during proving + let lookups = chain!( + (0..ROUNDS).rev().flat_map(|i| and_lookups[i].clone()), + (0..ROUNDS).rev().flat_map(|i| xor_lookups[i].clone()), + (0..ROUNDS).rev().flat_map(|i| range_lookups[i].clone()) + ) + .collect_vec(); + + push_instance( + chain!( + keccak_output32.into_iter().flatten().flatten(), + state32, + lookups + ) + .collect_vec(), + ); + } + + let len = layer_wits.len() - 1; + layer_wits[..len].reverse(); + + GKRCircuitWitness { layers: layer_wits } + } +} + +pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bool) { + let params = KeccakParams {}; + let (layout, chip) = KeccakLayout::build(params); + + let mut instances = vec![]; + for state in &states { + let state_mask64 = MaskRepresentation::from(state.iter().map(|e| (64, *e)).collect_vec()); + let state_mask32 = state_mask64.convert(vec![32; 50]); + + instances.push( + state_mask32 + .values() + .iter() + .map(|e| *e as u32) + .collect_vec() + .try_into() + .unwrap(), + ); + } + + let num_instances = instances.len(); + let phase1_witness = layout.phase1_witness(KeccakTrace { + instances: instances.clone(), + }); + + let mut prover_transcript = BasicTranscript::::new(b"protocol"); + + // Omit the commit phase1 and phase2. + let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &[]); + + let out_evals = { + let log2_num_instances = num_instances.next_power_of_two().trailing_zeros(); + let point = Arc::new(vec![E::from_u64(29); log2_num_instances as usize]); + + if test_outputs { + // Confront outputs with tiny_keccak::keccakf call + let mut instance_outputs = vec![vec![]; num_instances]; + for base in gkr_witness + .layers + .last() + .unwrap() + .bases + .iter() + .take(KECCAK_OUTPUT_SIZE) + { + assert_eq!(base.len(), num_instances); + for i in 0..num_instances { + instance_outputs[i].push(base[i]); + } + } + + for i in 0..num_instances { + let mut state = states[i]; + keccakf(&mut state); + assert_eq!( + state + .to_vec() + .iter() + .flat_map(|e| vec![*e as u32, (e >> 32) as u32]) + .map(|e| Goldilocks::from_u64(e as u64)) + .collect_vec(), + instance_outputs[i] + ); + } + } + + let out_evals = gkr_witness + .layers + .last() + .unwrap() + .bases + .iter() + .map(|base| PointAndEval { + point: point.clone(), + eval: subprotocols::utils::evaluate_mle_ext(base, &point), + }) + .collect_vec(); + + assert_eq!( + out_evals.len(), + KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS + ); + + out_evals + }; + + let gkr_circuit = chip.gkr_circuit(); + dbg!(&gkr_circuit.layers.len()); + let GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) + .expect("Failed to prove phase"); + + if verify { + { + let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + + gkr_circuit + .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) + .expect("GKR verify failed"); + + // Omit the PCS opening phase. + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::{Rng, SeedableRng}; + + #[test] + fn test_keccakf() { + for _ in 0..3 { + // let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + let num_instances = 8; + let mut states: Vec<[u64; 25]> = vec![]; + + for _ in 0..num_instances { + states.push(std::array::from_fn(|_| rng.gen())) + } + run_faster_keccakf(states, true, true); + } + } + + // TODO: make it pass + #[ignore] + #[test] + fn test_keccakf_nonpow2() { + for _ in 0..3 { + // let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + let num_instances = 3; + let mut states: Vec<[u64; 25]> = vec![]; + + for _ in 0..num_instances { + states.push(std::array::from_fn(|_| rng.gen())) + } + run_faster_keccakf(states, true, true); + } + } +} diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs new file mode 100644 index 000000000..ee783a711 --- /dev/null +++ b/gkr_iop/src/precompiles/mod.rs @@ -0,0 +1,8 @@ +mod bitwise_keccakf; +mod lookup_keccakf; +mod utils; +pub use bitwise_keccakf::run_keccakf; +pub use lookup_keccakf::{ + AND_LOOKUPS, AND_LOOKUPS_PER_ROUND, KeccakLayout, KeccakParams, KeccakTrace, RANGE_LOOKUPS, + RANGE_LOOKUPS_PER_ROUND, XOR_LOOKUPS, XOR_LOOKUPS_PER_ROUND, run_faster_keccakf, +}; diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs new file mode 100644 index 000000000..e5b17b3bd --- /dev/null +++ b/gkr_iop/src/precompiles/utils.rs @@ -0,0 +1,167 @@ +use ff_ext::ExtensionField; +use itertools::Itertools; +use p3_field::PrimeCharacteristicRing; +use subprotocols::expression::{Constant, Expression}; + +use crate::evaluation::EvalExpression; + +pub fn zero_expr() -> Expression { + Expression::Const(Constant::Base(0)) +} + +pub fn not8_expr(expr: Expression) -> Expression { + Expression::Const(Constant::Base(0xFF)) - expr +} + +pub fn zero_eval() -> EvalExpression { + EvalExpression::Linear(0, Constant::Base(0), Constant::Base(0)) +} + +pub fn nest(v: &[E::BaseField]) -> Vec> { + v.iter().map(|e| vec![*e]).collect_vec() +} + +pub fn u64s_to_felts(words: Vec) -> Vec { + words.into_iter().map(E::BaseField::from_u64).collect() +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Mask { + pub size: usize, + pub value: u64, +} + +impl Mask { + pub fn new(size: usize, value: u64) -> Self { + if size < 64 { + assert!(value < (1 << size)); + } + Self { size, value } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MaskRepresentation { + pub rep: Vec, +} + +impl From for (usize, u64) { + fn from(mask: Mask) -> Self { + (mask.size, mask.value) + } +} + +impl From<(usize, u64)> for Mask { + fn from(tuple: (usize, u64)) -> Self { + Mask::new(tuple.0, tuple.1) + } +} + +impl From for Vec<(usize, u64)> { + fn from(mask_rep: MaskRepresentation) -> Self { + mask_rep.rep.into_iter().map(|mask| mask.into()).collect() + } +} + +impl From> for MaskRepresentation { + fn from(tuples: Vec<(usize, u64)>) -> Self { + MaskRepresentation { + rep: tuples.into_iter().map(|tuple| tuple.into()).collect(), + } + } +} + +impl MaskRepresentation { + pub fn new(masks: Vec) -> Self { + Self { rep: masks } + } + + pub fn from_bits(bits: Vec, sizes: Vec) -> Self { + assert_eq!(bits.len(), sizes.iter().sum::()); + let mut masks = Vec::new(); + let mut bit_iter = bits.into_iter(); + for size in sizes { + let mut mask = 0; + for i in 0..size { + mask += (1 << i) * bit_iter.next().unwrap(); + } + masks.push(Mask::new(size, mask)); + } + Self { rep: masks } + } + + pub fn to_bits(&self) -> Vec { + self.rep + .iter() + .flat_map(|mask| (0..mask.size).map(move |i| ((mask.value >> i) & 1))) + .collect() + } + + pub fn convert(&self, new_sizes: Vec) -> Self { + let bits = self.to_bits(); + Self::from_bits(bits, new_sizes) + } + + pub fn values(&self) -> Vec { + self.rep.iter().map(|m| m.value).collect_vec() + } + + pub fn masks(&self) -> Vec { + self.rep.clone() + } +} + +#[derive(Debug)] +pub enum CenoLookup { + And(Expression, Expression, Expression), + Xor(Expression, Expression, Expression), + U16(Expression), +} + +impl IntoIterator for CenoLookup { + type Item = Expression; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + match self { + CenoLookup::And(a, b, c) => vec![a, b, c].into_iter(), + CenoLookup::Xor(a, b, c) => vec![a, b, c].into_iter(), + CenoLookup::U16(a) => vec![a].into_iter(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::precompiles::utils::{Mask, MaskRepresentation}; + + #[test] + fn test_mask_representation_from_bits() { + let bits = vec![1, 0, 1, 1, 0, 1, 0, 0]; + let sizes = vec![3, 5]; + let mask_rep = MaskRepresentation::from_bits(bits.clone(), sizes.clone()); + assert_eq!(mask_rep.rep.len(), 2); + assert_eq!(mask_rep.rep[0], Mask::new(3, 0b101)); + assert_eq!(mask_rep.rep[1], Mask::new(5, 0b00101)); + } + + #[test] + fn test_mask_representation_to_bits() { + let masks = vec![Mask::new(3, 0b101), Mask::new(5, 0b00101)]; + let mask_rep = MaskRepresentation::new(masks); + let bits = mask_rep.to_bits(); + assert_eq!(bits, vec![1, 0, 1, 1, 0, 1, 0, 0]); + } + + #[test] + fn test_mask_representation_convert() { + let bits = vec![1, 0, 1, 1, 0, 1, 0, 0]; + let sizes = vec![3, 5]; + let mask_rep = MaskRepresentation::from_bits(bits.clone(), sizes.clone()); + let new_sizes = vec![4, 4]; + let new_mask_rep = mask_rep.convert(new_sizes); + assert_eq!(new_mask_rep.rep.len(), 2); + assert_eq!(new_mask_rep.rep[0], Mask::new(4, 0b1101)); + assert_eq!(new_mask_rep.rep[1], Mask::new(4, 0b0010)); + } +} diff --git a/sanity.sh b/sanity.sh new file mode 100755 index 000000000..efdaec941 --- /dev/null +++ b/sanity.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +cargo make build + +# Run the first command +cargo run --bin e2e --release -- examples/target/riscv32im-ceno-zkvm-elf/release/examples/secp256k1_add_syscall +if [ $? -ne 0 ]; then + echo "Command 1 failed" + exit 1 +fi + +# Run the second command +cargo run --bin e2e --release -- examples/target/riscv32im-ceno-zkvm-elf/release/examples/ceno_rt_keccak +if [ $? -ne 0 ]; then + echo "Command 2 failed" + exit 1 +fi diff --git a/subprotocols/Cargo.toml b/subprotocols/Cargo.toml index 35e5c9408..c5fc1ed8e 100644 --- a/subprotocols/Cargo.toml +++ b/subprotocols/Cargo.toml @@ -17,6 +17,7 @@ multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" p3-field.workspace = true rand.workspace = true rayon.workspace = true +serde.workspace = true thiserror = "1" transcript = { path = "../transcript" } diff --git a/subprotocols/examples/zerocheck_logup.rs b/subprotocols/examples/zerocheck_logup.rs index 4390db649..36c227b7e 100644 --- a/subprotocols/examples/zerocheck_logup.rs +++ b/subprotocols/examples/zerocheck_logup.rs @@ -51,6 +51,7 @@ fn run_verifier( let verifier = ZerocheckVerifierState::new( vec![*ans], vec![expr], + vec![], vec![point], proof, &challenges, diff --git a/subprotocols/src/error.rs b/subprotocols/src/error.rs index 5dad7bc9d..ca8eddefe 100644 --- a/subprotocols/src/error.rs +++ b/subprotocols/src/error.rs @@ -4,6 +4,6 @@ use crate::expression::Expression; #[derive(Clone, Debug, Error)] pub enum VerifierError { - #[error("Claim not match: expr: {0:?}, expect: {1:?}, got: {2:?}")] - ClaimNotMatch(Expression, E, E), + #[error("Claim not match: expr: {0:?}\n (expr name: {3:?})\n expect: {1:?}, got: {2:?}")] + ClaimNotMatch(Expression, E, E, String), } diff --git a/subprotocols/src/expression.rs b/subprotocols/src/expression.rs index 7750280a6..5fa2a1556 100644 --- a/subprotocols/src/expression.rs +++ b/subprotocols/src/expression.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use ff_ext::ExtensionField; +use serde::Serialize; mod evaluate; mod op; @@ -9,7 +10,7 @@ mod macros; pub type Point = Arc>; -#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)] pub enum Constant { /// Base field Base(i64), @@ -31,7 +32,7 @@ impl Default for Constant { } } -#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)] pub enum Witness { /// Base field polynomial (index). BasePoly(usize), @@ -47,7 +48,7 @@ impl Default for Witness { } } -#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)] pub enum Expression { /// Constant Const(Constant), diff --git a/subprotocols/src/sumcheck.rs b/subprotocols/src/sumcheck.rs index d01c2702a..0de80b42d 100644 --- a/subprotocols/src/sumcheck.rs +++ b/subprotocols/src/sumcheck.rs @@ -3,6 +3,7 @@ use std::{iter, mem, sync::Arc, vec}; use ark_std::log2; use ff_ext::ExtensionField; use itertools::chain; +use serde::Serialize; use transcript::Transcript; use crate::{ @@ -40,6 +41,7 @@ where num_vars: usize, } +#[derive(Clone, Serialize)] pub struct SumcheckProof { /// Messages for each round. pub univariate_polys: Vec>>, @@ -179,6 +181,7 @@ where sigma: E, expr: Expression, proof: SumcheckProof, + expr_names: Vec, challenges: &'a [E], transcript: &'a mut Trans, out_points: Vec<&'a [E]>, @@ -202,7 +205,11 @@ where proof: SumcheckProof, challenges: &'a [E], transcript: &'a mut Trans, + expr_names: Vec, ) -> Self { + // Fill in missing debug data + let mut expr_names = expr_names; + expr_names.resize(1, "nothing".to_owned()); Self { sigma, expr, @@ -210,6 +217,7 @@ where challenges, transcript, out_points, + expr_names, } } @@ -221,6 +229,7 @@ where challenges, transcript, out_points, + expr_names, } = self; let SumcheckProof { univariate_polys, @@ -257,11 +266,13 @@ where &in_point, challenges, ); + if expected_claim != got_claim { return Err(VerifierError::ClaimNotMatch( expr, expected_claim, got_claim, + expr_names[0].clone(), )); } @@ -324,6 +335,7 @@ mod test { proof, &challenges, &mut verifier_transcript, + vec![], ); verifier.verify().expect("verification failed"); diff --git a/subprotocols/src/zerocheck.rs b/subprotocols/src/zerocheck.rs index 547e2f711..cfedf1dc4 100644 --- a/subprotocols/src/zerocheck.rs +++ b/subprotocols/src/zerocheck.rs @@ -196,6 +196,7 @@ where inv_of_one_minus_points: Vec>, exprs: Vec<(Expression, &'a [E])>, proof: SumcheckProof, + expr_names: Vec, challenges: &'a [E], transcript: &'a mut Trans, } @@ -208,11 +209,16 @@ where pub fn new( sigmas: Vec, exprs: Vec, + expr_names: Vec, points: Vec<&'a [E]>, proof: SumcheckProof, challenges: &'a [E], transcript: &'a mut Trans, ) -> Self { + // Fill in missing debug data + let mut expr_names = expr_names; + expr_names.resize(exprs.len(), "nothing".to_owned()); + let inv_of_one_minus_points = points .iter() .flat_map(|point| point.iter().map(|p| E::ONE - *p).collect_vec()) @@ -235,6 +241,7 @@ where proof, challenges, transcript, + expr_names, } } @@ -246,6 +253,7 @@ where proof, challenges, transcript, + expr_names, .. } = self; let SumcheckProof { @@ -288,7 +296,10 @@ where ); // Check the final evaluations. - for (expected_claim, (expr, _)) in izip!(expected_claims, exprs) { + assert_eq!(expr_names.len(), exprs.len()); + // assert_eq!(expected_claims.len(), expr_names.len()); + + for (expected_claim, (expr, _), expr_name) in izip!(expected_claims, exprs, expr_names) { let got_claim = expr.evaluate(&ext_mle_evals, &base_mle_evals, &[], &[], challenges); if expected_claim != got_claim { @@ -296,6 +307,7 @@ where expr, expected_claim, got_claim, + expr_name.clone(), )); } } @@ -355,6 +367,7 @@ mod test { let verifier = ZerocheckVerifierState::new( sigmas, exprs, + vec![], points, proof, &challenges,