From 5e18d9a006f07cd8a3c732c3f5ffeb773d272c65 Mon Sep 17 00:00:00 2001 From: Mihai Date: Mon, 24 Mar 2025 19:14:50 +0200 Subject: [PATCH 01/17] add keccak_f precompiles & utilities --- Cargo.lock | 49 + Cargo.toml | 1 + gkr_iop/Cargo.toml | 13 + gkr_iop/benches/faster_keccak.rs | 38 + gkr_iop/benches/keccak_f.rs | 38 + gkr_iop/examples/multi_layer_logup.rs | 40 +- gkr_iop/src/chip/builder.rs | 31 +- gkr_iop/src/gkr/layer.rs | 33 +- gkr_iop/src/gkr/layer/linear_layer.rs | 6 +- gkr_iop/src/gkr/layer/sumcheck_layer.rs | 1 + gkr_iop/src/gkr/layer/zerocheck_layer.rs | 1 + gkr_iop/src/lib.rs | 7 +- gkr_iop/src/precompiles/faster_keccak.rs | 1162 ++++++++++++++++++++++ gkr_iop/src/precompiles/keccak_f.rs | 600 +++++++++++ gkr_iop/src/precompiles/mod.rs | 5 + gkr_iop/src/precompiles/utils.rs | 190 ++++ subprotocols/examples/zerocheck_logup.rs | 1 + subprotocols/src/error.rs | 4 +- subprotocols/src/sumcheck.rs | 6 + subprotocols/src/zerocheck.rs | 8 +- 20 files changed, 2189 insertions(+), 45 deletions(-) create mode 100644 gkr_iop/benches/faster_keccak.rs create mode 100644 gkr_iop/benches/keccak_f.rs create mode 100644 gkr_iop/src/precompiles/faster_keccak.rs create mode 100644 gkr_iop/src/precompiles/keccak_f.rs create mode 100644 gkr_iop/src/precompiles/mod.rs create mode 100644 gkr_iop/src/precompiles/utils.rs diff --git a/Cargo.lock b/Cargo.lock index 3e76c8d84..b6418e00a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1172,15 +1172,18 @@ name = "gkr_iop" version = "0.1.0" dependencies = [ "ark-std", + "criterion", "ff_ext", "itertools 0.13.0", "multilinear_extensions", + "ndarray", "p3-field", "p3-goldilocks", "rand", "rayon", "subprotocols", "thiserror 1.0.69", + "tiny-keccak", "transcript", ] @@ -1487,6 +1490,16 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.4" @@ -1589,6 +1602,21 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "nimue" version = "0.2.0" @@ -2089,6 +2117,21 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "poseidon" version = "0.1.0" @@ -2296,6 +2339,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 4fd37c506..560c75a18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ cfg-if = "1.0" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" itertools = "0.13" +ndarray = "*" num-bigint = { version = "0.4.6" } num-derive = "0.4" num-traits = "0.2" diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index bbe7e85e0..323a24233 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -14,10 +14,23 @@ ark-std.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } +ndarray.workspace = true p3-field.workspace = true p3-goldilocks.workspace = true rand.workspace = true rayon.workspace = true subprotocols = { path = "../subprotocols" } thiserror = "1" +tiny-keccak.workspace = true transcript = { path = "../transcript" } + +[dev-dependencies] +criterion.workspace = true + +[[bench]] +harness = false +name = "keccak_f" + +[[bench]] +harness = false +name = "faster_keccak" \ No newline at end of file diff --git a/gkr_iop/benches/faster_keccak.rs b/gkr_iop/benches/faster_keccak.rs new file mode 100644 index 000000000..380ef91f5 --- /dev/null +++ b/gkr_iop/benches/faster_keccak.rs @@ -0,0 +1,38 @@ +use std::time::Duration; + +use criterion::*; +use gkr_iop::precompiles::run_faster_keccakf; + +use rand::{Rng, SeedableRng}; +criterion_group!(benches, keccak_f_fn); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; + +fn keccak_f_fn(c: &mut Criterion) { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("keccak_f")); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function(BenchmarkId::new("keccak_f", format!("keccak_f")), |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); + let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); + + let instant = std::time::Instant::now(); + let _ = black_box(run_faster_keccakf(vec![state1, state2], false, false)); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }); + + group.finish(); +} diff --git a/gkr_iop/benches/keccak_f.rs b/gkr_iop/benches/keccak_f.rs new file mode 100644 index 000000000..b3b330725 --- /dev/null +++ b/gkr_iop/benches/keccak_f.rs @@ -0,0 +1,38 @@ +use std::time::Duration; + +use criterion::*; +use gkr_iop::precompiles::{run_faster_keccakf, run_keccakf}; +use p3_field::extension::BinomialExtensionField; +use p3_goldilocks::Goldilocks; +use rand::{Rng, SeedableRng}; +criterion_group!(benches, keccak_f_fn); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; + +fn keccak_f_fn(c: &mut Criterion) { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("keccak_f")); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function(BenchmarkId::new("keccak_f", format!("keccak_f")), |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); + + let instant = std::time::Instant::now(); + let _ = black_box(run_keccakf(state, false, false)); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }); + + group.finish(); +} diff --git a/gkr_iop/examples/multi_layer_logup.rs b/gkr_iop/examples/multi_layer_logup.rs index b97d3c9ba..6bf7928d2 100644 --- a/gkr_iop/examples/multi_layer_logup.rs +++ b/gkr_iop/examples/multi_layer_logup.rs @@ -2,18 +2,18 @@ use std::{marker::PhantomData, mem, sync::Arc}; use ff_ext::ExtensionField; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, evaluation::{EvalExpression, PointAndEval}, gkr::{ - GKRCircuitWitness, GKRProverOutput, layer::{Layer, LayerType, LayerWitness}, + GKRCircuitWitness, GKRProverOutput, }, + ProtocolBuilder, ProtocolWitnessGenerator, }; -use itertools::{Itertools, izip}; -use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; +use itertools::{izip, Itertools}; +use p3_field::{extension::BinomialExtensionField, PrimeCharacteristicRing}; use p3_goldilocks::Goldilocks; -use rand::{Rng, rngs::OsRng}; +use rand::{rngs::OsRng, Rng}; use subprotocols::expression::{Constant, Expression}; use transcript::{BasicTranscript, Transcript}; @@ -64,14 +64,15 @@ impl ProtocolBuilder for TowerChipLayout { let height = self.params.height; let lookup_challenge = Expression::Const(self.lookup_challenge.clone()); - self.output_cumulative_sum = chip.allocate_output_evals(); + self.output_cumulative_sum = chip.allocate_output_evals::<2>().try_into().unwrap(); // Tower layers let ([updated_table, count], challenges) = (0..height).fold( (self.output_cumulative_sum.clone(), vec![]), |([den, num], challenges), i| { let [den_0, den_1, num_0, num_1] = if i == height - 1 { - // Allocate witnesses in the extension field, except numerator inputs in the base field. + // Allocate witnesses in the extension field, except numerator inputs in the + // base field. let ([num_0, num_1], [den_0, den_1]) = chip.allocate_wits_in_layer(); [den_0, den_1, num_0, num_1] } else { @@ -86,17 +87,20 @@ impl ProtocolBuilder for TowerChipLayout { num_1.0.into(), ]; let (in_bases, in_exts) = if i == height - 1 { - (vec![num_0.1.clone(), num_1.1.clone()], vec![ - den_0.1.clone(), - den_1.1.clone(), - ]) + ( + vec![num_0.1.clone(), num_1.1.clone()], + vec![den_0.1.clone(), den_1.1.clone()], + ) } else { - (vec![], vec![ - den_0.1.clone(), - den_1.1.clone(), - num_0.1.clone(), - num_1.1.clone(), - ]) + ( + vec![], + vec![ + den_0.1.clone(), + den_1.1.clone(), + num_0.1.clone(), + num_1.1.clone(), + ], + ) }; chip.add_layer(Layer::new( format!("Tower_layer_{}", i), @@ -109,6 +113,7 @@ impl ProtocolBuilder for TowerChipLayout { in_bases, in_exts, vec![den, num], + vec![], )); let [challenge] = chip.allocate_challenges(); ( @@ -138,6 +143,7 @@ impl ProtocolBuilder for TowerChipLayout { vec![table.1.clone()], vec![], vec![updated_table], + vec![], )); chip.allocate_base_opening(self.committed_table_id, table.1); diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs index ac742fe04..c05165d3d 100644 --- a/gkr_iop/src/chip/builder.rs +++ b/gkr_iop/src/chip/builder.rs @@ -1,5 +1,6 @@ use std::array; +use itertools::Itertools; use subprotocols::expression::{Constant, Witness}; use crate::{ @@ -22,10 +23,11 @@ impl Chip { array::from_fn(|i| i + self.n_committed_exts - N) } - /// Allocate `Witness` and `EvalExpression` for the input polynomials in a layer. - /// Where `Witness` denotes the index and `EvalExpression` denotes the position - /// to place the evaluation of the polynomial after processing the layer prover - /// for each polynomial. This should be called at most once for each layer! + /// Allocate `Witness` and `EvalExpression` for the input polynomials in a + /// layer. Where `Witness` denotes the index and `EvalExpression` + /// denotes the position to place the evaluation of the polynomial after + /// processing the layer prover for each polynomial. This should be + /// called at most once for each layer! #[allow(clippy::type_complexity)] pub fn allocate_wits_in_layer( &mut self, @@ -51,9 +53,16 @@ impl Chip { } /// Generate the evaluation expression for each output. - pub fn allocate_output_evals(&mut self) -> [EvalExpression; N] { + pub fn allocate_output_evals(&mut self) -> Vec +// -> [EvalExpression; N] + { self.n_evaluations += N; - array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) + //array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) + // TODO: hotfix to avoid stack overflow, fix later + (0..N) + .into_iter() + .map(|i| EvalExpression::Single(i + self.n_evaluations - N)) + .collect_vec() } /// Allocate challenges. @@ -62,14 +71,16 @@ impl Chip { array::from_fn(|i| Constant::Challenge(i + self.n_challenges - N)) } - /// Allocate a PCS opening action to a base polynomial with index `wit_index`. - /// The `EvalExpression` represents the expression to compute the evaluation. + /// Allocate a PCS opening action to a base polynomial with index + /// `wit_index`. The `EvalExpression` represents the expression to + /// compute the evaluation. pub fn allocate_base_opening(&mut self, wit_index: usize, eval: EvalExpression) { self.base_openings.push((wit_index, eval)); } - /// Allocate a PCS opening action to an ext polynomial with index `wit_index`. - /// The `EvalExpression` represents the expression to compute the evaluation. + /// Allocate a PCS opening action to an ext polynomial with index + /// `wit_index`. The `EvalExpression` represents the expression to + /// compute the evaluation. pub fn allocate_ext_opening(&mut self, wit_index: usize, eval: EvalExpression) { self.ext_openings.push((wit_index, eval)); } diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 2f23a6d77..00749d822 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -33,12 +33,13 @@ pub struct Layer { pub ty: LayerType, /// Challenges generated at the beginning of the layer protocol. pub challenges: Vec, - /// Expressions to prove in this layer. For zerocheck and linear layers, each - /// expression corresponds to an output. While in sumcheck, there is only 1 - /// expression, which corresponds to the sum of all outputs. This design is - /// for the convenience when building the following expression: - /// `e_0 + beta * e_1 = sum_x (eq(p_0, x) + beta * eq(p_1, x)) expr(x)`. - /// where `vec![e_0, beta * e_1]` will be the output evaluation expressions. + /// Expressions to prove in this layer. For zerocheck and linear layers, + /// each expression corresponds to an output. While in sumcheck, there + /// is only 1 expression, which corresponds to the sum of all outputs. + /// This design is for the convenience when building the following + /// expression: `e_0 + beta * e_1 = sum_x (eq(p_0, x) + beta * + /// eq(p_1, x)) expr(x)`. where `vec![e_0, beta * e_1]` will be the + /// output evaluation expressions. pub exprs: Vec, /// Positions to place the evaluations of the base inputs of this layer. pub in_bases: Vec, @@ -47,6 +48,9 @@ pub struct Layer { /// The expressions of the evaluations from the succeeding layers, which are /// connected to the outputs of this layer. pub outs: Vec, + + // For debugging purposes + pub expr_names: Vec, } #[derive(Clone, Debug)] @@ -66,7 +70,15 @@ impl Layer { in_bases: Vec, in_exts: Vec, outs: Vec, + expr_names: Vec, ) -> Self { + let mut expr_names = expr_names; + if expr_names.len() < exprs.len() { + expr_names.extend(vec![ + "unavailable".to_string(); + exprs.len() - expr_names.len() + ]); + } Self { name, ty, @@ -75,6 +87,7 @@ impl Layer { in_bases, in_exts, outs, + expr_names, } } @@ -208,10 +221,10 @@ impl Layer { ext_mle_evals: &[E], point: &Point, ) { - for (value, pos) in izip!(chain![base_mle_evals, ext_mle_evals], chain![ - &self.in_bases, - &self.in_exts - ]) { + for (value, pos) in izip!( + chain![base_mle_evals, ext_mle_evals], + chain![&self.in_bases, &self.in_exts] + ) { *(pos.entry_mut(claims)) = PointAndEval { point: point.clone(), eval: *value, diff --git a/gkr_iop/src/gkr/layer/linear_layer.rs b/gkr_iop/src/gkr/layer/linear_layer.rs index 40d8a6ec8..6367d1b10 100644 --- a/gkr_iop/src/gkr/layer/linear_layer.rs +++ b/gkr_iop/src/gkr/layer/linear_layer.rs @@ -1,5 +1,5 @@ use ff_ext::ExtensionField; -use itertools::{Itertools, izip}; +use itertools::{izip, Itertools}; use subprotocols::{ error::VerifierError, expression::Point, @@ -80,7 +80,7 @@ impl LinearLayer for Layer { transcript.append_field_element_exts(&ext_mle_evals); transcript.append_field_element_exts(&base_mle_evals); - for (sigma, expr) in izip!(sigmas, &self.exprs) { + for (sigma, expr, expr_name) in izip!(sigmas, &self.exprs, &self.expr_names) { let got = expr.evaluate( &ext_mle_evals, &base_mle_evals, @@ -91,7 +91,7 @@ impl LinearLayer for Layer { if *sigma != got { return Err(BackendError::LayerVerificationFailed( self.name.clone(), - VerifierError::::ClaimNotMatch(expr.clone(), *sigma, got), + VerifierError::::ClaimNotMatch(expr.clone(), *sigma, got, expr_name.clone()), )); } } diff --git a/gkr_iop/src/gkr/layer/sumcheck_layer.rs b/gkr_iop/src/gkr/layer/sumcheck_layer.rs index afb65d9a2..01bd00e8f 100644 --- a/gkr_iop/src/gkr/layer/sumcheck_layer.rs +++ b/gkr_iop/src/gkr/layer/sumcheck_layer.rs @@ -66,6 +66,7 @@ impl SumcheckLayer for Layer { proof, challenges, transcript, + vec![], ); verifier_state diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index e9e9bc577..4df79c405 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -63,6 +63,7 @@ impl ZerocheckLayer for Layer { let verifier_state = ZerocheckVerifierState::new( sigmas, self.exprs.clone(), + vec![], out_points, proof, challenges, diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index dff90eb6d..06c83f486 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -9,6 +9,7 @@ pub mod chip; pub mod error; pub mod evaluation; pub mod gkr; +pub mod precompiles; pub mod utils; pub trait ProtocolBuilder: Sized { @@ -48,12 +49,14 @@ where fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness; } -// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, `gkr_phase` and `opening_phase`. +// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, +// `gkr_phase` and `opening_phase`. pub struct ProtocolProver, PCS>( PhantomData<(E, Trans, PCS)>, ); -// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, `gkr_phase` and `opening_phase`. +// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, +// `gkr_phase` and `opening_phase`. pub struct ProtocolVerifier, PCS>( PhantomData<(E, Trans, PCS)>, ); diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs new file mode 100644 index 000000000..28ea0170d --- /dev/null +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -0,0 +1,1162 @@ +use std::{ + array::from_fn, + cmp::Ordering, + iter::{once, zip}, + marker::PhantomData, + sync::Arc, +}; + +use crate::{ + chip::Chip, + evaluation::{EvalExpression, PointAndEval}, + gkr::{ + layer::{Layer, LayerType, LayerWitness}, + GKRCircuitWitness, GKRProverOutput, + }, + precompiles::utils::{nest, not8_expr, zero_expr, MaskRepresentation}, + ProtocolBuilder, ProtocolWitnessGenerator, +}; +use ndarray::{range, s, ArrayView, Ix2, Ix3}; + +use super::utils::{u64s_to_felts, zero_eval, CenoLookup}; +use p3_field::{extension::BinomialExtensionField, PrimeCharacteristicRing}; + +use ff_ext::{ExtensionField, SmallField}; +use itertools::{chain, iproduct, zip_eq, Itertools}; +use p3_goldilocks::Goldilocks; +use subprotocols::expression::{Constant, Expression, Witness}; +use tiny_keccak::keccakf; +use transcript::BasicTranscript; + +type E = BinomialExtensionField; + +#[derive(Clone, Debug, Default)] +struct KeccakParams {} + +#[derive(Clone, Debug, Default)] +struct KeccakLayout { + params: KeccakParams, + + committed_bits_id: usize, + + result: Vec, + _marker: PhantomData, +} + +fn expansion_expr(expansion: &[(usize, Witness)]) -> Expression { + let (total, ret) = expansion + .iter() + .rev() + .fold((0, zero_expr()), |acc, (sz, felt)| { + ( + acc.0 + sz, + acc.1 * Expression::Const(Constant::Base(1 << sz)) + felt.clone().into(), + ) + }); + + assert_eq!(total, SIZE); + ret +} + +/// Compute an adequate split of 64-bits into chunks for performing a rotation +/// by `delta`. The first element of the return value is the vec of chunk sizes. +/// The second one is the length of its suffix that needs to be rotated +fn rotation_split(delta: usize) -> (Vec, usize) { + let delta = delta % 64; + + if delta == 0 { + return (vec![32, 32], 0); + } + + // This split meets all requirements except for <= 16 sizes + let split32 = match delta.cmp(&32) { + Ordering::Less => vec![32 - delta, delta, 32 - delta, delta], + Ordering::Equal => vec![32, 32], + Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32], + }; + + // Split off large chunks + let split16 = split32 + .into_iter() + .flat_map(|size| { + assert!(size < 32); + if size <= 16 { + vec![size] + } else { + vec![16, size - 16] + } + }) + .collect_vec(); + + let mut sum = 0; + for (i, size) in split16.iter().rev().enumerate() { + sum += size; + if sum == delta { + return (split16, i + 1); + } + } + + panic!(); +} + +struct ConstraintSystem { + expressions: Vec, + expr_names: Vec, + evals: Vec, + and_lookups: Vec, + xor_lookups: Vec, + range_lookups: Vec, +} + +impl ConstraintSystem { + fn new() -> Self { + ConstraintSystem { + expressions: vec![], + evals: vec![], + expr_names: vec![], + and_lookups: vec![], + xor_lookups: vec![], + range_lookups: vec![], + } + } + + fn add_constraint(&mut self, expr: Expression, name: String) { + self.expressions.push(expr); + self.evals.push(zero_eval()); + self.expr_names.push(name); + } + + fn lookup_and8(&mut self, a: Expression, b: Expression, c: Expression) { + self.and_lookups.push(CenoLookup::And(a, b, c)); + } + + fn lookup_xor8(&mut self, a: Expression, b: Expression, c: Expression) { + self.xor_lookups.push(CenoLookup::Xor(a, b, c)); + } + + /// Generates U16 lookups to prove that `value` fits on `size < 16` bits. + /// In general it can be done by two U16 checks: one for `value` and one for + /// `value << (16 - size)`. + fn lookup_range(&mut self, value: Expression, size: usize) { + assert!(size <= 16); + self.range_lookups.push(CenoLookup::U16(value.clone())); + if size < 16 { + self.range_lookups.push(CenoLookup::U16( + value * Expression::Const(Constant::Base(1 << (16 - size))), + )) + } + } + + fn constrain_eq(&mut self, lhs: Expression, rhs: Expression, name: String) { + self.add_constraint(lhs - rhs, name); + } + + // Constrains that lhs and rhs encode the same value of SIZE bits + // WARNING: Assumes that forall i, (lhs[i].1 < (2 ^ lhs[i].0)) + // This needs to be constrained separately + fn constrain_reps_eq( + &mut self, + lhs: &[(usize, Witness)], + rhs: &[(usize, Witness)], + name: String, + ) { + self.add_constraint( + expansion_expr::(lhs) - expansion_expr::(rhs), + name, + ); + } + + /// Checks that `rot8` is equal to `input8` left-rotated by `delta`. + /// `rot8` and `input8` each consist of 8 chunks of 8-bits. + /// + /// `split_rep` is a chunk representation of the input which + /// allows to reduce the required rotation to an array rotation. It may use + /// non-uniform chunks. + /// + /// For example, when `delta = 2`, the 64 bits are split into chunks of + /// sizes `[16a, 14b, 2c, 16d, 14e, 2f]` (here the first chunks contains the + /// least significant bits so a left rotation will become a right rotation + /// of the array). To perform the required rotation, we can + /// simply rotate the array: [2f, 16a, 14b, 2c, 16d, 14e]. + /// + /// In the first step, we check that `rot8` and `split_rep` represent the + /// same 64 bits. In the second step we check that `rot8` and the appropiate + /// array rotation of `split_rep` represent the same 64 bits. + /// + /// This type of representation-equality check is done by packing chunks + /// into sizes of exactly 32 (so for `delta = 2` we compare [16a, 14b, + /// 2c] to the first 4 elements of `rot8`). In addition, we do range + /// checks on `split_rep` which check that the felts meet the required + /// sizes. + /// + /// This algorithm imposes the following general requirements for + /// `split_rep`: + /// - There exists a suffix of `split_rep` which sums to exactly `delta`. + /// This suffix can contain several elements. + /// - Chunk sizes are at most 16 (so they can be range-checked) or they are + /// exactly equal to 32. + /// - There exists a prefix of chunks which sums exactly to 32. This must + /// hold for the rotated array as well. + /// - The number of chunks should be as small as possible. + /// + /// Consult the method `rotation_split` to see how splits are computed for a + /// given `delta + /// + /// Note that the function imposes range checks on chunk values, but it + /// makes two exceptions: + /// 1. It doesn't check the 8-bit reps (input and output). This is + /// because all 8-bit reps in the global circuit are implicitly + /// range-checked because they are lookup arguments. + /// 2. It doesn't range-check 32-bit chunks. This is because a 32-bit + /// chunk value is checked to be equal to the composition of 4 8-bit + /// chunks. As mentioned in 1., these can be trusted to be range + /// checked, so the resulting 32-bit is correct by construction as + /// well. + fn constrain_left_rotation64( + &mut self, + input8: &[Witness], + split_rep: &[(usize, Witness)], + rot8: &[Witness], + delta: usize, + label: String, + ) { + assert_eq!(input8.len(), 8); + assert_eq!(rot8.len(), 8); + + // Assert that the given split witnesses are correct for this delta + let (sizes, chunks_rotation) = rotation_split(delta); + assert_eq!(sizes, split_rep.iter().map(|e| e.0).collect_vec()); + + // Lookup ranges + for (size, elem) in split_rep { + if *size != 32 { + self.lookup_range(elem.clone().into(), *size); + } + } + + // constrain the fact that rep8 and repX.rotate_left(chunks_rotation) are + // the same 64 bitstring + let mut helper = |rep8: &[Witness], repX: &[(usize, Witness)], chunks_rotation: usize| { + // Do the same thing for the two 32-bit halves + let mut repX = repX.to_owned(); + repX.rotate_right(chunks_rotation); + + for i in 0..2 { + // The respective 4 elements in the byte representation + let lhs = rep8[4 * i..4 * (i + 1)] + .iter() + .map(|wit| (8, wit.clone())) + .collect_vec(); + let cnt = repX.len() / 2; + let rhs = &repX[cnt * i..cnt * (i + 1)]; + + assert_eq!(rhs.iter().map(|e| e.0).sum::(), 32); + + self.constrain_reps_eq::<32>( + &lhs, + &rhs, + format!( + "rotation internal {label}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", + sizes + ), + ); + } + }; + + helper(input8, split_rep, 0); + helper(rot8, split_rep, chunks_rotation); + } +} + +const ROUNDS: usize = 24; + +const RC: [u64; ROUNDS] = [ + 1u64, + 0x8082u64, + 0x800000000000808au64, + 0x8000000080008000u64, + 0x808bu64, + 0x80000001u64, + 0x8000000080008081u64, + 0x8000000000008009u64, + 0x8au64, + 0x88u64, + 0x80008009u64, + 0x8000000au64, + 0x8000808bu64, + 0x800000000000008bu64, + 0x8000000000008089u64, + 0x8000000000008003u64, + 0x8000000000008002u64, + 0x8000000000000080u64, + 0x800au64, + 0x800000008000000au64, + 0x8000000080008081u64, + 0x8000000000008080u64, + 0x80000001u64, + 0x8000000080008008u64, +]; + +const ROTATION_CONSTANTS: [[usize; 5]; 5] = [ + [0, 1, 62, 28, 27], + [36, 44, 6, 55, 20], + [3, 10, 43, 25, 39], + [41, 45, 15, 21, 8], + [18, 2, 61, 56, 14], +]; + +pub const KECCAK_INPUT_SIZE: usize = 50; +pub const KECCAK_OUTPUT_SIZE: usize = 50; + +pub const AND_LOOKUPS_PER_ROUND: usize = 200; +pub const XOR_LOOKUPS_PER_ROUND: usize = 608; +pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; +pub const LOOKUPS_PER_ROUND: usize = + AND_LOOKUPS_PER_ROUND + XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; + +macro_rules! allocate_and_split { + ($chip:expr, $total:expr, $( $size:expr ),* ) => {{ + let (witnesses, _) = $chip.allocate_wits_in_layer::<$total, 0>(); + let mut iter = witnesses.into_iter(); + ( + $( + iter.by_ref().take($size).collect_vec(), + )* + ) + }}; + } + +impl ProtocolBuilder for KeccakLayout { + type Params = KeccakParams; + + fn init(params: Self::Params) -> Self { + Self { + params, + ..Default::default() + } + } + + fn build_commit_phase(&mut self, chip: &mut Chip) { + [self.committed_bits_id] = chip.allocate_committed_base(); + } + + fn build_gkr_phase(&mut self, chip: &mut Chip) { + let final_outputs = + chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUPS_PER_ROUND * ROUNDS }>(); + + let mut final_outputs_iter = final_outputs.iter(); + + let [keccak_output32, keccak_input32, lookup_outputs] = [ + KECCAK_OUTPUT_SIZE, + KECCAK_INPUT_SIZE, + LOOKUPS_PER_ROUND * ROUNDS, + ] + .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); + + let keccak_output32 = keccak_output32.to_vec(); + let keccak_input32 = keccak_input32.to_vec(); + let lookup_outputs = lookup_outputs.to_vec(); + + let (keccak_output8, []) = chip.allocate_wits_in_layer::<200, 0>(); + + let keccak_output8: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &keccak_output8).unwrap(); + + let mut expressions = vec![]; + let mut evals = vec![]; + let mut expr_names = vec![]; + + let mut lookup_index = 0; + + for x in 0..5 { + for y in 0..5 { + for k in 0..2 { + // create an expression combining 4 elements of state8 into a single 32-bit felt + let expr = expansion_expr::<32>( + &keccak_output8 + .slice(s![x, y, 4 * k..4 * (k + 1)]) + .iter() + .map(|e| (8, e.0.clone())) + .collect_vec() + .as_slice(), + ); + expressions.push(expr); + evals.push(keccak_output32[evals.len()].clone()); + expr_names.push(format!("build 32-bit output: {x}, {y}, {k}")); + } + } + } + + chip.add_layer(Layer::new( + format!("build 32-bit output"), + LayerType::Zerocheck, + expressions, + vec![], + keccak_output8 + .into_iter() + .map(|e| e.1.clone()) + .collect_vec(), + vec![], + evals, + expr_names, + )); + + let state8_loop = (0..ROUNDS).rev().into_iter().fold( + keccak_output8.iter().map(|e| e.1.clone()).collect_vec(), + |round_output, round| { + #[allow(non_snake_case)] + let ( + state8, + c_aux, + c_temp, + c_rot, + d, + theta_output, + rotation_witness, + rhopi_output, + nonlinear, + chi_output, + iota_output, + ) = allocate_and_split!( + chip, 1656, 200, 200, 30, 40, 40, 200, 146, 200, 200, 200, 200 + ); + + let total_witnesses = 200 + 200 + 30 + 40 + 40 + 200 + 146 + 200 + 200 + 200 + 200; + // dbg!(total_witnesses); + assert_eq!(1656, total_witnesses); + + let bases = chain!( + state8.clone(), + c_aux.clone(), + c_temp.clone(), + c_rot.clone(), + d.clone(), + theta_output.clone(), + rotation_witness.clone(), + rhopi_output.clone(), + nonlinear.clone(), + chi_output.clone(), + iota_output.clone(), + ) + .collect_vec(); + + // TODO: replace ndarrays + + // Input state of the round in 8-bit chunks + let state8: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + + let mut system = ConstraintSystem::new(); + + // The purpose is to compute the auxiliary array + // c[i] = XOR (state[j][i]) for j in 0..5 + // We unroll it into + // c_aux[i][j] = XOR (state[k][i]) for k in 0..j + // We use c_aux[i][4] instead of c[i] + // c_aux is also stored in 8-bit chunks + let c_aux: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &c_aux).unwrap(); + + for i in 0..5 { + for k in 0..8 { + // Initialize first element + system.constrain_eq( + state8[[0, i, k]].0.into(), + c_aux[[i, 0, k]].0.into(), + "init c_aux".to_string(), + ); + } + for j in 1..5 { + // Check xor using lookups over all chunks + for k in 0..8 { + system.lookup_xor8( + c_aux[[i, j - 1, k]].0.into(), + state8[[j, i, k]].0.into(), + c_aux[[i, j, k]].0.into(), + ); + } + } + } + + // Compute c_rot[i] = c[i].rotate_left(1) + // To understand how rotations are performed in general, consult the + // documentation of `constrain_left_rotation64`. Here c_temp is the split + // witness for a 1-rotation. + + let c_temp: ArrayView<(Witness, EvalExpression), Ix2> = + ArrayView::from_shape((5, 6), &c_temp).unwrap(); + let c_rot: ArrayView<(Witness, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &c_rot).unwrap(); + + let (sizes, _) = rotation_split(1); + + for i in 0..5 { + assert_eq!(c_temp.slice(s![i, ..]).iter().len(), sizes.iter().len()); + + system.constrain_left_rotation64( + &c_aux.slice(s![i, 4, ..]).iter().map(|e| e.0).collect_vec(), + &zip_eq(c_temp.slice(s![i, ..]).iter(), sizes.iter()) + .map(|(e, sz)| (*sz, e.0)) + .collect_vec(), + &c_rot.slice(s![i, ..]).iter().map(|e| e.0).collect_vec(), + 1, + "theta rotation".to_string(), + ); + } + + // d is computed simply as XOR of required elements of c (and rotations) + // again stored as 8-bit chunks + let d: ArrayView<(Witness, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &d).unwrap(); + + for i in 0..5 { + for k in 0..8 { + system.lookup_xor8( + c_aux[[(i + 5 - 1) % 5, 4, k]].0.into(), + c_rot[[(i + 1) % 5, k]].0.into(), + d[[i, k]].0.into(), + ) + } + } + + // output state of the Theta sub-round, simple XOR, in 8-bit chunks + let theta_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &theta_output).unwrap(); + + for i in 0..5 { + for j in 0..5 { + for k in 0..8 { + system.lookup_xor8( + state8[[j, i, k]].0.into(), + d[[i, k]].0.into(), + theta_output[[j, i, k]].0.into(), + ) + } + } + } + + // output state after applying both Rho and Pi sub-rounds + // sub-round Pi is a simple permutation of 64-bit lanes + // sub-round Rho requires rotations + let rhopi_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &rhopi_output).unwrap(); + + // iterator over split witnesses + let mut rotation_witness = rotation_witness.iter(); + + for i in 0..5 { + for j in 0..5 { + let arg = theta_output + .slice(s!(j, i, ..)) + .iter() + .map(|e| e.0) + .collect_vec(); + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]); + let many = sizes.len(); + let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) + .map(|(sz, (wit, _))| (sz, wit.clone())) + .collect_vec(); + let arg_rotated = rhopi_output + .slice(s!((2 * i + 3 * j) % 5, j, ..)) + .iter() + .map(|e| e.0) + .collect_vec(); + system.constrain_left_rotation64( + &arg, + &rep_split, + &arg_rotated, + ROTATION_CONSTANTS[j][i], + format!("RHOPI {i}, {j}"), + ); + } + } + + let chi_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &chi_output).unwrap(); + + // for the Chi sub-round, we use an intermediate witness storing the result of + // the required AND + let nonlinear: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &nonlinear).unwrap(); + + for i in 0..5 { + for j in 0..5 { + for k in 0..8 { + system.lookup_and8( + not8_expr(rhopi_output[[j, (i + 1) % 5, k]].0.into()), + rhopi_output[[j, (i + 2) % 5, k]].0.into(), + nonlinear[[j, i, k]].0.into(), + ); + + system.lookup_xor8( + rhopi_output[[j, i, k]].0.into(), + nonlinear[[j, i, k]].0.into(), + chi_output[[j, i, k]].0.into(), + ); + } + } + } + + // TODO: 24/25 elements stay the same after Iota; eliminate duplication? + let iota_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &iota_output).unwrap(); + + for i in 0..5 { + for j in 0..5 { + if i == 0 && j == 0 { + for k in 0..8 { + system.lookup_xor8( + chi_output[[j, i, k]].0.into(), + Expression::Const(Constant::Base( + ((RC[round] >> (k * 8)) & 0xFF) as i64, + )), + iota_output[[j, i, k]].0.into(), + ); + } + } else { + for k in 0..8 { + system.constrain_eq( + iota_output[[j, i, k]].0.into(), + chi_output[[j, i, k]].0.into(), + "nothing special".to_string(), + ); + } + } + } + } + + let ConstraintSystem { + mut expressions, + mut expr_names, + mut evals, + and_lookups, + xor_lookups, + range_lookups, + .. + } = system; + + iota_output + .into_iter() + .enumerate() + .map(|(i, val)| { + expressions.push(val.0.into()); + expr_names.push(format!("iota_output {i}")); + evals.push(round_output[i].clone()); + }) + .count(); + + // TODO: use real challenge + let alpha = Constant::Base(1 << 8); + let beta = Constant::Base(0); + + // Send all lookups to the final output layer + for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups).enumerate() { + expressions.push(lookup.compress(alpha.clone(), beta.clone())); + expr_names.push(format!("{i}th: {:?}", lookup)); + evals.push(lookup_outputs[lookup_index].clone()); + lookup_index += 1; + } + + chip.add_layer(Layer::new( + format!("Round {round}"), + LayerType::Zerocheck, + expressions, + vec![], + bases.into_iter().map(|e| e.1).collect_vec(), + vec![], + evals, + expr_names, + )); + + state8 + .into_iter() + .map(|e| e.1.clone()) + .collect_vec() + .try_into() + .unwrap() + }, + ); + + let (state8, _) = chip.allocate_wits_in_layer::<200, 0>(); + + let state8: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + + let mut expressions = vec![]; + let mut evals = vec![]; + let mut expr_names = vec![]; + + for x in 0..5 { + for y in 0..5 { + for k in 0..2 { + // create an expression combining 4 elements of state8 into a single 32-bit felt + let expr = expansion_expr::<32>( + &state8 + .slice(s![x, y, 4 * k..4 * (k + 1)]) + .iter() + .map(|e| (8, e.0.clone())) + .collect_vec() + .as_slice(), + ); + expressions.push(expr); + evals.push(keccak_input32[evals.len()].clone()); + expr_names.push(format!("build 32-bit input: {x}, {y}, {k}")); + } + } + } + + // TODO: eliminate this duplication + zip_eq(state8.iter(), state8_loop.iter()) + .map(|(e, e_loop)| { + expressions.push(e.0.clone().into()); + evals.push(e_loop.clone()); + expr_names.push(format!("state8 identity")); + }) + .count(); + + chip.add_layer(Layer::new( + format!("build 32-bit input"), + LayerType::Zerocheck, + expressions, + vec![], + state8.into_iter().map(|e| e.1.clone()).collect_vec(), + vec![], + evals, + expr_names, + )); + } +} + +pub struct KeccakTrace { + pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, +} + +impl ProtocolWitnessGenerator for KeccakLayout +where + E: ExtensionField, +{ + type Trace = KeccakTrace; + + fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { + let mut res = vec![]; + for instance in phase1.instances { + res.push(u64s_to_felts::( + instance.into_iter().map(|e| e as u64).collect_vec(), + )); + } + res + } + + fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness { + let n_layers = 24 + 2 + 1; + let mut layer_wits = vec![ + LayerWitness { + bases: vec![], + exts: vec![], + num_vars: 1 + }; + n_layers + ]; + + for com_state in phase1 { + fn conv64to8(input: u64) -> [u64; 8] { + MaskRepresentation::new(vec![(64, input).into()]) + .convert(vec![8; 8]) + .values() + .try_into() + .unwrap() + } + + let mut and_lookups: Vec> = vec![vec![]; ROUNDS]; + let mut xor_lookups: Vec> = vec![vec![]; ROUNDS]; + let mut range_lookups: Vec> = vec![vec![]; ROUNDS]; + + let mut add_and = |a: u64, b: u64, round: usize| { + let c = a & b; + and_lookups[round].push((c << 16) + (b << 8) + a); + }; + + let mut add_xor = |a: u64, b: u64, round: usize| { + let c = a ^ b; + xor_lookups[round].push((c << 16) + (b << 8) + a); + }; + + let mut add_range = |value: u64, size: usize, round: usize| { + assert!(size <= 16, "{size}"); + range_lookups[round].push(value); + if size < 16 { + range_lookups[round].push(value << (16 - size)); + } + }; + + let state32 = com_state + .into_iter() + // TODO double check assumptions about canonical + .map(|e| e.to_canonical_u64()) + .collect_vec(); + let mut state64 = [[0u64; 5]; 5]; + let mut state8 = [[[0u64; 8]; 5]; 5]; + + zip_eq(iproduct!(0..5, 0..5), state32.clone().iter().tuples()) + .map(|((x, y), (lo, hi))| { + state64[x][y] = lo | (hi << 32); + }) + .count(); + + for x in 0..5 { + for y in 0..5 { + state8[x][y] = conv64to8(state64[x][y]); + } + } + + let mut curr_layer = 0; + let mut push_instance = |wits: Vec| { + let felts = u64s_to_felts::(wits); + if layer_wits[curr_layer].bases.is_empty() { + layer_wits[curr_layer] = LayerWitness::new(nest::(&felts), vec![]); + } else { + for (i, bases) in layer_wits[curr_layer].bases.iter_mut().enumerate() { + bases.push(felts[i]); + } + } + curr_layer += 1; + }; + + push_instance(state8.clone().into_iter().flatten().flatten().collect_vec()); + + for round in 0..24 { + let mut c_aux64 = [[0u64; 5]; 5]; + let mut c_aux8 = [[[0u64; 8]; 5]; 5]; + + for i in 0..5 { + c_aux64[i][0] = state64[0][i]; + c_aux8[i][0] = conv64to8(c_aux64[i][0]); + for j in 1..5 { + c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1]; + c_aux8[i][j] = conv64to8(c_aux64[i][j]); + + for k in 0..8 { + add_xor(c_aux8[i][j - 1][k], state8[j][i][k], round); + } + } + } + + let mut c64 = [0u64; 5]; + let mut c8 = [[0u64; 8]; 5]; + + for x in 0..5 { + c64[x] = c_aux64[x][4]; + c8[x] = conv64to8(c64[x]); + } + + let mut c_temp = [[0u64; 6]; 5]; + for i in 0..5 { + let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) + .convert(vec![16, 15, 1, 16, 15, 1]); + c_temp[i] = rep.values().try_into().unwrap(); + for mask in rep.rep { + add_range(mask.value, mask.size, round); + } + } + + let mut crot64 = [0u64; 5]; + let mut crot8 = [[0u64; 8]; 5]; + for i in 0..5 { + crot64[i] = c64[i].rotate_left(1); + crot8[i] = conv64to8(crot64[i]); + } + + let mut d64 = [0u64; 5]; + let mut d8 = [[0u64; 8]; 5]; + for x in 0..5 { + d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1); + d8[x] = conv64to8(d64[x]); + for k in 0..8 { + add_xor(c_aux8[(x + 4) % 5][4][k], crot8[(x + 1) % 5][k], round); + } + } + + let mut theta_state64 = state64.clone(); + let mut theta_state8 = [[[0u64; 8]; 5]; 5]; + let mut rotation_witness = vec![]; + + for x in 0..5 { + for y in 0..5 { + theta_state64[y][x] ^= d64[x]; + theta_state8[y][x] = conv64to8(theta_state64[y][x]); + + for k in 0..8 { + add_xor(state8[y][x][k], d8[x][k], round); + } + + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); + let rep = MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) + .convert(sizes); + for mask in rep.rep.iter() { + if mask.size != 32 { + add_range(mask.value, mask.size, round); + } + } + rotation_witness.extend(rep.values()); + } + } + + // Rho and Pi steps + let mut rhopi_output64 = [[0u64; 5]; 5]; + let mut rhopi_output8 = [[[0u64; 8]; 5]; 5]; + + for x in 0..5 { + for y in 0..5 { + rhopi_output64[(2 * x + 3 * y) % 5][y % 5] = + theta_state64[y][x].rotate_left(ROTATION_CONSTANTS[y][x] as u32); + } + } + + for x in 0..5 { + for y in 0..5 { + rhopi_output8[x][y] = conv64to8(rhopi_output64[x][y]); + } + } + + // Chi step + + let mut nonlinear64 = [[0u64; 5]; 5]; + let mut nonlinear8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + nonlinear64[y][x] = + !rhopi_output64[y][(x + 1) % 5] & rhopi_output64[y][(x + 2) % 5]; + nonlinear8[y][x] = conv64to8(nonlinear64[y][x]); + + for k in 0..8 { + add_and( + 0xFF - rhopi_output8[y][(x + 1) % 5][k], + rhopi_output8[y][(x + 2) % 5][k], + round, + ); + } + } + } + + let mut chi_output64 = [[0u64; 5]; 5]; + let mut chi_output8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + chi_output64[y][x] = nonlinear64[y][x] ^ rhopi_output64[y][x]; + chi_output8[y][x] = conv64to8(chi_output64[y][x]); + for k in 0..8 { + add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k], round) + } + } + } + + // Iota step + let mut iota_output64 = chi_output64.clone(); + let mut iota_output8 = [[[0u64; 8]; 5]; 5]; + iota_output64[0][0] ^= RC[round]; + + for k in 0..8 { + add_xor(chi_output8[0][0][k], (RC[round] >> (k * 8)) & 0xFF, round); + } + + for x in 0..5 { + for y in 0..5 { + iota_output8[x][y] = conv64to8(iota_output64[x][y]); + } + } + + let all_wits64 = [ + state8.into_iter().flatten().flatten().collect_vec(), + c_aux8.into_iter().flatten().flatten().collect_vec(), + c_temp.into_iter().flatten().collect_vec(), + crot8.into_iter().flatten().collect_vec(), + d8.into_iter().flatten().collect_vec(), + theta_state8.into_iter().flatten().flatten().collect_vec(), + rotation_witness, + rhopi_output8.into_iter().flatten().flatten().collect_vec(), + nonlinear8.into_iter().flatten().flatten().collect_vec(), + chi_output8.into_iter().flatten().flatten().collect_vec(), + iota_output8 + .clone() + .into_iter() + .flatten() + .flatten() + .collect_vec(), + ]; + + // let sizes = all_wits64.iter().map(|e| e.len()).collect_vec(); + // dbg!(&sizes); + + // let all_wits = nest::( + // &all_wits64 + // .into_iter() + // .flat_map(|v| u64s_to_felts::(v)) + // .collect_vec(), + // ); + + push_instance(all_wits64.into_iter().flatten().collect_vec()); + + state8 = iota_output8; + state64 = iota_output64; + } + + let mut keccak_output32 = vec![vec![vec![0; 2]; 5]; 5]; + + for x in 0..5 { + for y in 0..5 { + keccak_output32[x][y] = MaskRepresentation::from( + state8[x][y].into_iter().map(|e| (8, e)).collect_vec(), + ) + .convert(vec![32; 2]) + .values(); + } + } + + push_instance(state8.clone().into_iter().flatten().flatten().collect_vec()); + + // For temporary convenience, use one extra layer to store the correct outputs + // of the circuit This is not used during proving + let lookups = (0..24) + .into_iter() + .rev() + .flat_map(|i| { + chain!( + and_lookups[i].clone(), + xor_lookups[i].clone(), + range_lookups[i].clone() + ) + }) + .collect_vec(); + + push_instance( + chain!( + keccak_output32.into_iter().flatten().flatten(), + state32, + lookups + ) + .collect_vec(), + ); + } + + let len = layer_wits.len() - 1; + layer_wits[..len].reverse(); + + GKRCircuitWitness { layers: layer_wits } + } +} + +pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> () { + let params = KeccakParams {}; + let (layout, chip) = KeccakLayout::build(params); + let gkr_circuit = chip.gkr_circuit(); + + let mut instances = vec![]; + for state in states { + let state_mask64 = + MaskRepresentation::from(state.into_iter().map(|e| (64, e)).collect_vec()); + let state_mask32 = state_mask64.convert(vec![32; 50]); + + instances.push( + state_mask32 + .values() + .iter() + .map(|e| *e as u32) + .collect_vec() + .try_into() + .unwrap(), + ); + } + + let phase1_witness = layout.phase1_witness(KeccakTrace { instances }); + + let mut prover_transcript = BasicTranscript::::new(b"protocol"); + + // Omit the commit phase1 and phase2. + let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &vec![]); + + let out_evals = { + let point1 = Arc::new(vec![E::ZERO]); + let point2 = Arc::new(vec![E::ONE]); + let final_output1 = gkr_witness + .layers + .last() + .unwrap() + .bases + //.clone() + .iter() + .map(|base| base[0].clone()) + .collect_vec(); + let final_output2 = gkr_witness + .layers + .last() + .unwrap() + .bases + //.clone() + .iter() + .map(|base| base[1].clone()) + .collect_vec(); + + // if test { + // // confront outputs with tiny_keccak result + // let mut keccak_output64 = state_mask64.values().try_into().unwrap(); + // keccakf(&mut keccak_output64); + + // let keccak_output32 = MaskRepresentation::from( + // keccak_output64.into_iter().map(|e| (64, e)).collect_vec(), + // ) + // .convert(vec![32; 50]) + // .values(); + + // let keccak_output32 = u64s_to_felts::(keccak_output32); + // assert_eq!(keccak_output32, final_output[..50]); + // } + + let len = final_output1.len(); + let gkr_outputs = chain!( + zip(final_output1, once(point1).cycle().take(len)), + zip(final_output2, once(point2).cycle().take(len)) + ); + gkr_outputs + .into_iter() + .map(|(elem, point)| PointAndEval { + point: point.clone(), + eval: E::from_bases(&[elem, Goldilocks::ZERO]), + }) + .collect_vec() + }; + + let GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove(gkr_witness, &out_evals, &vec![], &mut prover_transcript) + .expect("Failed to prove phase"); + + if verify { + { + let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + + gkr_circuit + .verify(gkr_proof, &out_evals, &vec![], &mut verifier_transcript) + .expect("GKR verify failed"); + + // Omit the PCS opening phase. + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::{Rng, SeedableRng}; + + #[test] + fn test_v2_keccakf() { + for _ in 0..3 { + let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); + let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); + // let state = [0; 50]; + run_faster_keccakf(vec![state1, state2], true, true); + } + } +} diff --git a/gkr_iop/src/precompiles/keccak_f.rs b/gkr_iop/src/precompiles/keccak_f.rs new file mode 100644 index 000000000..0a2e44f09 --- /dev/null +++ b/gkr_iop/src/precompiles/keccak_f.rs @@ -0,0 +1,600 @@ +use std::{array::from_fn, marker::PhantomData, sync::Arc}; + +use crate::{ + chip::Chip, + evaluation::{EvalExpression, PointAndEval}, + gkr::{ + layer::{Layer, LayerType, LayerWitness}, + GKRCircuitWitness, GKRProverOutput, + }, + ProtocolBuilder, ProtocolWitnessGenerator, +}; +use ff_ext::ExtensionField; +use itertools::{chain, iproduct, Itertools}; +use p3_field::{extension::BinomialExtensionField, Field, PrimeCharacteristicRing}; +use p3_goldilocks::Goldilocks; + +use subprotocols::expression::{Constant, Expression, Witness}; +use tiny_keccak::keccakf; +use transcript::BasicTranscript; + +type E = BinomialExtensionField; +#[derive(Clone, Debug, Default)] +struct KeccakParams {} + +#[derive(Clone, Debug, Default)] +struct KeccakLayout { + params: KeccakParams, + + committed_bits_id: usize, + + result: Vec, + _marker: PhantomData, +} + +const X: usize = 5; +const Y: usize = 5; +const Z: usize = 64; +const STATE_SIZE: usize = X * Y * Z; +const C_SIZE: usize = X * Z; +const D_SIZE: usize = X * Z; + +fn to_xyz(i: usize) -> (usize, usize, usize) { + assert!(i < STATE_SIZE); + (i / 64 % 5, (i / 64) / 5, i % 64) +} + +fn from_xyz(x: usize, y: usize, z: usize) -> usize { + assert!(x < 5 && y < 5 && z < 64); + 64 * (5 * y + x) + z +} + +fn xor(a: F, b: F) -> F { + a + b - a * b - a * b +} + +fn and_expr(a: Expression, b: Expression) -> Expression { + a.clone() * b.clone() +} + +fn not_expr(a: Expression) -> Expression { + one_expr() - a +} + +fn xor_expr(a: Expression, b: Expression) -> Expression { + a.clone() + b.clone() - Expression::Const(Constant::Base(2)) * a * b +} + +fn zero_expr() -> Expression { + Expression::Const(Constant::Base(0)) +} + +fn one_expr() -> Expression { + Expression::Const(Constant::Base(1)) +} + +fn c(x: usize, z: usize, bits: &[F]) -> F { + (0..5) + .map(|y| bits[from_xyz(x, y, z)]) + .fold(F::ZERO, |acc, x| xor(acc, x)) +} + +fn c_expr(x: usize, z: usize, state_wits: &[Witness]) -> Expression { + (0..5) + .map(|y| Expression::from(state_wits[from_xyz(x, y, z)])) + .fold(zero_expr(), |acc, x| xor_expr(acc, x)) +} + +fn from_xz(x: usize, z: usize) -> usize { + x * 64 + z +} + +fn d(x: usize, z: usize, c_vals: &[F]) -> F { + let lhs = from_xz((x + 5 - 1) % 5, z); + let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64); + xor(c_vals[lhs], c_vals[rhs]) +} + +fn d_expr(x: usize, z: usize, c_wits: &[Witness]) -> Expression { + let lhs = from_xz((x + 5 - 1) % 5, z); + let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64); + xor_expr(c_wits[lhs].into(), c_wits[rhs].into()) +} + +fn theta(bits: Vec) -> Vec { + assert_eq!(bits.len(), STATE_SIZE); + + let c_vals = iproduct!(0..5, 0..64) + .map(|(x, z)| c(x, z, &bits)) + .collect_vec(); + + let d_vals = iproduct!(0..5, 0..64) + .map(|(x, z)| d(x, z, &c_vals)) + .collect_vec(); + + bits.iter() + .enumerate() + .map(|(i, bit)| { + let (x, _, z) = to_xyz(i); + xor(*bit, d_vals[from_xz(x, z)]) + }) + .collect() +} + +fn and(a: F, b: F) -> F { + a * b +} + +fn not(a: F) -> F { + F::ONE - a +} + +fn bools_to_u64s(state: &[bool]) -> Vec { + state + .chunks(64) + .map(|chunk| { + chunk + .iter() + .enumerate() + .fold(0u64, |acc, (i, &bit)| acc | ((bit as u64) << i)) + }) + .collect::>() +} + +fn u64s_to_bools(state64: &[u64]) -> Vec { + state64 + .iter() + .flat_map(|&word| (0..64).map(move |i| ((word >> i) & 1) == 1)) + .collect() +} + +fn dbg_vec(vec: &Vec) { + let res = vec + .iter() + .map(|f| if *f == F::ZERO { false } else { true }) + .collect_vec(); + // println!("{:?}", bools_to_u64s(&res)); +} +fn chi(bits: &Vec) -> Vec { + assert_eq!(bits.len(), STATE_SIZE); + + bits.iter() + .enumerate() + .map(|(i, bit)| { + let (x, y, z) = to_xyz(i); + let rhs = and( + not(bits[from_xyz((x + 1) % X, y, z)]), + bits[from_xyz((x + 2) % X, y, z)], + ); + xor(*bit, rhs) + }) + .collect() +} + +const ROUNDS: usize = 24; + +const RC: [u64; ROUNDS] = [ + 1u64, + 0x8082u64, + 0x800000000000808au64, + 0x8000000080008000u64, + 0x808bu64, + 0x80000001u64, + 0x8000000080008081u64, + 0x8000000000008009u64, + 0x8au64, + 0x88u64, + 0x80008009u64, + 0x8000000au64, + 0x8000808bu64, + 0x800000000000008bu64, + 0x8000000000008089u64, + 0x8000000000008003u64, + 0x8000000000008002u64, + 0x8000000000000080u64, + 0x800au64, + 0x800000008000000au64, + 0x8000000080008081u64, + 0x8000000000008080u64, + 0x80000001u64, + 0x8000000080008008u64, +]; + +fn iota(bits: &Vec, round_value: u64) -> Vec { + assert_eq!(bits.len(), STATE_SIZE); + let mut ret = bits.clone(); + + let cast = |x| match x { + 0 => F::ZERO, + 1 => F::ONE, + _ => unreachable!(), + }; + + for z in 0..Z { + ret[from_xyz(0, 0, z)] = xor(bits[from_xyz(0, 0, z)], cast((round_value >> z) & 1)); + } + + ret +} + +fn iota_expr(bits: &[Witness], index: usize, round_value: u64) -> Expression { + assert_eq!(bits.len(), STATE_SIZE); + let (x, y, z) = to_xyz(index); + + if x > 0 || y > 0 { + bits[index].into() + } else { + let round_bit = Expression::Const(Constant::Base( + ((round_value >> index) & 1).try_into().unwrap(), + )); + xor_expr(bits[from_xyz(0, 0, z)].into(), round_bit) + } +} + +fn chi_expr(i: usize, bits: &[Witness]) -> Expression { + assert_eq!(bits.len(), STATE_SIZE); + + let (x, y, z) = to_xyz(i); + let rhs = and_expr( + not_expr(bits[from_xyz((x + 1) % X, y, z)].into()), + bits[from_xyz((x + 2) % X, y, z)].into(), + ); + xor_expr((bits[i]).into(), rhs) +} + +impl ProtocolBuilder for KeccakLayout { + type Params = KeccakParams; + + fn init(params: Self::Params) -> Self { + Self { + params, + ..Default::default() + } + } + + fn build_commit_phase(&mut self, chip: &mut Chip) { + [self.committed_bits_id] = chip.allocate_committed_base(); + } + + fn build_gkr_phase(&mut self, chip: &mut Chip) { + let final_output = chip.allocate_output_evals::(); + + (0..ROUNDS) + .rev() + .into_iter() + .fold(final_output, |round_output, round| { + let (chi_output, _) = chip.allocate_wits_in_layer::(); + + let exprs = (0..STATE_SIZE) + .map(|i| iota_expr(&chi_output.iter().map(|e| e.0).collect_vec(), i, RC[round])) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Iota:: compute output"), + LayerType::Zerocheck, + exprs, + vec![], + chi_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + round_output.to_vec(), + vec![], + )); + + let (theta_output, _) = chip.allocate_wits_in_layer::(); + + // Apply the effects of the rho + pi permutation directly o the argument of chi + // No need for a separate layer + let perm = rho_and_pi_permutation(); + let permuted = (0..STATE_SIZE) + .map(|i| theta_output[perm[i]].0) + .collect_vec(); + + let exprs = (0..STATE_SIZE) + .map(|i| chi_expr(i, &permuted)) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Chi:: apply rho, pi and chi"), + LayerType::Zerocheck, + exprs, + vec![], + theta_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + chi_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); + + let (d_and_state, _) = chip.allocate_wits_in_layer::<{ D_SIZE + STATE_SIZE }, 0>(); + let (d, state2) = d_and_state.split_at(D_SIZE); + + // Compute post-theta state using original state and D[][] values + let exprs = (0..STATE_SIZE) + .map(|i| { + let (x, y, z) = to_xyz(i); + xor_expr(state2[i].0.into(), d[from_xz(x, z)].0.into()) + }) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute output"), + LayerType::Zerocheck, + exprs, + vec![], + d_and_state.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + theta_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); + + let (c, []) = chip.allocate_wits_in_layer::<{ C_SIZE }, 0>(); + + let c_wits = c.iter().map(|e| e.0.clone()).collect_vec(); + // Compute D[][] from C[][] values + let d_exprs = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| d_expr(x, z, &c_wits)) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute D[x][z]"), + LayerType::Zerocheck, + d_exprs, + vec![], + c.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + d.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); + + let (state, []) = chip.allocate_wits_in_layer::(); + let state_wits = state.iter().map(|s| s.0).collect_vec(); + + // Compute C[][] from state + let c_exprs = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| c_expr(x, z, &state_wits)) + .collect_vec(); + + // Copy state + let id_exprs: Vec = + (0..STATE_SIZE).map(|i| state_wits[i].into()).collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute C[x][z]"), + LayerType::Zerocheck, + chain!(c_exprs, id_exprs).collect_vec(), + vec![], + state.iter().map(|t| t.1.clone()).collect_vec(), + vec![], + chain!( + c.iter().map(|e| e.1.clone()), + state2.iter().map(|e| e.1.clone()) + ) + .collect_vec(), + vec![], + )); + + state + .iter() + .map(|e| e.1.clone()) + .collect_vec() + .try_into() + .unwrap() + }); + + // Skip base opening allocation + } +} + +pub struct KeccakTrace { + pub bits: [bool; STATE_SIZE], +} + +impl ProtocolWitnessGenerator for KeccakLayout +where + E: ExtensionField, +{ + type Trace = KeccakTrace; + + fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { + let mut res = vec![vec![]; 1]; + res[0] = phase1 + .bits + .into_iter() + .map(|b| E::BaseField::from_u64(b as u64)) + .collect(); + res + } + + fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness { + let mut bits = phase1[self.committed_bits_id].clone(); + + let n_layers = 100; + let mut layer_wits = Vec::>::with_capacity(n_layers + 1); + + for i in 0..24 { + if i == 0 { + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + } + + let c_wits = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| c(x, z, &bits)) + .collect_vec(); + + layer_wits.push(LayerWitness::new( + chain!( + c_wits.clone().into_iter().map(|b| vec![b]), + // Note: it seems test pass even if this is uncommented. + // Maybe it's good to assert there are no unused witnesses + // bits.clone().into_iter().map(|b| vec![b]) + ) + .collect_vec(), + vec![], + )); + + let d_wits = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| d(x, z, &c_wits)) + .collect_vec(); + + layer_wits.push(LayerWitness::new( + chain!( + d_wits.clone().into_iter().map(|b| vec![b]), + bits.clone().into_iter().map(|b| vec![b]) + ) + .collect_vec(), + vec![], + )); + + bits = theta(bits); + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + + bits = chi(&pi(&rho(&bits))); + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + + if i < 23 { + bits = iota(&bits, RC[i]); + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + } + } + + // Assumes one input instance + let total_witness_size: usize = layer_wits.iter().map(|layer| layer.bases.len()).sum(); + dbg!(total_witness_size); + + layer_wits.reverse(); + + GKRCircuitWitness { layers: layer_wits } + } +} + +// based on +// https://github.com/0xPolygonHermez/zkevm-prover/blob/main/tools/sm/keccak_f/keccak_rho.cpp +fn rho(state: &[T]) -> Vec { + assert_eq!(state.len(), STATE_SIZE); + let (mut x, mut y) = (1, 0); + let mut ret = [T::default(); STATE_SIZE]; + + for z in 0..Z { + ret[from_xyz(0, 0, z)] = state[from_xyz(0, 0, z)]; + } + for t in 0..24 { + for z in 0..Z { + let new_z = (1000 * Z + z - (t + 1) * (t + 2) / 2) % Z; + ret[from_xyz(x, y, z)] = state[from_xyz(x, y, new_z)]; + } + (x, y) = (y, (2 * x + 3 * y) % Y); + } + + return ret.to_vec(); +} + +// https://github.com/0xPolygonHermez/zkevm-prover/blob/main/tools/sm/keccak_f/keccak_pi.cpp +fn pi(state: &[T]) -> Vec { + assert_eq!(state.len(), STATE_SIZE); + let mut ret = [T::default(); STATE_SIZE]; + + iproduct!(0..X, 0..Y, 0..Z) + .into_iter() + .map(|(x, y, z)| ret[from_xyz(x, y, z)] = state[from_xyz((x + 3 * y) % X, x, z)]) + .count(); + + return ret.to_vec(); +} + +// Combines rho and pi steps into a single permutation +fn rho_and_pi_permutation() -> Vec { + let perm: [usize; STATE_SIZE] = from_fn(|i| i); + pi(&rho(&perm)) +} + +pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) -> () { + let params = KeccakParams {}; + let (layout, chip) = KeccakLayout::build(params); + let gkr_circuit = chip.gkr_circuit(); + + let bits = u64s_to_bools(&state); + + let phase1_witness = layout.phase1_witness(KeccakTrace { + bits: bits.try_into().unwrap(), + }); + let mut prover_transcript = BasicTranscript::::new(b"protocol"); + + // Omit the commit phase1 and phase2. + let gkr_witness = layout.gkr_witness(&phase1_witness, &vec![]); + + let out_evals = { + let point = Arc::new(vec![]); + + let last_witness = gkr_witness.layers[0] + .bases + .clone() + .into_iter() + .flatten() + .collect_vec(); + + // Last witness is missing the final sub-round; apply it now + let expected_result_manual = iota(&last_witness, RC[23]); + + if test { + let mut state = state.clone(); + keccakf(&mut state); + let state = u64s_to_bools(&state) + .into_iter() + .map(|b| Goldilocks::from_u64(b as u64)) + .collect_vec(); + assert_eq!(state, expected_result_manual); + } + + expected_result_manual + .iter() + .map(|bit| PointAndEval { + point: point.clone(), + eval: E::from_bases(&[*bit, Goldilocks::ZERO]), + }) + .collect_vec() + }; + + let GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove(gkr_witness, &out_evals, &vec![], &mut prover_transcript) + .expect("Failed to prove phase"); + + if verify { + { + let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + + gkr_circuit + .verify(gkr_proof, &out_evals, &vec![], &mut verifier_transcript) + .expect("GKR verify failed"); + + // Omit the PCS opening phase. + } + } +} + +#[cfg(test)] +mod tests { + use rand::{Rng, SeedableRng}; + + use super::*; + + #[test] + fn test_keccakf() { + for _ in 0..3 { + let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(random_u64); + let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); + run_keccakf(state, true, true); + } + } +} diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs new file mode 100644 index 000000000..56ce5c42f --- /dev/null +++ b/gkr_iop/src/precompiles/mod.rs @@ -0,0 +1,5 @@ +mod faster_keccak; +mod keccak_f; +mod utils; +pub use faster_keccak::run_faster_keccakf; +pub use keccak_f::run_keccakf; diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs new file mode 100644 index 000000000..bec858e43 --- /dev/null +++ b/gkr_iop/src/precompiles/utils.rs @@ -0,0 +1,190 @@ +use ff_ext::ExtensionField; +use itertools::Itertools; +use p3_field::PrimeCharacteristicRing; +use subprotocols::expression::{Constant, Expression}; + +use crate::evaluation::EvalExpression; + +pub fn zero_expr() -> Expression { + Expression::Const(Constant::Base(0)) +} + +pub fn one_expr() -> Expression { + Expression::Const(Constant::Base(1)) +} + +pub fn not8_expr(expr: Expression) -> Expression { + Expression::Const(Constant::Base(0xFF)) - expr +} + +pub fn zero_eval() -> EvalExpression { + EvalExpression::Linear(0, Constant::Base(0), Constant::Base(0)) +} + +pub fn nest(v: &Vec) -> Vec> { + v.clone().into_iter().map(|e| vec![e]).collect_vec() +} + +pub fn u64s_to_felts(words: Vec) -> Vec { + words + .into_iter() + .map(|word| E::BaseField::from_u64(word)) + .collect() +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Mask { + pub size: usize, + pub value: u64, +} + +impl Mask { + pub fn new(size: usize, value: u64) -> Self { + if size < 64 { + assert!(value < (1 << size)); + } + Self { size, value } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MaskRepresentation { + pub rep: Vec, +} + +impl From for (usize, u64) { + fn from(mask: Mask) -> Self { + (mask.size, mask.value) + } +} + +impl From<(usize, u64)> for Mask { + fn from(tuple: (usize, u64)) -> Self { + Mask::new(tuple.0, tuple.1) + } +} + +impl From for Vec<(usize, u64)> { + fn from(mask_rep: MaskRepresentation) -> Self { + mask_rep.rep.into_iter().map(|mask| mask.into()).collect() + } +} + +impl From> for MaskRepresentation { + fn from(tuples: Vec<(usize, u64)>) -> Self { + MaskRepresentation { + rep: tuples.into_iter().map(|tuple| tuple.into()).collect(), + } + } +} + +impl MaskRepresentation { + pub fn new(masks: Vec) -> Self { + Self { rep: masks } + } + + pub fn from_bits(bits: Vec, sizes: Vec) -> Self { + assert_eq!(bits.len(), sizes.iter().sum::()); + let mut masks = Vec::new(); + let mut bit_iter = bits.into_iter(); + for size in sizes { + let mut mask = 0; + for i in 0..size { + mask += (1 << i) * (bit_iter.next().unwrap() as u64); + } + masks.push(Mask::new(size, mask)); + } + Self { rep: masks } + } + + pub fn to_bits(&self) -> Vec { + self.rep + .iter() + .flat_map(|mask| (0..mask.size).map(move |i| ((mask.value >> i) & 1))) + .collect() + } + + pub fn convert(&self, new_sizes: Vec) -> Self { + let bits = self.to_bits(); + Self::from_bits(bits, new_sizes) + } + + pub fn values(&self) -> Vec { + self.rep.iter().map(|m| m.value).collect_vec() + } + + pub fn masks(&self) -> Vec { + self.rep.clone() + } + + fn len(&self) -> usize { + self.rep.len() + } +} + +#[derive(Debug)] +pub enum CenoLookup { + And(Expression, Expression, Expression), + Xor(Expression, Expression, Expression), + U16(Expression), +} + +impl CenoLookup { + // combine lookup arguments with challenges + pub fn compress(&self, alpha: Constant, beta: Constant) -> Expression { + let [alpha, beta] = [alpha, beta].map(|e| { + // assert!(matches!(e, Constant::Challenge(_))); + Expression::Const(e) + }); + + match self { + CenoLookup::And(a, b, c) => { + a.clone() + + alpha.clone() * b.clone() + + alpha.clone() * alpha.clone() * c.clone() + + beta + } + CenoLookup::Xor(a, b, c) => { + a.clone() + + alpha.clone() * b.clone() + + alpha.clone() * alpha.clone() * c.clone() + + beta + } + CenoLookup::U16(a) => a.clone() + beta, + } + } +} + +mod tests { + use crate::precompiles::utils::{Mask, MaskRepresentation}; + + #[test] + fn test_mask_representation_from_bits() { + let bits = vec![1, 0, 1, 1, 0, 1, 0, 0]; + let sizes = vec![3, 5]; + let mask_rep = MaskRepresentation::from_bits(bits.clone(), sizes.clone()); + assert_eq!(mask_rep.rep.len(), 2); + assert_eq!(mask_rep.rep[0], Mask::new(3, 0b101)); + assert_eq!(mask_rep.rep[1], Mask::new(5, 0b00101)); + } + + #[test] + fn test_mask_representation_to_bits() { + let masks = vec![Mask::new(3, 0b101), Mask::new(5, 0b00101)]; + let mask_rep = MaskRepresentation::new(masks); + let bits = mask_rep.to_bits(); + assert_eq!(bits, vec![1, 0, 1, 1, 0, 1, 0, 0]); + } + + #[test] + fn test_mask_representation_convert() { + let bits = vec![1, 0, 1, 1, 0, 1, 0, 0]; + let sizes = vec![3, 5]; + let mask_rep = MaskRepresentation::from_bits(bits.clone(), sizes.clone()); + let new_sizes = vec![4, 4]; + let new_mask_rep = mask_rep.convert(new_sizes); + assert_eq!(new_mask_rep.rep.len(), 2); + assert_eq!(new_mask_rep.rep[0], Mask::new(4, 0b1101)); + assert_eq!(new_mask_rep.rep[1], Mask::new(4, 0b0010)); + } +} diff --git a/subprotocols/examples/zerocheck_logup.rs b/subprotocols/examples/zerocheck_logup.rs index 4390db649..36c227b7e 100644 --- a/subprotocols/examples/zerocheck_logup.rs +++ b/subprotocols/examples/zerocheck_logup.rs @@ -51,6 +51,7 @@ fn run_verifier( let verifier = ZerocheckVerifierState::new( vec![*ans], vec![expr], + vec![], vec![point], proof, &challenges, diff --git a/subprotocols/src/error.rs b/subprotocols/src/error.rs index 5dad7bc9d..ca8eddefe 100644 --- a/subprotocols/src/error.rs +++ b/subprotocols/src/error.rs @@ -4,6 +4,6 @@ use crate::expression::Expression; #[derive(Clone, Debug, Error)] pub enum VerifierError { - #[error("Claim not match: expr: {0:?}, expect: {1:?}, got: {2:?}")] - ClaimNotMatch(Expression, E, E), + #[error("Claim not match: expr: {0:?}\n (expr name: {3:?})\n expect: {1:?}, got: {2:?}")] + ClaimNotMatch(Expression, E, E, String), } diff --git a/subprotocols/src/sumcheck.rs b/subprotocols/src/sumcheck.rs index d01c2702a..3b1ebc079 100644 --- a/subprotocols/src/sumcheck.rs +++ b/subprotocols/src/sumcheck.rs @@ -179,6 +179,7 @@ where sigma: E, expr: Expression, proof: SumcheckProof, + expr_names: Vec, challenges: &'a [E], transcript: &'a mut Trans, out_points: Vec<&'a [E]>, @@ -202,6 +203,7 @@ where proof: SumcheckProof, challenges: &'a [E], transcript: &'a mut Trans, + expr_names: Vec, ) -> Self { Self { sigma, @@ -210,6 +212,7 @@ where challenges, transcript, out_points, + expr_names, } } @@ -221,6 +224,7 @@ where challenges, transcript, out_points, + expr_names, } = self; let SumcheckProof { univariate_polys, @@ -262,6 +266,7 @@ where expr, expected_claim, got_claim, + "placeholder".to_string(), )); } @@ -324,6 +329,7 @@ mod test { proof, &challenges, &mut verifier_transcript, + vec![], ); verifier.verify().expect("verification failed"); diff --git a/subprotocols/src/zerocheck.rs b/subprotocols/src/zerocheck.rs index 0d8b1bcec..7b6d989e4 100644 --- a/subprotocols/src/zerocheck.rs +++ b/subprotocols/src/zerocheck.rs @@ -196,6 +196,7 @@ where inv_of_one_minus_points: Vec>, exprs: Vec<(Expression, &'a [E])>, proof: SumcheckProof, + expr_names: Vec, challenges: &'a [E], transcript: &'a mut Trans, } @@ -208,6 +209,7 @@ where pub fn new( sigmas: Vec, exprs: Vec, + expr_names: Vec, points: Vec<&'a [E]>, proof: SumcheckProof, challenges: &'a [E], @@ -235,6 +237,7 @@ where proof, challenges, transcript, + expr_names, } } @@ -246,6 +249,7 @@ where proof, challenges, transcript, + expr_names, .. } = self; let SumcheckProof { @@ -285,7 +289,7 @@ where ); // Check the final evaluations. - for (expected_claim, (expr, _)) in izip!(expected_claims, exprs) { + for (expected_claim, (expr, _), expr_name) in izip!(expected_claims, exprs, expr_names) { let got_claim = expr.evaluate(&ext_mle_evals, &base_mle_evals, &[], &[], challenges); if expected_claim != got_claim { @@ -293,6 +297,7 @@ where expr, expected_claim, got_claim, + expr_name.clone(), )); } } @@ -352,6 +357,7 @@ mod test { let verifier = ZerocheckVerifierState::new( sigmas, exprs, + vec![], points, proof, &challenges, From 7611e70e63f16fa6a987e644232a93d1543c893c Mon Sep 17 00:00:00 2001 From: Mihai Date: Tue, 1 Apr 2025 18:56:46 +0300 Subject: [PATCH 02/17] add gkr_iop state to ConstraintSystem/ZKVMWitness (wip, Apr 2) --- Cargo.lock | 2 + ceno_zkvm/Cargo.toml | 1 + ceno_zkvm/src/structs.rs | 83 +++++++++++++++++++++++- gkr_iop/Cargo.toml | 1 + gkr_iop/src/chip/protocol.rs | 8 +-- gkr_iop/src/gkr.rs | 18 ++--- gkr_iop/src/gkr/mock.rs | 6 +- gkr_iop/src/precompiles/faster_keccak.rs | 15 ++++- gkr_iop/src/precompiles/mod.rs | 5 +- 9 files changed, 118 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b6418e00a..16d310661 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -578,6 +578,7 @@ dependencies = [ "criterion", "ff_ext", "generic_static", + "gkr_iop", "glob", "itertools 0.13.0", "mpcs", @@ -1185,6 +1186,7 @@ dependencies = [ "thiserror 1.0.69", "tiny-keccak", "transcript", + "witness", ] [[package]] diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 3c23fe0b1..2446aa9cc 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -19,6 +19,7 @@ serde_json.workspace = true base64 = "0.22" ceno_emul = { path = "../ceno_emul" } ff_ext = { path = "../ff_ext" } +gkr_iop = { path = "../gkr_iop" } mpcs = { path = "../mpcs" } multilinear_extensions = { version = "0", path = "../multilinear_extensions" } p3 = { path = "../p3" } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 2fe82f980..bfcd2ade1 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -2,13 +2,17 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, expression::Expression, - instructions::Instruction, + instructions::{Instruction, riscv::dummy::LargeEcallDummy}, state::StateCircuit, tables::{RMMCollections, TableCircuit}, witness::LkMultiplicity, }; -use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; +use ceno_emul::{CENO_PLATFORM, KeccakSpec, Platform, StepRecord}; use ff_ext::ExtensionField; +use gkr_iop::{ + ProtocolWitnessGenerator, + precompiles::{KeccakLayout, KeccakTrace}, +}; use itertools::{Itertools, chain}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ @@ -130,6 +134,13 @@ impl> VerifyingKey } } +#[derive(Clone)] +pub struct KeccakGKRIOP { + pub chip: gkr_iop::chip::Chip, + pub layout: KeccakLayout, + pub circuit: gkr_iop::gkr::GKRCircuit, +} + #[derive(Clone, Debug)] pub struct ProgramParams { pub platform: Platform, @@ -154,6 +165,7 @@ pub struct ZKVMConstraintSystem { pub(crate) circuit_css: BTreeMap>, pub(crate) initial_global_state_expr: Expression, pub(crate) finalize_global_state_expr: Expression, + pub keccak_gkr_iop: Option>, pub params: ProgramParams, } @@ -164,6 +176,7 @@ impl Default for ZKVMConstraintSystem { initial_global_state_expr: Expression::ZERO, finalize_global_state_expr: Expression::ZERO, params: ProgramParams::default(), + keccak_gkr_iop: None, } } } @@ -175,6 +188,25 @@ impl ZKVMConstraintSystem { ..Default::default() } } + + pub fn register_keccakf_circuit( + &mut self, + ) -> as Instruction>::InstructionConfig { + // Add GKR-IOP instance + let params = gkr_iop::precompiles::KeccakParams {}; + let (layout, chip) = as gkr_iop::ProtocolBuilder>::build(params); + let circuit = chip.gkr_circuit(); + + assert!(self.keccak_gkr_iop.is_none()); + self.keccak_gkr_iop = Some(KeccakGKRIOP { + layout, + chip, + circuit, + }); + + self.register_opcode_circuit::>() + } + pub fn register_opcode_circuit>(&mut self) -> OC::InstructionConfig { let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); let mut circuit_builder = @@ -220,6 +252,14 @@ pub struct ZKVMFixedTraces { } impl ZKVMFixedTraces { + pub fn register_keccakf_circuit(&mut self, _cs: &ZKVMConstraintSystem) { + assert!( + self.circuit_fixed_traces + .insert(LargeEcallDummy::::name(), None) + .is_none() + ); + } + pub fn register_opcode_circuit>(&mut self, _cs: &ZKVMConstraintSystem) { assert!(self.circuit_fixed_traces.insert(OC::name(), None).is_none()); } @@ -244,6 +284,7 @@ impl ZKVMFixedTraces { #[derive(Default, Clone)] pub struct ZKVMWitnesses { + keccak_trace: as gkr_iop::ProtocolWitnessGenerator>::Trace, witnesses_opcodes: BTreeMap>, witnesses_tables: BTreeMap>, lk_mlts: BTreeMap, @@ -263,6 +304,44 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } + pub fn assign_keccakf_circuit( + &mut self, + css: &mut ZKVMConstraintSystem, + config: & as Instruction>::InstructionConfig, + records: Vec, + ) -> Result<(), ZKVMError> { + // Ugly copy paste from assign_opcode_circuit, but we need to use the row major matrix + let cs = css + .get_cs(&LargeEcallDummy::::name()) + .unwrap(); + let (witness, logup_multiplicity) = LargeEcallDummy::::assign_instances( + config, + cs.num_witin as usize, + records, + )?; + + // GKR-IOP-specific trace from row major witness + self.keccak_trace = KeccakTrace::from(witness.clone()); + + assert!( + self.witnesses_opcodes + .insert(LargeEcallDummy::::name(), witness) + .is_none() + ); + assert!( + !self + .witnesses_tables + .contains_key(&LargeEcallDummy::::name()) + ); + assert!( + self.lk_mlts + .insert(LargeEcallDummy::::name(), logup_multiplicity) + .is_none() + ); + + Ok(()) + } + pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index 323a24233..e05c48488 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -12,6 +12,7 @@ version.workspace = true [dependencies] ark-std.workspace = true ff_ext = { path = "../ff_ext" } +witness = { path = "../witness" } itertools.workspace = true multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } ndarray.workspace = true diff --git a/gkr_iop/src/chip/protocol.rs b/gkr_iop/src/chip/protocol.rs index a633791eb..0374d8e6c 100644 --- a/gkr_iop/src/chip/protocol.rs +++ b/gkr_iop/src/chip/protocol.rs @@ -4,13 +4,13 @@ use super::Chip; impl Chip { /// Extract information from Chip that required in the GKR phase. - pub fn gkr_circuit(&'_ self) -> GKRCircuit<'_> { + pub fn gkr_circuit(&self) -> GKRCircuit { GKRCircuit { - layers: &self.layers, + layers: self.layers.clone(), n_challenges: self.n_challenges, n_evaluations: self.n_evaluations, - base_openings: &self.base_openings, - ext_openings: &self.ext_openings, + base_openings: self.base_openings.clone(), + ext_openings: self.ext_openings.clone(), } } } diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 1e632a18f..50ae2e7a7 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -1,5 +1,5 @@ use ff_ext::ExtensionField; -use itertools::{Itertools, chain, izip}; +use itertools::{chain, izip, Itertools}; use layer::{Layer, LayerWitness}; use subprotocols::{expression::Point, sumcheck::SumcheckProof}; use transcript::Transcript; @@ -13,13 +13,13 @@ pub mod layer; pub mod mock; #[derive(Clone, Debug)] -pub struct GKRCircuit<'a> { - pub layers: &'a [Layer], +pub struct GKRCircuit { + pub layers: Vec, pub n_challenges: usize, pub n_evaluations: usize, - pub base_openings: &'a [(usize, EvalExpression)], - pub ext_openings: &'a [(usize, EvalExpression)], + pub base_openings: Vec<(usize, EvalExpression)>, + pub ext_openings: Vec<(usize, EvalExpression)>, } #[derive(Clone, Debug)] @@ -42,7 +42,7 @@ pub struct Evaluation { pub struct GKRClaims(pub Vec); -impl GKRCircuit<'_> { +impl GKRCircuit { pub fn prove( &self, circuit_wit: GKRCircuitWitness, @@ -56,7 +56,7 @@ impl GKRCircuit<'_> { let mut evaluations = out_evals.to_vec(); evaluations.resize(self.n_evaluations, PointAndEval::default()); let mut challenges = challenges.to_vec(); - let sumcheck_proofs = izip!(self.layers, circuit_wit.layers) + let sumcheck_proofs = izip!(&self.layers, circuit_wit.layers) .map(|(layer, layer_wit)| { layer.prove(layer_wit, &mut evaluations, &mut challenges, transcript) }) @@ -85,7 +85,7 @@ impl GKRCircuit<'_> { let mut challenges = challenges.to_vec(); let mut evaluations = out_evals.to_vec(); evaluations.resize(self.n_evaluations, PointAndEval::default()); - for (layer, layer_proof) in izip!(self.layers, sumcheck_proofs) { + for (layer, layer_proof) in izip!(self.layers.clone(), sumcheck_proofs) { layer.verify(layer_proof, &mut evaluations, &mut challenges, transcript)?; } @@ -99,7 +99,7 @@ impl GKRCircuit<'_> { evaluations: &[PointAndEval], challenges: &[E], ) -> Vec> { - chain!(self.base_openings, self.ext_openings) + chain!(&self.base_openings, &self.ext_openings) .map(|(poly, eval)| { let poly = *poly; let PointAndEval { point, eval: value } = eval.evaluate(evaluations, challenges); diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 660f5156e..06241fea5 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use ff_ext::ExtensionField; -use itertools::{Itertools, izip}; +use itertools::{izip, Itertools}; use rand::rngs::OsRng; use subprotocols::{ expression::{Expression, VectorType}, @@ -12,7 +12,7 @@ use thiserror::Error; use crate::{evaluation::EvalExpression, utils::SliceIterator}; -use super::{GKRCircuit, GKRCircuitWitness, layer::LayerType}; +use super::{layer::LayerType, GKRCircuit, GKRCircuitWitness}; pub struct MockProver(PhantomData); @@ -35,7 +35,7 @@ pub enum MockProverError { impl MockProver { pub fn check( - circuit: GKRCircuit<'_>, + circuit: GKRCircuit, circuit_wit: &GKRCircuitWitness, mut evaluations: Vec>, mut challenges: Vec, diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 28ea0170d..f061b68e7 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -18,6 +18,8 @@ use crate::{ }; use ndarray::{range, s, ArrayView, Ix2, Ix3}; +use witness::RowMajorMatrix; + use super::utils::{u64s_to_felts, zero_eval, CenoLookup}; use p3_field::{extension::BinomialExtensionField, PrimeCharacteristicRing}; @@ -31,10 +33,10 @@ use transcript::BasicTranscript; type E = BinomialExtensionField; #[derive(Clone, Debug, Default)] -struct KeccakParams {} +pub struct KeccakParams {} #[derive(Clone, Debug, Default)] -struct KeccakLayout { +pub struct KeccakLayout { params: KeccakParams, committed_bits_id: usize, @@ -727,10 +729,19 @@ impl ProtocolBuilder for KeccakLayout { } } +#[derive(Clone, Default)] pub struct KeccakTrace { pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, } +use p3_field::Field; + +impl From> for KeccakTrace { + fn from(value: RowMajorMatrix) -> Self { + unimplemented!(); + } +} + impl ProtocolWitnessGenerator for KeccakLayout where E: ExtensionField, diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs index 56ce5c42f..775b246ba 100644 --- a/gkr_iop/src/precompiles/mod.rs +++ b/gkr_iop/src/precompiles/mod.rs @@ -1,5 +1,8 @@ mod faster_keccak; mod keccak_f; mod utils; -pub use faster_keccak::run_faster_keccakf; pub use keccak_f::run_keccakf; +pub use { + faster_keccak::run_faster_keccakf, faster_keccak::KeccakLayout, faster_keccak::KeccakParams, + faster_keccak::KeccakTrace, +}; From 1cedbc8094c7ad5eb71f94255d9c835940107231 Mon Sep 17 00:00:00 2001 From: Mihai Date: Thu, 3 Apr 2025 15:57:56 +0300 Subject: [PATCH 03/17] add gkr_iop state to ConstraintSystem/ZKVMWitness (wip, Apr 3) --- ceno_zkvm/src/instructions/riscv/rv32im.rs | 11 +-- ceno_zkvm/src/keygen.rs | 5 ++ ceno_zkvm/src/scheme/prover.rs | 29 ++++++- ceno_zkvm/src/structs.rs | 97 ++++++++++++++++++---- gkr_iop/src/precompiles/faster_keccak.rs | 3 +- 5 files changed, 118 insertions(+), 27 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 0b6e5a824..ed42f25eb 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -466,7 +466,7 @@ pub struct DummyExtraConfig { impl DummyExtraConfig { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { let ecall_config = cs.register_opcode_circuit::>(); - let keccak_config = cs.register_opcode_circuit::>(); + let keccak_config = cs.register_keccakf_circuit(); let secp256k1_add_config = cs.register_opcode_circuit::>(); let secp256k1_double_config = @@ -509,7 +509,7 @@ impl DummyExtraConfig { fixed: &mut ZKVMFixedTraces, ) { fixed.register_opcode_circuit::>(cs); - fixed.register_opcode_circuit::>(cs); + fixed.register_keccakf_circuit(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); @@ -562,11 +562,8 @@ impl DummyExtraConfig { } } - witness.assign_opcode_circuit::>( - cs, - &self.keccak_config, - keccak_steps, - )?; + witness.assign_keccakf_circuit(&cs, &self.keccak_config, keccak_steps)?; + witness.assign_opcode_circuit::>( cs, &self.secp256k1_add_config, diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index f97a1886d..112ba1bc6 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -1,5 +1,6 @@ use crate::{ error::ZKVMError, + instructions::riscv::dummy::LargeEcallDummy, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMProvingKey}, }; use ff_ext::ExtensionField; @@ -30,6 +31,10 @@ impl ZKVMConstraintSystem { assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); } + // TODO: get keccak cs manually + // assert_eq!(keccak_cs.num_fixed, 0); + vm_pk.keccak_pk = self.keccak_gkr_iop.key_gen(&vm_pk.pp, None); + vm_pk.initial_global_state_expr = self.initial_global_state_expr; vm_pk.finalize_global_state_expr = self.finalize_global_state_expr; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 66d8ce625..5474edc77 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,3 +1,4 @@ +use ceno_emul::KeccakSpec; use ff_ext::ExtensionField; use std::{ collections::{BTreeMap, BTreeSet, HashMap}, @@ -24,6 +25,10 @@ use witness::{RowMajorMatrix, next_pow2_instance_padding}; use crate::{ error::ZKVMError, expression::Instance, + instructions::{ + Instruction, + riscv::{dummy::LargeEcallDummy, ecall::EcallDummy}, + }, scheme::{ constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ @@ -32,7 +37,8 @@ use crate::{ }, }, structs::{ - Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, + GKRIOPProvingKey, KeccakGKRIOP, Point, ProvingKey, TowerProofs, TowerProver, + TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, utils::{add_mle_list_by_expr, get_challenge_pows, optimal_sumcheck_threads}, }; @@ -174,7 +180,9 @@ impl> ZKVMProver { && cs.r_table_expressions.is_empty() && cs.w_table_expressions.is_empty(); - if is_opcode_circuit { + if *circuit_name == as Instruction>::name() { + unimplemented!("keccak impl wip"); + } else if is_opcode_circuit { tracing::debug!( "opcode circuit {} has {} witnesses, {} reads, {} writes, {} lookups", circuit_name, @@ -183,10 +191,21 @@ impl> ZKVMProver { cs.w_expressions.len(), cs.lk_expressions.len(), ); + + // Only Keccak has non-empty GKR-IOP component + let gkr_iop_pk = if *circuit_name + == as Instruction>::name() + { + Some(self.pk.keccak_pk.clone()) + } else { + None + }; + let opcode_proof = self.create_opcode_proof( circuit_name, &self.pk.pp, pk, + &gkr_iop_pk, witness, wits_commit, &pi, @@ -246,6 +265,7 @@ impl> ZKVMProver { name: &str, pp: &PCS::ProverParam, circuit_pk: &ProvingKey, + gkr_iop_pk: &Option>>, witnesses: Vec>, wits_commit: PCS::CommitmentWithWitness, pi: &[ArcMultilinearExtension<'_, E>], @@ -258,6 +278,11 @@ impl> ZKVMProver { let log2_num_instances = ceil_log2(next_pow2_instances); let (chip_record_alpha, _) = (challenges[0], challenges[1]); + if let Some(gkr_iop_pk) = gkr_iop_pk { + let mut gkr_iop_pk = gkr_iop_pk.clone(); + unimplemented!("cannot fully handle GKRIOP component yet") + } + // sanity check assert_eq!(witnesses.len(), cs.num_witin as usize); assert!( diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index bfcd2ade1..2d9fd7013 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -134,11 +134,77 @@ impl> VerifyingKey } } -#[derive(Clone)] +#[derive(Clone, Debug)] +pub struct GKRIOPProvingKey, State> { + pub fixed_traces: Option>>, + pub fixed_commit_wd: Option, + pub vk: GKRIOPVerifyingKey, +} + +impl, State: Default> Default + for GKRIOPProvingKey +{ + fn default() -> Self { + Self { + fixed_traces: None, + fixed_commit_wd: None, + vk: GKRIOPVerifyingKey::default(), + } + } +} + +#[derive(Clone, Debug)] +pub struct GKRIOPVerifyingKey, State> { + pub(crate) state: State, + pub fixed_commit: Option, +} + +impl, State: Default> Default + for GKRIOPVerifyingKey +{ + fn default() -> Self { + Self { + state: State::default(), + fixed_commit: None, + } + } +} + +impl, State> + GKRIOPVerifyingKey +{ + pub fn get_state(&self) -> &State { + &self.state + } +} + +#[derive(Clone, Default, Debug)] pub struct KeccakGKRIOP { pub chip: gkr_iop::chip::Chip, pub layout: KeccakLayout, - pub circuit: gkr_iop::gkr::GKRCircuit, +} + +impl KeccakGKRIOP { + pub fn key_gen>( + self, + pp: &PCS::ProverParam, + fixed_traces: Option>, + ) -> GKRIOPProvingKey> { + // transpose from row-major to column-major + let fixed_traces_polys = fixed_traces.as_ref().map(|rmm| rmm.to_mles()); + + let fixed_commit_wd = fixed_traces.map(|traces| PCS::batch_commit(pp, traces).unwrap()); + let fixed_commit = fixed_commit_wd.as_ref().map(PCS::get_pure_commitment); + + GKRIOPProvingKey { + fixed_traces: fixed_traces_polys, + fixed_commit_wd, + vk: GKRIOPVerifyingKey { + state: self, + fixed_commit, + }, + } + } } #[derive(Clone, Debug)] @@ -165,7 +231,7 @@ pub struct ZKVMConstraintSystem { pub(crate) circuit_css: BTreeMap>, pub(crate) initial_global_state_expr: Expression, pub(crate) finalize_global_state_expr: Expression, - pub keccak_gkr_iop: Option>, + pub keccak_gkr_iop: KeccakGKRIOP, pub params: ProgramParams, } @@ -176,7 +242,7 @@ impl Default for ZKVMConstraintSystem { initial_global_state_expr: Expression::ZERO, finalize_global_state_expr: Expression::ZERO, params: ProgramParams::default(), - keccak_gkr_iop: None, + keccak_gkr_iop: KeccakGKRIOP::default(), } } } @@ -195,14 +261,7 @@ impl ZKVMConstraintSystem { // Add GKR-IOP instance let params = gkr_iop::precompiles::KeccakParams {}; let (layout, chip) = as gkr_iop::ProtocolBuilder>::build(params); - let circuit = chip.gkr_circuit(); - - assert!(self.keccak_gkr_iop.is_none()); - self.keccak_gkr_iop = Some(KeccakGKRIOP { - layout, - chip, - circuit, - }); + self.keccak_gkr_iop = KeccakGKRIOP { layout, chip }; self.register_opcode_circuit::>() } @@ -284,7 +343,7 @@ impl ZKVMFixedTraces { #[derive(Default, Clone)] pub struct ZKVMWitnesses { - keccak_trace: as gkr_iop::ProtocolWitnessGenerator>::Trace, + keccak_phase1wit: Vec>, witnesses_opcodes: BTreeMap>, witnesses_tables: BTreeMap>, lk_mlts: BTreeMap, @@ -306,7 +365,7 @@ impl ZKVMWitnesses { pub fn assign_keccakf_circuit( &mut self, - css: &mut ZKVMConstraintSystem, + css: &ZKVMConstraintSystem, config: & as Instruction>::InstructionConfig, records: Vec, ) -> Result<(), ZKVMError> { @@ -320,8 +379,11 @@ impl ZKVMWitnesses { records, )?; - // GKR-IOP-specific trace from row major witness - self.keccak_trace = KeccakTrace::from(witness.clone()); + // Intercept row-major matrix, convert into KeccakTrace and obtain phase1_wit + self.keccak_phase1wit = css + .keccak_gkr_iop + .layout + .phase1_witness(KeccakTrace::from(witness.clone())); assert!( self.witnesses_opcodes @@ -416,7 +478,6 @@ impl ZKVMWitnesses { input, )?; assert!(self.witnesses_tables.insert(TC::name(), witness).is_none()); - assert!(!self.witnesses_opcodes.contains_key(&TC::name())); Ok(()) @@ -443,6 +504,7 @@ pub struct ZKVMProvingKey> pub vp: PCS::VerifierParam, // pk for opcode and table circuits pub circuit_pks: BTreeMap>, + pub keccak_pk: GKRIOPProvingKey>, // expression for global state in/out pub initial_global_state_expr: Expression, @@ -455,6 +517,7 @@ impl> ZKVMProvingKey From> for KeccakTrace { - fn from(value: RowMajorMatrix) -> Self { + fn from(rmm: RowMajorMatrix) -> Self { + // clarify rmm encoding for large_ecall_dummy to obtain input unimplemented!(); } } From dd5ade7893a664889c641d632cc7303a540f307b Mon Sep 17 00:00:00 2001 From: Mihai Date: Wed, 9 Apr 2025 15:12:11 +0300 Subject: [PATCH 04/17] add lookups to DummyEcall (wip, Apr 9) --- ceno_zkvm/src/instructions.rs | 82 ++++++++++++++ .../instructions/riscv/dummy/dummy_ecall.rs | 107 +++++++++++++++++- examples/examples/ceno_rt_keccak.rs | 2 +- gkr_iop/src/precompiles/faster_keccak.rs | 7 +- gkr_iop/src/precompiles/mod.rs | 3 +- 5 files changed, 195 insertions(+), 6 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 58efee2ae..38b1469e8 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -1,5 +1,11 @@ use ceno_emul::StepRecord; use ff_ext::ExtensionField; +use gkr_iop::{ + ProtocolBuilder, ProtocolWitnessGenerator, + gkr::{GKRCircuitWitness, layer::LayerWitness}, + precompiles::KeccakLayout, +}; +use itertools::Itertools; use multilinear_extensions::util::max_usable_threads; use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, @@ -67,3 +73,79 @@ pub trait Instruction { Ok((raw_witin, lk_multiplicity)) } } + +pub trait GKRIOPInstruction +where + Self: Instruction, +{ + type Layout: ProtocolWitnessGenerator + ProtocolBuilder; + fn assign_instance_with_gkr_iop( + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + lookups: Vec, + ) -> Result<(), ZKVMError>; + + fn phase1_witness_from_steps( + layout: &Self::Layout, + steps: Vec, + ) -> Vec>; + + fn assign_instances_with_gkr_iop( + config: &Self::InstructionConfig, + num_witin: usize, + steps: Vec, + gkr_layout: &mut Self::Layout, + ) -> Result<(RowMajorMatrix, LkMultiplicity), ZKVMError> { + let nthreads = max_usable_threads(); + let num_instance_per_batch = if steps.len() > 256 { + steps.len().div_ceil(nthreads) + } else { + steps.len() + } + .max(1); + let lk_multiplicity = LkMultiplicity::default(); + let mut raw_witin = + RowMajorMatrix::::new(steps.len(), num_witin, Self::padding_strategy()); + let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + + let gkr_witness = gkr_layout.gkr_witness( + &Self::phase1_witness_from_steps(gkr_layout, steps.clone()), + &vec![], + ); + + let lookups = gkr_witness + .layers + .last() + .unwrap() + .bases + .iter() + .map(|instance| instance.into_iter().skip(100).collect_vec()) + .collect_vec(); + + raw_witin_iter + .zip(steps.par_chunks(num_instance_per_batch)) + .flat_map(|(instances, steps)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip(steps) + .map(|(instance, step)| { + let lookups = vec![]; + Self::assign_instance_with_gkr_iop( + config, + instance, + &mut lk_multiplicity, + step, + lookups, + ) + }) + .collect::>() + }) + .collect::>()?; + + raw_witin.padding_by_strategy(); + Ok((raw_witin, lk_multiplicity)) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index a4e572935..12aae2438 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; -use ceno_emul::{Change, InsnKind, StepRecord, SyscallSpec}; -use ff_ext::ExtensionField; +use ceno_emul::{Change, InsnKind, KeccakSpec, StepRecord, SyscallSpec}; +use ff_ext::{ExtensionField, SmallField}; use itertools::Itertools; use super::{super::insn_base::WriteMEM, dummy_circuit::DummyConfig}; @@ -11,7 +11,7 @@ use crate::{ error::ZKVMError, expression::{ToExpr, WitIn}, instructions::{ - Instruction, + GKRIOPInstruction, Instruction, riscv::{constants::UInt, insn_base::WriteRD}, }, set_val, @@ -19,6 +19,14 @@ use crate::{ }; use ff_ext::FieldInto; +use gkr_iop::{ + ProtocolWitnessGenerator, + precompiles::{ + AND_LOOKUPS_PER_ROUND, KeccakLayout, KeccakTrace, RANGE_LOOKUPS_PER_ROUND, + XOR_LOOKUPS_PER_ROUND, + }, +}; + /// LargeEcallDummy can handle any instruction and produce its effects, /// including multiple memory operations. /// @@ -70,11 +78,44 @@ impl Instruction for LargeEcallDummy }) .collect::, _>>()?; + // Temporarily set this to < 24 to avoid cb.num_witin overflow + let active_rounds = 12; + + let mut lookups = Vec::with_capacity( + active_rounds + * (3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND), + ); + + dbg!(lookups.capacity()); + + for round in 0..active_rounds { + for i in 0..AND_LOOKUPS_PER_ROUND { + let a = cb.create_witin(|| format!("and_lookup_{round}_{i}_a")); + let b = cb.create_witin(|| format!("and_lookup_{round}_{i}_b")); + let c = cb.create_witin(|| format!("and_lookup_{round}_{i}_c")); + cb.lookup_and_byte(a.into(), b.into(), c.into())?; + lookups.extend(vec![a, b, c]); + } + for i in 0..XOR_LOOKUPS_PER_ROUND { + let a = cb.create_witin(|| format!("xor_lookup_{round}_{i}_a")); + let b = cb.create_witin(|| format!("xor_lookup_{round}_{i}_b")); + let c = cb.create_witin(|| format!("xor_lookup_{round}_{i}_c")); + cb.lookup_xor_byte(a.into(), b.into(), c.into())?; + lookups.extend(vec![a, b, c]); + } + for i in 0..RANGE_LOOKUPS_PER_ROUND { + let wit = cb.create_witin(|| format!("range_lookup_{round}_{i}")); + cb.assert_ux::<_, _, 16>(|| "nada", wit.into())?; + lookups.push(wit); + } + } + Ok(LargeEcallConfig { dummy_insn, start_addr, reg_writes, mem_writes, + lookups, }) } @@ -111,6 +152,65 @@ impl Instruction for LargeEcallDummy } } +impl GKRIOPInstruction for LargeEcallDummy { + type Layout = KeccakLayout; + + fn phase1_witness_from_steps( + layout: &Self::Layout, + steps: Vec, + ) -> Vec::BaseField>> { + let instances = steps + .iter() + .map(|step| { + step.syscall() + .unwrap() + .mem_ops + .iter() + .map(|op| op.value.before) + .collect_vec() + .try_into() + .unwrap() + }) + .collect_vec(); + + layout.phase1_witness(KeccakTrace { instances }) + } + + fn assign_instance_with_gkr_iop( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + lookups: Vec, + ) -> Result<(), ZKVMError> { + Self::assign_instance(config, instance, lk_multiplicity, step)?; + + let active_rounds = 12; + let mut wit_iter = lookups.iter().map(|f| f.to_canonical_u64()); + let mut var_iter = config.lookups.iter(); + + let mut pop_arg = || -> u64 { + let wit = wit_iter.next().unwrap(); + let var = var_iter.next().unwrap(); + set_val!(instance, var, wit); + wit + }; + + for round in 0..active_rounds { + for i in 0..AND_LOOKUPS_PER_ROUND { + lk_multiplicity.lookup_and_byte(pop_arg(), pop_arg()); + } + for i in 0..XOR_LOOKUPS_PER_ROUND { + lk_multiplicity.lookup_xor_byte(pop_arg(), pop_arg()); + } + for i in 0..RANGE_LOOKUPS_PER_ROUND { + lk_multiplicity.assert_ux::<16>(pop_arg()); + } + } + + Ok(()) + } +} #[derive(Debug)] pub struct LargeEcallConfig { dummy_insn: DummyConfig, @@ -119,4 +219,5 @@ pub struct LargeEcallConfig { start_addr: WitIn, mem_writes: Vec<(WitIn, Change, WriteMEM)>, + lookups: Vec, } diff --git a/examples/examples/ceno_rt_keccak.rs b/examples/examples/ceno_rt_keccak.rs index fa76acf7d..92173d9c3 100644 --- a/examples/examples/ceno_rt_keccak.rs +++ b/examples/examples/ceno_rt_keccak.rs @@ -13,7 +13,7 @@ fn main() { for _ in 0..ITERATIONS { syscall_keccak_permute(&mut state); - log_state(&state); + // log_state(&state); } } diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index d4b1f6007..9d82274e5 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -316,6 +316,10 @@ pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; pub const LOOKUPS_PER_ROUND: usize = AND_LOOKUPS_PER_ROUND + XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; +pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND; +pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND; +pub const RANGE_LOOKUPS: usize = ROUNDS * RANGE_LOOKUPS_PER_ROUND; + macro_rules! allocate_and_split { ($chip:expr, $total:expr, $( $size:expr ),* ) => {{ let (witnesses, _) = $chip.allocate_wits_in_layer::<$total, 0>(); @@ -739,7 +743,8 @@ use p3_field::Field; impl From> for KeccakTrace { fn from(rmm: RowMajorMatrix) -> Self { // clarify rmm encoding for large_ecall_dummy to obtain input - unimplemented!(); + //unimplemented!(); + KeccakTrace { instances: vec![] } } } diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs index 775b246ba..2d840460d 100644 --- a/gkr_iop/src/precompiles/mod.rs +++ b/gkr_iop/src/precompiles/mod.rs @@ -4,5 +4,6 @@ mod utils; pub use keccak_f::run_keccakf; pub use { faster_keccak::run_faster_keccakf, faster_keccak::KeccakLayout, faster_keccak::KeccakParams, - faster_keccak::KeccakTrace, + faster_keccak::KeccakTrace, faster_keccak::AND_LOOKUPS_PER_ROUND, + faster_keccak::RANGE_LOOKUPS_PER_ROUND, faster_keccak::XOR_LOOKUPS_PER_ROUND, }; From ee541560f4012c673d924405e63e3e575226aaba Mon Sep 17 00:00:00 2001 From: Mihai Date: Thu, 10 Apr 2025 15:28:25 +0300 Subject: [PATCH 05/17] add lookups to DummyEcall (wip, Apr 10) --- ceno_emul/src/syscalls.rs | 2 + ceno_emul/src/syscalls/keccak_permute.rs | 1 + ceno_zkvm/benches/riscv_add.rs | 1 + ceno_zkvm/src/e2e.rs | 16 ++-- ceno_zkvm/src/instructions.rs | 47 +++++++---- .../instructions/riscv/dummy/dummy_ecall.rs | 82 +++++++++++-------- ceno_zkvm/src/instructions/riscv/rv32im.rs | 4 +- ceno_zkvm/src/scheme/prover.rs | 11 +-- ceno_zkvm/src/scheme/tests.rs | 1 + ceno_zkvm/src/structs.rs | 29 ++++--- gkr_iop/src/precompiles/faster_keccak.rs | 20 +++-- gkr_iop/src/precompiles/utils.rs | 13 +++ rust-toolchain.toml | 2 +- sanity.sh | 17 ++++ 14 files changed, 165 insertions(+), 81 deletions(-) create mode 100755 sanity.sh diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index b6fd82ac2..363402243 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -20,6 +20,8 @@ pub trait SyscallSpec { const REG_OPS_COUNT: usize; const MEM_OPS_COUNT: usize; const CODE: u32; + + const HAS_LOOKUPS: bool = false; } /// Trace the inputs and effects of a syscall. diff --git a/ceno_emul/src/syscalls/keccak_permute.rs b/ceno_emul/src/syscalls/keccak_permute.rs index 3748fc3ea..7afd3ef83 100644 --- a/ceno_emul/src/syscalls/keccak_permute.rs +++ b/ceno_emul/src/syscalls/keccak_permute.rs @@ -16,6 +16,7 @@ impl SyscallSpec for KeccakSpec { const REG_OPS_COUNT: usize = 2; const MEM_OPS_COUNT: usize = KECCAK_WORDS; const CODE: u32 = ceno_rt::syscalls::KECCAK_PERMUTE; + const HAS_LOOKUPS: bool = true; } /// Wrapper type for the keccak_permute argument that implements conversions diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index b0e6cd2b0..1a972587f 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -95,6 +95,7 @@ fn bench_add(c: &mut Criterion) { "ADD", &prover.pk.pp, &circuit_pk, + &None, polys.into_iter().map(|mle| mle.into()).collect_vec(), commit, &[], diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 4f0a2dee3..48720a634 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -295,7 +295,7 @@ fn generate_fixed_traces( } pub fn generate_witness( - system_config: &ConstraintSystemConfig, + system_config: &mut ConstraintSystemConfig, emul_result: EmulationResult, program: &Program, is_mock_proving: bool, @@ -312,7 +312,7 @@ pub fn generate_witness( .unwrap(); system_config .dummy_config - .assign_opcode_circuit(&system_config.zkvm_cs, &mut zkvm_witness, dummy_records) + .assign_opcode_circuit(&mut system_config.zkvm_cs, &mut zkvm_witness, dummy_records) .unwrap(); zkvm_witness.finalize_lk_multiplicities(is_mock_proving); @@ -394,7 +394,7 @@ pub fn run_e2e_with_checkpoint< }; let program = Arc::new(program); - let system_config = construct_configs::(program_params); + let mut system_config = construct_configs::(program_params); let reg_init = system_config.mmu_config.initial_registers(); // IO is not used in this program, but it must have a particular size at the moment. @@ -432,7 +432,7 @@ pub fn run_e2e_with_checkpoint< init_full_mem, platform, hints, - &system_config, + &mut system_config, pk, zkvm_fixed_traces, is_mock_proving, @@ -451,11 +451,13 @@ pub fn run_e2e_with_checkpoint< if let Checkpoint::PrepWitnessGen = checkpoint { return ( None, - Box::new(move || _ = generate_witness(&system_config, emul_result, &program, false)), + Box::new(move || { + _ = generate_witness(&mut system_config, emul_result, &program, false) + }), ); } - let zkvm_witness = generate_witness(&system_config, emul_result, &program, is_mock_proving); + let zkvm_witness = generate_witness(&mut system_config, emul_result, &program, is_mock_proving); // proving let prover = ZKVMProver::new(pk); @@ -496,7 +498,7 @@ pub fn run_e2e_proof, - system_config: &ConstraintSystemConfig, + system_config: &mut ConstraintSystemConfig, pk: ZKVMProvingKey, zkvm_fixed_traces: ZKVMFixedTraces, is_mock_proving: bool, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 38b1469e8..cbfe6abd4 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -84,12 +84,12 @@ where instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, - lookups: Vec, + lookups: &[E::BaseField], ) -> Result<(), ZKVMError>; fn phase1_witness_from_steps( layout: &Self::Layout, - steps: Vec, + steps: &[StepRecord], ) -> Vec>; fn assign_instances_with_gkr_iop( @@ -111,34 +111,51 @@ where let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); let gkr_witness = gkr_layout.gkr_witness( - &Self::phase1_witness_from_steps(gkr_layout, steps.clone()), + &Self::phase1_witness_from_steps(gkr_layout, &steps), &vec![], ); - let lookups = gkr_witness - .layers - .last() - .unwrap() - .bases - .iter() - .map(|instance| instance.into_iter().skip(100).collect_vec()) - .collect_vec(); + let lookups = { + let mut lookups = vec![vec![]; steps.len()]; + for witness in gkr_witness + .layers + .last() + .unwrap() + .bases + .iter() + .skip(100) + .cloned() + { + for i in 0..witness.len() { + lookups[i].push(witness[i]); + } + } + lookups + }; + + // dbg!(&lookups); raw_witin_iter - .zip(steps.par_chunks(num_instance_per_batch)) + .zip( + steps + .iter() + .enumerate() + .collect_vec() + .par_chunks(num_instance_per_batch), + ) .flat_map(|(instances, steps)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) .zip(steps) - .map(|(instance, step)| { - let lookups = vec![]; + .map(|(instance, (i, step))| { + // dbg!(i, step); Self::assign_instance_with_gkr_iop( config, instance, &mut lk_multiplicity, step, - lookups, + &lookups[*i], ) }) .collect::>() diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 12aae2438..b247c5a4c 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -79,34 +79,36 @@ impl Instruction for LargeEcallDummy .collect::, _>>()?; // Temporarily set this to < 24 to avoid cb.num_witin overflow - let active_rounds = 12; + let active_rounds = 0; let mut lookups = Vec::with_capacity( active_rounds * (3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND), ); - dbg!(lookups.capacity()); - - for round in 0..active_rounds { - for i in 0..AND_LOOKUPS_PER_ROUND { - let a = cb.create_witin(|| format!("and_lookup_{round}_{i}_a")); - let b = cb.create_witin(|| format!("and_lookup_{round}_{i}_b")); - let c = cb.create_witin(|| format!("and_lookup_{round}_{i}_c")); - cb.lookup_and_byte(a.into(), b.into(), c.into())?; - lookups.extend(vec![a, b, c]); - } - for i in 0..XOR_LOOKUPS_PER_ROUND { - let a = cb.create_witin(|| format!("xor_lookup_{round}_{i}_a")); - let b = cb.create_witin(|| format!("xor_lookup_{round}_{i}_b")); - let c = cb.create_witin(|| format!("xor_lookup_{round}_{i}_c")); - cb.lookup_xor_byte(a.into(), b.into(), c.into())?; - lookups.extend(vec![a, b, c]); - } - for i in 0..RANGE_LOOKUPS_PER_ROUND { - let wit = cb.create_witin(|| format!("range_lookup_{round}_{i}")); - cb.assert_ux::<_, _, 16>(|| "nada", wit.into())?; - lookups.push(wit); + if S::HAS_LOOKUPS { + dbg!(lookups.capacity()); + + for round in 0..active_rounds { + for i in 0..AND_LOOKUPS_PER_ROUND { + let a = cb.create_witin(|| format!("and_lookup_{round}_{i}_a")); + let b = cb.create_witin(|| format!("and_lookup_{round}_{i}_b")); + let c = cb.create_witin(|| format!("and_lookup_{round}_{i}_c")); + cb.lookup_and_byte(a.into(), b.into(), c.into())?; + lookups.extend(vec![a, b, c]); + } + for i in 0..XOR_LOOKUPS_PER_ROUND { + let a = cb.create_witin(|| format!("xor_lookup_{round}_{i}_a")); + let b = cb.create_witin(|| format!("xor_lookup_{round}_{i}_b")); + let c = cb.create_witin(|| format!("xor_lookup_{round}_{i}_c")); + cb.lookup_xor_byte(a.into(), b.into(), c.into())?; + lookups.extend(vec![a, b, c]); + } + for i in 0..RANGE_LOOKUPS_PER_ROUND { + let wit = cb.create_witin(|| format!("range_lookup_{round}_{i}")); + cb.assert_ux::<_, _, 16>(|| "nada", wit.into())?; + lookups.push(wit); + } } } @@ -157,7 +159,7 @@ impl GKRIOPInstruction for LargeEcallDummy fn phase1_witness_from_steps( layout: &Self::Layout, - steps: Vec, + steps: &[StepRecord], ) -> Vec::BaseField>> { let instances = steps .iter() @@ -181,33 +183,47 @@ impl GKRIOPInstruction for LargeEcallDummy instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, - lookups: Vec, + lookups: &[E::BaseField], ) -> Result<(), ZKVMError> { Self::assign_instance(config, instance, lk_multiplicity, step)?; - let active_rounds = 12; + let active_rounds = 0; let mut wit_iter = lookups.iter().map(|f| f.to_canonical_u64()); let mut var_iter = config.lookups.iter(); + dbg!(wit_iter.clone().count()); + dbg!(var_iter.clone().count()); + let mut pop_arg = || -> u64 { let wit = wit_iter.next().unwrap(); let var = var_iter.next().unwrap(); - set_val!(instance, var, wit); + // set_val!(instance, var, wit); + set_val!(instance, var, 0); wit }; - for round in 0..active_rounds { - for i in 0..AND_LOOKUPS_PER_ROUND { - lk_multiplicity.lookup_and_byte(pop_arg(), pop_arg()); + for _round in 0..active_rounds { + for _i in 0..AND_LOOKUPS_PER_ROUND { + // lk_multiplicity.lookup_and_byte(pop_arg(), pop_arg()); + lk_multiplicity.lookup_and_byte(0, 0); + pop_arg(); + pop_arg(); + pop_arg(); } - for i in 0..XOR_LOOKUPS_PER_ROUND { - lk_multiplicity.lookup_xor_byte(pop_arg(), pop_arg()); + for _i in 0..XOR_LOOKUPS_PER_ROUND { + lk_multiplicity.lookup_xor_byte(0, 0); + pop_arg(); + pop_arg(); + pop_arg(); } - for i in 0..RANGE_LOOKUPS_PER_ROUND { - lk_multiplicity.assert_ux::<16>(pop_arg()); + for _i in 0..RANGE_LOOKUPS_PER_ROUND { + lk_multiplicity.assert_ux::<16>(0); + pop_arg(); } } + assert!(var_iter.next().is_none()); + Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index ed42f25eb..d7dada846 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -524,7 +524,7 @@ impl DummyExtraConfig { pub fn assign_opcode_circuit( &self, - cs: &ZKVMConstraintSystem, + cs: &mut ZKVMConstraintSystem, witness: &mut ZKVMWitnesses, steps: GroupedSteps, ) -> Result<(), ZKVMError> { @@ -562,7 +562,7 @@ impl DummyExtraConfig { } } - witness.assign_keccakf_circuit(&cs, &self.keccak_config, keccak_steps)?; + witness.assign_keccakf_circuit(cs, &self.keccak_config, keccak_steps)?; witness.assign_opcode_circuit::>( cs, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 5474edc77..801d17bc2 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -181,7 +181,8 @@ impl> ZKVMProver { && cs.w_table_expressions.is_empty(); if *circuit_name == as Instruction>::name() { - unimplemented!("keccak impl wip"); + // unimplemented!("keccak impl wip"); + () } else if is_opcode_circuit { tracing::debug!( "opcode circuit {} has {} witnesses, {} reads, {} writes, {} lookups", @@ -278,10 +279,10 @@ impl> ZKVMProver { let log2_num_instances = ceil_log2(next_pow2_instances); let (chip_record_alpha, _) = (challenges[0], challenges[1]); - if let Some(gkr_iop_pk) = gkr_iop_pk { - let mut gkr_iop_pk = gkr_iop_pk.clone(); - unimplemented!("cannot fully handle GKRIOP component yet") - } + // if let Some(gkr_iop_pk) = gkr_iop_pk { + // let mut gkr_iop_pk = gkr_iop_pk.clone(); + // unimplemented!("cannot fully handle GKRIOP component yet") + // } // sanity check assert_eq!(witnesses.len(), cs.num_witin as usize); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 6b1a2151e..8e73ae0b3 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -139,6 +139,7 @@ fn test_rw_lk_expression_combination() { name.as_str(), &prover.pk.pp, prover.pk.circuit_pks.get(&name).unwrap(), + &None, wits_in, commit, &[], diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 2d9fd7013..a8775c3d0 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -2,7 +2,7 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, expression::Expression, - instructions::{Instruction, riscv::dummy::LargeEcallDummy}, + instructions::{GKRIOPInstruction, Instruction, riscv::dummy::LargeEcallDummy}, state::StateCircuit, tables::{RMMCollections, TableCircuit}, witness::LkMultiplicity, @@ -365,7 +365,8 @@ impl ZKVMWitnesses { pub fn assign_keccakf_circuit( &mut self, - css: &ZKVMConstraintSystem, + // TODO: do without mutability requirement + css: &mut ZKVMConstraintSystem, config: & as Instruction>::InstructionConfig, records: Vec, ) -> Result<(), ZKVMError> { @@ -373,17 +374,19 @@ impl ZKVMWitnesses { let cs = css .get_cs(&LargeEcallDummy::::name()) .unwrap(); - let (witness, logup_multiplicity) = LargeEcallDummy::::assign_instances( - config, - cs.num_witin as usize, - records, - )?; - - // Intercept row-major matrix, convert into KeccakTrace and obtain phase1_wit - self.keccak_phase1wit = css - .keccak_gkr_iop - .layout - .phase1_witness(KeccakTrace::from(witness.clone())); + let (witness, logup_multiplicity) = + LargeEcallDummy::::assign_instances_with_gkr_iop( + config, + cs.num_witin as usize, + records, + &mut css.keccak_gkr_iop.layout, + )?; + + // // Intercept row-major matrix, convert into KeccakTrace and obtain phase1_wit + // self.keccak_phase1wit = css + // .keccak_gkr_iop + // .layout + // .phase1_witness(KeccakTrace::from(witness.clone())); assert!( self.witnesses_opcodes diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 9d82274e5..3a896c527 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -314,7 +314,7 @@ pub const AND_LOOKUPS_PER_ROUND: usize = 200; pub const XOR_LOOKUPS_PER_ROUND: usize = 608; pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; pub const LOOKUPS_PER_ROUND: usize = - AND_LOOKUPS_PER_ROUND + XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; + 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND; pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND; @@ -656,8 +656,11 @@ impl ProtocolBuilder for KeccakLayout { let beta = Constant::Base(0); // Send all lookups to the final output layer - for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups).enumerate() { - expressions.push(lookup.compress(alpha.clone(), beta.clone())); + for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) + .flatten() + .enumerate() + { + expressions.push(lookup.clone()); expr_names.push(format!("{i}th: {:?}", lookup)); evals.push(lookup_outputs[lookup_index].clone()); lookup_index += 1; @@ -683,6 +686,8 @@ impl ProtocolBuilder for KeccakLayout { }, ); + assert!(lookup_index == LOOKUPS_PER_ROUND * ROUNDS); + let (state8, _) = chip.allocate_wits_in_layer::<200, 0>(); let state8: ArrayView<(Witness, EvalExpression), Ix3> = @@ -790,12 +795,16 @@ where let mut add_and = |a: u64, b: u64, round: usize| { let c = a & b; - and_lookups[round].push((c << 16) + (b << 8) + a); + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + and_lookups[round].extend(vec![a, b, c]); }; let mut add_xor = |a: u64, b: u64, round: usize| { let c = a ^ b; - xor_lookups[round].push((c << 16) + (b << 8) + a); + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + xor_lookups[round].extend(vec![a, b, c]); }; let mut add_range = |value: u64, size: usize, round: usize| { @@ -803,6 +812,7 @@ where range_lookups[round].push(value); if size < 16 { range_lookups[round].push(value << (16 - size)); + assert!(value << (16 - size) < (1 << 16)); } }; diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs index bec858e43..0dd484b14 100644 --- a/gkr_iop/src/precompiles/utils.rs +++ b/gkr_iop/src/precompiles/utils.rs @@ -129,6 +129,19 @@ pub enum CenoLookup { U16(Expression), } +impl IntoIterator for CenoLookup { + type Item = Expression; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + match self { + CenoLookup::And(a, b, c) => vec![a, b, c].into_iter(), + CenoLookup::Xor(a, b, c) => vec![a, b, c].into_iter(), + CenoLookup::U16(a) => vec![a].into_iter(), + } + } +} + impl CenoLookup { // combine lookup arguments with challenges pub fn compress(&self, alpha: Constant, beta: Constant) -> Expression { diff --git a/rust-toolchain.toml b/rust-toolchain.toml index faa027589..98e47aaf7 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2025-01-06" +channel = "nightly-2025-03-23" targets = ["riscv32im-unknown-none-elf"] # We need the sources for build-std. components = ["rust-src"] diff --git a/sanity.sh b/sanity.sh new file mode 100755 index 000000000..efdaec941 --- /dev/null +++ b/sanity.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +cargo make build + +# Run the first command +cargo run --bin e2e --release -- examples/target/riscv32im-ceno-zkvm-elf/release/examples/secp256k1_add_syscall +if [ $? -ne 0 ]; then + echo "Command 1 failed" + exit 1 +fi + +# Run the second command +cargo run --bin e2e --release -- examples/target/riscv32im-ceno-zkvm-elf/release/examples/ceno_rt_keccak +if [ $? -ne 0 ]; then + echo "Command 2 failed" + exit 1 +fi From dcf7c6209c7ecffed4cd94a4c04e7d79ee52ca96 Mon Sep 17 00:00:00 2001 From: Mihai Date: Mon, 14 Apr 2025 10:56:53 +0300 Subject: [PATCH 06/17] debug --- ceno_zkvm/src/instructions.rs | 11 +++- .../instructions/riscv/dummy/dummy_ecall.rs | 27 +++++----- ceno_zkvm/src/instructions/riscv/rv32im.rs | 10 +++- ceno_zkvm/src/scheme/prover.rs | 50 ++++++++++++++----- ceno_zkvm/src/structs.rs | 6 ++- gkr_iop/src/gkr.rs | 4 +- gkr_iop/src/gkr/layer.rs | 2 +- gkr_iop/src/precompiles/faster_keccak.rs | 45 ++++++++++------- rust-toolchain.toml | 2 +- subprotocols/src/zerocheck.rs | 11 ++-- 10 files changed, 113 insertions(+), 55 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index cbfe6abd4..5d63ab602 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -97,7 +97,14 @@ where num_witin: usize, steps: Vec, gkr_layout: &mut Self::Layout, - ) -> Result<(RowMajorMatrix, LkMultiplicity), ZKVMError> { + ) -> Result< + ( + RowMajorMatrix, + GKRCircuitWitness, + LkMultiplicity, + ), + ZKVMError, + > { let nthreads = max_usable_threads(); let num_instance_per_batch = if steps.len() > 256 { steps.len().div_ceil(nthreads) @@ -163,6 +170,6 @@ where .collect::>()?; raw_witin.padding_by_strategy(); - Ok((raw_witin, lk_multiplicity)) + Ok((raw_witin, gkr_witness, lk_multiplicity)) } } diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index b247c5a4c..bec7911a7 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -79,7 +79,7 @@ impl Instruction for LargeEcallDummy .collect::, _>>()?; // Temporarily set this to < 24 to avoid cb.num_witin overflow - let active_rounds = 0; + let active_rounds = 12; let mut lookups = Vec::with_capacity( active_rounds @@ -187,7 +187,7 @@ impl GKRIOPInstruction for LargeEcallDummy ) -> Result<(), ZKVMError> { Self::assign_instance(config, instance, lk_multiplicity, step)?; - let active_rounds = 0; + let active_rounds = 12; let mut wit_iter = lookups.iter().map(|f| f.to_canonical_u64()); let mut var_iter = config.lookups.iter(); @@ -197,28 +197,29 @@ impl GKRIOPInstruction for LargeEcallDummy let mut pop_arg = || -> u64 { let wit = wit_iter.next().unwrap(); let var = var_iter.next().unwrap(); - // set_val!(instance, var, wit); - set_val!(instance, var, 0); + set_val!(instance, var, wit); + // set_val!(instance, var, 0); wit }; for _round in 0..active_rounds { for _i in 0..AND_LOOKUPS_PER_ROUND { - // lk_multiplicity.lookup_and_byte(pop_arg(), pop_arg()); - lk_multiplicity.lookup_and_byte(0, 0); - pop_arg(); - pop_arg(); + lk_multiplicity.lookup_and_byte(pop_arg(), pop_arg()); + // lk_multiplicity.lookup_and_byte(0, 0); + // pop_arg(); + // pop_arg(); pop_arg(); } for _i in 0..XOR_LOOKUPS_PER_ROUND { - lk_multiplicity.lookup_xor_byte(0, 0); - pop_arg(); - pop_arg(); + lk_multiplicity.lookup_xor_byte(pop_arg(), pop_arg()); + // lk_multiplicity.lookup_xor_byte(0, 0); + // pop_arg(); + // pop_arg(); pop_arg(); } for _i in 0..RANGE_LOOKUPS_PER_ROUND { - lk_multiplicity.assert_ux::<16>(0); - pop_arg(); + lk_multiplicity.assert_ux::<16>(pop_arg()); + // pop_arg(); } } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index d7dada846..6d279ca18 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -466,6 +466,7 @@ pub struct DummyExtraConfig { impl DummyExtraConfig { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { let ecall_config = cs.register_opcode_circuit::>(); + // let keccak_config = cs.register_opcode_circuit::>(); let keccak_config = cs.register_keccakf_circuit(); let secp256k1_add_config = cs.register_opcode_circuit::>(); @@ -510,9 +511,10 @@ impl DummyExtraConfig { ) { fixed.register_opcode_circuit::>(cs); fixed.register_keccakf_circuit(cs); + // fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); - fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); @@ -564,6 +566,12 @@ impl DummyExtraConfig { witness.assign_keccakf_circuit(cs, &self.keccak_config, keccak_steps)?; + // witness.assign_opcode_circuit::>( + // cs, + // &self.keccak_config, + // keccak_steps, + //)?; + witness.assign_opcode_circuit::>( cs, &self.secp256k1_add_config, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 801d17bc2..c7a8ecf6b 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,5 +1,6 @@ use ceno_emul::KeccakSpec; use ff_ext::ExtensionField; +use gkr_iop::gkr::GKRCircuitWitness; use std::{ collections::{BTreeMap, BTreeSet, HashMap}, sync::Arc, @@ -99,6 +100,7 @@ impl> ZKVMProver { let mut commitments = BTreeMap::new(); let mut wits = BTreeMap::new(); let mut structural_wits = BTreeMap::new(); + let mut keccak_gkr_wit = Some(witnesses.keccak_gkr_wit.clone()); let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); // commit to opcode circuits first and then commit to table circuits, sorted by name @@ -180,10 +182,7 @@ impl> ZKVMProver { && cs.r_table_expressions.is_empty() && cs.w_table_expressions.is_empty(); - if *circuit_name == as Instruction>::name() { - // unimplemented!("keccak impl wip"); - () - } else if is_opcode_circuit { + if is_opcode_circuit { tracing::debug!( "opcode circuit {} has {} witnesses, {} reads, {} writes, {} lookups", circuit_name, @@ -197,7 +196,7 @@ impl> ZKVMProver { let gkr_iop_pk = if *circuit_name == as Instruction>::name() { - Some(self.pk.keccak_pk.clone()) + Some((&self.pk.keccak_pk, keccak_gkr_wit.take().unwrap())) } else { None }; @@ -206,7 +205,7 @@ impl> ZKVMProver { circuit_name, &self.pk.pp, pk, - &gkr_iop_pk, + gkr_iop_pk, witness, wits_commit, &pi, @@ -266,7 +265,10 @@ impl> ZKVMProver { name: &str, pp: &PCS::ProverParam, circuit_pk: &ProvingKey, - gkr_iop_pk: &Option>>, + gkr_iop_pk: Option<( + &GKRIOPProvingKey>, + GKRCircuitWitness, + )>, witnesses: Vec>, wits_commit: PCS::CommitmentWithWitness, pi: &[ArcMultilinearExtension<'_, E>], @@ -279,11 +281,6 @@ impl> ZKVMProver { let log2_num_instances = ceil_log2(next_pow2_instances); let (chip_record_alpha, _) = (challenges[0], challenges[1]); - // if let Some(gkr_iop_pk) = gkr_iop_pk { - // let mut gkr_iop_pk = gkr_iop_pk.clone(); - // unimplemented!("cannot fully handle GKRIOP component yet") - // } - // sanity check assert_eq!(witnesses.len(), cs.num_witin as usize); assert!( @@ -331,6 +328,7 @@ impl> ZKVMProver { // infer all tower witness after last layer let span = entered_span!("tower_witness_r_layers"); + let r_wit_layers = infer_tower_product_witness( log2_num_instances + log2_r_count, r_records_last_layer, @@ -455,6 +453,8 @@ impl> ZKVMProver { rt_tower[..log2_num_instances].to_vec(), ); + dbg!(rt_r.len(), rt_w.len(), rt_lk.len()); + let num_threads = optimal_sumcheck_threads(log2_num_instances); let alpha_pow = get_challenge_pows( MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + cs.assert_zero_sumcheck_expressions.len(), @@ -481,6 +481,8 @@ impl> ZKVMProver { ); } + dbg!(sel_r.len()); + let mut sel_w = build_eq_x_r_vec(&rt_w[log2_w_count..]); if num_instances < sel_w.len() { sel_w.splice( @@ -630,6 +632,11 @@ impl> ZKVMProver { distrinct_zerocheck_terms_set.len() + 1 // +1 from sel_non_lc_zero_sumcheck } ); + + dbg!(r_counts_per_instance); + dbg!(w_counts_per_instance); + dbg!(lk_counts_per_instance); + let mut main_sel_evals_iter = main_sel_evals.into_iter(); main_sel_evals_iter.next(); // skip sel_r let r_records_in_evals = (0..r_counts_per_instance) @@ -653,6 +660,12 @@ impl> ZKVMProver { distrinct_zerocheck_terms_set.len() + 1 } ); + + // if let Some(gkr_iop_pk) = gkr_iop_pk { + // let mut gkr_iop_pk = gkr_iop_pk.clone(); + // unimplemented!("cannot fully handle GKRIOP component yet") + // } + let input_open_point = main_sel_sumcheck_proofs.point.clone(); assert!(input_open_point.len() == log2_num_instances); exit_span!(main_sel_span); @@ -689,6 +702,19 @@ impl> ZKVMProver { exit_span!(pcs_open_span); let wits_commit = PCS::get_pure_commitment(&wits_commit); + if let Some((gkr_iop_pk, gkr_wit)) = gkr_iop_pk { + let mut gkr_iop_pk = gkr_iop_pk.clone(); + let gkr_circuit = gkr_iop_pk.vk.get_state().chip.gkr_circuit(); + + // let mut prover_transcript = transcript::BasicTranscript::::new(b"protocol"); + let out_evals = vec![]; + + let gkr_iop::gkr::GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove(gkr_wit, &out_evals, &vec![], transcript) + .expect("Failed to prove phase"); + // unimplemented!("cannot fully handle GKRIOP component yet") + } + Ok(ZKVMOpcodeProof { num_instances, record_r_out_evals, diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index a8775c3d0..91706e037 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -11,6 +11,7 @@ use ceno_emul::{CENO_PLATFORM, KeccakSpec, Platform, StepRecord}; use ff_ext::ExtensionField; use gkr_iop::{ ProtocolWitnessGenerator, + gkr::GKRCircuitWitness, precompiles::{KeccakLayout, KeccakTrace}, }; use itertools::{Itertools, chain}; @@ -343,7 +344,7 @@ impl ZKVMFixedTraces { #[derive(Default, Clone)] pub struct ZKVMWitnesses { - keccak_phase1wit: Vec>, + pub keccak_gkr_wit: GKRCircuitWitness, witnesses_opcodes: BTreeMap>, witnesses_tables: BTreeMap>, lk_mlts: BTreeMap, @@ -374,13 +375,14 @@ impl ZKVMWitnesses { let cs = css .get_cs(&LargeEcallDummy::::name()) .unwrap(); - let (witness, logup_multiplicity) = + let (witness, gkr_witness, logup_multiplicity) = LargeEcallDummy::::assign_instances_with_gkr_iop( config, cs.num_witin as usize, records, &mut css.keccak_gkr_iop.layout, )?; + self.keccak_gkr_wit = gkr_witness; // // Intercept row-major matrix, convert into KeccakTrace and obtain phase1_wit // self.keccak_phase1wit = css diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 50ae2e7a7..eb908d047 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -22,7 +22,7 @@ pub struct GKRCircuit { pub ext_openings: Vec<(usize, EvalExpression)>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct GKRCircuitWitness { pub layers: Vec>, } @@ -85,7 +85,7 @@ impl GKRCircuit { let mut challenges = challenges.to_vec(); let mut evaluations = out_evals.to_vec(); evaluations.resize(self.n_evaluations, PointAndEval::default()); - for (layer, layer_proof) in izip!(self.layers.clone(), sumcheck_proofs) { + for (layer, layer_proof) in izip!(&self.layers, sumcheck_proofs) { layer.verify(layer_proof, &mut evaluations, &mut challenges, transcript)?; } diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 00749d822..b9e51708c 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -53,7 +53,7 @@ pub struct Layer { pub expr_names: Vec, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct LayerWitness { pub bases: Vec>, pub exts: Vec>, diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 3a896c527..73fa6d106 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -314,6 +314,8 @@ pub const AND_LOOKUPS_PER_ROUND: usize = 200; pub const XOR_LOOKUPS_PER_ROUND: usize = 608; pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; pub const LOOKUPS_PER_ROUND: usize = + AND_LOOKUPS_PER_ROUND + XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; +pub const LOOKUP_FELTS_PER_ROUND: usize = 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND; @@ -348,14 +350,15 @@ impl ProtocolBuilder for KeccakLayout { fn build_gkr_phase(&mut self, chip: &mut Chip) { let final_outputs = - chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUPS_PER_ROUND * ROUNDS }>(); + chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS }>(); let mut final_outputs_iter = final_outputs.iter(); let [keccak_output32, keccak_input32, lookup_outputs] = [ KECCAK_OUTPUT_SIZE, KECCAK_INPUT_SIZE, - LOOKUPS_PER_ROUND * ROUNDS, + //LOOKUPS_PER_ROUND + LOOKUP_FELTS_PER_ROUND * ROUNDS, ] .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); @@ -656,12 +659,18 @@ impl ProtocolBuilder for KeccakLayout { let beta = Constant::Base(0); // Send all lookups to the final output layer + // for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) + // .flatten() + // .enumerate() + // { + // expressions.push(lookup.clone()); for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) .flatten() .enumerate() { - expressions.push(lookup.clone()); - expr_names.push(format!("{i}th: {:?}", lookup)); + expressions.push(lookup); + //expressions.push(lookup.compress(alpha.clone(), beta.clone())); + expr_names.push(format!("{i}th: aha")); evals.push(lookup_outputs[lookup_index].clone()); lookup_index += 1; } @@ -686,7 +695,7 @@ impl ProtocolBuilder for KeccakLayout { }, ); - assert!(lookup_index == LOOKUPS_PER_ROUND * ROUNDS); + assert!(lookup_index == LOOKUP_FELTS_PER_ROUND * ROUNDS); let (state8, _) = chip.allocate_wits_in_layer::<200, 0>(); @@ -743,16 +752,6 @@ pub struct KeccakTrace { pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, } -use p3_field::Field; - -impl From> for KeccakTrace { - fn from(rmm: RowMajorMatrix) -> Self { - // clarify rmm encoding for large_ecall_dummy to obtain input - //unimplemented!(); - KeccakTrace { instances: vec![] } - } -} - impl ProtocolWitnessGenerator for KeccakLayout where E: ExtensionField, @@ -842,8 +841,9 @@ where if layer_wits[curr_layer].bases.is_empty() { layer_wits[curr_layer] = LayerWitness::new(nest::(&felts), vec![]); } else { - for (i, bases) in layer_wits[curr_layer].bases.iter_mut().enumerate() { - bases.push(felts[i]); + assert_eq!(felts.len(), layer_wits[curr_layer].bases.len()); + for (i, base) in layer_wits[curr_layer].bases.iter_mut().enumerate() { + base.push(felts[i]); } } curr_layer += 1; @@ -1065,6 +1065,8 @@ where ); } + dbg!(layer_wits.len()); + let len = layer_wits.len() - 1; layer_wits[..len].reverse(); @@ -1075,7 +1077,6 @@ where pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> () { let params = KeccakParams {}; let (layout, chip) = KeccakLayout::build(params); - let gkr_circuit = chip.gkr_circuit(); let mut instances = vec![]; for state in states { @@ -1094,6 +1095,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( ); } + dbg!(instances.len()); let phase1_witness = layout.phase1_witness(KeccakTrace { instances }); let mut prover_transcript = BasicTranscript::::new(b"protocol"); @@ -1139,6 +1141,11 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( // } let len = final_output1.len(); + assert_eq!( + len, + KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS + ); + assert!(final_output1.len() == final_output2.len()); let gkr_outputs = chain!( zip(final_output1, once(point1).cycle().take(len)), zip(final_output2, once(point2).cycle().take(len)) @@ -1152,6 +1159,8 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( .collect_vec() }; + let gkr_circuit = chip.gkr_circuit(); + dbg!(&gkr_circuit.layers.len()); let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove(gkr_witness, &out_evals, &vec![], &mut prover_transcript) .expect("Failed to prove phase"); diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 98e47aaf7..faa027589 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2025-03-23" +channel = "nightly-2025-01-06" targets = ["riscv32im-unknown-none-elf"] # We need the sources for build-std. components = ["rust-src"] diff --git a/subprotocols/src/zerocheck.rs b/subprotocols/src/zerocheck.rs index 7b6d989e4..aad339b5a 100644 --- a/subprotocols/src/zerocheck.rs +++ b/subprotocols/src/zerocheck.rs @@ -258,6 +258,8 @@ where base_mle_evals, } = proof; + // dbg!(&exprs); + let (in_point, expected_claims) = univariate_polys.into_iter().enumerate().fold( (vec![], sigmas), |(mut last_point, last_sigmas), (round, round_msg)| { @@ -272,11 +274,14 @@ where let sigmas = izip!(&exprs, &inv_of_one_minus_points, round_msg, last_sigmas) .map(|((_, point), inv_of_one_minus_point, poly, last_sigma)| { let len = poly.len() + 1; + // dbg!(&round, &point, &inv_of_one_minus_point, &poly); // last_sigma = (1 - point[round]) * eval_at_0 + point[round] * eval_at_1 // eval_at_0 = (last_sigma - point[round] * eval_at_1) * inv(1 - point[round]) - let eval_at_0 = - (last_sigma - point[round] * poly[0]) * inv_of_one_minus_point[round]; - + let eval_at_0 = if !poly.is_empty() { + (last_sigma - point[round] * poly[0]) * inv_of_one_minus_point[round] + } else { + last_sigma + }; // Evaluations on degree, degree - 1, ..., 1, 0. let evals_iter_rev = chain![poly.into_iter().rev(), iter::once(eval_at_0)]; From db68cfd2bb7893524302859bb6e0ca80b028457a Mon Sep 17 00:00:00 2001 From: Mihai Date: Thu, 17 Apr 2025 13:30:04 +0300 Subject: [PATCH 07/17] correct phase1_witness and commit --- gkr_iop/src/precompiles/faster_keccak.rs | 25 ++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 73fa6d106..6591137b9 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -39,7 +39,7 @@ pub struct KeccakParams {} pub struct KeccakLayout { params: KeccakParams, - committed_bits_id: usize, + input_columns: Vec, result: Vec, _marker: PhantomData, @@ -345,7 +345,8 @@ impl ProtocolBuilder for KeccakLayout { } fn build_commit_phase(&mut self, chip: &mut Chip) { - [self.committed_bits_id] = chip.allocate_committed_base(); + let input_columns: [usize; KECCAK_INPUT_SIZE] = chip.allocate_committed_base(); + self.input_columns = input_columns.into(); } fn build_gkr_phase(&mut self, chip: &mut Chip) { @@ -759,13 +760,14 @@ where type Trace = KeccakTrace; fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { - let mut res = vec![]; + let mut poly = vec![vec![]; KECCAK_INPUT_SIZE]; for instance in phase1.instances { - res.push(u64s_to_felts::( - instance.into_iter().map(|e| e as u64).collect_vec(), - )); + let felts = u64s_to_felts::(instance.into_iter().map(|e| e as u64).collect_vec()); + for i in 0..KECCAK_INPUT_SIZE { + poly[i].push(felts[i]); + } } - res + poly } fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness { @@ -779,7 +781,9 @@ where n_layers ]; - for com_state in phase1 { + let num_instances = phase1[0].len(); + + for i in 0..num_instances { fn conv64to8(input: u64) -> [u64; 8] { MaskRepresentation::new(vec![(64, input).into()]) .convert(vec![8; 8]) @@ -788,6 +792,11 @@ where .unwrap() } + let mut com_state = vec![]; + for j in 0..KECCAK_INPUT_SIZE { + com_state.push(phase1[j][i]); + } + let mut and_lookups: Vec> = vec![vec![]; ROUNDS]; let mut xor_lookups: Vec> = vec![vec![]; ROUNDS]; let mut range_lookups: Vec> = vec![vec![]; ROUNDS]; From dd572224df947d044bf9a1c4d9cbc11c7faa9013 Mon Sep 17 00:00:00 2001 From: Mihai Date: Fri, 18 Apr 2025 11:00:00 +0300 Subject: [PATCH 08/17] add aux wits to PCS --- ceno_zkvm/benches/riscv_add.rs | 2 +- ceno_zkvm/src/instructions.rs | 19 ++++++++++++++++-- .../instructions/riscv/dummy/dummy_ecall.rs | 20 ++++++++++++++++--- ceno_zkvm/src/scheme/tests.rs | 2 +- ceno_zkvm/src/structs.rs | 2 +- ceno_zkvm/src/utils.rs | 4 ++-- gkr_iop/src/precompiles/faster_keccak.rs | 8 ++++---- 7 files changed, 43 insertions(+), 14 deletions(-) diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 1a972587f..b35af5385 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -95,7 +95,7 @@ fn bench_add(c: &mut Criterion) { "ADD", &prover.pk.pp, &circuit_pk, - &None, + None, polys.into_iter().map(|mle| mle.into()).collect_vec(), commit, &[], diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 5d63ab602..8448b898f 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -85,6 +85,7 @@ where lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, lookups: &[E::BaseField], + aux_wits: &[E::BaseField], ) -> Result<(), ZKVMError>; fn phase1_witness_from_steps( @@ -122,7 +123,8 @@ where &vec![], ); - let lookups = { + let (lookups, aux_wits) = { + // Keccak-specific let mut lookups = vec![vec![]; steps.len()]; for witness in gkr_witness .layers @@ -137,7 +139,19 @@ where lookups[i].push(witness[i]); } } - lookups + + let mut aux_wits: Vec> = vec![vec![]; steps.len()]; + let n_layers = gkr_witness.layers.len(); + + for i in 0..steps.len() { + for layer in gkr_witness.layers[..n_layers - 1].iter() { + for base in layer.bases.iter() { + aux_wits[i].push(base[i]); + } + } + } + + (lookups, aux_wits) }; // dbg!(&lookups); @@ -163,6 +177,7 @@ where &mut lk_multiplicity, step, &lookups[*i], + &aux_wits[*i], ) }) .collect::>() diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index bec7911a7..401853718 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use ceno_emul::{Change, InsnKind, KeccakSpec, StepRecord, SyscallSpec}; use ff_ext::{ExtensionField, SmallField}; -use itertools::Itertools; +use itertools::{Itertools, zip_eq}; use super::{super::insn_base::WriteMEM, dummy_circuit::DummyConfig}; use crate::{ @@ -79,13 +79,15 @@ impl Instruction for LargeEcallDummy .collect::, _>>()?; // Temporarily set this to < 24 to avoid cb.num_witin overflow - let active_rounds = 12; + let active_rounds = 24; let mut lookups = Vec::with_capacity( active_rounds * (3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND), ); + let mut aux_wits = vec![]; + if S::HAS_LOOKUPS { dbg!(lookups.capacity()); @@ -110,6 +112,10 @@ impl Instruction for LargeEcallDummy lookups.push(wit); } } + + for i in 0..40144 { + aux_wits.push(cb.create_witin(|| format!("aux_wit{i}"))); + } } Ok(LargeEcallConfig { @@ -118,6 +124,7 @@ impl Instruction for LargeEcallDummy reg_writes, mem_writes, lookups, + aux_wits, }) } @@ -184,10 +191,11 @@ impl GKRIOPInstruction for LargeEcallDummy lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, lookups: &[E::BaseField], + aux_wits: &[E::BaseField], ) -> Result<(), ZKVMError> { Self::assign_instance(config, instance, lk_multiplicity, step)?; - let active_rounds = 12; + let active_rounds = 24; let mut wit_iter = lookups.iter().map(|f| f.to_canonical_u64()); let mut var_iter = config.lookups.iter(); @@ -223,6 +231,10 @@ impl GKRIOPInstruction for LargeEcallDummy } } + dbg!(aux_wits.len()); + for (aux_wit_var, aux_wit) in zip_eq(config.aux_wits.iter(), aux_wits) { + set_val!(instance, aux_wit_var, (aux_wit.to_canonical_u64())); + } assert!(var_iter.next().is_none()); Ok(()) @@ -236,5 +248,7 @@ pub struct LargeEcallConfig { start_addr: WitIn, mem_writes: Vec<(WitIn, Change, WriteMEM)>, + + aux_wits: Vec, lookups: Vec, } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 8e73ae0b3..98ef23176 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -139,7 +139,7 @@ fn test_rw_lk_expression_combination() { name.as_str(), &prover.pk.pp, prover.pk.circuit_pks.get(&name).unwrap(), - &None, + None, wits_in, commit, &[], diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 91706e037..cdaae5cdc 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -50,7 +50,7 @@ pub struct TowerProverSpec<'a, E: ExtensionField> { pub witness: Vec>>, } -pub type WitnessId = u16; +pub type WitnessId = u32; pub type ChallengeId = u16; #[derive(Copy, Clone, Debug, EnumIter, PartialEq, Eq, Hash)] diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 4a45a0a5e..d413f017b 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -225,7 +225,7 @@ pub fn add_mle_list_by_expr<'a, E: ExtensionField>( challenges: &[E], // sumcheck batch challenge alpha: E, -) -> BTreeSet { +) -> BTreeSet { assert!(expr.is_monomial_form()); let monomial_terms = expr.evaluate( &|_| unreachable!(), @@ -287,7 +287,7 @@ pub fn add_mle_list_by_expr<'a, E: ExtensionField>( monomial_terms .into_iter() .flat_map(|(_, monomial_term)| monomial_term.into_iter().collect_vec()) - .collect::>() + .collect::>() } #[cfg(test)] diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 6591137b9..3e35ff09c 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -345,8 +345,10 @@ impl ProtocolBuilder for KeccakLayout { } fn build_commit_phase(&mut self, chip: &mut Chip) { - let input_columns: [usize; KECCAK_INPUT_SIZE] = chip.allocate_committed_base(); - self.input_columns = input_columns.into(); + //let input_columns: [usize; KECCAK_INPUT_SIZE] = + // chip.allocate_committed_base(); self.input_columns = + // input_columns.into(); + let committed = chip.allocate_committed_base::<{ 50 + 40144 }>(); } fn build_gkr_phase(&mut self, chip: &mut Chip) { @@ -1074,8 +1076,6 @@ where ); } - dbg!(layer_wits.len()); - let len = layer_wits.len() - 1; layer_wits[..len].reverse(); From 039465bfe9efe77dd9da10b5b142117059ce04b4 Mon Sep 17 00:00:00 2001 From: Mihai Date: Fri, 18 Apr 2025 12:10:12 +0300 Subject: [PATCH 09/17] debug evals out of bounds --- ceno_zkvm/src/scheme/prover.rs | 20 +++++++++++++++++--- examples/examples/ceno_rt_keccak.rs | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index c7a8ecf6b..c8bea77c9 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,12 +1,12 @@ use ceno_emul::KeccakSpec; use ff_ext::ExtensionField; -use gkr_iop::gkr::GKRCircuitWitness; +use gkr_iop::{evaluation::PointAndEval, gkr::GKRCircuitWitness}; use std::{ collections::{BTreeMap, BTreeSet, HashMap}, sync::Arc, }; -use itertools::{Itertools, enumerate, izip}; +use itertools::{Itertools, chain, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, @@ -706,8 +706,21 @@ impl> ZKVMProver { let mut gkr_iop_pk = gkr_iop_pk.clone(); let gkr_circuit = gkr_iop_pk.vk.get_state().chip.gkr_circuit(); + let point = Arc::new(input_open_point); + dbg!(&point); // let mut prover_transcript = transcript::BasicTranscript::::new(b"protocol"); - let out_evals = vec![]; + let out_evals = chain!( + r_records_in_evals.clone(), + w_records_in_evals.clone(), + lk_records_in_evals.clone() + ) + .map(|record| PointAndEval { + point: point.clone(), + eval: record, + }) + .collect_vec(); + + // out_evals from point and output polynomials instead of *_records which is combined let gkr_iop::gkr::GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove(gkr_wit, &out_evals, &vec![], transcript) @@ -715,6 +728,7 @@ impl> ZKVMProver { // unimplemented!("cannot fully handle GKRIOP component yet") } + // extend with Optio(gkr evals (not combined)) Ok(ZKVMOpcodeProof { num_instances, record_r_out_evals, diff --git a/examples/examples/ceno_rt_keccak.rs b/examples/examples/ceno_rt_keccak.rs index 92173d9c3..4d09e4d89 100644 --- a/examples/examples/ceno_rt_keccak.rs +++ b/examples/examples/ceno_rt_keccak.rs @@ -6,7 +6,7 @@ extern crate ceno_rt; use ceno_rt::{info_out, syscalls::syscall_keccak_permute}; use core::slice; -const ITERATIONS: usize = 3; +const ITERATIONS: usize = 8; fn main() { let mut state = [0_u64; 25]; From 08b3f5dc8cffab59fa62a3f2aa086ec0656383c5 Mon Sep 17 00:00:00 2001 From: Mihai Date: Fri, 18 Apr 2025 13:51:05 +0300 Subject: [PATCH 10/17] weird verify behavior for wrong evals --- gkr_iop/src/precompiles/faster_keccak.rs | 45 ++++++++++-------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 3e35ff09c..303bafe20 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -1113,9 +1113,8 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &vec![]); let out_evals = { - let point1 = Arc::new(vec![E::ZERO]); - let point2 = Arc::new(vec![E::ONE]); - let final_output1 = gkr_witness + let point = Arc::new(vec![E::from_u64(29)]); + let output_records1 = gkr_witness .layers .last() .unwrap() @@ -1124,7 +1123,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( .iter() .map(|base| base[0].clone()) .collect_vec(); - let final_output2 = gkr_witness + let output_records2 = gkr_witness .layers .last() .unwrap() @@ -1134,36 +1133,27 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( .map(|base| base[1].clone()) .collect_vec(); - // if test { - // // confront outputs with tiny_keccak result - // let mut keccak_output64 = state_mask64.values().try_into().unwrap(); - // keccakf(&mut keccak_output64); - - // let keccak_output32 = MaskRepresentation::from( - // keccak_output64.into_iter().map(|e| (64, e)).collect_vec(), - // ) - // .convert(vec![32; 50]) - // .values(); - - // let keccak_output32 = u64s_to_felts::(keccak_output32); - // assert_eq!(keccak_output32, final_output[..50]); - // } + let evals = zip(&output_records1, &output_records2) + .map(|(p0, p1)| { + let poly = vec![p0.clone(), p1.clone()]; + subprotocols::utils::evaluate_mle_ext(&poly, &point) + }) + .collect_vec(); - let len = final_output1.len(); + dbg!(&evals); + let len = output_records1.len(); assert_eq!( len, KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS ); - assert!(final_output1.len() == final_output2.len()); - let gkr_outputs = chain!( - zip(final_output1, once(point1).cycle().take(len)), - zip(final_output2, once(point2).cycle().take(len)) - ); - gkr_outputs + assert!(output_records1.len() == output_records2.len()); + + evals .into_iter() - .map(|(elem, point)| PointAndEval { + .map(|elem| PointAndEval { point: point.clone(), - eval: E::from_bases(&[elem, Goldilocks::ZERO]), + //eval: elem, + eval: E::ONE, }) .collect_vec() }; @@ -1176,6 +1166,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( if verify { { + dbg!("sanity"); let mut verifier_transcript = BasicTranscript::::new(b"protocol"); gkr_circuit From 4cb0d9582e18125161c44863e8954215277f8abb Mon Sep 17 00:00:00 2001 From: Mihai Date: Fri, 18 Apr 2025 14:12:29 +0300 Subject: [PATCH 11/17] fix expr_names bug in verifier --- gkr_iop/src/precompiles/faster_keccak.rs | 4 ++-- subprotocols/src/sumcheck.rs | 3 +++ subprotocols/src/zerocheck.rs | 7 +++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 303bafe20..8afadfc3e 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -1152,8 +1152,8 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( .into_iter() .map(|elem| PointAndEval { point: point.clone(), - //eval: elem, - eval: E::ONE, + eval: elem, + //eval: E::ONE, }) .collect_vec() }; diff --git a/subprotocols/src/sumcheck.rs b/subprotocols/src/sumcheck.rs index 3b1ebc079..0541f6ad4 100644 --- a/subprotocols/src/sumcheck.rs +++ b/subprotocols/src/sumcheck.rs @@ -205,6 +205,9 @@ where transcript: &'a mut Trans, expr_names: Vec, ) -> Self { + // Fill in missing debug data + let mut expr_names = expr_names; + expr_names.resize(1, "nothing".to_owned()); Self { sigma, expr, diff --git a/subprotocols/src/zerocheck.rs b/subprotocols/src/zerocheck.rs index aad339b5a..3b16aded9 100644 --- a/subprotocols/src/zerocheck.rs +++ b/subprotocols/src/zerocheck.rs @@ -215,6 +215,10 @@ where challenges: &'a [E], transcript: &'a mut Trans, ) -> Self { + // Fill in missing debug data + let mut expr_names = expr_names; + expr_names.resize(exprs.len(), "nothing".to_owned()); + let inv_of_one_minus_points = points .iter() .flat_map(|point| point.iter().map(|p| E::ONE - *p).collect_vec()) @@ -294,6 +298,9 @@ where ); // Check the final evaluations. + assert_eq!(expr_names.len(), exprs.len()); + // assert_eq!(expected_claims.len(), expr_names.len()); + for (expected_claim, (expr, _), expr_name) in izip!(expected_claims, exprs, expr_names) { let got_claim = expr.evaluate(&ext_mle_evals, &base_mle_evals, &[], &[], challenges); From b398a1455b230618f77bd4571b3d081e2cf87abd Mon Sep 17 00:00:00 2001 From: Mihai Date: Fri, 18 Apr 2025 17:22:27 +0300 Subject: [PATCH 12/17] debugging verifier --- Cargo.lock | 1 + ceno_zkvm/Cargo.toml | 1 + ceno_zkvm/src/scheme.rs | 2 ++ ceno_zkvm/src/scheme/prover.rs | 42 ++++++++++++++++++++--------- ceno_zkvm/src/scheme/verifier.rs | 19 +++++++++++++ examples/examples/ceno_rt_keccak.rs | 2 +- 6 files changed, 53 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 16d310661..6d0f99745 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -598,6 +598,7 @@ dependencies = [ "serde_json", "strum", "strum_macros", + "subprotocols", "sumcheck", "tempfile", "thread_local", diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 2446aa9cc..5fee3d3c5 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -23,6 +23,7 @@ gkr_iop = { path = "../gkr_iop" } mpcs = { path = "../mpcs" } multilinear_extensions = { version = "0", path = "../multilinear_extensions" } p3 = { path = "../p3" } +subprotocols = {path = "../subprotocols"} sumcheck = { version = "0", path = "../sumcheck" } transcript = { path = "../transcript" } witness = { path = "../witness" } diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 4571fec73..2a875820f 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -50,6 +50,8 @@ pub struct ZKVMOpcodeProof pub wits_commit: PCS::Commitment, pub wits_opening_proof: PCS::Proof, pub wits_in_evals: Vec, + + pub gkr_out_evals: Option>, } #[derive(Clone, Serialize, Deserialize)] diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index c8bea77c9..4015be756 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -702,23 +702,35 @@ impl> ZKVMProver { exit_span!(pcs_open_span); let wits_commit = PCS::get_pure_commitment(&wits_commit); - if let Some((gkr_iop_pk, gkr_wit)) = gkr_iop_pk { + let gkr_out_evals = if let Some((gkr_iop_pk, gkr_wit)) = gkr_iop_pk { let mut gkr_iop_pk = gkr_iop_pk.clone(); let gkr_circuit = gkr_iop_pk.vk.get_state().chip.gkr_circuit(); let point = Arc::new(input_open_point); dbg!(&point); - // let mut prover_transcript = transcript::BasicTranscript::::new(b"protocol"); - let out_evals = chain!( - r_records_in_evals.clone(), - w_records_in_evals.clone(), - lk_records_in_evals.clone() - ) - .map(|record| PointAndEval { - point: point.clone(), - eval: record, - }) - .collect_vec(); + // // let mut prover_transcript = transcript::BasicTranscript::::new(b"protocol"); + // let out_evals = chain!( + // r_records_in_evals.clone(), + // w_records_in_evals.clone(), + // lk_records_in_evals.clone() + // ) + // .map(|record| PointAndEval { + // point: point.clone(), + // eval: record, + // }) + // .collect_vec(); + + let out_evals = gkr_wit + .layers + .last() + .unwrap() + .bases + .iter() + .map(|base| PointAndEval { + point: point.clone(), + eval: subprotocols::utils::evaluate_mle_ext(&base, &point), + }) + .collect_vec(); // out_evals from point and output polynomials instead of *_records which is combined @@ -726,7 +738,10 @@ impl> ZKVMProver { .prove(gkr_wit, &out_evals, &vec![], transcript) .expect("Failed to prove phase"); // unimplemented!("cannot fully handle GKRIOP component yet") - } + Some(out_evals.into_iter().map(|pae| pae.eval).collect_vec()) + } else { + None + }; // extend with Optio(gkr evals (not combined)) Ok(ZKVMOpcodeProof { @@ -745,6 +760,7 @@ impl> ZKVMProver { wits_commit, wits_opening_proof, wits_in_evals, + gkr_out_evals, }) } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 53178f0aa..fd773aa4c 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -307,6 +307,25 @@ impl> ZKVMVerifier num_product_fanin, transcript, )?; + + if let Some(evals) = &proof.gkr_out_evals { + let gkr_reads = evals.iter().take(50).collect_vec(); + let intersection_size = gkr_reads + .iter() + .counts() + .iter() + .map(|(val, count)| { + let proof_count = proof + .r_records_in_evals + .iter() + .filter(|x| x == *val) + .count(); + *count.min(&proof_count) + }) + .sum::(); + dbg!(&intersection_size); + panic!(); + } assert!(record_evals.len() == 2, "[r_record, w_record]"); assert!(logup_q_evals.len() == 1, "[lk_q_record]"); assert!(logup_p_evals.len() == 1, "[lk_p_record]"); diff --git a/examples/examples/ceno_rt_keccak.rs b/examples/examples/ceno_rt_keccak.rs index 4d09e4d89..135e8d9b0 100644 --- a/examples/examples/ceno_rt_keccak.rs +++ b/examples/examples/ceno_rt_keccak.rs @@ -6,7 +6,7 @@ extern crate ceno_rt; use ceno_rt::{info_out, syscalls::syscall_keccak_permute}; use core::slice; -const ITERATIONS: usize = 8; +const ITERATIONS: usize = 2; fn main() { let mut state = [0_u64; 25]; From 34956c917662ab93f8fd11f4bf44986089b1a171 Mon Sep 17 00:00:00 2001 From: Mihai Date: Wed, 23 Apr 2025 13:00:21 +0300 Subject: [PATCH 13/17] debug verifier --- ceno_zkvm/src/scheme/verifier.rs | 43 +++++++++++++++++++------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index fd773aa4c..5bfdc682a 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -273,6 +273,9 @@ impl> ZKVMVerifier cs.w_expressions.len(), cs.lk_expressions.len(), ); + + // dbg!(&cs.r_expressions); + let (log2_r_count, log2_w_count, log2_lk_count) = ( ceil_log2(r_counts_per_instance), ceil_log2(w_counts_per_instance), @@ -308,24 +311,6 @@ impl> ZKVMVerifier transcript, )?; - if let Some(evals) = &proof.gkr_out_evals { - let gkr_reads = evals.iter().take(50).collect_vec(); - let intersection_size = gkr_reads - .iter() - .counts() - .iter() - .map(|(val, count)| { - let proof_count = proof - .r_records_in_evals - .iter() - .filter(|x| x == *val) - .count(); - *count.min(&proof_count) - }) - .sum::(); - dbg!(&intersection_size); - panic!(); - } assert!(record_evals.len() == 2, "[r_record, w_record]"); assert!(logup_q_evals.len() == 1, "[lk_q_record]"); assert!(logup_p_evals.len() == 1, "[lk_p_record]"); @@ -345,6 +330,8 @@ impl> ZKVMVerifier logup_q_evals[0].point.clone(), ); + // dbg!(&rt_r, &rt_w, &rt_lk); + let alpha_pow = get_challenge_pows( MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + cs.assert_zero_sumcheck_expressions.len(), transcript, @@ -421,6 +408,26 @@ impl> ZKVMVerifier ) }; + if let Some(evals) = &proof.gkr_out_evals { + let gkr_reads = evals.iter().take(50).collect_vec(); + let intersection_size = gkr_reads + .iter() + .counts() + .iter() + .map(|(val, count)| { + let proof_count = proof + .r_records_in_evals + .iter() + .enumerate() + .filter(|(i, x)| **x * eq_r[*i] == ***val) + .count(); + *count.min(&proof_count) + }) + .sum::(); + dbg!(&intersection_size); + panic!(); + } + let computed_evals = [ // read *alpha_read From 8fcd809b49697b592bb586576697543a60a281b3 Mon Sep 17 00:00:00 2001 From: Mihai Date: Wed, 23 Apr 2025 18:21:21 +0300 Subject: [PATCH 14/17] refactor & clean-up --- ceno_emul/src/syscalls.rs | 2 + ceno_zkvm/src/e2e.rs | 16 +-- ceno_zkvm/src/instructions.rs | 55 +++++--- .../instructions/riscv/dummy/dummy_ecall.rs | 133 +++++++++--------- ceno_zkvm/src/instructions/riscv/rv32im.rs | 2 +- ceno_zkvm/src/scheme/prover.rs | 3 +- ceno_zkvm/src/scheme/verifier.rs | 9 +- ceno_zkvm/src/structs.rs | 26 ++-- gkr_iop/src/precompiles/faster_keccak.rs | 64 +++++---- gkr_iop/src/precompiles/mod.rs | 5 +- 10 files changed, 175 insertions(+), 140 deletions(-) diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 363402243..1d6cc7df8 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -22,6 +22,8 @@ pub trait SyscallSpec { const CODE: u32; const HAS_LOOKUPS: bool = false; + + const GKR_OUTPUTS: usize = 0; } /// Trace the inputs and effects of a syscall. diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 48720a634..4f0a2dee3 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -295,7 +295,7 @@ fn generate_fixed_traces( } pub fn generate_witness( - system_config: &mut ConstraintSystemConfig, + system_config: &ConstraintSystemConfig, emul_result: EmulationResult, program: &Program, is_mock_proving: bool, @@ -312,7 +312,7 @@ pub fn generate_witness( .unwrap(); system_config .dummy_config - .assign_opcode_circuit(&mut system_config.zkvm_cs, &mut zkvm_witness, dummy_records) + .assign_opcode_circuit(&system_config.zkvm_cs, &mut zkvm_witness, dummy_records) .unwrap(); zkvm_witness.finalize_lk_multiplicities(is_mock_proving); @@ -394,7 +394,7 @@ pub fn run_e2e_with_checkpoint< }; let program = Arc::new(program); - let mut system_config = construct_configs::(program_params); + let system_config = construct_configs::(program_params); let reg_init = system_config.mmu_config.initial_registers(); // IO is not used in this program, but it must have a particular size at the moment. @@ -432,7 +432,7 @@ pub fn run_e2e_with_checkpoint< init_full_mem, platform, hints, - &mut system_config, + &system_config, pk, zkvm_fixed_traces, is_mock_proving, @@ -451,13 +451,11 @@ pub fn run_e2e_with_checkpoint< if let Checkpoint::PrepWitnessGen = checkpoint { return ( None, - Box::new(move || { - _ = generate_witness(&mut system_config, emul_result, &program, false) - }), + Box::new(move || _ = generate_witness(&system_config, emul_result, &program, false)), ); } - let zkvm_witness = generate_witness(&mut system_config, emul_result, &program, is_mock_proving); + let zkvm_witness = generate_witness(&system_config, emul_result, &program, is_mock_proving); // proving let prover = ZKVMProver::new(pk); @@ -498,7 +496,7 @@ pub fn run_e2e_proof, - system_config: &mut ConstraintSystemConfig, + system_config: &ConstraintSystemConfig, pk: ZKVMProvingKey, zkvm_fixed_traces: ZKVMFixedTraces, is_mock_proving: bool, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 8448b898f..fdece9214 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -74,11 +74,39 @@ pub trait Instruction { } } +pub struct GKRinfo { + pub and_lookups: usize, + pub xor_lookups: usize, + pub range_lookups: usize, + pub aux_wits: usize, +} + +impl GKRinfo { + fn lookup_total(&self) -> usize { + self.and_lookups + self.xor_lookups + self.range_lookups + } +} + pub trait GKRIOPInstruction where Self: Instruction, { type Layout: ProtocolWitnessGenerator + ProtocolBuilder; + + fn construct_circuit_with_gkr_iop( + cb: &mut CircuitBuilder, + ) -> Result { + unimplemented!(); + } + + fn phase1_witness_from_steps( + layout: &Self::Layout, + steps: &[StepRecord], + ) -> Vec>; + + // Number of lookup arguments used by this GKR proof + fn gkr_info() -> GKRinfo; + fn assign_instance_with_gkr_iop( config: &Self::InstructionConfig, instance: &mut [E::BaseField], @@ -88,16 +116,11 @@ where aux_wits: &[E::BaseField], ) -> Result<(), ZKVMError>; - fn phase1_witness_from_steps( - layout: &Self::Layout, - steps: &[StepRecord], - ) -> Vec>; - fn assign_instances_with_gkr_iop( config: &Self::InstructionConfig, num_witin: usize, steps: Vec, - gkr_layout: &mut Self::Layout, + gkr_layout: &Self::Layout, ) -> Result< ( RowMajorMatrix, @@ -124,17 +147,14 @@ where ); let (lookups, aux_wits) = { - // Keccak-specific + // Extract lookups and auxiliary witnesses from GKR protocol + // Here we assume that the gkr_witness's last layer is a convenience layer which holds + // the output records for all instances; further, we assume that the last ```Self::lookup_count()``` + // elements of this layer are the lookup arguments. let mut lookups = vec![vec![]; steps.len()]; - for witness in gkr_witness - .layers - .last() - .unwrap() - .bases - .iter() - .skip(100) - .cloned() - { + let last_layer = gkr_witness.layers.last().unwrap().bases.clone(); + let len = last_layer.len(); + for witness in last_layer[len - Self::gkr_info().lookup_total()..].iter() { for i in 0..witness.len() { lookups[i].push(witness[i]); } @@ -144,6 +164,7 @@ where let n_layers = gkr_witness.layers.len(); for i in 0..steps.len() { + // Omit last layer, which stores outputs and not real witnesses for layer in gkr_witness.layers[..n_layers - 1].iter() { for base in layer.bases.iter() { aux_wits[i].push(base[i]); @@ -154,8 +175,6 @@ where (lookups, aux_wits) }; - // dbg!(&lookups); - raw_witin_iter .zip( steps diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 401853718..fd5a7d079 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -11,7 +11,7 @@ use crate::{ error::ZKVMError, expression::{ToExpr, WitIn}, instructions::{ - GKRIOPInstruction, Instruction, + GKRIOPInstruction, GKRinfo, Instruction, riscv::{constants::UInt, insn_base::WriteRD}, }, set_val, @@ -21,10 +21,7 @@ use ff_ext::FieldInto; use gkr_iop::{ ProtocolWitnessGenerator, - precompiles::{ - AND_LOOKUPS_PER_ROUND, KeccakLayout, KeccakTrace, RANGE_LOOKUPS_PER_ROUND, - XOR_LOOKUPS_PER_ROUND, - }, + precompiles::{AND_LOOKUPS, KeccakLayout, KeccakTrace, RANGE_LOOKUPS, XOR_LOOKUPS}, }; /// LargeEcallDummy can handle any instruction and produce its effects, @@ -37,7 +34,7 @@ impl Instruction for LargeEcallDummy type InstructionConfig = LargeEcallConfig; fn name() -> String { - format!("{}_DUMMY", S::NAME) + S::NAME.to_owned() } fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let dummy_insn = DummyConfig::construct_circuit( @@ -78,45 +75,9 @@ impl Instruction for LargeEcallDummy }) .collect::, _>>()?; - // Temporarily set this to < 24 to avoid cb.num_witin overflow - let active_rounds = 24; - - let mut lookups = Vec::with_capacity( - active_rounds - * (3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND), - ); - - let mut aux_wits = vec![]; - - if S::HAS_LOOKUPS { - dbg!(lookups.capacity()); - - for round in 0..active_rounds { - for i in 0..AND_LOOKUPS_PER_ROUND { - let a = cb.create_witin(|| format!("and_lookup_{round}_{i}_a")); - let b = cb.create_witin(|| format!("and_lookup_{round}_{i}_b")); - let c = cb.create_witin(|| format!("and_lookup_{round}_{i}_c")); - cb.lookup_and_byte(a.into(), b.into(), c.into())?; - lookups.extend(vec![a, b, c]); - } - for i in 0..XOR_LOOKUPS_PER_ROUND { - let a = cb.create_witin(|| format!("xor_lookup_{round}_{i}_a")); - let b = cb.create_witin(|| format!("xor_lookup_{round}_{i}_b")); - let c = cb.create_witin(|| format!("xor_lookup_{round}_{i}_c")); - cb.lookup_xor_byte(a.into(), b.into(), c.into())?; - lookups.extend(vec![a, b, c]); - } - for i in 0..RANGE_LOOKUPS_PER_ROUND { - let wit = cb.create_witin(|| format!("range_lookup_{round}_{i}")); - cb.assert_ux::<_, _, 16>(|| "nada", wit.into())?; - lookups.push(wit); - } - } - - for i in 0..40144 { - aux_wits.push(cb.create_witin(|| format!("aux_wit{i}"))); - } - } + // Will be filled in by GKR Instruction trait + let lookups = vec![]; + let aux_wits = vec![]; Ok(LargeEcallConfig { dummy_insn, @@ -164,6 +125,57 @@ impl Instruction for LargeEcallDummy impl GKRIOPInstruction for LargeEcallDummy { type Layout = KeccakLayout; + fn gkr_info() -> crate::instructions::GKRinfo { + GKRinfo { + and_lookups: 3 * AND_LOOKUPS, + xor_lookups: 3 * XOR_LOOKUPS, + range_lookups: RANGE_LOOKUPS, + aux_wits: 40144, + } + } + + fn construct_circuit_with_gkr_iop( + cb: &mut CircuitBuilder, + ) -> Result { + let mut partial_config = Self::construct_circuit(cb)?; + + assert!(partial_config.lookups.is_empty()); + assert!(partial_config.aux_wits.is_empty()); + + // TODO: capacity + let mut lookups = vec![]; + let mut aux_wits = vec![]; + + for i in 0..AND_LOOKUPS { + let a = cb.create_witin(|| format!("and_lookup_{i}_a")); + let b = cb.create_witin(|| format!("and_lookup_{i}_b")); + let c = cb.create_witin(|| format!("and_lookup_{i}_c")); + cb.lookup_and_byte(a.into(), b.into(), c.into())?; + lookups.extend(vec![a, b, c]); + } + for i in 0..XOR_LOOKUPS { + let a = cb.create_witin(|| format!("xor_lookup_{i}_a")); + let b = cb.create_witin(|| format!("xor_lookup_{i}_b")); + let c = cb.create_witin(|| format!("xor_lookup_{i}_c")); + cb.lookup_xor_byte(a.into(), b.into(), c.into())?; + lookups.extend(vec![a, b, c]); + } + for i in 0..RANGE_LOOKUPS { + let wit = cb.create_witin(|| format!("range_lookup_{i}")); + cb.assert_ux::<_, _, 16>(|| "nada", wit.into())?; + lookups.push(wit); + } + + for i in 0..40144 { + aux_wits.push(cb.create_witin(|| format!("aux_wit{i}"))); + } + + partial_config.lookups = lookups; + partial_config.aux_wits = aux_wits; + + Ok(partial_config) + } + fn phase1_witness_from_steps( layout: &Self::Layout, steps: &[StepRecord], @@ -195,7 +207,6 @@ impl GKRIOPInstruction for LargeEcallDummy ) -> Result<(), ZKVMError> { Self::assign_instance(config, instance, lk_multiplicity, step)?; - let active_rounds = 24; let mut wit_iter = lookups.iter().map(|f| f.to_canonical_u64()); let mut var_iter = config.lookups.iter(); @@ -206,29 +217,19 @@ impl GKRIOPInstruction for LargeEcallDummy let wit = wit_iter.next().unwrap(); let var = var_iter.next().unwrap(); set_val!(instance, var, wit); - // set_val!(instance, var, 0); wit }; - for _round in 0..active_rounds { - for _i in 0..AND_LOOKUPS_PER_ROUND { - lk_multiplicity.lookup_and_byte(pop_arg(), pop_arg()); - // lk_multiplicity.lookup_and_byte(0, 0); - // pop_arg(); - // pop_arg(); - pop_arg(); - } - for _i in 0..XOR_LOOKUPS_PER_ROUND { - lk_multiplicity.lookup_xor_byte(pop_arg(), pop_arg()); - // lk_multiplicity.lookup_xor_byte(0, 0); - // pop_arg(); - // pop_arg(); - pop_arg(); - } - for _i in 0..RANGE_LOOKUPS_PER_ROUND { - lk_multiplicity.assert_ux::<16>(pop_arg()); - // pop_arg(); - } + for _i in 0..AND_LOOKUPS { + lk_multiplicity.lookup_and_byte(pop_arg(), pop_arg()); + pop_arg(); + } + for _i in 0..XOR_LOOKUPS { + lk_multiplicity.lookup_xor_byte(pop_arg(), pop_arg()); + pop_arg(); + } + for _i in 0..RANGE_LOOKUPS { + lk_multiplicity.assert_ux::<16>(pop_arg()); } dbg!(aux_wits.len()); diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 6d279ca18..bc8dbc02a 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -526,7 +526,7 @@ impl DummyExtraConfig { pub fn assign_opcode_circuit( &self, - cs: &mut ZKVMConstraintSystem, + cs: &ZKVMConstraintSystem, witness: &mut ZKVMWitnesses, steps: GroupedSteps, ) -> Result<(), ZKVMError> { diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 4015be756..39418b2b7 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -481,7 +481,8 @@ impl> ZKVMProver { ); } - dbg!(sel_r.len()); + // dbg!(sel_r); + // panic!(); let mut sel_w = build_eq_x_r_vec(&rt_w[log2_w_count..]); if num_instances < sel_w.len() { diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 5bfdc682a..7e14aec51 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -409,8 +409,9 @@ impl> ZKVMVerifier }; if let Some(evals) = &proof.gkr_out_evals { - let gkr_reads = evals.iter().take(50).collect_vec(); - let intersection_size = gkr_reads + dbg!(&sel_r); + // let gkr_reads = evals.iter().take(50).collect_vec(); + let intersection_size = evals .iter() .counts() .iter() @@ -419,13 +420,13 @@ impl> ZKVMVerifier .r_records_in_evals .iter() .enumerate() - .filter(|(i, x)| **x * eq_r[*i] == ***val) + .filter(|(i, x)| *alpha_read * **val == **x) .count(); *count.min(&proof_count) }) .sum::(); dbg!(&intersection_size); - panic!(); + // panic!(); } let computed_evals = [ diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index cdaae5cdc..e41486ed4 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -7,13 +7,9 @@ use crate::{ tables::{RMMCollections, TableCircuit}, witness::LkMultiplicity, }; -use ceno_emul::{CENO_PLATFORM, KeccakSpec, Platform, StepRecord}; +use ceno_emul::{CENO_PLATFORM, KeccakSpec, Platform, StepRecord, SyscallSpec}; use ff_ext::ExtensionField; -use gkr_iop::{ - ProtocolWitnessGenerator, - gkr::GKRCircuitWitness, - precompiles::{KeccakLayout, KeccakTrace}, -}; +use gkr_iop::{gkr::GKRCircuitWitness, precompiles::KeccakLayout}; use itertools::{Itertools, chain}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ @@ -264,7 +260,19 @@ impl ZKVMConstraintSystem { let (layout, chip) = as gkr_iop::ProtocolBuilder>::build(params); self.keccak_gkr_iop = KeccakGKRIOP { layout, chip }; - self.register_opcode_circuit::>() + let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", KeccakSpec::NAME)); + let mut circuit_builder = + CircuitBuilder::::new_with_params(&mut cs, self.params.clone()); + let config = + LargeEcallDummy::::construct_circuit_with_gkr_iop(&mut circuit_builder) + .unwrap(); + assert!( + self.circuit_css + .insert(KeccakSpec::NAME.to_owned(), cs) + .is_none() + ); + + config } pub fn register_opcode_circuit>(&mut self) -> OC::InstructionConfig { @@ -367,7 +375,7 @@ impl ZKVMWitnesses { pub fn assign_keccakf_circuit( &mut self, // TODO: do without mutability requirement - css: &mut ZKVMConstraintSystem, + css: &ZKVMConstraintSystem, config: & as Instruction>::InstructionConfig, records: Vec, ) -> Result<(), ZKVMError> { @@ -380,7 +388,7 @@ impl ZKVMWitnesses { config, cs.num_witin as usize, records, - &mut css.keccak_gkr_iop.layout, + &css.keccak_gkr_iop.layout, )?; self.keccak_gkr_wit = gkr_witness; diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 8afadfc3e..407acbc99 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -313,8 +313,6 @@ pub const KECCAK_OUTPUT_SIZE: usize = 50; pub const AND_LOOKUPS_PER_ROUND: usize = 200; pub const XOR_LOOKUPS_PER_ROUND: usize = 608; pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; -pub const LOOKUPS_PER_ROUND: usize = - AND_LOOKUPS_PER_ROUND + XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; pub const LOOKUP_FELTS_PER_ROUND: usize = 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; @@ -378,7 +376,9 @@ impl ProtocolBuilder for KeccakLayout { let mut evals = vec![]; let mut expr_names = vec![]; - let mut lookup_index = 0; + let mut global_and_lookup = 0; + let mut global_xor_lookup = 3 * AND_LOOKUPS; + let mut global_range_lookup = 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS; for x in 0..5 { for y in 0..5 { @@ -657,25 +657,21 @@ impl ProtocolBuilder for KeccakLayout { }) .count(); - // TODO: use real challenge - let alpha = Constant::Base(1 << 8); - let beta = Constant::Base(0); - - // Send all lookups to the final output layer - // for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) - // .flatten() - // .enumerate() - // { - // expressions.push(lookup.clone()); for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) .flatten() .enumerate() { expressions.push(lookup); - //expressions.push(lookup.compress(alpha.clone(), beta.clone())); - expr_names.push(format!("{i}th: aha")); - evals.push(lookup_outputs[lookup_index].clone()); - lookup_index += 1; + expr_names.push(format!("round {round}: {i}th lookup felt")); + let idx = if i < 3 * AND_LOOKUPS_PER_ROUND { + &mut global_and_lookup + } else if i < 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND { + &mut global_xor_lookup + } else { + &mut global_range_lookup + }; + evals.push(lookup_outputs[*idx].clone()); + *idx += 1; } chip.add_layer(Layer::new( @@ -698,7 +694,9 @@ impl ProtocolBuilder for KeccakLayout { }, ); - assert!(lookup_index == LOOKUP_FELTS_PER_ROUND * ROUNDS); + assert!(global_and_lookup == 3 * AND_LOOKUPS); + assert!(global_xor_lookup == 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS); + assert!(global_range_lookup == LOOKUP_FELTS_PER_ROUND * ROUNDS); let (state8, _) = chip.allocate_wits_in_layer::<200, 0>(); @@ -1054,17 +1052,24 @@ where // For temporary convenience, use one extra layer to store the correct outputs // of the circuit This is not used during proving - let lookups = (0..24) - .into_iter() - .rev() - .flat_map(|i| { - chain!( - and_lookups[i].clone(), - xor_lookups[i].clone(), - range_lookups[i].clone() - ) - }) - .collect_vec(); + let lookups = chain!( + (0..ROUNDS) + .into_iter() + .rev() + .map(|i| and_lookups[i].clone()) + .flatten(), + (0..ROUNDS) + .into_iter() + .rev() + .map(|i| xor_lookups[i].clone()) + .flatten(), + (0..ROUNDS) + .into_iter() + .rev() + .map(|i| range_lookups[i].clone()) + .flatten() + ) + .collect_vec(); push_instance( chain!( @@ -1140,7 +1145,6 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( }) .collect_vec(); - dbg!(&evals); let len = output_records1.len(); assert_eq!( len, diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs index 2d840460d..cc17ee536 100644 --- a/gkr_iop/src/precompiles/mod.rs +++ b/gkr_iop/src/precompiles/mod.rs @@ -4,6 +4,7 @@ mod utils; pub use keccak_f::run_keccakf; pub use { faster_keccak::run_faster_keccakf, faster_keccak::KeccakLayout, faster_keccak::KeccakParams, - faster_keccak::KeccakTrace, faster_keccak::AND_LOOKUPS_PER_ROUND, - faster_keccak::RANGE_LOOKUPS_PER_ROUND, faster_keccak::XOR_LOOKUPS_PER_ROUND, + faster_keccak::KeccakTrace, faster_keccak::AND_LOOKUPS, faster_keccak::AND_LOOKUPS_PER_ROUND, + faster_keccak::RANGE_LOOKUPS, faster_keccak::RANGE_LOOKUPS_PER_ROUND, + faster_keccak::XOR_LOOKUPS, faster_keccak::XOR_LOOKUPS_PER_ROUND, }; From 29e38a4b6073f1a3eaa8dceb14b3d1484f469f6f Mon Sep 17 00:00:00 2001 From: Mihai Date: Thu, 24 Apr 2025 10:33:58 +0300 Subject: [PATCH 15/17] back to the verifier --- Cargo.lock | 2 + Cargo.toml | 2 +- ceno_zkvm/src/instructions.rs | 5 ++ .../instructions/riscv/dummy/dummy_ecall.rs | 15 +++++ ceno_zkvm/src/scheme.rs | 12 +++- ceno_zkvm/src/scheme/prover.rs | 36 +++++------ ceno_zkvm/src/scheme/verifier.rs | 64 +++++++++++++------ gkr_iop/Cargo.toml | 1 + gkr_iop/src/evaluation.rs | 15 +++-- gkr_iop/src/gkr.rs | 6 +- gkr_iop/src/gkr/layer.rs | 5 +- gkr_iop/src/precompiles/faster_keccak.rs | 17 ++++- subprotocols/Cargo.toml | 1 + subprotocols/src/expression.rs | 7 +- subprotocols/src/sumcheck.rs | 2 + 15 files changed, 137 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6d0f99745..4eb8ac199 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1183,6 +1183,7 @@ dependencies = [ "p3-goldilocks", "rand", "rayon", + "serde", "subprotocols", "thiserror 1.0.69", "tiny-keccak", @@ -2735,6 +2736,7 @@ dependencies = [ "p3-goldilocks", "rand", "rayon", + "serde", "thiserror 1.0.69", "transcript", ] diff --git a/Cargo.toml b/Cargo.toml index 560c75a18..608af8c6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,7 @@ rand_core = "0.6" rand_xorshift = "0.3" rayon = "1.10" secp = "0.4.1" -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" strum = "0.26" strum_macros = "0.26" diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index fdece9214..839026fcd 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -99,6 +99,11 @@ where unimplemented!(); } + // Returns index of `i-th` GKR-IOP output-eval in PCS + fn output_map(i: usize) -> usize { + unimplemented!(); + } + fn phase1_witness_from_steps( layout: &Self::Layout, steps: &[StepRecord], diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index fd5a7d079..e217cd152 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -112,6 +112,7 @@ impl Instruction for LargeEcallDummy // Assign memory. for ((addr, value, writer), op) in config.mem_writes.iter().zip_eq(&ops.mem_ops) { + dbg!(&op.value.before); set_val!(instance, value.before, op.value.before as u64); set_val!(instance, value.after, op.value.after as u64); set_val!(instance, addr, u64::from(op.addr)); @@ -173,9 +174,23 @@ impl GKRIOPInstruction for LargeEcallDummy partial_config.lookups = lookups; partial_config.aux_wits = aux_wits; + dbg!(&partial_config.mem_writes[0].1.after); + dbg!(&partial_config.mem_writes[1].1.before); + dbg!(&partial_config.lookups[0]); + Ok(partial_config) } + fn output_map(i: usize) -> usize { + if i < 50 { + 27 + 6 * i + } else if i < 100 { + 26 + 6 * (i - 50) + } else { + 326 + i - 100 + } + } + fn phase1_witness_from_steps( layout: &Self::Layout, steps: &[StepRecord], diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 2a875820f..400e60963 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,4 +1,5 @@ use ff_ext::ExtensionField; +use gkr_iop::gkr::{Evaluation, GKRCircuit, GKRProverOutput}; use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; use p3::field::PrimeCharacteristicRing; @@ -6,6 +7,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::BTreeMap, fmt::{self, Debug}, + marker::PhantomData, ops::Div, }; use sumcheck::structs::IOPProverMessage; @@ -51,7 +53,15 @@ pub struct ZKVMOpcodeProof pub wits_opening_proof: PCS::Proof, pub wits_in_evals: Vec, - pub gkr_out_evals: Option>, + pub gkr_opcode_proof: Option>, +} + +#[derive(Clone, Serialize)] +pub struct GKROpcodeProof> { + output_evals: Vec, + prover_output: GKRProverOutput>, + circuit: GKRCircuit, + _marker: PhantomData, } #[derive(Clone, Serialize, Deserialize)] diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 39418b2b7..4a1cffd92 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -3,6 +3,7 @@ use ff_ext::ExtensionField; use gkr_iop::{evaluation::PointAndEval, gkr::GKRCircuitWitness}; use std::{ collections::{BTreeMap, BTreeSet, HashMap}, + marker::PhantomData, sync::Arc, }; @@ -31,6 +32,7 @@ use crate::{ riscv::{dummy::LargeEcallDummy, ecall::EcallDummy}, }, scheme::{ + GKROpcodeProof, constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, @@ -481,9 +483,6 @@ impl> ZKVMProver { ); } - // dbg!(sel_r); - // panic!(); - let mut sel_w = build_eq_x_r_vec(&rt_w[log2_w_count..]); if num_instances < sel_w.len() { sel_w.splice( @@ -703,23 +702,11 @@ impl> ZKVMProver { exit_span!(pcs_open_span); let wits_commit = PCS::get_pure_commitment(&wits_commit); - let gkr_out_evals = if let Some((gkr_iop_pk, gkr_wit)) = gkr_iop_pk { - let mut gkr_iop_pk = gkr_iop_pk.clone(); + let gkr_opcode_proof = if let Some((gkr_iop_pk, gkr_wit)) = gkr_iop_pk { + let gkr_iop_pk = gkr_iop_pk.clone(); let gkr_circuit = gkr_iop_pk.vk.get_state().chip.gkr_circuit(); let point = Arc::new(input_open_point); - dbg!(&point); - // // let mut prover_transcript = transcript::BasicTranscript::::new(b"protocol"); - // let out_evals = chain!( - // r_records_in_evals.clone(), - // w_records_in_evals.clone(), - // lk_records_in_evals.clone() - // ) - // .map(|record| PointAndEval { - // point: point.clone(), - // eval: record, - // }) - // .collect_vec(); let out_evals = gkr_wit .layers @@ -735,11 +722,20 @@ impl> ZKVMProver { // out_evals from point and output polynomials instead of *_records which is combined - let gkr_iop::gkr::GKRProverOutput { gkr_proof, .. } = gkr_circuit + let prover_output = gkr_circuit .prove(gkr_wit, &out_evals, &vec![], transcript) .expect("Failed to prove phase"); // unimplemented!("cannot fully handle GKRIOP component yet") - Some(out_evals.into_iter().map(|pae| pae.eval).collect_vec()) + + dbg!(&prover_output.opening_evaluations.len()); + let output_evals = out_evals.into_iter().map(|pae| pae.eval).collect_vec(); + + Some(GKROpcodeProof { + output_evals, + prover_output, + circuit: gkr_circuit, + _marker: PhantomData::default(), + }) } else { None }; @@ -761,7 +757,7 @@ impl> ZKVMProver { wits_commit, wits_opening_proof, wits_in_evals, - gkr_out_evals, + gkr_opcode_proof, }) } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 7e14aec51..a49084a55 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,6 +1,7 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, sync::Arc}; use ark_std::iterable::Iterable; +use ceno_emul::{KeccakSpec, SyscallSpec}; use ff_ext::ExtensionField; use itertools::{Itertools, chain, interleave, izip}; @@ -17,6 +18,7 @@ use witness::next_pow2_instance_padding; use crate::{ error::ZKVMError, expression::{Instance, StructuralWitIn}, + instructions::{GKRIOPInstruction, riscv::dummy::LargeEcallDummy}, scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, utils::eval_by_expr_with_instance, @@ -408,25 +410,51 @@ impl> ZKVMVerifier ) }; - if let Some(evals) = &proof.gkr_out_evals { - dbg!(&sel_r); - // let gkr_reads = evals.iter().take(50).collect_vec(); - let intersection_size = evals - .iter() - .counts() + if name == KeccakSpec::NAME { + // TODO: remove clone + let gkr_iop = proof + .gkr_opcode_proof + .clone() + .expect("Keccak syscall should contain GKR-IOP proof"); + + // Match output_evals with EcallDummy polynomials + let mut matches = 0; + for (i, gkr_out_eval) in gkr_iop.output_evals.iter().enumerate() { + // assert_eq!( + // *gkr_out_eval, + // proof.wits_in_evals[LargeEcallDummy::::output_map(i)], + // "{i}" + // ); + if *gkr_out_eval + == proof.wits_in_evals[LargeEcallDummy::::output_map(i)] + { + matches += 1; + } else { + dbg!(i); + } + } + dbg!(matches); + panic!(); + // Verify GKR proof + let point = Arc::new(input_opening_point.clone()); + let out_evals = gkr_iop + .output_evals .iter() - .map(|(val, count)| { - let proof_count = proof - .r_records_in_evals - .iter() - .enumerate() - .filter(|(i, x)| *alpha_read * **val == **x) - .count(); - *count.min(&proof_count) + .map(|eval| gkr_iop::evaluation::PointAndEval { + point: point.clone(), + eval: eval.clone(), }) - .sum::(); - dbg!(&intersection_size); - // panic!(); + .collect_vec(); + + gkr_iop + .circuit + .verify( + gkr_iop.prover_output.gkr_proof.clone(), + &out_evals, + &vec![], + transcript, + ) + .expect("GKR IOP verify failure"); } let computed_evals = [ diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index e05c48488..3a784c393 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -20,6 +20,7 @@ p3-field.workspace = true p3-goldilocks.workspace = true rand.workspace = true rayon.workspace = true +serde.workspace = true subprotocols = { path = "../subprotocols" } thiserror = "1" tiny-keccak.workspace = true diff --git a/gkr_iop/src/evaluation.rs b/gkr_iop/src/evaluation.rs index 365f1a250..4ba4a6b79 100644 --- a/gkr_iop/src/evaluation.rs +++ b/gkr_iop/src/evaluation.rs @@ -1,12 +1,14 @@ use std::sync::Arc; use ff_ext::ExtensionField; -use itertools::{Itertools, izip}; +use itertools::{izip, Itertools}; use multilinear_extensions::virtual_poly::build_eq_x_r_vec_sequential; +use serde::Serialize; use subprotocols::expression::{Constant, Point}; -/// Evaluation expression for the gkr layer reduction and PCS opening preparation. -#[derive(Clone, Debug)] +/// Evaluation expression for the gkr layer reduction and PCS opening +/// preparation. +#[derive(Clone, Debug, Serialize)] pub enum EvalExpression { /// Single entry in the evaluation vector. Single(usize), @@ -14,9 +16,10 @@ pub enum EvalExpression { Linear(usize, Constant, Constant), /// Merging multiple evaluations which denotes a partition of the original /// polynomial. `(usize, Constant)` denote the modification of the point. - /// For example, when it receive a point `(p0, p1, p2, p3)` from a succeeding - /// layer, `vec![(2, c0), (4, c1)]` will modify the point to `(p0, p1, c0, p2, c1, p3)`. - /// where the indices specify how the partition applied to the original polynomial. + /// For example, when it receive a point `(p0, p1, p2, p3)` from a + /// succeeding layer, `vec![(2, c0), (4, c1)]` will modify the point to + /// `(p0, p1, c0, p2, c1, p3)`. where the indices specify how the + /// partition applied to the original polynomial. Partition(Vec>, Vec<(usize, Constant)>), } diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index eb908d047..e34ab7a80 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -1,6 +1,7 @@ use ff_ext::ExtensionField; use itertools::{chain, izip, Itertools}; use layer::{Layer, LayerWitness}; +use serde::Serialize; use subprotocols::{expression::Point, sumcheck::SumcheckProof}; use transcript::Transcript; @@ -12,7 +13,7 @@ use crate::{ pub mod layer; pub mod mock; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize)] pub struct GKRCircuit { pub layers: Vec, @@ -27,13 +28,16 @@ pub struct GKRCircuitWitness { pub layers: Vec>, } +#[derive(Clone, Serialize)] pub struct GKRProverOutput { pub gkr_proof: GKRProof, pub opening_evaluations: Vec, } +#[derive(Clone, Serialize)] pub struct GKRProof(pub Vec>); +#[derive(Clone, Debug, Serialize)] pub struct Evaluation { pub value: E, pub point: Point, diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index b9e51708c..3b30e1dc3 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -2,6 +2,7 @@ use ark_std::log2; use ff_ext::ExtensionField; use itertools::{chain, izip}; use linear_layer::LinearLayer; +use serde::Serialize; use subprotocols::{ expression::{Constant, Expression, Point}, sumcheck::{SumcheckClaims, SumcheckProof, SumcheckProverOutput}, @@ -20,14 +21,14 @@ pub mod linear_layer; pub mod sumcheck_layer; pub mod zerocheck_layer; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize)] pub enum LayerType { Sumcheck, Zerocheck, Linear, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize)] pub struct Layer { pub name: String, pub ty: LayerType, diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 407acbc99..4ed5449ec 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -745,6 +745,10 @@ impl ProtocolBuilder for KeccakLayout { evals, expr_names, )); + + // TODO: allocate everything + chip.allocate_base_opening(0, state8[[0, 0, 0]].clone().1); + chip.allocate_base_opening(1, state8[[0, 0, 1]].clone().1); } } @@ -824,11 +828,20 @@ where } }; - let state32 = com_state + let mut state32 = com_state .into_iter() // TODO double check assumptions about canonical .map(|e| e.to_canonical_u64()) .collect_vec(); + + //dbg!(&state32); + //panic!(); + + // TODO: Need to swap to match front-end encoding? + // for i in 0..25 { + // state32.swap(2 * i, 2 * i + 1); + // } + let mut state64 = [[0u64; 5]; 5]; let mut state8 = [[[0u64; 8]; 5]; 5]; @@ -1071,6 +1084,7 @@ where ) .collect_vec(); + dbg!(&keccak_output32); push_instance( chain!( keccak_output32.into_iter().flatten().flatten(), @@ -1118,6 +1132,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &vec![]); let out_evals = { + // TODO: hard-coded for two instances, improve let point = Arc::new(vec![E::from_u64(29)]); let output_records1 = gkr_witness .layers diff --git a/subprotocols/Cargo.toml b/subprotocols/Cargo.toml index 35e5c9408..c5fc1ed8e 100644 --- a/subprotocols/Cargo.toml +++ b/subprotocols/Cargo.toml @@ -17,6 +17,7 @@ multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" p3-field.workspace = true rand.workspace = true rayon.workspace = true +serde.workspace = true thiserror = "1" transcript = { path = "../transcript" } diff --git a/subprotocols/src/expression.rs b/subprotocols/src/expression.rs index 7750280a6..5fa2a1556 100644 --- a/subprotocols/src/expression.rs +++ b/subprotocols/src/expression.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use ff_ext::ExtensionField; +use serde::Serialize; mod evaluate; mod op; @@ -9,7 +10,7 @@ mod macros; pub type Point = Arc>; -#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)] pub enum Constant { /// Base field Base(i64), @@ -31,7 +32,7 @@ impl Default for Constant { } } -#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)] pub enum Witness { /// Base field polynomial (index). BasePoly(usize), @@ -47,7 +48,7 @@ impl Default for Witness { } } -#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)] pub enum Expression { /// Constant Const(Constant), diff --git a/subprotocols/src/sumcheck.rs b/subprotocols/src/sumcheck.rs index 0541f6ad4..952bebf1b 100644 --- a/subprotocols/src/sumcheck.rs +++ b/subprotocols/src/sumcheck.rs @@ -3,6 +3,7 @@ use std::{iter, mem, sync::Arc, vec}; use ark_std::log2; use ff_ext::ExtensionField; use itertools::chain; +use serde::Serialize; use transcript::Transcript; use crate::{ @@ -40,6 +41,7 @@ where num_vars: usize, } +#[derive(Clone, Serialize)] pub struct SumcheckProof { /// Messages for each round. pub univariate_polys: Vec>>, From 639ee44e2b1946a32f9c5e62a9733559c012e8fd Mon Sep 17 00:00:00 2001 From: Mihai Date: Fri, 25 Apr 2025 10:16:24 +0300 Subject: [PATCH 16/17] improve unit tests --- ceno_zkvm/src/scheme/prover.rs | 5 - ceno_zkvm/src/scheme/verifier.rs | 13 ++- gkr_iop/src/precompiles/faster_keccak.rs | 124 +++++++++++------------ 3 files changed, 68 insertions(+), 74 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 4a1cffd92..27a8b5f33 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -661,11 +661,6 @@ impl> ZKVMProver { } ); - // if let Some(gkr_iop_pk) = gkr_iop_pk { - // let mut gkr_iop_pk = gkr_iop_pk.clone(); - // unimplemented!("cannot fully handle GKRIOP component yet") - // } - let input_open_point = main_sel_sumcheck_proofs.point.clone(); assert!(input_open_point.len() == log2_num_instances); exit_span!(main_sel_span); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index a49084a55..8983f7dc2 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -420,11 +420,11 @@ impl> ZKVMVerifier // Match output_evals with EcallDummy polynomials let mut matches = 0; for (i, gkr_out_eval) in gkr_iop.output_evals.iter().enumerate() { - // assert_eq!( - // *gkr_out_eval, - // proof.wits_in_evals[LargeEcallDummy::::output_map(i)], - // "{i}" - // ); + assert_eq!( + *gkr_out_eval, + proof.wits_in_evals[LargeEcallDummy::::output_map(i)], + "{i}" + ); if *gkr_out_eval == proof.wits_in_evals[LargeEcallDummy::::output_map(i)] { @@ -434,7 +434,6 @@ impl> ZKVMVerifier } } dbg!(matches); - panic!(); // Verify GKR proof let point = Arc::new(input_opening_point.clone()); let out_evals = gkr_iop @@ -454,7 +453,7 @@ impl> ZKVMVerifier &vec![], transcript, ) - .expect("GKR IOP verify failure"); + .expect("GKR-IOP verify failure"); } let computed_evals = [ diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs index 4ed5449ec..5c455f46a 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -16,6 +16,7 @@ use crate::{ precompiles::utils::{nest, not8_expr, zero_expr, MaskRepresentation}, ProtocolBuilder, ProtocolWitnessGenerator, }; +use ark_std::log2; use ndarray::{range, s, ArrayView, Ix2, Ix3}; use witness::RowMajorMatrix; @@ -828,20 +829,12 @@ where } }; - let mut state32 = com_state + let state32 = com_state .into_iter() // TODO double check assumptions about canonical .map(|e| e.to_canonical_u64()) .collect_vec(); - //dbg!(&state32); - //panic!(); - - // TODO: Need to swap to match front-end encoding? - // for i in 0..25 { - // state32.swap(2 * i, 2 * i + 1); - // } - let mut state64 = [[0u64; 5]; 5]; let mut state8 = [[[0u64; 8]; 5]; 5]; @@ -1033,16 +1026,6 @@ where .collect_vec(), ]; - // let sizes = all_wits64.iter().map(|e| e.len()).collect_vec(); - // dbg!(&sizes); - - // let all_wits = nest::( - // &all_wits64 - // .into_iter() - // .flat_map(|v| u64s_to_felts::(v)) - // .collect_vec(), - // ); - push_instance(all_wits64.into_iter().flatten().collect_vec()); state8 = iota_output8; @@ -1084,7 +1067,6 @@ where ) .collect_vec(); - dbg!(&keccak_output32); push_instance( chain!( keccak_output32.into_iter().flatten().flatten(), @@ -1102,14 +1084,14 @@ where } } -pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> () { +pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bool) -> () { let params = KeccakParams {}; let (layout, chip) = KeccakLayout::build(params); let mut instances = vec![]; - for state in states { + for state in &states { let state_mask64 = - MaskRepresentation::from(state.into_iter().map(|e| (64, e)).collect_vec()); + MaskRepresentation::from(state.into_iter().map(|e| (64, *e)).collect_vec()); let state_mask32 = state_mask64.convert(vec![32; 50]); instances.push( @@ -1123,8 +1105,10 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( ); } - dbg!(instances.len()); - let phase1_witness = layout.phase1_witness(KeccakTrace { instances }); + let num_instances = instances.len(); + let phase1_witness = layout.phase1_witness(KeccakTrace { + instances: instances.clone(), + }); let mut prover_transcript = BasicTranscript::::new(b"protocol"); @@ -1132,49 +1116,62 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &vec![]); let out_evals = { - // TODO: hard-coded for two instances, improve - let point = Arc::new(vec![E::from_u64(29)]); - let output_records1 = gkr_witness - .layers - .last() - .unwrap() - .bases - //.clone() - .iter() - .map(|base| base[0].clone()) - .collect_vec(); - let output_records2 = gkr_witness + let log2_num_instances = (num_instances as usize) + .next_power_of_two() + .trailing_zeros() as usize; + let point = Arc::new(vec![E::from_u64(29); log2_num_instances]); + + if test_outputs { + // Confront outputs with tiny_keccak::keccakf call + let mut instance_outputs = vec![vec![]; num_instances]; + for base in gkr_witness + .layers + .last() + .unwrap() + .bases + .iter() + .take(KECCAK_OUTPUT_SIZE) + { + assert_eq!(base.len(), num_instances); + for i in 0..num_instances { + instance_outputs[i].push(base[i].clone()); + } + } + + for i in 0..num_instances { + let mut state = states[i]; + keccakf(&mut state); + assert_eq!( + state + .to_vec() + .iter() + .map(|e| vec![*e as u32, (e >> 32) as u32]) + .flatten() + .map(|e| Goldilocks::from_u64(e as u64)) + .collect_vec(), + instance_outputs[i] + ); + } + } + + let out_evals = gkr_witness .layers .last() .unwrap() .bases - //.clone() .iter() - .map(|base| base[1].clone()) - .collect_vec(); - - let evals = zip(&output_records1, &output_records2) - .map(|(p0, p1)| { - let poly = vec![p0.clone(), p1.clone()]; - subprotocols::utils::evaluate_mle_ext(&poly, &point) + .map(|base| PointAndEval { + point: point.clone(), + eval: subprotocols::utils::evaluate_mle_ext(&base, &point), }) .collect_vec(); - let len = output_records1.len(); assert_eq!( - len, + out_evals.len(), KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS ); - assert!(output_records1.len() == output_records2.len()); - evals - .into_iter() - .map(|elem| PointAndEval { - point: point.clone(), - eval: elem, - //eval: E::ONE, - }) - .collect_vec() + out_evals }; let gkr_circuit = chip.gkr_circuit(); @@ -1185,7 +1182,6 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> ( if verify { { - dbg!("sanity"); let mut verifier_transcript = BasicTranscript::::new(b"protocol"); gkr_circuit @@ -1205,13 +1201,17 @@ mod tests { #[test] fn test_v2_keccakf() { for _ in 0..3 { - let random_u64: u64 = rand::random(); + //let random_u64: u64 = rand::random(); // Use seeded rng for debugging convenience let mut rng = rand::rngs::StdRng::seed_from_u64(42); - let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); - let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); - // let state = [0; 50]; - run_faster_keccakf(vec![state1, state2], true, true); + + let num_instances = 8; + let mut states: Vec<[u64; 25]> = vec![]; + + for _ in 0..num_instances { + states.push(std::array::from_fn(|_| rng.gen())) + } + run_faster_keccakf(states, true, true); } } } From 6b9015db7936d000bf7c5d3cd63a70cb7835090a Mon Sep 17 00:00:00 2001 From: Mihai Date: Fri, 25 Apr 2025 12:09:38 +0300 Subject: [PATCH 17/17] refactor & polish --- ceno_host/tests/test_elf.rs | 2 +- ceno_zkvm/Cargo.toml | 2 +- ceno_zkvm/src/instructions.rs | 64 +++-- .../instructions/riscv/dummy/dummy_ecall.rs | 35 +-- ceno_zkvm/src/instructions/riscv/rv32im.rs | 8 - ceno_zkvm/src/keygen.rs | 1 - ceno_zkvm/src/scheme.rs | 1 + ceno_zkvm/src/scheme/prover.rs | 27 +- ceno_zkvm/src/scheme/verifier.rs | 19 +- ceno_zkvm/src/structs.rs | 8 - examples/examples/ceno_rt_keccak.rs | 3 +- gkr_iop/Cargo.toml | 6 +- .../{keccak_f.rs => bitwise_keccakf.rs} | 11 +- .../{faster_keccak.rs => lookup_keccakf.rs} | 7 +- gkr_iop/examples/multi_layer_logup.rs | 33 +-- gkr_iop/src/chip/builder.rs | 3 +- gkr_iop/src/evaluation.rs | 2 +- gkr_iop/src/gkr.rs | 2 +- gkr_iop/src/gkr/layer.rs | 8 +- gkr_iop/src/gkr/layer/linear_layer.rs | 2 +- gkr_iop/src/gkr/mock.rs | 4 +- .../{keccak_f.rs => bitwise_keccakf.rs} | 269 ++++++++---------- .../{faster_keccak.rs => lookup_keccakf.rs} | 186 ++++++------ gkr_iop/src/precompiles/mod.rs | 14 +- gkr_iop/src/precompiles/utils.rs | 46 +-- subprotocols/src/sumcheck.rs | 3 +- subprotocols/src/zerocheck.rs | 3 - 27 files changed, 335 insertions(+), 434 deletions(-) rename gkr_iop/benches/{keccak_f.rs => bitwise_keccakf.rs} (71%) rename gkr_iop/benches/{faster_keccak.rs => lookup_keccakf.rs} (80%) rename gkr_iop/src/precompiles/{keccak_f.rs => bitwise_keccakf.rs} (64%) rename gkr_iop/src/precompiles/{faster_keccak.rs => lookup_keccakf.rs} (90%) diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index 6353a0cf9..d848652d9 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -231,7 +231,7 @@ fn test_ceno_rt_keccak() -> Result<()> { let steps = run(&mut state)?; // Expect the program to have written successive states between Keccak permutations. - const ITERATIONS: usize = 3; + const ITERATIONS: usize = 4; let keccak_outs = sample_keccak_f(ITERATIONS); let all_messages = read_all_messages(&state); diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 5fee3d3c5..1d525e471 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -23,7 +23,7 @@ gkr_iop = { path = "../gkr_iop" } mpcs = { path = "../mpcs" } multilinear_extensions = { version = "0", path = "../multilinear_extensions" } p3 = { path = "../p3" } -subprotocols = {path = "../subprotocols"} +subprotocols = { path = "../subprotocols" } sumcheck = { version = "0", path = "../sumcheck" } transcript = { path = "../transcript" } witness = { path = "../witness" } diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 839026fcd..634c35d56 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -1,10 +1,6 @@ use ceno_emul::StepRecord; use ff_ext::ExtensionField; -use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuitWitness, layer::LayerWitness}, - precompiles::KeccakLayout, -}; +use gkr_iop::{ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuitWitness}; use itertools::Itertools; use multilinear_extensions::util::max_usable_threads; use rayon::{ @@ -87,31 +83,34 @@ impl GKRinfo { } } +// Trait that should be implemented by opcodes which use the +// GKR-IOP prover. Presently, such opcodes should also implement +// the Instruction trait and in general methods of GKRIOPInstruction +// will call corresponding methods of Instruction and do something extra +// with respect to syncing state between GKR-IOP and old-style circuits/witnesses pub trait GKRIOPInstruction where Self: Instruction, { type Layout: ProtocolWitnessGenerator + ProtocolBuilder; + /// Similar to Instruction::construct_circuit; generally + /// meant to extend InstructionConfig with GKR-specific + /// fields + #[allow(unused_variables)] fn construct_circuit_with_gkr_iop( cb: &mut CircuitBuilder, ) -> Result { unimplemented!(); } - // Returns index of `i-th` GKR-IOP output-eval in PCS - fn output_map(i: usize) -> usize { - unimplemented!(); - } - + /// Should generate phase1 witness for GKR from step records fn phase1_witness_from_steps( layout: &Self::Layout, steps: &[StepRecord], ) -> Vec>; - // Number of lookup arguments used by this GKR proof - fn gkr_info() -> GKRinfo; - + /// Similar to Instruction::assign_instance, but with access to GKR lookups and wits fn assign_instance_with_gkr_iop( config: &Self::InstructionConfig, instance: &mut [E::BaseField], @@ -121,6 +120,8 @@ where aux_wits: &[E::BaseField], ) -> Result<(), ZKVMError>; + /// Similar to Instruction::assign_instances, but with access to the GKR layout. + #[allow(clippy::type_complexity)] fn assign_instances_with_gkr_iop( config: &Self::InstructionConfig, num_witin: usize, @@ -146,10 +147,8 @@ where RowMajorMatrix::::new(steps.len(), num_witin, Self::padding_strategy()); let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); - let gkr_witness = gkr_layout.gkr_witness( - &Self::phase1_witness_from_steps(gkr_layout, &steps), - &vec![], - ); + let gkr_witness = + gkr_layout.gkr_witness(&Self::phase1_witness_from_steps(gkr_layout, &steps), &[]); let (lookups, aux_wits) = { // Extract lookups and auxiliary witnesses from GKR protocol @@ -157,15 +156,24 @@ where // the output records for all instances; further, we assume that the last ```Self::lookup_count()``` // elements of this layer are the lookup arguments. let mut lookups = vec![vec![]; steps.len()]; + let mut aux_wits: Vec> = vec![vec![]; steps.len()]; let last_layer = gkr_witness.layers.last().unwrap().bases.clone(); let len = last_layer.len(); - for witness in last_layer[len - Self::gkr_info().lookup_total()..].iter() { + let lookup_total = Self::gkr_info().lookup_total(); + + let non_lookups = if len == 0 { + 0 + } else { + assert!(len >= lookup_total); + len - lookup_total + }; + + for witness in last_layer.iter().skip(non_lookups) { for i in 0..witness.len() { lookups[i].push(witness[i]); } } - let mut aux_wits: Vec> = vec![vec![]; steps.len()]; let n_layers = gkr_witness.layers.len(); for i in 0..steps.len() { @@ -194,7 +202,6 @@ where .chunks_mut(num_witin) .zip(steps) .map(|(instance, (i, step))| { - // dbg!(i, step); Self::assign_instance_with_gkr_iop( config, instance, @@ -211,4 +218,21 @@ where raw_witin.padding_by_strategy(); Ok((raw_witin, gkr_witness, lk_multiplicity)) } + + /// Lookup and witness counts used by GKR proof + fn gkr_info() -> GKRinfo; + + /// Returns corresponding column in RMM for the i-th + /// output evaluation of the GKR proof + #[allow(unused_variables)] + fn output_evals_map(i: usize) -> usize { + unimplemented!(); + } + + /// Returns corresponding column in RMM for the i-th + /// witness of the GKR proof + #[allow(unused_variables)] + fn witness_map(i: usize) -> usize { + unimplemented!(); + } } diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index e217cd152..4fa0b0e5f 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -61,8 +61,8 @@ impl Instruction for LargeEcallDummy let mem_writes = (0..S::MEM_OPS_COUNT) .map(|i| { - let val_before = cb.create_witin(|| format!("mem_before_{}", i)); - let val_after = cb.create_witin(|| format!("mem_after_{}", i)); + let val_before = cb.create_witin(|| format!("mem_before_{}_READ_ARG", i)); + let val_after = cb.create_witin(|| format!("mem_after_{}_WRITE_ARG", i)); let addr = cb.create_witin(|| format!("addr_{}", i)); WriteMEM::construct_circuit( cb, @@ -112,7 +112,6 @@ impl Instruction for LargeEcallDummy // Assign memory. for ((addr, value, writer), op) in config.mem_writes.iter().zip_eq(&ops.mem_ops) { - dbg!(&op.value.before); set_val!(instance, value.before, op.value.before as u64); set_val!(instance, value.after, op.value.after as u64); set_val!(instance, addr, u64::from(op.addr)); @@ -148,40 +147,40 @@ impl GKRIOPInstruction for LargeEcallDummy let mut aux_wits = vec![]; for i in 0..AND_LOOKUPS { - let a = cb.create_witin(|| format!("and_lookup_{i}_a")); - let b = cb.create_witin(|| format!("and_lookup_{i}_b")); - let c = cb.create_witin(|| format!("and_lookup_{i}_c")); + let a = cb.create_witin(|| format!("and_lookup_{i}_a_LOOKUP_ARG")); + let b = cb.create_witin(|| format!("and_lookup_{i}_b_LOOKUP_ARG")); + let c = cb.create_witin(|| format!("and_lookup_{i}_c_LOOKUP_ARG")); cb.lookup_and_byte(a.into(), b.into(), c.into())?; lookups.extend(vec![a, b, c]); } for i in 0..XOR_LOOKUPS { - let a = cb.create_witin(|| format!("xor_lookup_{i}_a")); - let b = cb.create_witin(|| format!("xor_lookup_{i}_b")); - let c = cb.create_witin(|| format!("xor_lookup_{i}_c")); + let a = cb.create_witin(|| format!("xor_lookup_{i}_a_LOOKUP_ARG")); + let b = cb.create_witin(|| format!("xor_lookup_{i}_b_LOOKUP_ARG")); + let c = cb.create_witin(|| format!("xor_lookup_{i}_c_LOOKUP_ARG")); cb.lookup_xor_byte(a.into(), b.into(), c.into())?; lookups.extend(vec![a, b, c]); } for i in 0..RANGE_LOOKUPS { - let wit = cb.create_witin(|| format!("range_lookup_{i}")); + let wit = cb.create_witin(|| format!("range_lookup_{i}_LOOKUP_ARG")); cb.assert_ux::<_, _, 16>(|| "nada", wit.into())?; lookups.push(wit); } for i in 0..40144 { - aux_wits.push(cb.create_witin(|| format!("aux_wit{i}"))); + aux_wits.push(cb.create_witin(|| format!("{i}_GKR_WITNESS"))); } partial_config.lookups = lookups; partial_config.aux_wits = aux_wits; - dbg!(&partial_config.mem_writes[0].1.after); - dbg!(&partial_config.mem_writes[1].1.before); - dbg!(&partial_config.lookups[0]); - Ok(partial_config) } - fn output_map(i: usize) -> usize { + // TODO: make this nicer without access to config + // one alternative: the verifier uses the namespaces + // contained in the constraint system to select + // gkr-specific fields + fn output_evals_map(i: usize) -> usize { if i < 50 { 27 + 6 * i } else if i < 100 { @@ -225,9 +224,6 @@ impl GKRIOPInstruction for LargeEcallDummy let mut wit_iter = lookups.iter().map(|f| f.to_canonical_u64()); let mut var_iter = config.lookups.iter(); - dbg!(wit_iter.clone().count()); - dbg!(var_iter.clone().count()); - let mut pop_arg = || -> u64 { let wit = wit_iter.next().unwrap(); let var = var_iter.next().unwrap(); @@ -247,7 +243,6 @@ impl GKRIOPInstruction for LargeEcallDummy lk_multiplicity.assert_ux::<16>(pop_arg()); } - dbg!(aux_wits.len()); for (aux_wit_var, aux_wit) in zip_eq(config.aux_wits.iter(), aux_wits) { set_val!(instance, aux_wit_var, (aux_wit.to_canonical_u64())); } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index bc8dbc02a..47fb43e4b 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -466,7 +466,6 @@ pub struct DummyExtraConfig { impl DummyExtraConfig { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { let ecall_config = cs.register_opcode_circuit::>(); - // let keccak_config = cs.register_opcode_circuit::>(); let keccak_config = cs.register_keccakf_circuit(); let secp256k1_add_config = cs.register_opcode_circuit::>(); @@ -511,7 +510,6 @@ impl DummyExtraConfig { ) { fixed.register_opcode_circuit::>(cs); fixed.register_keccakf_circuit(cs); - // fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); @@ -566,12 +564,6 @@ impl DummyExtraConfig { witness.assign_keccakf_circuit(cs, &self.keccak_config, keccak_steps)?; - // witness.assign_opcode_circuit::>( - // cs, - // &self.keccak_config, - // keccak_steps, - //)?; - witness.assign_opcode_circuit::>( cs, &self.secp256k1_add_config, diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index 112ba1bc6..f32d72800 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -1,6 +1,5 @@ use crate::{ error::ZKVMError, - instructions::riscv::dummy::LargeEcallDummy, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMProvingKey}, }; use ff_ext::ExtensionField; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 400e60963..02915479e 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -57,6 +57,7 @@ pub struct ZKVMOpcodeProof } #[derive(Clone, Serialize)] +// WARN/TODO: depends on serde's `arc` feature which might not behave correctly pub struct GKROpcodeProof> { output_evals: Vec, prover_output: GKRProverOutput>, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 27a8b5f33..4cbab1e00 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -7,7 +7,7 @@ use std::{ sync::Arc, }; -use itertools::{Itertools, chain, enumerate, izip}; +use itertools::{Itertools, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, @@ -27,10 +27,7 @@ use witness::{RowMajorMatrix, next_pow2_instance_padding}; use crate::{ error::ZKVMError, expression::Instance, - instructions::{ - Instruction, - riscv::{dummy::LargeEcallDummy, ecall::EcallDummy}, - }, + instructions::{Instruction, riscv::dummy::LargeEcallDummy}, scheme::{ GKROpcodeProof, constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, @@ -261,6 +258,7 @@ impl> ZKVMProver { /// 1: witness layer inferring from input -> output /// 2: proof (sumcheck reduce) from output to input #[allow(clippy::too_many_arguments)] + #[allow(clippy::type_complexity)] #[tracing::instrument(skip_all, name = "create_opcode_proof", fields(circuit_name=name,profiling_2), level="trace")] pub fn create_opcode_proof( &self, @@ -455,8 +453,6 @@ impl> ZKVMProver { rt_tower[..log2_num_instances].to_vec(), ); - dbg!(rt_r.len(), rt_w.len(), rt_lk.len()); - let num_threads = optimal_sumcheck_threads(log2_num_instances); let alpha_pow = get_challenge_pows( MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + cs.assert_zero_sumcheck_expressions.len(), @@ -633,10 +629,6 @@ impl> ZKVMProver { } ); - dbg!(r_counts_per_instance); - dbg!(w_counts_per_instance); - dbg!(lk_counts_per_instance); - let mut main_sel_evals_iter = main_sel_evals.into_iter(); main_sel_evals_iter.next(); // skip sel_r let r_records_in_evals = (0..r_counts_per_instance) @@ -695,6 +687,7 @@ impl> ZKVMProver { opening_dur.elapsed(), ); exit_span!(pcs_open_span); + let wits_commit = PCS::get_pure_commitment(&wits_commit); let gkr_opcode_proof = if let Some((gkr_iop_pk, gkr_wit)) = gkr_iop_pk { @@ -711,25 +704,25 @@ impl> ZKVMProver { .iter() .map(|base| PointAndEval { point: point.clone(), - eval: subprotocols::utils::evaluate_mle_ext(&base, &point), + eval: subprotocols::utils::evaluate_mle_ext(base, &point), }) .collect_vec(); - // out_evals from point and output polynomials instead of *_records which is combined - let prover_output = gkr_circuit - .prove(gkr_wit, &out_evals, &vec![], transcript) + .prove(gkr_wit, &out_evals, &[], transcript) .expect("Failed to prove phase"); // unimplemented!("cannot fully handle GKRIOP component yet") - dbg!(&prover_output.opening_evaluations.len()); + let _gkr_open_point = prover_output.opening_evaluations[0].point.clone(); + // TODO: open polynomials for GKR proof + let output_evals = out_evals.into_iter().map(|pae| pae.eval).collect_vec(); Some(GKROpcodeProof { output_evals, prover_output, circuit: gkr_circuit, - _marker: PhantomData::default(), + _marker: PhantomData, }) } else { None diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 8983f7dc2..cf80b4eee 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -276,8 +276,6 @@ impl> ZKVMVerifier cs.lk_expressions.len(), ); - // dbg!(&cs.r_expressions); - let (log2_r_count, log2_w_count, log2_lk_count) = ( ceil_log2(r_counts_per_instance), ceil_log2(w_counts_per_instance), @@ -332,8 +330,6 @@ impl> ZKVMVerifier logup_q_evals[0].point.clone(), ); - // dbg!(&rt_r, &rt_w, &rt_lk); - let alpha_pow = get_challenge_pows( MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + cs.assert_zero_sumcheck_expressions.len(), transcript, @@ -418,22 +414,13 @@ impl> ZKVMVerifier .expect("Keccak syscall should contain GKR-IOP proof"); // Match output_evals with EcallDummy polynomials - let mut matches = 0; for (i, gkr_out_eval) in gkr_iop.output_evals.iter().enumerate() { assert_eq!( *gkr_out_eval, - proof.wits_in_evals[LargeEcallDummy::::output_map(i)], + proof.wits_in_evals[LargeEcallDummy::::output_evals_map(i)], "{i}" ); - if *gkr_out_eval - == proof.wits_in_evals[LargeEcallDummy::::output_map(i)] - { - matches += 1; - } else { - dbg!(i); - } } - dbg!(matches); // Verify GKR proof let point = Arc::new(input_opening_point.clone()); let out_evals = gkr_iop @@ -441,7 +428,7 @@ impl> ZKVMVerifier .iter() .map(|eval| gkr_iop::evaluation::PointAndEval { point: point.clone(), - eval: eval.clone(), + eval: *eval, }) .collect_vec(); @@ -450,7 +437,7 @@ impl> ZKVMVerifier .verify( gkr_iop.prover_output.gkr_proof.clone(), &out_evals, - &vec![], + &[], transcript, ) .expect("GKR-IOP verify failure"); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index e41486ed4..768310e85 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -374,12 +374,10 @@ impl ZKVMWitnesses { pub fn assign_keccakf_circuit( &mut self, - // TODO: do without mutability requirement css: &ZKVMConstraintSystem, config: & as Instruction>::InstructionConfig, records: Vec, ) -> Result<(), ZKVMError> { - // Ugly copy paste from assign_opcode_circuit, but we need to use the row major matrix let cs = css .get_cs(&LargeEcallDummy::::name()) .unwrap(); @@ -392,12 +390,6 @@ impl ZKVMWitnesses { )?; self.keccak_gkr_wit = gkr_witness; - // // Intercept row-major matrix, convert into KeccakTrace and obtain phase1_wit - // self.keccak_phase1wit = css - // .keccak_gkr_iop - // .layout - // .phase1_witness(KeccakTrace::from(witness.clone())); - assert!( self.witnesses_opcodes .insert(LargeEcallDummy::::name(), witness) diff --git a/examples/examples/ceno_rt_keccak.rs b/examples/examples/ceno_rt_keccak.rs index 135e8d9b0..46f950142 100644 --- a/examples/examples/ceno_rt_keccak.rs +++ b/examples/examples/ceno_rt_keccak.rs @@ -6,7 +6,7 @@ extern crate ceno_rt; use ceno_rt::{info_out, syscalls::syscall_keccak_permute}; use core::slice; -const ITERATIONS: usize = 2; +const ITERATIONS: usize = 4; fn main() { let mut state = [0_u64; 25]; @@ -17,6 +17,7 @@ fn main() { } } +#[allow(dead_code)] fn log_state(state: &[u64; 25]) { let out = unsafe { slice::from_raw_parts(state.as_ptr() as *const u8, state.len() * size_of::()) diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index 3a784c393..e366c5db6 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -12,7 +12,6 @@ version.workspace = true [dependencies] ark-std.workspace = true ff_ext = { path = "../ff_ext" } -witness = { path = "../witness" } itertools.workspace = true multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } ndarray.workspace = true @@ -25,14 +24,15 @@ subprotocols = { path = "../subprotocols" } thiserror = "1" tiny-keccak.workspace = true transcript = { path = "../transcript" } +witness = { path = "../witness" } [dev-dependencies] criterion.workspace = true [[bench]] harness = false -name = "keccak_f" +name = "bitwise_keccakf" [[bench]] harness = false -name = "faster_keccak" \ No newline at end of file +name = "lookup_keccakf" diff --git a/gkr_iop/benches/keccak_f.rs b/gkr_iop/benches/bitwise_keccakf.rs similarity index 71% rename from gkr_iop/benches/keccak_f.rs rename to gkr_iop/benches/bitwise_keccakf.rs index b3b330725..902b63a1b 100644 --- a/gkr_iop/benches/keccak_f.rs +++ b/gkr_iop/benches/bitwise_keccakf.rs @@ -1,9 +1,7 @@ use std::time::Duration; use criterion::*; -use gkr_iop::precompiles::{run_faster_keccakf, run_keccakf}; -use p3_field::extension::BinomialExtensionField; -use p3_goldilocks::Goldilocks; +use gkr_iop::precompiles::run_keccakf; use rand::{Rng, SeedableRng}; criterion_group!(benches, keccak_f_fn); criterion_main!(benches); @@ -12,11 +10,11 @@ const NUM_SAMPLES: usize = 10; fn keccak_f_fn(c: &mut Criterion) { // expand more input size once runtime is acceptable - let mut group = c.benchmark_group(format!("keccak_f")); + let mut group = c.benchmark_group("keccak_f".to_string()); group.sample_size(NUM_SAMPLES); // Benchmark the proving time - group.bench_function(BenchmarkId::new("keccak_f", format!("keccak_f")), |b| { + group.bench_function(BenchmarkId::new("keccak_f", "keccak_f"), |b| { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters { @@ -25,7 +23,8 @@ fn keccak_f_fn(c: &mut Criterion) { let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); let instant = std::time::Instant::now(); - let _ = black_box(run_keccakf(state, false, false)); + #[allow(clippy::unit_arg)] + black_box(run_keccakf(state, false, false)); let elapsed = instant.elapsed(); time += elapsed; } diff --git a/gkr_iop/benches/faster_keccak.rs b/gkr_iop/benches/lookup_keccakf.rs similarity index 80% rename from gkr_iop/benches/faster_keccak.rs rename to gkr_iop/benches/lookup_keccakf.rs index 380ef91f5..a9ab4f5f0 100644 --- a/gkr_iop/benches/faster_keccak.rs +++ b/gkr_iop/benches/lookup_keccakf.rs @@ -11,11 +11,11 @@ const NUM_SAMPLES: usize = 10; fn keccak_f_fn(c: &mut Criterion) { // expand more input size once runtime is acceptable - let mut group = c.benchmark_group(format!("keccak_f")); + let mut group = c.benchmark_group("keccakf"); group.sample_size(NUM_SAMPLES); // Benchmark the proving time - group.bench_function(BenchmarkId::new("keccak_f", format!("keccak_f")), |b| { + group.bench_function(BenchmarkId::new("keccakf", "keccakf"), |b| { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters { @@ -25,7 +25,8 @@ fn keccak_f_fn(c: &mut Criterion) { let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); let instant = std::time::Instant::now(); - let _ = black_box(run_faster_keccakf(vec![state1, state2], false, false)); + #[allow(clippy::unit_arg)] + black_box(run_faster_keccakf(vec![state1, state2], false, false)); let elapsed = instant.elapsed(); time += elapsed; } diff --git a/gkr_iop/examples/multi_layer_logup.rs b/gkr_iop/examples/multi_layer_logup.rs index 6bf7928d2..2ac39416b 100644 --- a/gkr_iop/examples/multi_layer_logup.rs +++ b/gkr_iop/examples/multi_layer_logup.rs @@ -2,18 +2,18 @@ use std::{marker::PhantomData, mem, sync::Arc}; use ff_ext::ExtensionField; use gkr_iop::{ + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, evaluation::{EvalExpression, PointAndEval}, gkr::{ - layer::{Layer, LayerType, LayerWitness}, GKRCircuitWitness, GKRProverOutput, + layer::{Layer, LayerType, LayerWitness}, }, - ProtocolBuilder, ProtocolWitnessGenerator, }; -use itertools::{izip, Itertools}; -use p3_field::{extension::BinomialExtensionField, PrimeCharacteristicRing}; +use itertools::{Itertools, izip}; +use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_goldilocks::Goldilocks; -use rand::{rngs::OsRng, Rng}; +use rand::{Rng, rngs::OsRng}; use subprotocols::expression::{Constant, Expression}; use transcript::{BasicTranscript, Transcript}; @@ -87,20 +87,17 @@ impl ProtocolBuilder for TowerChipLayout { num_1.0.into(), ]; let (in_bases, in_exts) = if i == height - 1 { - ( - vec![num_0.1.clone(), num_1.1.clone()], - vec![den_0.1.clone(), den_1.1.clone()], - ) + (vec![num_0.1.clone(), num_1.1.clone()], vec![ + den_0.1.clone(), + den_1.1.clone(), + ]) } else { - ( - vec![], - vec![ - den_0.1.clone(), - den_1.1.clone(), - num_0.1.clone(), - num_1.1.clone(), - ], - ) + (vec![], vec![ + den_0.1.clone(), + den_1.1.clone(), + num_0.1.clone(), + num_1.1.clone(), + ]) }; chip.add_layer(Layer::new( format!("Tower_layer_{}", i), diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs index c05165d3d..e8c5079e1 100644 --- a/gkr_iop/src/chip/builder.rs +++ b/gkr_iop/src/chip/builder.rs @@ -57,10 +57,9 @@ impl Chip { // -> [EvalExpression; N] { self.n_evaluations += N; - //array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) + // array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) // TODO: hotfix to avoid stack overflow, fix later (0..N) - .into_iter() .map(|i| EvalExpression::Single(i + self.n_evaluations - N)) .collect_vec() } diff --git a/gkr_iop/src/evaluation.rs b/gkr_iop/src/evaluation.rs index 4ba4a6b79..264486c83 100644 --- a/gkr_iop/src/evaluation.rs +++ b/gkr_iop/src/evaluation.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use ff_ext::ExtensionField; -use itertools::{izip, Itertools}; +use itertools::{Itertools, izip}; use multilinear_extensions::virtual_poly::build_eq_x_r_vec_sequential; use serde::Serialize; use subprotocols::expression::{Constant, Point}; diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index e34ab7a80..97906626b 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -1,5 +1,5 @@ use ff_ext::ExtensionField; -use itertools::{chain, izip, Itertools}; +use itertools::{Itertools, chain, izip}; use layer::{Layer, LayerWitness}; use serde::Serialize; use subprotocols::{expression::Point, sumcheck::SumcheckProof}; diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 3b30e1dc3..ef0005d18 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -222,10 +222,10 @@ impl Layer { ext_mle_evals: &[E], point: &Point, ) { - for (value, pos) in izip!( - chain![base_mle_evals, ext_mle_evals], - chain![&self.in_bases, &self.in_exts] - ) { + for (value, pos) in izip!(chain![base_mle_evals, ext_mle_evals], chain![ + &self.in_bases, + &self.in_exts + ]) { *(pos.entry_mut(claims)) = PointAndEval { point: point.clone(), eval: *value, diff --git a/gkr_iop/src/gkr/layer/linear_layer.rs b/gkr_iop/src/gkr/layer/linear_layer.rs index 6367d1b10..8387c4d80 100644 --- a/gkr_iop/src/gkr/layer/linear_layer.rs +++ b/gkr_iop/src/gkr/layer/linear_layer.rs @@ -1,5 +1,5 @@ use ff_ext::ExtensionField; -use itertools::{izip, Itertools}; +use itertools::{Itertools, izip}; use subprotocols::{ error::VerifierError, expression::Point, diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 06241fea5..3c359bdfe 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use ff_ext::ExtensionField; -use itertools::{izip, Itertools}; +use itertools::{Itertools, izip}; use rand::rngs::OsRng; use subprotocols::{ expression::{Expression, VectorType}, @@ -12,7 +12,7 @@ use thiserror::Error; use crate::{evaluation::EvalExpression, utils::SliceIterator}; -use super::{layer::LayerType, GKRCircuit, GKRCircuitWitness}; +use super::{GKRCircuit, GKRCircuitWitness, layer::LayerType}; pub struct MockProver(PhantomData); diff --git a/gkr_iop/src/precompiles/keccak_f.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs similarity index 64% rename from gkr_iop/src/precompiles/keccak_f.rs rename to gkr_iop/src/precompiles/bitwise_keccakf.rs index 0a2e44f09..153e01d2b 100644 --- a/gkr_iop/src/precompiles/keccak_f.rs +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -1,17 +1,17 @@ use std::{array::from_fn, marker::PhantomData, sync::Arc}; use crate::{ + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, evaluation::{EvalExpression, PointAndEval}, gkr::{ - layer::{Layer, LayerType, LayerWitness}, GKRCircuitWitness, GKRProverOutput, + layer::{Layer, LayerType, LayerWitness}, }, - ProtocolBuilder, ProtocolWitnessGenerator, }; use ff_ext::ExtensionField; -use itertools::{chain, iproduct, Itertools}; -use p3_field::{extension::BinomialExtensionField, Field, PrimeCharacteristicRing}; +use itertools::{Itertools, chain, iproduct}; +use p3_field::{Field, PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_goldilocks::Goldilocks; use subprotocols::expression::{Constant, Expression, Witness}; @@ -24,11 +24,11 @@ struct KeccakParams {} #[derive(Clone, Debug, Default)] struct KeccakLayout { - params: KeccakParams, + _params: KeccakParams, committed_bits_id: usize, - result: Vec, + _result: Vec, _marker: PhantomData, } @@ -82,7 +82,7 @@ fn c(x: usize, z: usize, bits: &[F]) -> F { fn c_expr(x: usize, z: usize, state_wits: &[Witness]) -> Expression { (0..5) .map(|y| Expression::from(state_wits[from_xyz(x, y, z)])) - .fold(zero_expr(), |acc, x| xor_expr(acc, x)) + .fold(zero_expr(), xor_expr) } fn from_xz(x: usize, z: usize) -> usize { @@ -129,18 +129,6 @@ fn not(a: F) -> F { F::ONE - a } -fn bools_to_u64s(state: &[bool]) -> Vec { - state - .chunks(64) - .map(|chunk| { - chunk - .iter() - .enumerate() - .fold(0u64, |acc, (i, &bit)| acc | ((bit as u64) << i)) - }) - .collect::>() -} - fn u64s_to_bools(state64: &[u64]) -> Vec { state64 .iter() @@ -148,14 +136,7 @@ fn u64s_to_bools(state64: &[u64]) -> Vec { .collect() } -fn dbg_vec(vec: &Vec) { - let res = vec - .iter() - .map(|f| if *f == F::ZERO { false } else { true }) - .collect_vec(); - // println!("{:?}", bools_to_u64s(&res)); -} -fn chi(bits: &Vec) -> Vec { +fn chi(bits: &[F]) -> Vec { assert_eq!(bits.len(), STATE_SIZE); bits.iter() @@ -200,9 +181,9 @@ const RC: [u64; ROUNDS] = [ 0x8000000080008008u64, ]; -fn iota(bits: &Vec, round_value: u64) -> Vec { +fn iota(bits: &[F], round_value: u64) -> Vec { assert_eq!(bits.len(), STATE_SIZE); - let mut ret = bits.clone(); + let mut ret = bits.to_vec(); let cast = |x| match x { 0 => F::ZERO, @@ -247,7 +228,7 @@ impl ProtocolBuilder for KeccakLayout { fn init(params: Self::Params) -> Self { Self { - params, + _params: params, ..Default::default() } } @@ -259,126 +240,118 @@ impl ProtocolBuilder for KeccakLayout { fn build_gkr_phase(&mut self, chip: &mut Chip) { let final_output = chip.allocate_output_evals::(); - (0..ROUNDS) - .rev() - .into_iter() - .fold(final_output, |round_output, round| { - let (chi_output, _) = chip.allocate_wits_in_layer::(); + (0..ROUNDS).rev().fold(final_output, |round_output, round| { + let (chi_output, _) = chip.allocate_wits_in_layer::(); - let exprs = (0..STATE_SIZE) - .map(|i| iota_expr(&chi_output.iter().map(|e| e.0).collect_vec(), i, RC[round])) - .collect_vec(); + let exprs = (0..STATE_SIZE) + .map(|i| iota_expr(&chi_output.iter().map(|e| e.0).collect_vec(), i, RC[round])) + .collect_vec(); - chip.add_layer(Layer::new( - format!("Round {round}: Iota:: compute output"), - LayerType::Zerocheck, - exprs, - vec![], - chi_output.iter().map(|e| e.1.clone()).collect_vec(), - vec![], - round_output.to_vec(), - vec![], - )); + chip.add_layer(Layer::new( + format!("Round {round}: Iota:: compute output"), + LayerType::Zerocheck, + exprs, + vec![], + chi_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + round_output.to_vec(), + vec![], + )); - let (theta_output, _) = chip.allocate_wits_in_layer::(); + let (theta_output, _) = chip.allocate_wits_in_layer::(); - // Apply the effects of the rho + pi permutation directly o the argument of chi - // No need for a separate layer - let perm = rho_and_pi_permutation(); - let permuted = (0..STATE_SIZE) - .map(|i| theta_output[perm[i]].0) - .collect_vec(); + // Apply the effects of the rho + pi permutation directly o the argument of chi + // No need for a separate layer + let perm = rho_and_pi_permutation(); + let permuted = (0..STATE_SIZE) + .map(|i| theta_output[perm[i]].0) + .collect_vec(); - let exprs = (0..STATE_SIZE) - .map(|i| chi_expr(i, &permuted)) - .collect_vec(); + let exprs = (0..STATE_SIZE) + .map(|i| chi_expr(i, &permuted)) + .collect_vec(); - chip.add_layer(Layer::new( - format!("Round {round}: Chi:: apply rho, pi and chi"), - LayerType::Zerocheck, - exprs, - vec![], - theta_output.iter().map(|e| e.1.clone()).collect_vec(), - vec![], - chi_output.iter().map(|e| e.1.clone()).collect_vec(), - vec![], - )); + chip.add_layer(Layer::new( + format!("Round {round}: Chi:: apply rho, pi and chi"), + LayerType::Zerocheck, + exprs, + vec![], + theta_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + chi_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); - let (d_and_state, _) = chip.allocate_wits_in_layer::<{ D_SIZE + STATE_SIZE }, 0>(); - let (d, state2) = d_and_state.split_at(D_SIZE); - - // Compute post-theta state using original state and D[][] values - let exprs = (0..STATE_SIZE) - .map(|i| { - let (x, y, z) = to_xyz(i); - xor_expr(state2[i].0.into(), d[from_xz(x, z)].0.into()) - }) - .collect_vec(); - - chip.add_layer(Layer::new( - format!("Round {round}: Theta::compute output"), - LayerType::Zerocheck, - exprs, - vec![], - d_and_state.iter().map(|e| e.1.clone()).collect_vec(), - vec![], - theta_output.iter().map(|e| e.1.clone()).collect_vec(), - vec![], - )); + let (d_and_state, _) = chip.allocate_wits_in_layer::<{ D_SIZE + STATE_SIZE }, 0>(); + let (d, state2) = d_and_state.split_at(D_SIZE); - let (c, []) = chip.allocate_wits_in_layer::<{ C_SIZE }, 0>(); + // Compute post-theta state using original state and D[][] values + let exprs = (0..STATE_SIZE) + .map(|i| { + let (x, _, z) = to_xyz(i); + xor_expr(state2[i].0.into(), d[from_xz(x, z)].0.into()) + }) + .collect_vec(); - let c_wits = c.iter().map(|e| e.0.clone()).collect_vec(); - // Compute D[][] from C[][] values - let d_exprs = iproduct!(0..5usize, 0..64usize) - .map(|(x, z)| d_expr(x, z, &c_wits)) - .collect_vec(); + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute output"), + LayerType::Zerocheck, + exprs, + vec![], + d_and_state.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + theta_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); - chip.add_layer(Layer::new( - format!("Round {round}: Theta::compute D[x][z]"), - LayerType::Zerocheck, - d_exprs, - vec![], - c.iter().map(|e| e.1.clone()).collect_vec(), - vec![], - d.iter().map(|e| e.1.clone()).collect_vec(), - vec![], - )); + let (c, []) = chip.allocate_wits_in_layer::<{ C_SIZE }, 0>(); - let (state, []) = chip.allocate_wits_in_layer::(); - let state_wits = state.iter().map(|s| s.0).collect_vec(); + let c_wits = c.iter().map(|e| e.0).collect_vec(); + // Compute D[][] from C[][] values + let d_exprs = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| d_expr(x, z, &c_wits)) + .collect_vec(); - // Compute C[][] from state - let c_exprs = iproduct!(0..5usize, 0..64usize) - .map(|(x, z)| c_expr(x, z, &state_wits)) - .collect_vec(); + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute D[x][z]"), + LayerType::Zerocheck, + d_exprs, + vec![], + c.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + d.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); - // Copy state - let id_exprs: Vec = - (0..STATE_SIZE).map(|i| state_wits[i].into()).collect_vec(); + let (state, []) = chip.allocate_wits_in_layer::(); + let state_wits = state.iter().map(|s| s.0).collect_vec(); - chip.add_layer(Layer::new( - format!("Round {round}: Theta::compute C[x][z]"), - LayerType::Zerocheck, - chain!(c_exprs, id_exprs).collect_vec(), - vec![], - state.iter().map(|t| t.1.clone()).collect_vec(), - vec![], - chain!( - c.iter().map(|e| e.1.clone()), - state2.iter().map(|e| e.1.clone()) - ) - .collect_vec(), - vec![], - )); + // Compute C[][] from state + let c_exprs = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| c_expr(x, z, &state_wits)) + .collect_vec(); + + // Copy state + let id_exprs: Vec = + (0..STATE_SIZE).map(|i| state_wits[i].into()).collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute C[x][z]"), + LayerType::Zerocheck, + chain!(c_exprs, id_exprs).collect_vec(), + vec![], + state.iter().map(|t| t.1.clone()).collect_vec(), + vec![], + chain!( + c.iter().map(|e| e.1.clone()), + state2.iter().map(|e| e.1.clone()) + ) + .collect_vec(), + vec![], + )); - state - .iter() - .map(|e| e.1.clone()) - .collect_vec() - .try_into() - .unwrap() - }); + state.iter().map(|e| e.1.clone()).collect_vec() + }); // Skip base opening allocation } @@ -404,14 +377,15 @@ where res } - fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness { + fn gkr_witness(&self, phase1: &[Vec], _challenges: &[E]) -> GKRCircuitWitness { let mut bits = phase1[self.committed_bits_id].clone(); let n_layers = 100; let mut layer_wits = Vec::>::with_capacity(n_layers + 1); - for i in 0..24 { - if i == 0 { + #[allow(clippy::needless_range_loop)] + for round in 0..24 { + if round == 0 { layer_wits.push(LayerWitness::new( bits.clone().into_iter().map(|b| vec![b]).collect_vec(), vec![], @@ -458,8 +432,8 @@ where vec![], )); - if i < 23 { - bits = iota(&bits, RC[i]); + if round < 23 { + bits = iota(&bits, RC[round]); layer_wits.push(LayerWitness::new( bits.clone().into_iter().map(|b| vec![b]).collect_vec(), vec![], @@ -495,7 +469,7 @@ fn rho(state: &[T]) -> Vec { (x, y) = (y, (2 * x + 3 * y) % Y); } - return ret.to_vec(); + ret.to_vec() } // https://github.com/0xPolygonHermez/zkevm-prover/blob/main/tools/sm/keccak_f/keccak_pi.cpp @@ -504,11 +478,10 @@ fn pi(state: &[T]) -> Vec { let mut ret = [T::default(); STATE_SIZE]; iproduct!(0..X, 0..Y, 0..Z) - .into_iter() .map(|(x, y, z)| ret[from_xyz(x, y, z)] = state[from_xyz((x + 3 * y) % X, x, z)]) .count(); - return ret.to_vec(); + ret.to_vec() } // Combines rho and pi steps into a single permutation @@ -517,7 +490,7 @@ fn rho_and_pi_permutation() -> Vec { pi(&rho(&perm)) } -pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) -> () { +pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) { let params = KeccakParams {}; let (layout, chip) = KeccakLayout::build(params); let gkr_circuit = chip.gkr_circuit(); @@ -530,7 +503,7 @@ pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) -> () { let mut prover_transcript = BasicTranscript::::new(b"protocol"); // Omit the commit phase1 and phase2. - let gkr_witness = layout.gkr_witness(&phase1_witness, &vec![]); + let gkr_witness = layout.gkr_witness(&phase1_witness, &[]); let out_evals = { let point = Arc::new(vec![]); @@ -546,7 +519,7 @@ pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) -> () { let expected_result_manual = iota(&last_witness, RC[23]); if test { - let mut state = state.clone(); + let mut state = state; keccakf(&mut state); let state = u64s_to_bools(&state) .into_iter() @@ -565,7 +538,7 @@ pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) -> () { }; let GKRProverOutput { gkr_proof, .. } = gkr_circuit - .prove(gkr_witness, &out_evals, &vec![], &mut prover_transcript) + .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) .expect("Failed to prove phase"); if verify { @@ -573,7 +546,7 @@ pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) -> () { let mut verifier_transcript = BasicTranscript::::new(b"protocol"); gkr_circuit - .verify(gkr_proof, &out_evals, &vec![], &mut verifier_transcript) + .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) .expect("GKR verify failed"); // Omit the PCS opening phase. diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs similarity index 90% rename from gkr_iop/src/precompiles/faster_keccak.rs rename to gkr_iop/src/precompiles/lookup_keccakf.rs index 5c455f46a..50dd09384 100644 --- a/gkr_iop/src/precompiles/faster_keccak.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -1,31 +1,22 @@ -use std::{ - array::from_fn, - cmp::Ordering, - iter::{once, zip}, - marker::PhantomData, - sync::Arc, -}; +use std::{cmp::Ordering, marker::PhantomData, sync::Arc}; use crate::{ + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, evaluation::{EvalExpression, PointAndEval}, gkr::{ - layer::{Layer, LayerType, LayerWitness}, GKRCircuitWitness, GKRProverOutput, + layer::{Layer, LayerType, LayerWitness}, }, - precompiles::utils::{nest, not8_expr, zero_expr, MaskRepresentation}, - ProtocolBuilder, ProtocolWitnessGenerator, + precompiles::utils::{MaskRepresentation, nest, not8_expr, zero_expr}, }; -use ark_std::log2; -use ndarray::{range, s, ArrayView, Ix2, Ix3}; +use ndarray::{ArrayView, Ix2, Ix3, s}; -use witness::RowMajorMatrix; - -use super::utils::{u64s_to_felts, zero_eval, CenoLookup}; -use p3_field::{extension::BinomialExtensionField, PrimeCharacteristicRing}; +use super::utils::{CenoLookup, u64s_to_felts, zero_eval}; +use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use ff_ext::{ExtensionField, SmallField}; -use itertools::{chain, iproduct, zip_eq, Itertools}; +use itertools::{Itertools, chain, iproduct, zip_eq}; use p3_goldilocks::Goldilocks; use subprotocols::expression::{Constant, Expression, Witness}; use tiny_keccak::keccakf; @@ -38,11 +29,9 @@ pub struct KeccakParams {} #[derive(Clone, Debug, Default)] pub struct KeccakLayout { - params: KeccakParams, - - input_columns: Vec, - - result: Vec, + _params: KeccakParams, + _input_columns: Vec, + _result: Vec, _marker: PhantomData, } @@ -53,7 +42,7 @@ fn expansion_expr(expansion: &[(usize, Witness)]) -> Expressi .fold((0, zero_expr()), |acc, (sz, felt)| { ( acc.0 + sz, - acc.1 * Expression::Const(Constant::Base(1 << sz)) + felt.clone().into(), + acc.1 * Expression::Const(Constant::Base(1 << sz)) + (*felt).into(), ) }); @@ -195,7 +184,7 @@ impl ConstraintSystem { /// This algorithm imposes the following general requirements for /// `split_rep`: /// - There exists a suffix of `split_rep` which sums to exactly `delta`. - /// This suffix can contain several elements. + /// This suffix can contain several elements. /// - Chunk sizes are at most 16 (so they can be range-checked) or they are /// exactly equal to 32. /// - There exists a prefix of chunks which sums exactly to 32. This must @@ -233,31 +222,31 @@ impl ConstraintSystem { // Lookup ranges for (size, elem) in split_rep { if *size != 32 { - self.lookup_range(elem.clone().into(), *size); + self.lookup_range((*elem).into(), *size); } } // constrain the fact that rep8 and repX.rotate_left(chunks_rotation) are // the same 64 bitstring - let mut helper = |rep8: &[Witness], repX: &[(usize, Witness)], chunks_rotation: usize| { + let mut helper = |rep8: &[Witness], rep_x: &[(usize, Witness)], chunks_rotation: usize| { // Do the same thing for the two 32-bit halves - let mut repX = repX.to_owned(); - repX.rotate_right(chunks_rotation); + let mut rep_x = rep_x.to_owned(); + rep_x.rotate_right(chunks_rotation); for i in 0..2 { // The respective 4 elements in the byte representation let lhs = rep8[4 * i..4 * (i + 1)] .iter() - .map(|wit| (8, wit.clone())) + .map(|wit| (8, *wit)) .collect_vec(); - let cnt = repX.len() / 2; - let rhs = &repX[cnt * i..cnt * (i + 1)]; + let cnt = rep_x.len() / 2; + let rhs = &rep_x[cnt * i..cnt * (i + 1)]; assert_eq!(rhs.iter().map(|e| e.0).sum::(), 32); self.constrain_reps_eq::<32>( &lhs, - &rhs, + rhs, format!( "rotation internal {label}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", sizes @@ -338,16 +327,13 @@ impl ProtocolBuilder for KeccakLayout { fn init(params: Self::Params) -> Self { Self { - params, + _params: params, ..Default::default() } } fn build_commit_phase(&mut self, chip: &mut Chip) { - //let input_columns: [usize; KECCAK_INPUT_SIZE] = - // chip.allocate_committed_base(); self.input_columns = - // input_columns.into(); - let committed = chip.allocate_committed_base::<{ 50 + 40144 }>(); + let _ = chip.allocate_committed_base::<{ 50 + 40144 }>(); } fn build_gkr_phase(&mut self, chip: &mut Chip) { @@ -359,7 +345,7 @@ impl ProtocolBuilder for KeccakLayout { let [keccak_output32, keccak_input32, lookup_outputs] = [ KECCAK_OUTPUT_SIZE, KECCAK_INPUT_SIZE, - //LOOKUPS_PER_ROUND + // LOOKUPS_PER_ROUND LOOKUP_FELTS_PER_ROUND * ROUNDS, ] .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); @@ -381,6 +367,8 @@ impl ProtocolBuilder for KeccakLayout { let mut global_xor_lookup = 3 * AND_LOOKUPS; let mut global_range_lookup = 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS; + let mut openings = 0; + for x in 0..5 { for y in 0..5 { for k in 0..2 { @@ -389,9 +377,8 @@ impl ProtocolBuilder for KeccakLayout { &keccak_output8 .slice(s![x, y, 4 * k..4 * (k + 1)]) .iter() - .map(|e| (8, e.0.clone())) - .collect_vec() - .as_slice(), + .map(|e| (8, e.0)) + .collect_vec(), ); expressions.push(expr); evals.push(keccak_output32[evals.len()].clone()); @@ -401,7 +388,7 @@ impl ProtocolBuilder for KeccakLayout { } chip.add_layer(Layer::new( - format!("build 32-bit output"), + "build 32-bit output".to_string(), LayerType::Zerocheck, expressions, vec![], @@ -414,7 +401,7 @@ impl ProtocolBuilder for KeccakLayout { expr_names, )); - let state8_loop = (0..ROUNDS).rev().into_iter().fold( + let state8_loop = (0..ROUNDS).rev().fold( keccak_output8.iter().map(|e| e.1.clone()).collect_vec(), |round_output, round| { #[allow(non_snake_case)] @@ -453,7 +440,12 @@ impl ProtocolBuilder for KeccakLayout { ) .collect_vec(); - // TODO: replace ndarrays + for wit in &bases { + chip.allocate_base_opening(openings, wit.clone().1); + openings += 1; + } + + // TODO: ndarrays can be replaced with normal arrays // Input state of the round in 8-bit chunks let state8: ArrayView<(Witness, EvalExpression), Ix3> = @@ -558,6 +550,7 @@ impl ProtocolBuilder for KeccakLayout { let mut rotation_witness = rotation_witness.iter(); for i in 0..5 { + #[allow(clippy::needless_range_loop)] for j in 0..5 { let arg = theta_output .slice(s!(j, i, ..)) @@ -567,7 +560,7 @@ impl ProtocolBuilder for KeccakLayout { let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]); let many = sizes.len(); let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) - .map(|(sz, (wit, _))| (sz, wit.clone())) + .map(|(sz, (wit, _))| (sz, *wit)) .collect_vec(); let arg_rotated = rhopi_output .slice(s!((2 * i + 3 * j) % 5, j, ..)) @@ -686,12 +679,7 @@ impl ProtocolBuilder for KeccakLayout { expr_names, )); - state8 - .into_iter() - .map(|e| e.1.clone()) - .collect_vec() - .try_into() - .unwrap() + state8.into_iter().map(|e| e.1.clone()).collect_vec() }, ); @@ -713,10 +701,10 @@ impl ProtocolBuilder for KeccakLayout { for k in 0..2 { // create an expression combining 4 elements of state8 into a single 32-bit felt let expr = expansion_expr::<32>( - &state8 + state8 .slice(s![x, y, 4 * k..4 * (k + 1)]) .iter() - .map(|e| (8, e.0.clone())) + .map(|e| (8, e.0)) .collect_vec() .as_slice(), ); @@ -730,14 +718,14 @@ impl ProtocolBuilder for KeccakLayout { // TODO: eliminate this duplication zip_eq(state8.iter(), state8_loop.iter()) .map(|(e, e_loop)| { - expressions.push(e.0.clone().into()); + expressions.push(e.0.into()); evals.push(e_loop.clone()); - expr_names.push(format!("state8 identity")); + expr_names.push("state8 identity".to_string()); }) .count(); chip.add_layer(Layer::new( - format!("build 32-bit input"), + "build 32-bit input".to_string(), LayerType::Zerocheck, expressions, vec![], @@ -775,7 +763,7 @@ where poly } - fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness { + fn gkr_witness(&self, phase1: &[Vec], _challenges: &[E]) -> GKRCircuitWitness { let n_layers = 24 + 2 + 1; let mut layer_wits = vec![ LayerWitness { @@ -798,6 +786,7 @@ where } let mut com_state = vec![]; + #[allow(clippy::needless_range_loop)] for j in 0..KECCAK_INPUT_SIZE { com_state.push(phase1[j][i]); } @@ -864,9 +853,10 @@ where curr_layer += 1; }; - push_instance(state8.clone().into_iter().flatten().flatten().collect_vec()); + push_instance(state8.into_iter().flatten().flatten().collect_vec()); - for round in 0..24 { + #[allow(clippy::needless_range_loop)] + for round in 0..ROUNDS { let mut c_aux64 = [[0u64; 5]; 5]; let mut c_aux8 = [[[0u64; 8]; 5]; 5]; @@ -918,7 +908,7 @@ where } } - let mut theta_state64 = state64.clone(); + let mut theta_state64 = state64; let mut theta_state8 = [[[0u64; 8]; 5]; 5]; let mut rotation_witness = vec![]; @@ -993,7 +983,7 @@ where } // Iota step - let mut iota_output64 = chi_output64.clone(); + let mut iota_output64 = chi_output64; let mut iota_output8 = [[[0u64; 8]; 5]; 5]; iota_output64[0][0] ^= RC[round]; @@ -1018,12 +1008,7 @@ where rhopi_output8.into_iter().flatten().flatten().collect_vec(), nonlinear8.into_iter().flatten().flatten().collect_vec(), chi_output8.into_iter().flatten().flatten().collect_vec(), - iota_output8 - .clone() - .into_iter() - .flatten() - .flatten() - .collect_vec(), + iota_output8.into_iter().flatten().flatten().collect_vec(), ]; push_instance(all_wits64.into_iter().flatten().collect_vec()); @@ -1044,26 +1029,14 @@ where } } - push_instance(state8.clone().into_iter().flatten().flatten().collect_vec()); + push_instance(state8.into_iter().flatten().flatten().collect_vec()); // For temporary convenience, use one extra layer to store the correct outputs // of the circuit This is not used during proving let lookups = chain!( - (0..ROUNDS) - .into_iter() - .rev() - .map(|i| and_lookups[i].clone()) - .flatten(), - (0..ROUNDS) - .into_iter() - .rev() - .map(|i| xor_lookups[i].clone()) - .flatten(), - (0..ROUNDS) - .into_iter() - .rev() - .map(|i| range_lookups[i].clone()) - .flatten() + (0..ROUNDS).rev().flat_map(|i| and_lookups[i].clone()), + (0..ROUNDS).rev().flat_map(|i| xor_lookups[i].clone()), + (0..ROUNDS).rev().flat_map(|i| range_lookups[i].clone()) ) .collect_vec(); @@ -1084,14 +1057,13 @@ where } } -pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bool) -> () { +pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bool) { let params = KeccakParams {}; let (layout, chip) = KeccakLayout::build(params); let mut instances = vec![]; for state in &states { - let state_mask64 = - MaskRepresentation::from(state.into_iter().map(|e| (64, *e)).collect_vec()); + let state_mask64 = MaskRepresentation::from(state.iter().map(|e| (64, *e)).collect_vec()); let state_mask32 = state_mask64.convert(vec![32; 50]); instances.push( @@ -1113,13 +1085,11 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo let mut prover_transcript = BasicTranscript::::new(b"protocol"); // Omit the commit phase1 and phase2. - let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &vec![]); + let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &[]); let out_evals = { - let log2_num_instances = (num_instances as usize) - .next_power_of_two() - .trailing_zeros() as usize; - let point = Arc::new(vec![E::from_u64(29); log2_num_instances]); + let log2_num_instances = num_instances.next_power_of_two().trailing_zeros(); + let point = Arc::new(vec![E::from_u64(29); log2_num_instances as usize]); if test_outputs { // Confront outputs with tiny_keccak::keccakf call @@ -1134,7 +1104,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo { assert_eq!(base.len(), num_instances); for i in 0..num_instances { - instance_outputs[i].push(base[i].clone()); + instance_outputs[i].push(base[i]); } } @@ -1145,8 +1115,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo state .to_vec() .iter() - .map(|e| vec![*e as u32, (e >> 32) as u32]) - .flatten() + .flat_map(|e| vec![*e as u32, (e >> 32) as u32]) .map(|e| Goldilocks::from_u64(e as u64)) .collect_vec(), instance_outputs[i] @@ -1162,7 +1131,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo .iter() .map(|base| PointAndEval { point: point.clone(), - eval: subprotocols::utils::evaluate_mle_ext(&base, &point), + eval: subprotocols::utils::evaluate_mle_ext(base, &point), }) .collect_vec(); @@ -1177,7 +1146,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo let gkr_circuit = chip.gkr_circuit(); dbg!(&gkr_circuit.layers.len()); let GKRProverOutput { gkr_proof, .. } = gkr_circuit - .prove(gkr_witness, &out_evals, &vec![], &mut prover_transcript) + .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) .expect("Failed to prove phase"); if verify { @@ -1185,7 +1154,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo let mut verifier_transcript = BasicTranscript::::new(b"protocol"); gkr_circuit - .verify(gkr_proof, &out_evals, &vec![], &mut verifier_transcript) + .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) .expect("GKR verify failed"); // Omit the PCS opening phase. @@ -1199,9 +1168,9 @@ mod tests { use rand::{Rng, SeedableRng}; #[test] - fn test_v2_keccakf() { + fn test_keccakf() { for _ in 0..3 { - //let random_u64: u64 = rand::random(); + // let random_u64: u64 = rand::random(); // Use seeded rng for debugging convenience let mut rng = rand::rngs::StdRng::seed_from_u64(42); @@ -1214,4 +1183,23 @@ mod tests { run_faster_keccakf(states, true, true); } } + + // TODO: make it pass + #[ignore] + #[test] + fn test_keccakf_nonpow2() { + for _ in 0..3 { + // let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + let num_instances = 3; + let mut states: Vec<[u64; 25]> = vec![]; + + for _ in 0..num_instances { + states.push(std::array::from_fn(|_| rng.gen())) + } + run_faster_keccakf(states, true, true); + } + } } diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs index cc17ee536..ee783a711 100644 --- a/gkr_iop/src/precompiles/mod.rs +++ b/gkr_iop/src/precompiles/mod.rs @@ -1,10 +1,8 @@ -mod faster_keccak; -mod keccak_f; +mod bitwise_keccakf; +mod lookup_keccakf; mod utils; -pub use keccak_f::run_keccakf; -pub use { - faster_keccak::run_faster_keccakf, faster_keccak::KeccakLayout, faster_keccak::KeccakParams, - faster_keccak::KeccakTrace, faster_keccak::AND_LOOKUPS, faster_keccak::AND_LOOKUPS_PER_ROUND, - faster_keccak::RANGE_LOOKUPS, faster_keccak::RANGE_LOOKUPS_PER_ROUND, - faster_keccak::XOR_LOOKUPS, faster_keccak::XOR_LOOKUPS_PER_ROUND, +pub use bitwise_keccakf::run_keccakf; +pub use lookup_keccakf::{ + AND_LOOKUPS, AND_LOOKUPS_PER_ROUND, KeccakLayout, KeccakParams, KeccakTrace, RANGE_LOOKUPS, + RANGE_LOOKUPS_PER_ROUND, XOR_LOOKUPS, XOR_LOOKUPS_PER_ROUND, run_faster_keccakf, }; diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs index 0dd484b14..e5b17b3bd 100644 --- a/gkr_iop/src/precompiles/utils.rs +++ b/gkr_iop/src/precompiles/utils.rs @@ -9,10 +9,6 @@ pub fn zero_expr() -> Expression { Expression::Const(Constant::Base(0)) } -pub fn one_expr() -> Expression { - Expression::Const(Constant::Base(1)) -} - pub fn not8_expr(expr: Expression) -> Expression { Expression::Const(Constant::Base(0xFF)) - expr } @@ -21,15 +17,12 @@ pub fn zero_eval() -> EvalExpression { EvalExpression::Linear(0, Constant::Base(0), Constant::Base(0)) } -pub fn nest(v: &Vec) -> Vec> { - v.clone().into_iter().map(|e| vec![e]).collect_vec() +pub fn nest(v: &[E::BaseField]) -> Vec> { + v.iter().map(|e| vec![*e]).collect_vec() } pub fn u64s_to_felts(words: Vec) -> Vec { - words - .into_iter() - .map(|word| E::BaseField::from_u64(word)) - .collect() + words.into_iter().map(E::BaseField::from_u64).collect() } #[derive(Debug, Clone, PartialEq, Eq)] @@ -90,7 +83,7 @@ impl MaskRepresentation { for size in sizes { let mut mask = 0; for i in 0..size { - mask += (1 << i) * (bit_iter.next().unwrap() as u64); + mask += (1 << i) * bit_iter.next().unwrap(); } masks.push(Mask::new(size, mask)); } @@ -116,10 +109,6 @@ impl MaskRepresentation { pub fn masks(&self) -> Vec { self.rep.clone() } - - fn len(&self) -> usize { - self.rep.len() - } } #[derive(Debug)] @@ -142,32 +131,7 @@ impl IntoIterator for CenoLookup { } } -impl CenoLookup { - // combine lookup arguments with challenges - pub fn compress(&self, alpha: Constant, beta: Constant) -> Expression { - let [alpha, beta] = [alpha, beta].map(|e| { - // assert!(matches!(e, Constant::Challenge(_))); - Expression::Const(e) - }); - - match self { - CenoLookup::And(a, b, c) => { - a.clone() - + alpha.clone() * b.clone() - + alpha.clone() * alpha.clone() * c.clone() - + beta - } - CenoLookup::Xor(a, b, c) => { - a.clone() - + alpha.clone() * b.clone() - + alpha.clone() * alpha.clone() * c.clone() - + beta - } - CenoLookup::U16(a) => a.clone() + beta, - } - } -} - +#[cfg(test)] mod tests { use crate::precompiles::utils::{Mask, MaskRepresentation}; diff --git a/subprotocols/src/sumcheck.rs b/subprotocols/src/sumcheck.rs index 952bebf1b..0de80b42d 100644 --- a/subprotocols/src/sumcheck.rs +++ b/subprotocols/src/sumcheck.rs @@ -266,12 +266,13 @@ where &in_point, challenges, ); + if expected_claim != got_claim { return Err(VerifierError::ClaimNotMatch( expr, expected_claim, got_claim, - "placeholder".to_string(), + expr_names[0].clone(), )); } diff --git a/subprotocols/src/zerocheck.rs b/subprotocols/src/zerocheck.rs index 3b16aded9..05e9f4c1c 100644 --- a/subprotocols/src/zerocheck.rs +++ b/subprotocols/src/zerocheck.rs @@ -262,8 +262,6 @@ where base_mle_evals, } = proof; - // dbg!(&exprs); - let (in_point, expected_claims) = univariate_polys.into_iter().enumerate().fold( (vec![], sigmas), |(mut last_point, last_sigmas), (round, round_msg)| { @@ -278,7 +276,6 @@ where let sigmas = izip!(&exprs, &inv_of_one_minus_points, round_msg, last_sigmas) .map(|((_, point), inv_of_one_minus_point, poly, last_sigma)| { let len = poly.len() + 1; - // dbg!(&round, &point, &inv_of_one_minus_point, &poly); // last_sigma = (1 - point[round]) * eval_at_0 + point[round] * eval_at_1 // eval_at_0 = (last_sigma - point[round] * eval_at_1) * inv(1 - point[round]) let eval_at_0 = if !poly.is_empty() {