From 5e18d9a006f07cd8a3c732c3f5ffeb773d272c65 Mon Sep 17 00:00:00 2001 From: Mihai Date: Mon, 24 Mar 2025 19:14:50 +0200 Subject: [PATCH] add keccak_f precompiles & utilities --- Cargo.lock | 49 + Cargo.toml | 1 + gkr_iop/Cargo.toml | 13 + gkr_iop/benches/faster_keccak.rs | 38 + gkr_iop/benches/keccak_f.rs | 38 + gkr_iop/examples/multi_layer_logup.rs | 40 +- gkr_iop/src/chip/builder.rs | 31 +- gkr_iop/src/gkr/layer.rs | 33 +- gkr_iop/src/gkr/layer/linear_layer.rs | 6 +- gkr_iop/src/gkr/layer/sumcheck_layer.rs | 1 + gkr_iop/src/gkr/layer/zerocheck_layer.rs | 1 + gkr_iop/src/lib.rs | 7 +- gkr_iop/src/precompiles/faster_keccak.rs | 1162 ++++++++++++++++++++++ gkr_iop/src/precompiles/keccak_f.rs | 600 +++++++++++ gkr_iop/src/precompiles/mod.rs | 5 + gkr_iop/src/precompiles/utils.rs | 190 ++++ subprotocols/examples/zerocheck_logup.rs | 1 + subprotocols/src/error.rs | 4 +- subprotocols/src/sumcheck.rs | 6 + subprotocols/src/zerocheck.rs | 8 +- 20 files changed, 2189 insertions(+), 45 deletions(-) create mode 100644 gkr_iop/benches/faster_keccak.rs create mode 100644 gkr_iop/benches/keccak_f.rs create mode 100644 gkr_iop/src/precompiles/faster_keccak.rs create mode 100644 gkr_iop/src/precompiles/keccak_f.rs create mode 100644 gkr_iop/src/precompiles/mod.rs create mode 100644 gkr_iop/src/precompiles/utils.rs diff --git a/Cargo.lock b/Cargo.lock index 3e76c8d84..b6418e00a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1172,15 +1172,18 @@ name = "gkr_iop" version = "0.1.0" dependencies = [ "ark-std", + "criterion", "ff_ext", "itertools 0.13.0", "multilinear_extensions", + "ndarray", "p3-field", "p3-goldilocks", "rand", "rayon", "subprotocols", "thiserror 1.0.69", + "tiny-keccak", "transcript", ] @@ -1487,6 +1490,16 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ 
+ "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.4" @@ -1589,6 +1602,21 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "nimue" version = "0.2.0" @@ -2089,6 +2117,21 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "poseidon" version = "0.1.0" @@ -2296,6 +2339,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 4fd37c506..560c75a18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ cfg-if = "1.0" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" itertools = "0.13" +ndarray = "*" num-bigint = { version = "0.4.6" } num-derive = "0.4" num-traits = "0.2" diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index bbe7e85e0..323a24233 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -14,10 +14,23 @@ ark-std.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true multilinear_extensions = 
{ version = "0.1.0", path = "../multilinear_extensions" } +ndarray.workspace = true p3-field.workspace = true p3-goldilocks.workspace = true rand.workspace = true rayon.workspace = true subprotocols = { path = "../subprotocols" } thiserror = "1" +tiny-keccak.workspace = true transcript = { path = "../transcript" } + +[dev-dependencies] +criterion.workspace = true + +[[bench]] +harness = false +name = "keccak_f" + +[[bench]] +harness = false +name = "faster_keccak" \ No newline at end of file diff --git a/gkr_iop/benches/faster_keccak.rs b/gkr_iop/benches/faster_keccak.rs new file mode 100644 index 000000000..380ef91f5 --- /dev/null +++ b/gkr_iop/benches/faster_keccak.rs @@ -0,0 +1,38 @@ +use std::time::Duration; + +use criterion::*; +use gkr_iop::precompiles::run_faster_keccakf; + +use rand::{Rng, SeedableRng}; +criterion_group!(benches, keccak_f_fn); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; + +fn keccak_f_fn(c: &mut Criterion) { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("keccak_f")); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function(BenchmarkId::new("keccak_f", format!("keccak_f")), |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); + let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); + + let instant = std::time::Instant::now(); + let _ = black_box(run_faster_keccakf(vec![state1, state2], false, false)); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }); + + group.finish(); +} diff --git a/gkr_iop/benches/keccak_f.rs b/gkr_iop/benches/keccak_f.rs new file mode 100644 index 000000000..b3b330725 --- /dev/null +++ b/gkr_iop/benches/keccak_f.rs @@ -0,0 +1,38 @@ +use std::time::Duration; + +use criterion::*; +use 
gkr_iop::precompiles::{run_faster_keccakf, run_keccakf}; +use p3_field::extension::BinomialExtensionField; +use p3_goldilocks::Goldilocks; +use rand::{Rng, SeedableRng}; +criterion_group!(benches, keccak_f_fn); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; + +fn keccak_f_fn(c: &mut Criterion) { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("keccak_f")); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function(BenchmarkId::new("keccak_f", format!("keccak_f")), |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); + + let instant = std::time::Instant::now(); + let _ = black_box(run_keccakf(state, false, false)); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }); + + group.finish(); +} diff --git a/gkr_iop/examples/multi_layer_logup.rs b/gkr_iop/examples/multi_layer_logup.rs index b97d3c9ba..6bf7928d2 100644 --- a/gkr_iop/examples/multi_layer_logup.rs +++ b/gkr_iop/examples/multi_layer_logup.rs @@ -2,18 +2,18 @@ use std::{marker::PhantomData, mem, sync::Arc}; use ff_ext::ExtensionField; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, evaluation::{EvalExpression, PointAndEval}, gkr::{ - GKRCircuitWitness, GKRProverOutput, layer::{Layer, LayerType, LayerWitness}, + GKRCircuitWitness, GKRProverOutput, }, + ProtocolBuilder, ProtocolWitnessGenerator, }; -use itertools::{Itertools, izip}; -use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; +use itertools::{izip, Itertools}; +use p3_field::{extension::BinomialExtensionField, PrimeCharacteristicRing}; use p3_goldilocks::Goldilocks; -use rand::{Rng, rngs::OsRng}; +use rand::{rngs::OsRng, Rng}; use subprotocols::expression::{Constant, Expression}; use 
transcript::{BasicTranscript, Transcript}; @@ -64,14 +64,15 @@ impl ProtocolBuilder for TowerChipLayout { let height = self.params.height; let lookup_challenge = Expression::Const(self.lookup_challenge.clone()); - self.output_cumulative_sum = chip.allocate_output_evals(); + self.output_cumulative_sum = chip.allocate_output_evals::<2>().try_into().unwrap(); // Tower layers let ([updated_table, count], challenges) = (0..height).fold( (self.output_cumulative_sum.clone(), vec![]), |([den, num], challenges), i| { let [den_0, den_1, num_0, num_1] = if i == height - 1 { - // Allocate witnesses in the extension field, except numerator inputs in the base field. + // Allocate witnesses in the extension field, except numerator inputs in the + // base field. let ([num_0, num_1], [den_0, den_1]) = chip.allocate_wits_in_layer(); [den_0, den_1, num_0, num_1] } else { @@ -86,17 +87,20 @@ impl ProtocolBuilder for TowerChipLayout { num_1.0.into(), ]; let (in_bases, in_exts) = if i == height - 1 { - (vec![num_0.1.clone(), num_1.1.clone()], vec![ - den_0.1.clone(), - den_1.1.clone(), - ]) + ( + vec![num_0.1.clone(), num_1.1.clone()], + vec![den_0.1.clone(), den_1.1.clone()], + ) } else { - (vec![], vec![ - den_0.1.clone(), - den_1.1.clone(), - num_0.1.clone(), - num_1.1.clone(), - ]) + ( + vec![], + vec![ + den_0.1.clone(), + den_1.1.clone(), + num_0.1.clone(), + num_1.1.clone(), + ], + ) }; chip.add_layer(Layer::new( format!("Tower_layer_{}", i), @@ -109,6 +113,7 @@ impl ProtocolBuilder for TowerChipLayout { in_bases, in_exts, vec![den, num], + vec![], )); let [challenge] = chip.allocate_challenges(); ( @@ -138,6 +143,7 @@ impl ProtocolBuilder for TowerChipLayout { vec![table.1.clone()], vec![], vec![updated_table], + vec![], )); chip.allocate_base_opening(self.committed_table_id, table.1); diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs index ac742fe04..c05165d3d 100644 --- a/gkr_iop/src/chip/builder.rs +++ b/gkr_iop/src/chip/builder.rs @@ -1,5 +1,6 @@ use 
std::array; +use itertools::Itertools; use subprotocols::expression::{Constant, Witness}; use crate::{ @@ -22,10 +23,11 @@ impl Chip { array::from_fn(|i| i + self.n_committed_exts - N) } - /// Allocate `Witness` and `EvalExpression` for the input polynomials in a layer. - /// Where `Witness` denotes the index and `EvalExpression` denotes the position - /// to place the evaluation of the polynomial after processing the layer prover - /// for each polynomial. This should be called at most once for each layer! + /// Allocate `Witness` and `EvalExpression` for the input polynomials in a + /// layer. Where `Witness` denotes the index and `EvalExpression` + /// denotes the position to place the evaluation of the polynomial after + /// processing the layer prover for each polynomial. This should be + /// called at most once for each layer! #[allow(clippy::type_complexity)] pub fn allocate_wits_in_layer( &mut self, @@ -51,9 +53,16 @@ impl Chip { } /// Generate the evaluation expression for each output. - pub fn allocate_output_evals(&mut self) -> [EvalExpression; N] { + pub fn allocate_output_evals(&mut self) -> Vec +// -> [EvalExpression; N] + { self.n_evaluations += N; - array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) + //array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) + // TODO: hotfix to avoid stack overflow, fix later + (0..N) + .into_iter() + .map(|i| EvalExpression::Single(i + self.n_evaluations - N)) + .collect_vec() } /// Allocate challenges. @@ -62,14 +71,16 @@ impl Chip { array::from_fn(|i| Constant::Challenge(i + self.n_challenges - N)) } - /// Allocate a PCS opening action to a base polynomial with index `wit_index`. - /// The `EvalExpression` represents the expression to compute the evaluation. + /// Allocate a PCS opening action to a base polynomial with index + /// `wit_index`. The `EvalExpression` represents the expression to + /// compute the evaluation. 
pub fn allocate_base_opening(&mut self, wit_index: usize, eval: EvalExpression) { self.base_openings.push((wit_index, eval)); } - /// Allocate a PCS opening action to an ext polynomial with index `wit_index`. - /// The `EvalExpression` represents the expression to compute the evaluation. + /// Allocate a PCS opening action to an ext polynomial with index + /// `wit_index`. The `EvalExpression` represents the expression to + /// compute the evaluation. pub fn allocate_ext_opening(&mut self, wit_index: usize, eval: EvalExpression) { self.ext_openings.push((wit_index, eval)); } diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 2f23a6d77..00749d822 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -33,12 +33,13 @@ pub struct Layer { pub ty: LayerType, /// Challenges generated at the beginning of the layer protocol. pub challenges: Vec, - /// Expressions to prove in this layer. For zerocheck and linear layers, each - /// expression corresponds to an output. While in sumcheck, there is only 1 - /// expression, which corresponds to the sum of all outputs. This design is - /// for the convenience when building the following expression: - /// `e_0 + beta * e_1 = sum_x (eq(p_0, x) + beta * eq(p_1, x)) expr(x)`. - /// where `vec![e_0, beta * e_1]` will be the output evaluation expressions. + /// Expressions to prove in this layer. For zerocheck and linear layers, + /// each expression corresponds to an output. While in sumcheck, there + /// is only 1 expression, which corresponds to the sum of all outputs. + /// This design is for the convenience when building the following + /// expression: `e_0 + beta * e_1 = sum_x (eq(p_0, x) + beta * + /// eq(p_1, x)) expr(x)`. where `vec![e_0, beta * e_1]` will be the + /// output evaluation expressions. pub exprs: Vec, /// Positions to place the evaluations of the base inputs of this layer. 
pub in_bases: Vec, @@ -47,6 +48,9 @@ pub struct Layer { /// The expressions of the evaluations from the succeeding layers, which are /// connected to the outputs of this layer. pub outs: Vec, + + // For debugging purposes + pub expr_names: Vec, } #[derive(Clone, Debug)] @@ -66,7 +70,15 @@ impl Layer { in_bases: Vec, in_exts: Vec, outs: Vec, + expr_names: Vec, ) -> Self { + let mut expr_names = expr_names; + if expr_names.len() < exprs.len() { + expr_names.extend(vec![ + "unavailable".to_string(); + exprs.len() - expr_names.len() + ]); + } Self { name, ty, @@ -75,6 +87,7 @@ impl Layer { in_bases, in_exts, outs, + expr_names, } } @@ -208,10 +221,10 @@ impl Layer { ext_mle_evals: &[E], point: &Point, ) { - for (value, pos) in izip!(chain![base_mle_evals, ext_mle_evals], chain![ - &self.in_bases, - &self.in_exts - ]) { + for (value, pos) in izip!( + chain![base_mle_evals, ext_mle_evals], + chain![&self.in_bases, &self.in_exts] + ) { *(pos.entry_mut(claims)) = PointAndEval { point: point.clone(), eval: *value, diff --git a/gkr_iop/src/gkr/layer/linear_layer.rs b/gkr_iop/src/gkr/layer/linear_layer.rs index 40d8a6ec8..6367d1b10 100644 --- a/gkr_iop/src/gkr/layer/linear_layer.rs +++ b/gkr_iop/src/gkr/layer/linear_layer.rs @@ -1,5 +1,5 @@ use ff_ext::ExtensionField; -use itertools::{Itertools, izip}; +use itertools::{izip, Itertools}; use subprotocols::{ error::VerifierError, expression::Point, @@ -80,7 +80,7 @@ impl LinearLayer for Layer { transcript.append_field_element_exts(&ext_mle_evals); transcript.append_field_element_exts(&base_mle_evals); - for (sigma, expr) in izip!(sigmas, &self.exprs) { + for (sigma, expr, expr_name) in izip!(sigmas, &self.exprs, &self.expr_names) { let got = expr.evaluate( &ext_mle_evals, &base_mle_evals, @@ -91,7 +91,7 @@ impl LinearLayer for Layer { if *sigma != got { return Err(BackendError::LayerVerificationFailed( self.name.clone(), - VerifierError::::ClaimNotMatch(expr.clone(), *sigma, got), + VerifierError::::ClaimNotMatch(expr.clone(), 
*sigma, got, expr_name.clone()), )); } } diff --git a/gkr_iop/src/gkr/layer/sumcheck_layer.rs b/gkr_iop/src/gkr/layer/sumcheck_layer.rs index afb65d9a2..01bd00e8f 100644 --- a/gkr_iop/src/gkr/layer/sumcheck_layer.rs +++ b/gkr_iop/src/gkr/layer/sumcheck_layer.rs @@ -66,6 +66,7 @@ impl SumcheckLayer for Layer { proof, challenges, transcript, + vec![], ); verifier_state diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index e9e9bc577..4df79c405 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -63,6 +63,7 @@ impl ZerocheckLayer for Layer { let verifier_state = ZerocheckVerifierState::new( sigmas, self.exprs.clone(), + vec![], out_points, proof, challenges, diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index dff90eb6d..06c83f486 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -9,6 +9,7 @@ pub mod chip; pub mod error; pub mod evaluation; pub mod gkr; +pub mod precompiles; pub mod utils; pub trait ProtocolBuilder: Sized { @@ -48,12 +49,14 @@ where fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness; } -// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, `gkr_phase` and `opening_phase`. +// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, +// `gkr_phase` and `opening_phase`. pub struct ProtocolProver, PCS>( PhantomData<(E, Trans, PCS)>, ); -// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, `gkr_phase` and `opening_phase`. +// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, +// `gkr_phase` and `opening_phase`. 
pub struct ProtocolVerifier, PCS>( PhantomData<(E, Trans, PCS)>, ); diff --git a/gkr_iop/src/precompiles/faster_keccak.rs b/gkr_iop/src/precompiles/faster_keccak.rs new file mode 100644 index 000000000..28ea0170d --- /dev/null +++ b/gkr_iop/src/precompiles/faster_keccak.rs @@ -0,0 +1,1162 @@ +use std::{ + array::from_fn, + cmp::Ordering, + iter::{once, zip}, + marker::PhantomData, + sync::Arc, +}; + +use crate::{ + chip::Chip, + evaluation::{EvalExpression, PointAndEval}, + gkr::{ + layer::{Layer, LayerType, LayerWitness}, + GKRCircuitWitness, GKRProverOutput, + }, + precompiles::utils::{nest, not8_expr, zero_expr, MaskRepresentation}, + ProtocolBuilder, ProtocolWitnessGenerator, +}; +use ndarray::{range, s, ArrayView, Ix2, Ix3}; + +use super::utils::{u64s_to_felts, zero_eval, CenoLookup}; +use p3_field::{extension::BinomialExtensionField, PrimeCharacteristicRing}; + +use ff_ext::{ExtensionField, SmallField}; +use itertools::{chain, iproduct, zip_eq, Itertools}; +use p3_goldilocks::Goldilocks; +use subprotocols::expression::{Constant, Expression, Witness}; +use tiny_keccak::keccakf; +use transcript::BasicTranscript; + +type E = BinomialExtensionField; + +#[derive(Clone, Debug, Default)] +struct KeccakParams {} + +#[derive(Clone, Debug, Default)] +struct KeccakLayout { + params: KeccakParams, + + committed_bits_id: usize, + + result: Vec, + _marker: PhantomData, +} + +fn expansion_expr(expansion: &[(usize, Witness)]) -> Expression { + let (total, ret) = expansion + .iter() + .rev() + .fold((0, zero_expr()), |acc, (sz, felt)| { + ( + acc.0 + sz, + acc.1 * Expression::Const(Constant::Base(1 << sz)) + felt.clone().into(), + ) + }); + + assert_eq!(total, SIZE); + ret +} + +/// Compute an adequate split of 64-bits into chunks for performing a rotation +/// by `delta`. The first element of the return value is the vec of chunk sizes. 
+/// The second one is the length of its suffix that needs to be rotated +fn rotation_split(delta: usize) -> (Vec, usize) { + let delta = delta % 64; + + if delta == 0 { + return (vec![32, 32], 0); + } + + // This split meets all requirements except for <= 16 sizes + let split32 = match delta.cmp(&32) { + Ordering::Less => vec![32 - delta, delta, 32 - delta, delta], + Ordering::Equal => vec![32, 32], + Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32], + }; + + // Split off large chunks + let split16 = split32 + .into_iter() + .flat_map(|size| { + assert!(size < 32); + if size <= 16 { + vec![size] + } else { + vec![16, size - 16] + } + }) + .collect_vec(); + + let mut sum = 0; + for (i, size) in split16.iter().rev().enumerate() { + sum += size; + if sum == delta { + return (split16, i + 1); + } + } + + panic!(); +} + +struct ConstraintSystem { + expressions: Vec, + expr_names: Vec, + evals: Vec, + and_lookups: Vec, + xor_lookups: Vec, + range_lookups: Vec, +} + +impl ConstraintSystem { + fn new() -> Self { + ConstraintSystem { + expressions: vec![], + evals: vec![], + expr_names: vec![], + and_lookups: vec![], + xor_lookups: vec![], + range_lookups: vec![], + } + } + + fn add_constraint(&mut self, expr: Expression, name: String) { + self.expressions.push(expr); + self.evals.push(zero_eval()); + self.expr_names.push(name); + } + + fn lookup_and8(&mut self, a: Expression, b: Expression, c: Expression) { + self.and_lookups.push(CenoLookup::And(a, b, c)); + } + + fn lookup_xor8(&mut self, a: Expression, b: Expression, c: Expression) { + self.xor_lookups.push(CenoLookup::Xor(a, b, c)); + } + + /// Generates U16 lookups to prove that `value` fits on `size < 16` bits. + /// In general it can be done by two U16 checks: one for `value` and one for + /// `value << (16 - size)`. 
+ fn lookup_range(&mut self, value: Expression, size: usize) { + assert!(size <= 16); + self.range_lookups.push(CenoLookup::U16(value.clone())); + if size < 16 { + self.range_lookups.push(CenoLookup::U16( + value * Expression::Const(Constant::Base(1 << (16 - size))), + )) + } + } + + fn constrain_eq(&mut self, lhs: Expression, rhs: Expression, name: String) { + self.add_constraint(lhs - rhs, name); + } + + // Constrains that lhs and rhs encode the same value of SIZE bits + // WARNING: Assumes that forall i, (lhs[i].1 < (2 ^ lhs[i].0)) + // This needs to be constrained separately + fn constrain_reps_eq( + &mut self, + lhs: &[(usize, Witness)], + rhs: &[(usize, Witness)], + name: String, + ) { + self.add_constraint( + expansion_expr::(lhs) - expansion_expr::(rhs), + name, + ); + } + + /// Checks that `rot8` is equal to `input8` left-rotated by `delta`. + /// `rot8` and `input8` each consist of 8 chunks of 8-bits. + /// + /// `split_rep` is a chunk representation of the input which + /// allows to reduce the required rotation to an array rotation. It may use + /// non-uniform chunks. + /// + /// For example, when `delta = 2`, the 64 bits are split into chunks of + /// sizes `[16a, 14b, 2c, 16d, 14e, 2f]` (here the first chunks contains the + /// least significant bits so a left rotation will become a right rotation + /// of the array). To perform the required rotation, we can + /// simply rotate the array: [2f, 16a, 14b, 2c, 16d, 14e]. + /// + /// In the first step, we check that `rot8` and `split_rep` represent the + /// same 64 bits. In the second step we check that `rot8` and the appropiate + /// array rotation of `split_rep` represent the same 64 bits. + /// + /// This type of representation-equality check is done by packing chunks + /// into sizes of exactly 32 (so for `delta = 2` we compare [16a, 14b, + /// 2c] to the first 4 elements of `rot8`). In addition, we do range + /// checks on `split_rep` which check that the felts meet the required + /// sizes. 
+ /// + /// This algorithm imposes the following general requirements for + /// `split_rep`: + /// - There exists a suffix of `split_rep` which sums to exactly `delta`. + /// This suffix can contain several elements. + /// - Chunk sizes are at most 16 (so they can be range-checked) or they are + /// exactly equal to 32. + /// - There exists a prefix of chunks which sums exactly to 32. This must + /// hold for the rotated array as well. + /// - The number of chunks should be as small as possible. + /// + /// Consult the method `rotation_split` to see how splits are computed for a + /// given `delta + /// + /// Note that the function imposes range checks on chunk values, but it + /// makes two exceptions: + /// 1. It doesn't check the 8-bit reps (input and output). This is + /// because all 8-bit reps in the global circuit are implicitly + /// range-checked because they are lookup arguments. + /// 2. It doesn't range-check 32-bit chunks. This is because a 32-bit + /// chunk value is checked to be equal to the composition of 4 8-bit + /// chunks. As mentioned in 1., these can be trusted to be range + /// checked, so the resulting 32-bit is correct by construction as + /// well. 
+ fn constrain_left_rotation64( + &mut self, + input8: &[Witness], + split_rep: &[(usize, Witness)], + rot8: &[Witness], + delta: usize, + label: String, + ) { + assert_eq!(input8.len(), 8); + assert_eq!(rot8.len(), 8); + + // Assert that the given split witnesses are correct for this delta + let (sizes, chunks_rotation) = rotation_split(delta); + assert_eq!(sizes, split_rep.iter().map(|e| e.0).collect_vec()); + + // Lookup ranges + for (size, elem) in split_rep { + if *size != 32 { + self.lookup_range(elem.clone().into(), *size); + } + } + + // constrain the fact that rep8 and repX.rotate_left(chunks_rotation) are + // the same 64 bitstring + let mut helper = |rep8: &[Witness], repX: &[(usize, Witness)], chunks_rotation: usize| { + // Do the same thing for the two 32-bit halves + let mut repX = repX.to_owned(); + repX.rotate_right(chunks_rotation); + + for i in 0..2 { + // The respective 4 elements in the byte representation + let lhs = rep8[4 * i..4 * (i + 1)] + .iter() + .map(|wit| (8, wit.clone())) + .collect_vec(); + let cnt = repX.len() / 2; + let rhs = &repX[cnt * i..cnt * (i + 1)]; + + assert_eq!(rhs.iter().map(|e| e.0).sum::(), 32); + + self.constrain_reps_eq::<32>( + &lhs, + &rhs, + format!( + "rotation internal {label}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", + sizes + ), + ); + } + }; + + helper(input8, split_rep, 0); + helper(rot8, split_rep, chunks_rotation); + } +} + +const ROUNDS: usize = 24; + +const RC: [u64; ROUNDS] = [ + 1u64, + 0x8082u64, + 0x800000000000808au64, + 0x8000000080008000u64, + 0x808bu64, + 0x80000001u64, + 0x8000000080008081u64, + 0x8000000000008009u64, + 0x8au64, + 0x88u64, + 0x80008009u64, + 0x8000000au64, + 0x8000808bu64, + 0x800000000000008bu64, + 0x8000000000008089u64, + 0x8000000000008003u64, + 0x8000000000008002u64, + 0x8000000000000080u64, + 0x800au64, + 0x800000008000000au64, + 0x8000000080008081u64, + 0x8000000000008080u64, + 0x80000001u64, + 0x8000000080008008u64, +]; + +const ROTATION_CONSTANTS: 
[[usize; 5]; 5] = [ + [0, 1, 62, 28, 27], + [36, 44, 6, 55, 20], + [3, 10, 43, 25, 39], + [41, 45, 15, 21, 8], + [18, 2, 61, 56, 14], +]; + +pub const KECCAK_INPUT_SIZE: usize = 50; +pub const KECCAK_OUTPUT_SIZE: usize = 50; + +pub const AND_LOOKUPS_PER_ROUND: usize = 200; +pub const XOR_LOOKUPS_PER_ROUND: usize = 608; +pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; +pub const LOOKUPS_PER_ROUND: usize = + AND_LOOKUPS_PER_ROUND + XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; + +macro_rules! allocate_and_split { + ($chip:expr, $total:expr, $( $size:expr ),* ) => {{ + let (witnesses, _) = $chip.allocate_wits_in_layer::<$total, 0>(); + let mut iter = witnesses.into_iter(); + ( + $( + iter.by_ref().take($size).collect_vec(), + )* + ) + }}; + } + +impl ProtocolBuilder for KeccakLayout { + type Params = KeccakParams; + + fn init(params: Self::Params) -> Self { + Self { + params, + ..Default::default() + } + } + + fn build_commit_phase(&mut self, chip: &mut Chip) { + [self.committed_bits_id] = chip.allocate_committed_base(); + } + + fn build_gkr_phase(&mut self, chip: &mut Chip) { + let final_outputs = + chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUPS_PER_ROUND * ROUNDS }>(); + + let mut final_outputs_iter = final_outputs.iter(); + + let [keccak_output32, keccak_input32, lookup_outputs] = [ + KECCAK_OUTPUT_SIZE, + KECCAK_INPUT_SIZE, + LOOKUPS_PER_ROUND * ROUNDS, + ] + .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); + + let keccak_output32 = keccak_output32.to_vec(); + let keccak_input32 = keccak_input32.to_vec(); + let lookup_outputs = lookup_outputs.to_vec(); + + let (keccak_output8, []) = chip.allocate_wits_in_layer::<200, 0>(); + + let keccak_output8: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &keccak_output8).unwrap(); + + let mut expressions = vec![]; + let mut evals = vec![]; + let mut expr_names = vec![]; + + let mut lookup_index = 0; + + for x in 0..5 { + for y in 0..5 { + 
for k in 0..2 { + // create an expression combining 4 elements of state8 into a single 32-bit felt + let expr = expansion_expr::<32>( + &keccak_output8 + .slice(s![x, y, 4 * k..4 * (k + 1)]) + .iter() + .map(|e| (8, e.0.clone())) + .collect_vec() + .as_slice(), + ); + expressions.push(expr); + evals.push(keccak_output32[evals.len()].clone()); + expr_names.push(format!("build 32-bit output: {x}, {y}, {k}")); + } + } + } + + chip.add_layer(Layer::new( + format!("build 32-bit output"), + LayerType::Zerocheck, + expressions, + vec![], + keccak_output8 + .into_iter() + .map(|e| e.1.clone()) + .collect_vec(), + vec![], + evals, + expr_names, + )); + + let state8_loop = (0..ROUNDS).rev().into_iter().fold( + keccak_output8.iter().map(|e| e.1.clone()).collect_vec(), + |round_output, round| { + #[allow(non_snake_case)] + let ( + state8, + c_aux, + c_temp, + c_rot, + d, + theta_output, + rotation_witness, + rhopi_output, + nonlinear, + chi_output, + iota_output, + ) = allocate_and_split!( + chip, 1656, 200, 200, 30, 40, 40, 200, 146, 200, 200, 200, 200 + ); + + let total_witnesses = 200 + 200 + 30 + 40 + 40 + 200 + 146 + 200 + 200 + 200 + 200; + // dbg!(total_witnesses); + assert_eq!(1656, total_witnesses); + + let bases = chain!( + state8.clone(), + c_aux.clone(), + c_temp.clone(), + c_rot.clone(), + d.clone(), + theta_output.clone(), + rotation_witness.clone(), + rhopi_output.clone(), + nonlinear.clone(), + chi_output.clone(), + iota_output.clone(), + ) + .collect_vec(); + + // TODO: replace ndarrays + + // Input state of the round in 8-bit chunks + let state8: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + + let mut system = ConstraintSystem::new(); + + // The purpose is to compute the auxiliary array + // c[i] = XOR (state[j][i]) for j in 0..5 + // We unroll it into + // c_aux[i][j] = XOR (state[k][i]) for k in 0..j + // We use c_aux[i][4] instead of c[i] + // c_aux is also stored in 8-bit chunks + let c_aux: 
ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &c_aux).unwrap(); + + for i in 0..5 { + for k in 0..8 { + // Initialize first element + system.constrain_eq( + state8[[0, i, k]].0.into(), + c_aux[[i, 0, k]].0.into(), + "init c_aux".to_string(), + ); + } + for j in 1..5 { + // Check xor using lookups over all chunks + for k in 0..8 { + system.lookup_xor8( + c_aux[[i, j - 1, k]].0.into(), + state8[[j, i, k]].0.into(), + c_aux[[i, j, k]].0.into(), + ); + } + } + } + + // Compute c_rot[i] = c[i].rotate_left(1) + // To understand how rotations are performed in general, consult the + // documentation of `constrain_left_rotation64`. Here c_temp is the split + // witness for a 1-rotation. + + let c_temp: ArrayView<(Witness, EvalExpression), Ix2> = + ArrayView::from_shape((5, 6), &c_temp).unwrap(); + let c_rot: ArrayView<(Witness, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &c_rot).unwrap(); + + let (sizes, _) = rotation_split(1); + + for i in 0..5 { + assert_eq!(c_temp.slice(s![i, ..]).iter().len(), sizes.iter().len()); + + system.constrain_left_rotation64( + &c_aux.slice(s![i, 4, ..]).iter().map(|e| e.0).collect_vec(), + &zip_eq(c_temp.slice(s![i, ..]).iter(), sizes.iter()) + .map(|(e, sz)| (*sz, e.0)) + .collect_vec(), + &c_rot.slice(s![i, ..]).iter().map(|e| e.0).collect_vec(), + 1, + "theta rotation".to_string(), + ); + } + + // d is computed simply as XOR of required elements of c (and rotations) + // again stored as 8-bit chunks + let d: ArrayView<(Witness, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &d).unwrap(); + + for i in 0..5 { + for k in 0..8 { + system.lookup_xor8( + c_aux[[(i + 5 - 1) % 5, 4, k]].0.into(), + c_rot[[(i + 1) % 5, k]].0.into(), + d[[i, k]].0.into(), + ) + } + } + + // output state of the Theta sub-round, simple XOR, in 8-bit chunks + let theta_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &theta_output).unwrap(); + + for i in 0..5 { + for j in 0..5 { 
+ for k in 0..8 { + system.lookup_xor8( + state8[[j, i, k]].0.into(), + d[[i, k]].0.into(), + theta_output[[j, i, k]].0.into(), + ) + } + } + } + + // output state after applying both Rho and Pi sub-rounds + // sub-round Pi is a simple permutation of 64-bit lanes + // sub-round Rho requires rotations + let rhopi_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &rhopi_output).unwrap(); + + // iterator over split witnesses + let mut rotation_witness = rotation_witness.iter(); + + for i in 0..5 { + for j in 0..5 { + let arg = theta_output + .slice(s!(j, i, ..)) + .iter() + .map(|e| e.0) + .collect_vec(); + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]); + let many = sizes.len(); + let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) + .map(|(sz, (wit, _))| (sz, wit.clone())) + .collect_vec(); + let arg_rotated = rhopi_output + .slice(s!((2 * i + 3 * j) % 5, j, ..)) + .iter() + .map(|e| e.0) + .collect_vec(); + system.constrain_left_rotation64( + &arg, + &rep_split, + &arg_rotated, + ROTATION_CONSTANTS[j][i], + format!("RHOPI {i}, {j}"), + ); + } + } + + let chi_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &chi_output).unwrap(); + + // for the Chi sub-round, we use an intermediate witness storing the result of + // the required AND + let nonlinear: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &nonlinear).unwrap(); + + for i in 0..5 { + for j in 0..5 { + for k in 0..8 { + system.lookup_and8( + not8_expr(rhopi_output[[j, (i + 1) % 5, k]].0.into()), + rhopi_output[[j, (i + 2) % 5, k]].0.into(), + nonlinear[[j, i, k]].0.into(), + ); + + system.lookup_xor8( + rhopi_output[[j, i, k]].0.into(), + nonlinear[[j, i, k]].0.into(), + chi_output[[j, i, k]].0.into(), + ); + } + } + } + + // TODO: 24/25 elements stay the same after Iota; eliminate duplication? 
+ let iota_output: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &iota_output).unwrap(); + + for i in 0..5 { + for j in 0..5 { + if i == 0 && j == 0 { + for k in 0..8 { + system.lookup_xor8( + chi_output[[j, i, k]].0.into(), + Expression::Const(Constant::Base( + ((RC[round] >> (k * 8)) & 0xFF) as i64, + )), + iota_output[[j, i, k]].0.into(), + ); + } + } else { + for k in 0..8 { + system.constrain_eq( + iota_output[[j, i, k]].0.into(), + chi_output[[j, i, k]].0.into(), + "nothing special".to_string(), + ); + } + } + } + } + + let ConstraintSystem { + mut expressions, + mut expr_names, + mut evals, + and_lookups, + xor_lookups, + range_lookups, + .. + } = system; + + iota_output + .into_iter() + .enumerate() + .map(|(i, val)| { + expressions.push(val.0.into()); + expr_names.push(format!("iota_output {i}")); + evals.push(round_output[i].clone()); + }) + .count(); + + // TODO: use real challenge + let alpha = Constant::Base(1 << 8); + let beta = Constant::Base(0); + + // Send all lookups to the final output layer + for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups).enumerate() { + expressions.push(lookup.compress(alpha.clone(), beta.clone())); + expr_names.push(format!("{i}th: {:?}", lookup)); + evals.push(lookup_outputs[lookup_index].clone()); + lookup_index += 1; + } + + chip.add_layer(Layer::new( + format!("Round {round}"), + LayerType::Zerocheck, + expressions, + vec![], + bases.into_iter().map(|e| e.1).collect_vec(), + vec![], + evals, + expr_names, + )); + + state8 + .into_iter() + .map(|e| e.1.clone()) + .collect_vec() + .try_into() + .unwrap() + }, + ); + + let (state8, _) = chip.allocate_wits_in_layer::<200, 0>(); + + let state8: ArrayView<(Witness, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + + let mut expressions = vec![]; + let mut evals = vec![]; + let mut expr_names = vec![]; + + for x in 0..5 { + for y in 0..5 { + for k in 0..2 { + // create an expression combining 4 
elements of state8 into a single 32-bit felt + let expr = expansion_expr::<32>( + &state8 + .slice(s![x, y, 4 * k..4 * (k + 1)]) + .iter() + .map(|e| (8, e.0.clone())) + .collect_vec() + .as_slice(), + ); + expressions.push(expr); + evals.push(keccak_input32[evals.len()].clone()); + expr_names.push(format!("build 32-bit input: {x}, {y}, {k}")); + } + } + } + + // TODO: eliminate this duplication + zip_eq(state8.iter(), state8_loop.iter()) + .map(|(e, e_loop)| { + expressions.push(e.0.clone().into()); + evals.push(e_loop.clone()); + expr_names.push(format!("state8 identity")); + }) + .count(); + + chip.add_layer(Layer::new( + format!("build 32-bit input"), + LayerType::Zerocheck, + expressions, + vec![], + state8.into_iter().map(|e| e.1.clone()).collect_vec(), + vec![], + evals, + expr_names, + )); + } +} + +pub struct KeccakTrace { + pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, +} + +impl ProtocolWitnessGenerator for KeccakLayout +where + E: ExtensionField, +{ + type Trace = KeccakTrace; + + fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { + let mut res = vec![]; + for instance in phase1.instances { + res.push(u64s_to_felts::( + instance.into_iter().map(|e| e as u64).collect_vec(), + )); + } + res + } + + fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness { + let n_layers = 24 + 2 + 1; + let mut layer_wits = vec![ + LayerWitness { + bases: vec![], + exts: vec![], + num_vars: 1 + }; + n_layers + ]; + + for com_state in phase1 { + fn conv64to8(input: u64) -> [u64; 8] { + MaskRepresentation::new(vec![(64, input).into()]) + .convert(vec![8; 8]) + .values() + .try_into() + .unwrap() + } + + let mut and_lookups: Vec> = vec![vec![]; ROUNDS]; + let mut xor_lookups: Vec> = vec![vec![]; ROUNDS]; + let mut range_lookups: Vec> = vec![vec![]; ROUNDS]; + + let mut add_and = |a: u64, b: u64, round: usize| { + let c = a & b; + and_lookups[round].push((c << 16) + (b << 8) + a); + }; + + let mut add_xor = |a: u64, b: u64, round: usize| { + let c = a 
^ b; + xor_lookups[round].push((c << 16) + (b << 8) + a); + }; + + let mut add_range = |value: u64, size: usize, round: usize| { + assert!(size <= 16, "{size}"); + range_lookups[round].push(value); + if size < 16 { + range_lookups[round].push(value << (16 - size)); + } + }; + + let state32 = com_state + .into_iter() + // TODO double check assumptions about canonical + .map(|e| e.to_canonical_u64()) + .collect_vec(); + let mut state64 = [[0u64; 5]; 5]; + let mut state8 = [[[0u64; 8]; 5]; 5]; + + zip_eq(iproduct!(0..5, 0..5), state32.clone().iter().tuples()) + .map(|((x, y), (lo, hi))| { + state64[x][y] = lo | (hi << 32); + }) + .count(); + + for x in 0..5 { + for y in 0..5 { + state8[x][y] = conv64to8(state64[x][y]); + } + } + + let mut curr_layer = 0; + let mut push_instance = |wits: Vec| { + let felts = u64s_to_felts::(wits); + if layer_wits[curr_layer].bases.is_empty() { + layer_wits[curr_layer] = LayerWitness::new(nest::(&felts), vec![]); + } else { + for (i, bases) in layer_wits[curr_layer].bases.iter_mut().enumerate() { + bases.push(felts[i]); + } + } + curr_layer += 1; + }; + + push_instance(state8.clone().into_iter().flatten().flatten().collect_vec()); + + for round in 0..24 { + let mut c_aux64 = [[0u64; 5]; 5]; + let mut c_aux8 = [[[0u64; 8]; 5]; 5]; + + for i in 0..5 { + c_aux64[i][0] = state64[0][i]; + c_aux8[i][0] = conv64to8(c_aux64[i][0]); + for j in 1..5 { + c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1]; + c_aux8[i][j] = conv64to8(c_aux64[i][j]); + + for k in 0..8 { + add_xor(c_aux8[i][j - 1][k], state8[j][i][k], round); + } + } + } + + let mut c64 = [0u64; 5]; + let mut c8 = [[0u64; 8]; 5]; + + for x in 0..5 { + c64[x] = c_aux64[x][4]; + c8[x] = conv64to8(c64[x]); + } + + let mut c_temp = [[0u64; 6]; 5]; + for i in 0..5 { + let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) + .convert(vec![16, 15, 1, 16, 15, 1]); + c_temp[i] = rep.values().try_into().unwrap(); + for mask in rep.rep { + add_range(mask.value, mask.size, round); + } + } + 
+ let mut crot64 = [0u64; 5]; + let mut crot8 = [[0u64; 8]; 5]; + for i in 0..5 { + crot64[i] = c64[i].rotate_left(1); + crot8[i] = conv64to8(crot64[i]); + } + + let mut d64 = [0u64; 5]; + let mut d8 = [[0u64; 8]; 5]; + for x in 0..5 { + d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1); + d8[x] = conv64to8(d64[x]); + for k in 0..8 { + add_xor(c_aux8[(x + 4) % 5][4][k], crot8[(x + 1) % 5][k], round); + } + } + + let mut theta_state64 = state64.clone(); + let mut theta_state8 = [[[0u64; 8]; 5]; 5]; + let mut rotation_witness = vec![]; + + for x in 0..5 { + for y in 0..5 { + theta_state64[y][x] ^= d64[x]; + theta_state8[y][x] = conv64to8(theta_state64[y][x]); + + for k in 0..8 { + add_xor(state8[y][x][k], d8[x][k], round); + } + + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); + let rep = MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) + .convert(sizes); + for mask in rep.rep.iter() { + if mask.size != 32 { + add_range(mask.value, mask.size, round); + } + } + rotation_witness.extend(rep.values()); + } + } + + // Rho and Pi steps + let mut rhopi_output64 = [[0u64; 5]; 5]; + let mut rhopi_output8 = [[[0u64; 8]; 5]; 5]; + + for x in 0..5 { + for y in 0..5 { + rhopi_output64[(2 * x + 3 * y) % 5][y % 5] = + theta_state64[y][x].rotate_left(ROTATION_CONSTANTS[y][x] as u32); + } + } + + for x in 0..5 { + for y in 0..5 { + rhopi_output8[x][y] = conv64to8(rhopi_output64[x][y]); + } + } + + // Chi step + + let mut nonlinear64 = [[0u64; 5]; 5]; + let mut nonlinear8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + nonlinear64[y][x] = + !rhopi_output64[y][(x + 1) % 5] & rhopi_output64[y][(x + 2) % 5]; + nonlinear8[y][x] = conv64to8(nonlinear64[y][x]); + + for k in 0..8 { + add_and( + 0xFF - rhopi_output8[y][(x + 1) % 5][k], + rhopi_output8[y][(x + 2) % 5][k], + round, + ); + } + } + } + + let mut chi_output64 = [[0u64; 5]; 5]; + let mut chi_output8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + chi_output64[y][x] = 
nonlinear64[y][x] ^ rhopi_output64[y][x]; + chi_output8[y][x] = conv64to8(chi_output64[y][x]); + for k in 0..8 { + add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k], round) + } + } + } + + // Iota step + let mut iota_output64 = chi_output64.clone(); + let mut iota_output8 = [[[0u64; 8]; 5]; 5]; + iota_output64[0][0] ^= RC[round]; + + for k in 0..8 { + add_xor(chi_output8[0][0][k], (RC[round] >> (k * 8)) & 0xFF, round); + } + + for x in 0..5 { + for y in 0..5 { + iota_output8[x][y] = conv64to8(iota_output64[x][y]); + } + } + + let all_wits64 = [ + state8.into_iter().flatten().flatten().collect_vec(), + c_aux8.into_iter().flatten().flatten().collect_vec(), + c_temp.into_iter().flatten().collect_vec(), + crot8.into_iter().flatten().collect_vec(), + d8.into_iter().flatten().collect_vec(), + theta_state8.into_iter().flatten().flatten().collect_vec(), + rotation_witness, + rhopi_output8.into_iter().flatten().flatten().collect_vec(), + nonlinear8.into_iter().flatten().flatten().collect_vec(), + chi_output8.into_iter().flatten().flatten().collect_vec(), + iota_output8 + .clone() + .into_iter() + .flatten() + .flatten() + .collect_vec(), + ]; + + // let sizes = all_wits64.iter().map(|e| e.len()).collect_vec(); + // dbg!(&sizes); + + // let all_wits = nest::( + // &all_wits64 + // .into_iter() + // .flat_map(|v| u64s_to_felts::(v)) + // .collect_vec(), + // ); + + push_instance(all_wits64.into_iter().flatten().collect_vec()); + + state8 = iota_output8; + state64 = iota_output64; + } + + let mut keccak_output32 = vec![vec![vec![0; 2]; 5]; 5]; + + for x in 0..5 { + for y in 0..5 { + keccak_output32[x][y] = MaskRepresentation::from( + state8[x][y].into_iter().map(|e| (8, e)).collect_vec(), + ) + .convert(vec![32; 2]) + .values(); + } + } + + push_instance(state8.clone().into_iter().flatten().flatten().collect_vec()); + + // For temporary convenience, use one extra layer to store the correct outputs + // of the circuit This is not used during proving + let lookups = (0..24) + 
.into_iter() + .rev() + .flat_map(|i| { + chain!( + and_lookups[i].clone(), + xor_lookups[i].clone(), + range_lookups[i].clone() + ) + }) + .collect_vec(); + + push_instance( + chain!( + keccak_output32.into_iter().flatten().flatten(), + state32, + lookups + ) + .collect_vec(), + ); + } + + let len = layer_wits.len() - 1; + layer_wits[..len].reverse(); + + GKRCircuitWitness { layers: layer_wits } + } +} + +pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) -> () { + let params = KeccakParams {}; + let (layout, chip) = KeccakLayout::build(params); + let gkr_circuit = chip.gkr_circuit(); + + let mut instances = vec![]; + for state in states { + let state_mask64 = + MaskRepresentation::from(state.into_iter().map(|e| (64, e)).collect_vec()); + let state_mask32 = state_mask64.convert(vec![32; 50]); + + instances.push( + state_mask32 + .values() + .iter() + .map(|e| *e as u32) + .collect_vec() + .try_into() + .unwrap(), + ); + } + + let phase1_witness = layout.phase1_witness(KeccakTrace { instances }); + + let mut prover_transcript = BasicTranscript::::new(b"protocol"); + + // Omit the commit phase1 and phase2. 
+ let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &vec![]); + + let out_evals = { + let point1 = Arc::new(vec![E::ZERO]); + let point2 = Arc::new(vec![E::ONE]); + let final_output1 = gkr_witness + .layers + .last() + .unwrap() + .bases + //.clone() + .iter() + .map(|base| base[0].clone()) + .collect_vec(); + let final_output2 = gkr_witness + .layers + .last() + .unwrap() + .bases + //.clone() + .iter() + .map(|base| base[1].clone()) + .collect_vec(); + + // if test { + // // confront outputs with tiny_keccak result + // let mut keccak_output64 = state_mask64.values().try_into().unwrap(); + // keccakf(&mut keccak_output64); + + // let keccak_output32 = MaskRepresentation::from( + // keccak_output64.into_iter().map(|e| (64, e)).collect_vec(), + // ) + // .convert(vec![32; 50]) + // .values(); + + // let keccak_output32 = u64s_to_felts::(keccak_output32); + // assert_eq!(keccak_output32, final_output[..50]); + // } + + let len = final_output1.len(); + let gkr_outputs = chain!( + zip(final_output1, once(point1).cycle().take(len)), + zip(final_output2, once(point2).cycle().take(len)) + ); + gkr_outputs + .into_iter() + .map(|(elem, point)| PointAndEval { + point: point.clone(), + eval: E::from_bases(&[elem, Goldilocks::ZERO]), + }) + .collect_vec() + }; + + let GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove(gkr_witness, &out_evals, &vec![], &mut prover_transcript) + .expect("Failed to prove phase"); + + if verify { + { + let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + + gkr_circuit + .verify(gkr_proof, &out_evals, &vec![], &mut verifier_transcript) + .expect("GKR verify failed"); + + // Omit the PCS opening phase. 
+ } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::{Rng, SeedableRng}; + + #[test] + fn test_v2_keccakf() { + for _ in 0..3 { + let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); + let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); + // let state = [0; 50]; + run_faster_keccakf(vec![state1, state2], true, true); + } + } +} diff --git a/gkr_iop/src/precompiles/keccak_f.rs b/gkr_iop/src/precompiles/keccak_f.rs new file mode 100644 index 000000000..0a2e44f09 --- /dev/null +++ b/gkr_iop/src/precompiles/keccak_f.rs @@ -0,0 +1,600 @@ +use std::{array::from_fn, marker::PhantomData, sync::Arc}; + +use crate::{ + chip::Chip, + evaluation::{EvalExpression, PointAndEval}, + gkr::{ + layer::{Layer, LayerType, LayerWitness}, + GKRCircuitWitness, GKRProverOutput, + }, + ProtocolBuilder, ProtocolWitnessGenerator, +}; +use ff_ext::ExtensionField; +use itertools::{chain, iproduct, Itertools}; +use p3_field::{extension::BinomialExtensionField, Field, PrimeCharacteristicRing}; +use p3_goldilocks::Goldilocks; + +use subprotocols::expression::{Constant, Expression, Witness}; +use tiny_keccak::keccakf; +use transcript::BasicTranscript; + +type E = BinomialExtensionField; +#[derive(Clone, Debug, Default)] +struct KeccakParams {} + +#[derive(Clone, Debug, Default)] +struct KeccakLayout { + params: KeccakParams, + + committed_bits_id: usize, + + result: Vec, + _marker: PhantomData, +} + +const X: usize = 5; +const Y: usize = 5; +const Z: usize = 64; +const STATE_SIZE: usize = X * Y * Z; +const C_SIZE: usize = X * Z; +const D_SIZE: usize = X * Z; + +fn to_xyz(i: usize) -> (usize, usize, usize) { + assert!(i < STATE_SIZE); + (i / 64 % 5, (i / 64) / 5, i % 64) +} + +fn from_xyz(x: usize, y: usize, z: usize) -> usize { + assert!(x < 5 && y < 5 && z < 64); + 64 * (5 * y + x) + z +} + +fn xor(a: F, b: F) -> F { + a + b - a 
* b - a * b
}

