Skip to content

Add num_instances to transcript #648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ report.txt
report.json
table_cache_dev_*
.DS_Store
.env
tracing.folded
.env
13 changes: 9 additions & 4 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,20 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
let (witness, num_instances) = wits
.remove(circuit_name)
.ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?;
if witness.is_empty() {
continue;
}
let wits_commit = commitments.remove(circuit_name).unwrap();
// TODO: add an enum for circuit type either in constraint_system or vk
let cs = pk.get_cs();
let is_opcode_circuit = cs.lk_table_expressions.is_empty()
&& cs.r_table_expressions.is_empty()
&& cs.w_table_expressions.is_empty();
if is_opcode_circuit {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of the table circuits may also be skipped. Is there a blocker that requires this if?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it is a little less obvious how the num_instances of tables is communicated. Look at the verifier variable expected_rounds. Maybe we can do the table case somehow in another PR.

transcript.append_field_element(&E::BaseField::from(num_instances as u64));
}

if witness.is_empty() {
assert!(num_instances == 0);
continue;
}
let wits_commit = commitments.remove(circuit_name).unwrap();

if is_opcode_circuit {
tracing::debug!(
Expand Down
92 changes: 56 additions & 36 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,43 +140,63 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
let point_eval = PointAndEval::default();
let mut transcripts = transcript.fork(self.vk.circuit_vks.len());

for (name, (i, opcode_proof)) in vm_proof.opcode_proofs {
let transcript = &mut transcripts[i];
// For each opcode, include the num_instances
// into its corresponding fork of the transcript.
for ((name, _), transcript) in self
.vk
.circuit_vks
.iter() // Sorted by key.
.zip_eq(transcripts.iter_mut())
{
// get num_instances from opcode proof
let opcode_result = vm_proof.opcode_proofs.get(name);
let num_instances = opcode_result.map(|(_, p)| p.num_instances).unwrap_or(0);
if opcode_result.is_some() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is supposed to be the same as if is_opcode_circuit in the prover side, but I’m not sure that’s the case. Looking at vm_proof.opcode_proofs.insert which ultimately becomes opcode_result.is_some(), not the same as is_opcode_circuit?

transcript.append_field_element(&E::BaseField::from(num_instances as u64));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would protect against more potential attacks to do this before the alpha, beta read_challenge above. It’s before the transcript fork, so just a sequence of append to the root transcript.

}
}

let circuit_vk = self
.vk
.circuit_vks
.get(&name)
.ok_or(ZKVMError::VKNotFound(name.clone()))?;
let _rand_point = self.verify_opcode_proof(
&name,
&self.vk.vp,
circuit_vk,
&opcode_proof,
pi_evals,
transcript,
NUM_FANIN,
&point_eval,
&challenges,
)?;
tracing::info!("verified proof for opcode {}", name);

// getting the number of dummy padding item that we used in this opcode circuit
let num_lks = circuit_vk.get_cs().lk_expressions.len();
let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks;
let num_padded_instance =
next_pow2_instance_padding(opcode_proof.num_instances) - opcode_proof.num_instances;
dummy_table_item_multiplicity += num_padded_lks_per_instance
* opcode_proof.num_instances
+ num_lks.next_power_of_two() * num_padded_instance;

prod_r *= opcode_proof.record_r_out_evals.iter().product::<E>();
prod_w *= opcode_proof.record_w_out_evals.iter().product::<E>();

logup_sum +=
opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap();
logup_sum +=
opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap();
for (name, (i, opcode_proof)) in &vm_proof.opcode_proofs {
let transcript = &mut transcripts[*i];
let num_instances = opcode_proof.num_instances;

if num_instances != 0 {
let opcode_proof = &vm_proof.opcode_proofs.get(name).unwrap().1;
let circuit_vk = self
.vk
.circuit_vks
.get(name)
.ok_or(ZKVMError::VKNotFound(name.clone()))?;
let _rand_point = self.verify_opcode_proof(
name,
&self.vk.vp,
circuit_vk,
opcode_proof,
pi_evals,
transcript,
NUM_FANIN,
&point_eval,
&challenges,
)?;
tracing::info!("verified proof for opcode {}", name);

// getting the number of dummy padding item that we used in this opcode circuit
let num_lks = circuit_vk.get_cs().lk_expressions.len();
let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks;
let num_padded_instance = next_pow2_instance_padding(opcode_proof.num_instances)
- opcode_proof.num_instances;
dummy_table_item_multiplicity += num_padded_lks_per_instance
* opcode_proof.num_instances
+ num_lks.next_power_of_two() * num_padded_instance;

prod_r *= opcode_proof.record_r_out_evals.iter().product::<E>();
prod_w *= opcode_proof.record_w_out_evals.iter().product::<E>();

logup_sum +=
opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap();
logup_sum +=
opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap();
}
}

for (name, (i, table_proof)) in vm_proof.table_proofs {
Expand Down