diff --git a/README.md b/README.md index 7841cc0..2355150 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,22 @@ let cfg = CircomConfig::::new( // Insert our public inputs as key value pairs let mut builder = CircomBuilder::new(cfg); -builder.push_input("a", 3); -builder.push_input("b", 11); +builder.push_input("a", ark_circom::circom::Inputs::BigInt(3.into())); +builder.push_input("b", ark_circom::circom::Inputs::BigInt(11.into())); +// for 2d array, use +// builder.push_input("a", ark_circom::circom::Inputs::BigIntVec(vec[3.into()])); +// for 3d array, use +// builder.push_input( +// "a", +// ark_circom::circom::InputValue::BigIntVecVec(vec![ +// vec![ +// 3.into(), +// ], +// vec![ +// 11.into(), +// ], +// ]), +// ); // Create an empty instance for setting it up let circom = builder.setup(); diff --git a/benches/groth16.rs b/benches/groth16.rs index 2063d7f..0b31ea5 100644 --- a/benches/groth16.rs +++ b/benches/groth16.rs @@ -1,7 +1,7 @@ use ark_crypto_primitives::snark::SNARK; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use ark_circom::{read_zkey, CircomReduction, WitnessCalculator}; +use ark_circom::{circom::Inputs, read_zkey, CircomReduction, WitnessCalculator}; use ark_std::rand::thread_rng; use ark_bn254::Bn254; @@ -22,9 +22,10 @@ fn bench_groth(c: &mut Criterion, num_validators: u32, num_constraints: u32) { let num_constraints = matrices.num_constraints; let inputs = { - let mut inputs: HashMap> = HashMap::new(); - let values = inputs.entry("a".to_string()).or_insert_with(Vec::new); - values.push(3.into()); + let mut inputs: HashMap = HashMap::new(); + let values = inputs + .entry("a".to_string()) + .or_insert(Inputs::BigInt(3.into())); inputs }; diff --git a/src/circom/builder.rs b/src/circom/builder.rs index 670e22b..28a93c1 100644 --- a/src/circom/builder.rs +++ b/src/circom/builder.rs @@ -9,10 +9,40 @@ use std::collections::HashMap; use crate::{circom::R1CSFile, witness::WitnessCalculator}; use color_eyre::Result; +#[derive(Clone, Debug)] +pub enum Inputs { + BigInt(BigInt), + BigIntVec(Vec), + BigIntVecVec(Vec>), +} + +impl Inputs { + pub fn flatten(&self) -> Vec { + match self { + Inputs::BigInt(n) => vec![n.clone()], + Inputs::BigIntVec(n) => n.clone(), + Inputs::BigIntVecVec(n) => n.clone().into_iter().flatten().collect(), + } + } +} + +impl std::iter::FromIterator for Inputs { + fn from_iter>(iter: I) -> Self { + let items: Vec = iter.into_iter().collect(); + + match items.len() { + 0 => panic!("Cannot create Value from an empty iterator"), + 1 => Inputs::BigInt(items[0].clone()), + 2 => Inputs::BigIntVec(items), + _ => panic!("Cannot create Value from an iterator with more than 3 elements"), + } + } +} + #[derive(Clone, Debug)] pub struct CircomBuilder { pub cfg: CircomConfig, - pub inputs: HashMap>, + pub inputs: HashMap, } // Add utils for creating this from files / directly from bytes @@ -47,9 +77,8 @@ impl CircomBuilder { } /// Pushes a Circom input at the specified name. - pub fn push_input>(&mut self, name: impl ToString, val: T) { - let values = self.inputs.entry(name.to_string()).or_insert_with(Vec::new); - values.push(val.into()); + pub fn push_input(&mut self, name: impl ToString, val: Inputs) { + self.inputs.entry(name.to_string()).or_insert(val); } /// Generates an empty circom circuit with no witness set, to be used for diff --git a/src/circom/circuit.rs b/src/circom/circuit.rs index 12540e9..64706e5 100644 --- a/src/circom/circuit.rs +++ b/src/circom/circuit.rs @@ -92,6 +92,7 @@ mod tests { use crate::{CircomBuilder, CircomConfig}; use ark_bn254::{Bn254, Fr}; use ark_relations::r1cs::ConstraintSystem; + use num_bigint::BigInt; #[test] fn satisfied() { @@ -101,8 +102,8 @@ mod tests { ) .unwrap(); let mut builder = CircomBuilder::new(cfg); - builder.push_input("a", 3); - builder.push_input("b", 11); + builder.push_input("a", crate::circom::Inputs::BigInt(BigInt::from(3))); + builder.push_input("b", crate::circom::Inputs::BigInt(BigInt::from(11))); let circom = builder.build().unwrap(); let cs = ConstraintSystem::::new_ref(); diff --git a/src/circom/mod.rs b/src/circom/mod.rs index 6ad7fd0..3a5097f 100644 --- a/src/circom/mod.rs +++ b/src/circom/mod.rs @@ -7,7 +7,7 @@ mod circuit; pub use circuit::CircomCircuit; mod builder; -pub use builder::{CircomBuilder, CircomConfig}; +pub use builder::{CircomBuilder, CircomConfig, Inputs}; mod qap; pub use qap::CircomReduction; diff --git a/src/witness/witness_calculator.rs b/src/witness/witness_calculator.rs index 49582b9..dea8c0d 100644 --- a/src/witness/witness_calculator.rs +++ b/src/witness/witness_calculator.rs @@ -1,4 +1,5 @@ use super::{fnv, CircomBase, SafeMemory, Wasm}; +use crate::circom::Inputs; use color_eyre::Result; use num_bigint::BigInt; use num_traits::Zero; @@ -152,7 +153,7 @@ impl WitnessCalculator { } } - pub fn calculate_witness)>>( + pub fn calculate_witness>( &mut self, inputs: I, sanity_check: bool, @@ -173,7 +174,7 @@ impl WitnessCalculator { } // Circom 1 default behavior - fn calculate_witness_circom1)>>( + fn calculate_witness_circom1>( &mut self, inputs: I, sanity_check: bool, @@ -187,13 +188,13 @@ impl WitnessCalculator { // allocate the inputs for (name, values) in inputs.into_iter() { let (msb, lsb) = fnv(&name); - + let flatten = values.flatten(); self.instance .get_signal_offset32(p_sig_offset, 0, msb, lsb)?; let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize; - for (i, value) in values.into_iter().enumerate() { + for (i, value) in flatten.into_iter().enumerate() { self.memory.write_fr(p_fr as usize, &value)?; self.instance .set_signal(0, 0, (sig_offset + i) as u32, p_fr)?; @@ -216,7 +217,7 @@ impl WitnessCalculator { // Circom 2 feature flag with version 2 #[cfg(feature = "circom-2")] - fn calculate_witness_circom2)>>( + fn calculate_witness_circom2>( &mut self, inputs: I, sanity_check: bool, @@ -229,7 +230,8 @@ impl WitnessCalculator { for (name, values) in inputs.into_iter() { let (msb, lsb) = fnv(&name); - for (i, value) in values.into_iter().enumerate() { + let flatten = values.flatten(); + for (i, value) in flatten.into_iter().enumerate() { let f_arr = to_array32(&value, n32 as usize); for j in 0..n32 { self.instance @@ -256,7 +258,7 @@ impl WitnessCalculator { pub fn calculate_witness_element< E: ark_ec::pairing::Pairing, - I: IntoIterator)>, + I: IntoIterator, >( &mut self, inputs: I, @@ -473,23 +475,25 @@ mod tests { let inputs: std::collections::HashMap = serde_json::from_str(&inputs_str).unwrap(); - let inputs = inputs + let inputs: HashMap = inputs .iter() .map(|(key, value)| { let res = match value { - Value::String(inner) => { - vec![BigInt::from_str(inner).unwrap()] - } + Value::String(inner) => Inputs::BigInt(BigInt::from_str(inner).unwrap()), Value::Number(inner) => { - vec![BigInt::from(inner.as_u64().expect("not a u32"))] + Inputs::BigInt(BigInt::from(inner.as_u64().expect("not a u32"))) + } + Value::Array(inner) => { + let k = inner.iter().cloned().map(value_to_bigint).collect(); + k } - Value::Array(inner) => inner.iter().cloned().map(value_to_bigint).collect(), _ => panic!(), }; (key.clone(), res) }) .collect::>(); + println!("496, {:?}", inputs); let res = wtns.calculate_witness(inputs, false).unwrap(); for (r, w) in res.iter().zip(case.witness) { diff --git a/src/zkey.rs b/src/zkey.rs index 8dd6422..c972496 100644 --- a/src/zkey.rs +++ b/src/zkey.rs @@ -372,11 +372,12 @@ mod tests { use super::*; use ark_bn254::{G1Projective, G2Projective}; use ark_crypto_primitives::snark::SNARK; + use num_bigint::BigInt; use num_bigint::BigUint; use serde_json::Value; use std::fs::File; - use crate::circom::CircomReduction; + use crate::circom::{CircomReduction, Inputs}; use crate::witness::WitnessCalculator; use crate::{CircomBuilder, CircomConfig}; use ark_groth16::Groth16; @@ -854,8 +855,8 @@ mod tests { ) .unwrap(); let mut builder = CircomBuilder::new(cfg); - builder.push_input("a", 3); - builder.push_input("b", 11); + builder.push_input("a", Inputs::BigInt(BigInt::from(3))); + builder.push_input("b", Inputs::BigInt(BigInt::from(11))); let circom = builder.build().unwrap(); @@ -878,12 +879,14 @@ mod tests { let (params, matrices) = read_zkey(&mut file).unwrap(); let mut wtns = WitnessCalculator::new("./test-vectors/mycircuit.wasm").unwrap(); - let mut inputs: HashMap> = HashMap::new(); - let values = inputs.entry("a".to_string()).or_insert_with(Vec::new); - values.push(3.into()); - - let values = inputs.entry("b".to_string()).or_insert_with(Vec::new); - values.push(11.into()); + let mut inputs: HashMap = HashMap::new(); + inputs + .entry("a".to_string()) + .or_insert(Inputs::BigInt(BigInt::from(3))); + + inputs + .entry("b".to_string()) + .or_insert(Inputs::BigInt(BigInt::from(11))); let mut rng = thread_rng(); use ark_std::UniformRand; diff --git a/tests/groth16.rs b/tests/groth16.rs index 8b29ba3..39dbaec 100644 --- a/tests/groth16.rs +++ b/tests/groth16.rs @@ -1,10 +1,9 @@ -use ark_circom::{CircomBuilder, CircomConfig}; -use ark_std::rand::thread_rng; -use color_eyre::Result; - use ark_bn254::Bn254; +use ark_circom::{CircomBuilder, CircomConfig}; use ark_crypto_primitives::snark::SNARK; use ark_groth16::Groth16; +use ark_std::rand::thread_rng; +use color_eyre::Result; type GrothBn = Groth16; @@ -15,8 +14,8 @@ fn groth16_proof() -> Result<()> { "./test-vectors/mycircuit.r1cs", )?; let mut builder = CircomBuilder::new(cfg); - builder.push_input("a", 3); - builder.push_input("b", 11); + builder.push_input("a", ark_circom::circom::Inputs::BigInt(3.into())); + builder.push_input("b", ark_circom::circom::Inputs::BigInt(11.into())); // create an empty instance for setting it up let circom = builder.setup(); @@ -47,9 +46,9 @@ fn groth16_proof_wrong_input() { ) .unwrap(); let mut builder = CircomBuilder::new(cfg); - builder.push_input("a", 3); + builder.push_input("a", ark_circom::circom::Inputs::BigInt(3.into())); // This isn't a public input to the circuit, should fail - builder.push_input("foo", 11); + builder.push_input("foo", ark_circom::circom::Inputs::BigInt(11.into())); // create an empty instance for setting it up let circom = builder.setup(); @@ -68,8 +67,8 @@ fn groth16_proof_circom2() -> Result<()> { "./test-vectors/circom2_multiplier2.r1cs", )?; let mut builder = CircomBuilder::new(cfg); - builder.push_input("a", 3); - builder.push_input("b", 11); + builder.push_input("a", ark_circom::circom::Inputs::BigInt(3.into())); + builder.push_input("b", ark_circom::circom::Inputs::BigInt(11.into())); // create an empty instance for setting it up let circom = builder.setup(); @@ -100,8 +99,11 @@ fn witness_generation_circom2() -> Result<()> { "./test-vectors/circom2_multiplier2.r1cs", )?; let mut builder = CircomBuilder::new(cfg); - builder.push_input("a", 3); - builder.push_input("b", 0x100000000u64 - 1); + builder.push_input("a", ark_circom::circom::Inputs::BigInt(3.into())); + builder.push_input( + "b", + ark_circom::circom::Inputs::BigInt((0x100000000u64 - 1).into()), + ); assert!(builder.build().is_ok()); diff --git a/tests/solidity.rs b/tests/solidity.rs index c183709..d3944d0 100644 --- a/tests/solidity.rs +++ b/tests/solidity.rs @@ -21,8 +21,8 @@ async fn solidity_verifier() -> Result<()> { "./test-vectors/mycircuit.r1cs", )?; let mut builder = CircomBuilder::new(cfg); - builder.push_input("a", 3); - builder.push_input("b", 11); + builder.push_input("a", ark_circom::circom::Inputs::BigInt(3.into())); + builder.push_input("b", ark_circom::circom::Inputs::BigInt(11.into())); // create an empty instance for setting it up let circom = builder.setup();