/// Arithmetized AND of two bit expressions: for a, b in {0, 1}, `a & b == a * b`.
fn and_expr(a: Expression, b: Expression) -> Expression {
    // `a` and `b` are owned here, so multiplying them directly avoids the
    // redundant clones the original carried (clippy: redundant_clone).
    a * b
}

/// Arithmetized NOT of a bit expression: `!a == 1 - a`.
fn not_expr(a: Expression) -> Expression {
    one_expr() - a
}

/// Arithmetized XOR of two bit expressions: `a ^ b == a + b - 2ab`.
fn xor_expr(a: Expression, b: Expression) -> Expression {
    // The leading terms still need `a`/`b`, so only those uses are cloned;
    // the final product consumes the originals.
    a.clone() + b.clone() - Expression::Const(Constant::Base(2)) * a * b
}

/// The constant-0 expression.
fn zero_expr() -> Expression {
    Expression::Const(Constant::Base(0))
}

/// The constant-1 expression.
fn one_expr() -> Expression {
    Expression::Const(Constant::Base(1))
}

/// Theta helper: C[x][z] = XOR of the five state bits in column (x, z).
fn c<F: Field>(x: usize, z: usize, bits: &[F]) -> F {
    (0..5)
        .map(|y| bits[from_xyz(x, y, z)])
        // `bit` instead of the original `x`, which shadowed the parameter.
        .fold(F::ZERO, |acc, bit| xor(acc, bit))
}

