diff --git a/std/machines/hash/keccakf32_memory.asm b/std/machines/hash/keccakf32_memory.asm index de4ae884c3..3bc1507a47 100644 --- a/std/machines/hash/keccakf32_memory.asm +++ b/std/machines/hash/keccakf32_memory.asm @@ -8,7 +8,7 @@ use std::convert::fe; use std::prelude::set_hint; use std::prelude::Query; use std::prover::eval; -use std::prover::provide_value; +use std::prover::compute_from_multi; use std::machines::large_field::memory::Memory; machine Keccakf32Memory(mem: Memory) with @@ -578,6 +578,8 @@ machine Keccakf32Memory(mem: Memory) with }); // Prover function section (for witness generation). + // Hints are only needed for c and a_prime, the solver is able to figure out the + // rest of the witness. // // Populate C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]). // for x in 0..5 { @@ -592,49 +594,21 @@ machine Keccakf32Memory(mem: Memory) with // } // } - let query_c: int, int, int -> int = query |x, limb, bit_in_limb| - utils::fold( - 5, - |y| (int(eval(a[y * 10 + x * 2 + limb])) >> bit_in_limb) & 0x1, - 0, - |acc, e| acc ^ e - ); - - query |row| { - let _ = array::map_enumerated(c, |i, c_i| { + query |row| compute_from_multi( + c, row, a, + |a_fe| array::new(array::len(c), |i| { let x = i / 64; let z = i % 64; let limb = z / 32; let bit_in_limb = z % 32; - provide_value(c_i, row, fe(query_c(x, limb, bit_in_limb))); - }); - }; - - // // Populate C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]). - // for x in 0..5 { - // for z in 0..64 { - // row.c_prime[x][z] = xor([ - // row.c[x][z], - // row.c[(x + 4) % 5][z], - // row.c[(x + 1) % 5][(z + 63) % 64], - // ]); - // } - // } - - let query_c_prime: int, int -> int = query |x, z| - int(eval(c[x * 64 + z])) ^ - int(eval(c[((x + 4) % 5) * 64 + z])) ^ - int(eval(c[((x + 1) % 5) * 64 + (z + 63) % 64])); - - query |row| { - let _ = array::map_enumerated(c_prime, |i, c_i| { - let x = i / 64; - let z = i % 64; - - provide_value(c_i, row, fe(query_c_prime(x, z))); - }); - }; + fe(utils::fold( + 5, + |y| (int(a_fe[y * 10 + x * 2 + limb]) >> bit_in_limb) & 0x1, + 0, + |acc, e| acc ^ e + )) + })); // // Populate A'. To avoid shifting indices, we rewrite // // A'[x, y, z] = xor(A[x, y, z], C[x - 1, z], C[x + 1, z - 1]) @@ -652,110 +626,32 @@ machine Keccakf32Memory(mem: Memory) with // } // } - let query_a_prime: int, int, int, int, int -> int = query |x, y, z, limb, bit_in_limb| - ((int(eval(a[y * 10 + x * 2 + limb])) >> bit_in_limb) & 0x1) ^ - int(eval(c[x * 64 + z])) ^ - int(eval(c_prime[x * 64 + z])); - - query |row| { - let _ = array::map_enumerated(a_prime, |i, a_i| { + query |row| compute_from_multi( + a_prime, row, a + c + c_prime, + |inputs| array::new(array::len(a_prime), |i| { let y = i / 320; let x = (i / 64) % 5; let z = i % 64; let limb = z / 32; let bit_in_limb = z % 32; - provide_value(a_i, row, fe(query_a_prime(x, y, z, limb, bit_in_limb))); - }); - }; + let a_elem = inputs[y * 10 + x * 2 + limb]; + let c_elem = inputs[x * 64 + z + 5 * 5 * 2]; + let c_prime_elem = inputs[x * 64 + z + 5 * 5 * 2 + 5 * 64]; - // // Populate A''.P - // // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). - // for y in 0..5 { - // for x in 0..5 { - // for limb in 0..U64_LIMBS { - // row.a_prime_prime[y][x][limb] = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB) - // .rev() - // .fold(F::zero(), |acc, z| { - // let bit = xor([ - // row.b(x, y, z), - // andn(row.b((x + 1) % 5, y, z), row.b((x + 2) % 5, y, z)), - // ]); - // acc.double() + bit - // }); - // } - // } - // } - - let query_a_prime_prime: int, int, int -> int = query |x, y, limb| - utils::fold( - 32, - |z| - int(eval(b(x, y, (limb + 1) * 32 - 1 - z))) ^ - int(eval(andn(b((x + 1) % 5, y, (limb + 1) * 32 - 1 - z), - b((x + 2) % 5, y, (limb + 1) * 32 - 1 - z)))), - 0, - |acc, e| acc * 2 + e - ); - - query |row| { - let _ = array::map_enumerated(a_prime_prime, |i, a_i| { - let y = i / 10; - let x = (i / 2) % 5; - let limb = i % 2; - - provide_value(a_i, row, fe(query_a_prime_prime(x, y, limb))); - }); - }; - - // // For the XOR, we split A''[0, 0] to bits. - // let mut val = 0; // smaller address correspond to less significant limb - // for limb in 0..U64_LIMBS { - // let val_limb = row.a_prime_prime[0][0][limb].as_canonical_u64(); - // val |= val_limb << (limb * BITS_PER_LIMB); - // } - // let val_bits: Vec<bool> = (0..64) // smaller address correspond to less significant bit - // .scan(val, |acc, _| { - // let bit = (*acc & 1) != 0; - // *acc >>= 1; - // Some(bit) - // }) - // .collect(); - // for (i, bit) in row.a_prime_prime_0_0_bits.iter_mut().enumerate() { - // *bit = F::from_bool(val_bits[i]); - // } + fe(((int(a_elem) >> bit_in_limb) & 0x1) ^ int(c_elem) ^ int(c_prime_elem)) + })); + // TODO: This hint is correct but not needed (the solver can figure this out). + // We keep it here because it prevents the JIT solver from succeeding (because of the + // use of `provide_value`), because it currently fails when compiling Rust code. + // Once these issues are resolved, we can remove this hint. query |row| { - let _ = array::map_enumerated(a_prime_prime_0_0_bits, |i, a_i| { - let limb = i / 32; - let bit_in_limb = i % 32; - - provide_value( - a_i, - row, - fe((int(eval(a_prime_prime[limb])) >> bit_in_limb) & 0x1) - ); - }); + std::prover::provide_value( + a_prime_prime_0_0_bits[0], + row, + fe((int(eval(a_prime_prime[0]))) & 0x1) + ); }; - // // A''[0, 0] is additionally xor'd with RC. - // for limb in 0..U64_LIMBS { - // let rc_lo = rc_value_limb(round, limb); - // row.a_prime_prime_prime_0_0_limbs[limb] = - // F::from_canonical_u16(row.a_prime_prime[0][0][limb].as_canonical_u64() as u16 ^ rc_lo); - // } - - let query_a_prime_prime_prime_0_0_limbs: int, int -> int = query |round, limb| - int(eval(a_prime_prime[limb])) ^ - ((RC[round] >> (limb * 32)) & 0xffffffff); - - query |row| { - let _ = array::new(2, |limb| { - provide_value( - a_prime_prime_prime_0_0_limbs[limb], - row, - fe(query_a_prime_prime_prime_0_0_limbs(row % NUM_ROUNDS, limb) - )); - }); - }; }