/// Expression form of `c`: XOR-fold of the five column witnesses.
fn c_expr(x: usize, z: usize, state_wits: &[Witness]) -> Expression {
    (0..5)
        .map(|y| Expression::from(state_wits[from_xyz(x, y, z)]))
        .fold(zero_expr(), |acc, e| xor_expr(acc, e))
}

/// Flat index of (x, z) into a 5 x 64 sheet.
fn from_xz(x: usize, z: usize) -> usize {
    x * 64 + z
}

/// Theta helper: D[x][z] = C[x-1][z] ^ C[x+1][z-1] (x mod 5, z mod 64).
fn d<F: Field>(x: usize, z: usize, c_vals: &[F]) -> F {
    let lhs = from_xz((x + 5 - 1) % 5, z);
    let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64);
    xor(c_vals[lhs], c_vals[rhs])
}

/// Expression form of `d`.
fn d_expr(x: usize, z: usize, c_wits: &[Witness]) -> Expression {
    let lhs = from_xz((x + 5 - 1) % 5, z);
    let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64);
    xor_expr(c_wits[lhs].into(), c_wits[rhs].into())
}

/// Keccak Theta step over the full 1600-bit state: bit[x][y][z] ^= D[x][z].
fn theta<F: Field>(bits: Vec<F>) -> Vec<F> {
    assert_eq!(bits.len(), STATE_SIZE);

    let c_vals = iproduct!(0..5, 0..64)
        .map(|(x, z)| c(x, z, &bits))
        .collect_vec();

    let d_vals = iproduct!(0..5, 0..64)
        .map(|(x, z)| d(x, z, &c_vals))
        .collect_vec();

    bits.iter()
        .enumerate()
        .map(|(i, bit)| {
            let (x, _, z) = to_xyz(i);
            xor(*bit, d_vals[from_xz(x, z)])
        })
        .collect()
}

/// Field AND of two bits.
fn and<F: Field>(a: F, b: F) -> F {
    a * b
}

/// Field NOT of a bit.
fn not<F: Field>(a: F) -> F {
    F::ONE - a
}

/// Pack bits into u64 words, LSB-first within each 64-bit chunk.
fn bools_to_u64s(state: &[bool]) -> Vec<u64> {
    state
        .chunks(64)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u64, |acc, (i, &bit)| acc | ((bit as u64) << i))
        })
        .collect::<Vec<u64>>()
}

/// Unpack u64 words into bits, LSB-first.
fn u64s_to_bools(state64: &[u64]) -> Vec<bool> {
    state64
        .iter()
        .flat_map(|&word| 
(0..64).map(move |i| ((word >> i) & 1) == 1)) + .collect() +} + +fn dbg_vec(vec: &Vec) { + let res = vec + .iter() + .map(|f| if *f == F::ZERO { false } else { true }) + .collect_vec(); + // println!("{:?}", bools_to_u64s(&res)); +} +fn chi(bits: &Vec) -> Vec { + assert_eq!(bits.len(), STATE_SIZE); + + bits.iter() + .enumerate() + .map(|(i, bit)| { + let (x, y, z) = to_xyz(i); + let rhs = and( + not(bits[from_xyz((x + 1) % X, y, z)]), + bits[from_xyz((x + 2) % X, y, z)], + ); + xor(*bit, rhs) + }) + .collect() +} + +const ROUNDS: usize = 24; + +const RC: [u64; ROUNDS] = [ + 1u64, + 0x8082u64, + 0x800000000000808au64, + 0x8000000080008000u64, + 0x808bu64, + 0x80000001u64, + 0x8000000080008081u64, + 0x8000000000008009u64, + 0x8au64, + 0x88u64, + 0x80008009u64, + 0x8000000au64, + 0x8000808bu64, + 0x800000000000008bu64, + 0x8000000000008089u64, + 0x8000000000008003u64, + 0x8000000000008002u64, + 0x8000000000000080u64, + 0x800au64, + 0x800000008000000au64, + 0x8000000080008081u64, + 0x8000000000008080u64, + 0x80000001u64, + 0x8000000080008008u64, +]; + +fn iota(bits: &Vec, round_value: u64) -> Vec { + assert_eq!(bits.len(), STATE_SIZE); + let mut ret = bits.clone(); + + let cast = |x| match x { + 0 => F::ZERO, + 1 => F::ONE, + _ => unreachable!(), + }; + + for z in 0..Z { + ret[from_xyz(0, 0, z)] = xor(bits[from_xyz(0, 0, z)], cast((round_value >> z) & 1)); + } + + ret +} + +fn iota_expr(bits: &[Witness], index: usize, round_value: u64) -> Expression { + assert_eq!(bits.len(), STATE_SIZE); + let (x, y, z) = to_xyz(index); + + if x > 0 || y > 0 { + bits[index].into() + } else { + let round_bit = Expression::Const(Constant::Base( + ((round_value >> index) & 1).try_into().unwrap(), + )); + xor_expr(bits[from_xyz(0, 0, z)].into(), round_bit) + } +} + +fn chi_expr(i: usize, bits: &[Witness]) -> Expression { + assert_eq!(bits.len(), STATE_SIZE); + + let (x, y, z) = to_xyz(i); + let rhs = and_expr( + not_expr(bits[from_xyz((x + 1) % X, y, z)].into()), + bits[from_xyz((x + 2) % 
X, y, z)].into(), + ); + xor_expr((bits[i]).into(), rhs) +} + +impl ProtocolBuilder for KeccakLayout { + type Params = KeccakParams; + + fn init(params: Self::Params) -> Self { + Self { + params, + ..Default::default() + } + } + + fn build_commit_phase(&mut self, chip: &mut Chip) { + [self.committed_bits_id] = chip.allocate_committed_base(); + } + + fn build_gkr_phase(&mut self, chip: &mut Chip) { + let final_output = chip.allocate_output_evals::(); + + (0..ROUNDS) + .rev() + .into_iter() + .fold(final_output, |round_output, round| { + let (chi_output, _) = chip.allocate_wits_in_layer::(); + + let exprs = (0..STATE_SIZE) + .map(|i| iota_expr(&chi_output.iter().map(|e| e.0).collect_vec(), i, RC[round])) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Iota:: compute output"), + LayerType::Zerocheck, + exprs, + vec![], + chi_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + round_output.to_vec(), + vec![], + )); + + let (theta_output, _) = chip.allocate_wits_in_layer::(); + + // Apply the effects of the rho + pi permutation directly o the argument of chi + // No need for a separate layer + let perm = rho_and_pi_permutation(); + let permuted = (0..STATE_SIZE) + .map(|i| theta_output[perm[i]].0) + .collect_vec(); + + let exprs = (0..STATE_SIZE) + .map(|i| chi_expr(i, &permuted)) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Chi:: apply rho, pi and chi"), + LayerType::Zerocheck, + exprs, + vec![], + theta_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + chi_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); + + let (d_and_state, _) = chip.allocate_wits_in_layer::<{ D_SIZE + STATE_SIZE }, 0>(); + let (d, state2) = d_and_state.split_at(D_SIZE); + + // Compute post-theta state using original state and D[][] values + let exprs = (0..STATE_SIZE) + .map(|i| { + let (x, y, z) = to_xyz(i); + xor_expr(state2[i].0.into(), d[from_xz(x, z)].0.into()) + }) + .collect_vec(); + + 
chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute output"), + LayerType::Zerocheck, + exprs, + vec![], + d_and_state.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + theta_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); + + let (c, []) = chip.allocate_wits_in_layer::<{ C_SIZE }, 0>(); + + let c_wits = c.iter().map(|e| e.0.clone()).collect_vec(); + // Compute D[][] from C[][] values + let d_exprs = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| d_expr(x, z, &c_wits)) + .collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute D[x][z]"), + LayerType::Zerocheck, + d_exprs, + vec![], + c.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + d.iter().map(|e| e.1.clone()).collect_vec(), + vec![], + )); + + let (state, []) = chip.allocate_wits_in_layer::(); + let state_wits = state.iter().map(|s| s.0).collect_vec(); + + // Compute C[][] from state + let c_exprs = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| c_expr(x, z, &state_wits)) + .collect_vec(); + + // Copy state + let id_exprs: Vec = + (0..STATE_SIZE).map(|i| state_wits[i].into()).collect_vec(); + + chip.add_layer(Layer::new( + format!("Round {round}: Theta::compute C[x][z]"), + LayerType::Zerocheck, + chain!(c_exprs, id_exprs).collect_vec(), + vec![], + state.iter().map(|t| t.1.clone()).collect_vec(), + vec![], + chain!( + c.iter().map(|e| e.1.clone()), + state2.iter().map(|e| e.1.clone()) + ) + .collect_vec(), + vec![], + )); + + state + .iter() + .map(|e| e.1.clone()) + .collect_vec() + .try_into() + .unwrap() + }); + + // Skip base opening allocation + } +} + +pub struct KeccakTrace { + pub bits: [bool; STATE_SIZE], +} + +impl ProtocolWitnessGenerator for KeccakLayout +where + E: ExtensionField, +{ + type Trace = KeccakTrace; + + fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { + let mut res = vec![vec![]; 1]; + res[0] = phase1 + .bits + .into_iter() + .map(|b| E::BaseField::from_u64(b as u64)) + .collect(); + res + } + + fn 
gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness { + let mut bits = phase1[self.committed_bits_id].clone(); + + let n_layers = 100; + let mut layer_wits = Vec::>::with_capacity(n_layers + 1); + + for i in 0..24 { + if i == 0 { + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + } + + let c_wits = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| c(x, z, &bits)) + .collect_vec(); + + layer_wits.push(LayerWitness::new( + chain!( + c_wits.clone().into_iter().map(|b| vec![b]), + // Note: it seems test pass even if this is uncommented. + // Maybe it's good to assert there are no unused witnesses + // bits.clone().into_iter().map(|b| vec![b]) + ) + .collect_vec(), + vec![], + )); + + let d_wits = iproduct!(0..5usize, 0..64usize) + .map(|(x, z)| d(x, z, &c_wits)) + .collect_vec(); + + layer_wits.push(LayerWitness::new( + chain!( + d_wits.clone().into_iter().map(|b| vec![b]), + bits.clone().into_iter().map(|b| vec![b]) + ) + .collect_vec(), + vec![], + )); + + bits = theta(bits); + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + + bits = chi(&pi(&rho(&bits))); + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + + if i < 23 { + bits = iota(&bits, RC[i]); + layer_wits.push(LayerWitness::new( + bits.clone().into_iter().map(|b| vec![b]).collect_vec(), + vec![], + )); + } + } + + // Assumes one input instance + let total_witness_size: usize = layer_wits.iter().map(|layer| layer.bases.len()).sum(); + dbg!(total_witness_size); + + layer_wits.reverse(); + + GKRCircuitWitness { layers: layer_wits } + } +} + +// based on +// https://github.com/0xPolygonHermez/zkevm-prover/blob/main/tools/sm/keccak_f/keccak_rho.cpp +fn rho(state: &[T]) -> Vec { + assert_eq!(state.len(), STATE_SIZE); + let (mut x, mut y) = (1, 0); + let mut ret = [T::default(); STATE_SIZE]; + + for z in 0..Z { 
+ ret[from_xyz(0, 0, z)] = state[from_xyz(0, 0, z)]; + } + for t in 0..24 { + for z in 0..Z { + let new_z = (1000 * Z + z - (t + 1) * (t + 2) / 2) % Z; + ret[from_xyz(x, y, z)] = state[from_xyz(x, y, new_z)]; + } + (x, y) = (y, (2 * x + 3 * y) % Y); + } + + return ret.to_vec(); +} + +// https://github.com/0xPolygonHermez/zkevm-prover/blob/main/tools/sm/keccak_f/keccak_pi.cpp +fn pi(state: &[T]) -> Vec { + assert_eq!(state.len(), STATE_SIZE); + let mut ret = [T::default(); STATE_SIZE]; + + iproduct!(0..X, 0..Y, 0..Z) + .into_iter() + .map(|(x, y, z)| ret[from_xyz(x, y, z)] = state[from_xyz((x + 3 * y) % X, x, z)]) + .count(); + + return ret.to_vec(); +} + +// Combines rho and pi steps into a single permutation +fn rho_and_pi_permutation() -> Vec { + let perm: [usize; STATE_SIZE] = from_fn(|i| i); + pi(&rho(&perm)) +} + +pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) -> () { + let params = KeccakParams {}; + let (layout, chip) = KeccakLayout::build(params); + let gkr_circuit = chip.gkr_circuit(); + + let bits = u64s_to_bools(&state); + + let phase1_witness = layout.phase1_witness(KeccakTrace { + bits: bits.try_into().unwrap(), + }); + let mut prover_transcript = BasicTranscript::::new(b"protocol"); + + // Omit the commit phase1 and phase2. 
+ let gkr_witness = layout.gkr_witness(&phase1_witness, &vec![]); + + let out_evals = { + let point = Arc::new(vec![]); + + let last_witness = gkr_witness.layers[0] + .bases + .clone() + .into_iter() + .flatten() + .collect_vec(); + + // Last witness is missing the final sub-round; apply it now + let expected_result_manual = iota(&last_witness, RC[23]); + + if test { + let mut state = state.clone(); + keccakf(&mut state); + let state = u64s_to_bools(&state) + .into_iter() + .map(|b| Goldilocks::from_u64(b as u64)) + .collect_vec(); + assert_eq!(state, expected_result_manual); + } + + expected_result_manual + .iter() + .map(|bit| PointAndEval { + point: point.clone(), + eval: E::from_bases(&[*bit, Goldilocks::ZERO]), + }) + .collect_vec() + }; + + let GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove(gkr_witness, &out_evals, &vec![], &mut prover_transcript) + .expect("Failed to prove phase"); + + if verify { + { + let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + + gkr_circuit + .verify(gkr_proof, &out_evals, &vec![], &mut verifier_transcript) + .expect("GKR verify failed"); + + // Omit the PCS opening phase. 
+ } + } +} + +#[cfg(test)] +mod tests { + use rand::{Rng, SeedableRng}; + + use super::*; + + #[test] + fn test_keccakf() { + for _ in 0..3 { + let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(random_u64); + let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); + run_keccakf(state, true, true); + } + } +} diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs new file mode 100644 index 000000000..56ce5c42f --- /dev/null +++ b/gkr_iop/src/precompiles/mod.rs @@ -0,0 +1,5 @@ +mod faster_keccak; +mod keccak_f; +mod utils; +pub use faster_keccak::run_faster_keccakf; +pub use keccak_f::run_keccakf; diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs new file mode 100644 index 000000000..bec858e43 --- /dev/null +++ b/gkr_iop/src/precompiles/utils.rs @@ -0,0 +1,190 @@ +use ff_ext::ExtensionField; +use itertools::Itertools; +use p3_field::PrimeCharacteristicRing; +use subprotocols::expression::{Constant, Expression}; + +use crate::evaluation::EvalExpression; + +pub fn zero_expr() -> Expression { + Expression::Const(Constant::Base(0)) +} + +pub fn one_expr() -> Expression { + Expression::Const(Constant::Base(1)) +} + +pub fn not8_expr(expr: Expression) -> Expression { + Expression::Const(Constant::Base(0xFF)) - expr +} + +pub fn zero_eval() -> EvalExpression { + EvalExpression::Linear(0, Constant::Base(0), Constant::Base(0)) +} + +pub fn nest(v: &Vec) -> Vec> { + v.clone().into_iter().map(|e| vec![e]).collect_vec() +} + +pub fn u64s_to_felts(words: Vec) -> Vec { + words + .into_iter() + .map(|word| E::BaseField::from_u64(word)) + .collect() +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Mask { + pub size: usize, + pub value: u64, +} + +impl Mask { + pub fn new(size: usize, value: u64) -> Self { + if size < 64 { + assert!(value < (1 << size)); + } + Self { size, value } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub 
struct MaskRepresentation { + pub rep: Vec, +} + +impl From for (usize, u64) { + fn from(mask: Mask) -> Self { + (mask.size, mask.value) + } +} + +impl From<(usize, u64)> for Mask { + fn from(tuple: (usize, u64)) -> Self { + Mask::new(tuple.0, tuple.1) + } +} + +impl From for Vec<(usize, u64)> { + fn from(mask_rep: MaskRepresentation) -> Self { + mask_rep.rep.into_iter().map(|mask| mask.into()).collect() + } +} + +impl From> for MaskRepresentation { + fn from(tuples: Vec<(usize, u64)>) -> Self { + MaskRepresentation { + rep: tuples.into_iter().map(|tuple| tuple.into()).collect(), + } + } +} + +impl MaskRepresentation { + pub fn new(masks: Vec) -> Self { + Self { rep: masks } + } + + pub fn from_bits(bits: Vec, sizes: Vec) -> Self { + assert_eq!(bits.len(), sizes.iter().sum::()); + let mut masks = Vec::new(); + let mut bit_iter = bits.into_iter(); + for size in sizes { + let mut mask = 0; + for i in 0..size { + mask += (1 << i) * (bit_iter.next().unwrap() as u64); + } + masks.push(Mask::new(size, mask)); + } + Self { rep: masks } + } + + pub fn to_bits(&self) -> Vec { + self.rep + .iter() + .flat_map(|mask| (0..mask.size).map(move |i| ((mask.value >> i) & 1))) + .collect() + } + + pub fn convert(&self, new_sizes: Vec) -> Self { + let bits = self.to_bits(); + Self::from_bits(bits, new_sizes) + } + + pub fn values(&self) -> Vec { + self.rep.iter().map(|m| m.value).collect_vec() + } + + pub fn masks(&self) -> Vec { + self.rep.clone() + } + + fn len(&self) -> usize { + self.rep.len() + } +} + +#[derive(Debug)] +pub enum CenoLookup { + And(Expression, Expression, Expression), + Xor(Expression, Expression, Expression), + U16(Expression), +} + +impl CenoLookup { + // combine lookup arguments with challenges + pub fn compress(&self, alpha: Constant, beta: Constant) -> Expression { + let [alpha, beta] = [alpha, beta].map(|e| { + // assert!(matches!(e, Constant::Challenge(_))); + Expression::Const(e) + }); + + match self { + CenoLookup::And(a, b, c) => { + a.clone() + + 
alpha.clone() * b.clone() + + alpha.clone() * alpha.clone() * c.clone() + + beta + } + CenoLookup::Xor(a, b, c) => { + a.clone() + + alpha.clone() * b.clone() + + alpha.clone() * alpha.clone() * c.clone() + + beta + } + CenoLookup::U16(a) => a.clone() + beta, + } + } +} + +mod tests { + use crate::precompiles::utils::{Mask, MaskRepresentation}; + + #[test] + fn test_mask_representation_from_bits() { + let bits = vec![1, 0, 1, 1, 0, 1, 0, 0]; + let sizes = vec![3, 5]; + let mask_rep = MaskRepresentation::from_bits(bits.clone(), sizes.clone()); + assert_eq!(mask_rep.rep.len(), 2); + assert_eq!(mask_rep.rep[0], Mask::new(3, 0b101)); + assert_eq!(mask_rep.rep[1], Mask::new(5, 0b00101)); + } + + #[test] + fn test_mask_representation_to_bits() { + let masks = vec![Mask::new(3, 0b101), Mask::new(5, 0b00101)]; + let mask_rep = MaskRepresentation::new(masks); + let bits = mask_rep.to_bits(); + assert_eq!(bits, vec![1, 0, 1, 1, 0, 1, 0, 0]); + } + + #[test] + fn test_mask_representation_convert() { + let bits = vec![1, 0, 1, 1, 0, 1, 0, 0]; + let sizes = vec![3, 5]; + let mask_rep = MaskRepresentation::from_bits(bits.clone(), sizes.clone()); + let new_sizes = vec![4, 4]; + let new_mask_rep = mask_rep.convert(new_sizes); + assert_eq!(new_mask_rep.rep.len(), 2); + assert_eq!(new_mask_rep.rep[0], Mask::new(4, 0b1101)); + assert_eq!(new_mask_rep.rep[1], Mask::new(4, 0b0010)); + } +} diff --git a/subprotocols/examples/zerocheck_logup.rs b/subprotocols/examples/zerocheck_logup.rs index 4390db649..36c227b7e 100644 --- a/subprotocols/examples/zerocheck_logup.rs +++ b/subprotocols/examples/zerocheck_logup.rs @@ -51,6 +51,7 @@ fn run_verifier( let verifier = ZerocheckVerifierState::new( vec![*ans], vec![expr], + vec![], vec![point], proof, &challenges, diff --git a/subprotocols/src/error.rs b/subprotocols/src/error.rs index 5dad7bc9d..ca8eddefe 100644 --- a/subprotocols/src/error.rs +++ b/subprotocols/src/error.rs @@ -4,6 +4,6 @@ use crate::expression::Expression; #[derive(Clone, 
Debug, Error)] pub enum VerifierError { - #[error("Claim not match: expr: {0:?}, expect: {1:?}, got: {2:?}")] - ClaimNotMatch(Expression, E, E), + #[error("Claim not match: expr: {0:?}\n (expr name: {3:?})\n expect: {1:?}, got: {2:?}")] + ClaimNotMatch(Expression, E, E, String), } diff --git a/subprotocols/src/sumcheck.rs b/subprotocols/src/sumcheck.rs index d01c2702a..3b1ebc079 100644 --- a/subprotocols/src/sumcheck.rs +++ b/subprotocols/src/sumcheck.rs @@ -179,6 +179,7 @@ where sigma: E, expr: Expression, proof: SumcheckProof, + expr_names: Vec, challenges: &'a [E], transcript: &'a mut Trans, out_points: Vec<&'a [E]>, @@ -202,6 +203,7 @@ where proof: SumcheckProof, challenges: &'a [E], transcript: &'a mut Trans, + expr_names: Vec, ) -> Self { Self { sigma, @@ -210,6 +212,7 @@ where challenges, transcript, out_points, + expr_names, } } @@ -221,6 +224,7 @@ where challenges, transcript, out_points, + expr_names, } = self; let SumcheckProof { univariate_polys, @@ -262,6 +266,7 @@ where expr, expected_claim, got_claim, + "placeholder".to_string(), )); } @@ -324,6 +329,7 @@ mod test { proof, &challenges, &mut verifier_transcript, + vec![], ); verifier.verify().expect("verification failed"); diff --git a/subprotocols/src/zerocheck.rs b/subprotocols/src/zerocheck.rs index 0d8b1bcec..7b6d989e4 100644 --- a/subprotocols/src/zerocheck.rs +++ b/subprotocols/src/zerocheck.rs @@ -196,6 +196,7 @@ where inv_of_one_minus_points: Vec>, exprs: Vec<(Expression, &'a [E])>, proof: SumcheckProof, + expr_names: Vec, challenges: &'a [E], transcript: &'a mut Trans, } @@ -208,6 +209,7 @@ where pub fn new( sigmas: Vec, exprs: Vec, + expr_names: Vec, points: Vec<&'a [E]>, proof: SumcheckProof, challenges: &'a [E], @@ -235,6 +237,7 @@ where proof, challenges, transcript, + expr_names, } } @@ -246,6 +249,7 @@ where proof, challenges, transcript, + expr_names, .. } = self; let SumcheckProof { @@ -285,7 +289,7 @@ where ); // Check the final evaluations. 
- for (expected_claim, (expr, _)) in izip!(expected_claims, exprs) { + for (expected_claim, (expr, _), expr_name) in izip!(expected_claims, exprs, expr_names) { let got_claim = expr.evaluate(&ext_mle_evals, &base_mle_evals, &[], &[], challenges); if expected_claim != got_claim { @@ -293,6 +297,7 @@ where expr, expected_claim, got_claim, + expr_name.clone(), )); } } @@ -352,6 +357,7 @@ mod test { let verifier = ZerocheckVerifierState::new( sigmas, exprs, + vec![], points, proof, &challenges,