Skip to content

Commit de5cc96

Browse files
committed
correct opt: half size on poseidon and keep ecc sum as binary tree
1 parent 8162e6f commit de5cc96

3 files changed

Lines changed: 60 additions & 134 deletions

File tree

ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,9 @@ pub fn extract_shard_ram_column_map<E: ExtensionField>(
7070

7171
let mut x = [0u32; 7];
7272
let mut y = [0u32; 7];
73-
let mut slope = [0u32; 7];
7473
for i in 0..7 {
7574
x[i] = config.x[i].id as u32;
7675
y[i] = config.y[i].id as u32;
77-
slope[i] = config.slope[i].id as u32;
7876
}
7977

8078
// Poseidon2 columns: p3_cols are contiguous, followed by post_linear_layer_cols
@@ -114,7 +112,7 @@ pub fn extract_shard_ram_column_map<E: ExtensionField>(
114112
is_global_write,
115113
x,
116114
y,
117-
slope,
115+
slope: [0; 7],
118116
poseidon2_base_col,
119117
num_poseidon2_cols,
120118
num_p3_cols,
@@ -126,22 +124,6 @@ pub fn extract_shard_ram_ec_tree_column_map<E: ExtensionField>(
126124
config: &ShardRamEcTreeConfig<E>,
127125
num_witin: usize,
128126
) -> ShardRamColumnMap {
129-
let addr = config.addr.id as u32;
130-
let is_ram_register = config.is_ram_register.id as u32;
131-
132-
let value_limbs = config
133-
.value
134-
.wits_in()
135-
.expect("value should have WitIn limbs");
136-
assert_eq!(value_limbs.len(), 2, "Expected 2 value limbs");
137-
let value = [value_limbs[0].id as u32, value_limbs[1].id as u32];
138-
139-
let shard = config.shard.id as u32;
140-
let global_clk = config.global_clk.id as u32;
141-
let local_clk = config.local_clk.id as u32;
142-
let nonce = config.nonce.id as u32;
143-
let is_global_write = config.is_global_write.id as u32;
144-
145127
let mut x = [0u32; 7];
146128
let mut y = [0u32; 7];
147129
let mut slope = [0u32; 7];
@@ -152,14 +134,14 @@ pub fn extract_shard_ram_ec_tree_column_map<E: ExtensionField>(
152134
}
153135

154136
ShardRamColumnMap {
155-
addr,
156-
is_ram_register,
157-
value,
158-
shard,
159-
global_clk,
160-
local_clk,
161-
nonce,
162-
is_global_write,
137+
addr: 0,
138+
is_ram_register: 0,
139+
value: [0; 2],
140+
shard: 0,
141+
global_clk: 0,
142+
local_clk: 0,
143+
nonce: 0,
144+
is_global_write: 0,
163145
x,
164146
y,
165147
slope,

ceno_zkvm/src/scheme/cpu/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ use crate::{
1414
GkrOutputStageMask, assign_group_evals, derive_ecc_bridge_claims,
1515
extract_ecc_quark_witness_inputs, first_layer_output_group_stage_masks,
1616
first_layer_selector_contexts, infer_tower_logup_witness, infer_tower_product_witness,
17-
interleaving_mles_to_mles,
18-
split_rotation_evals,
17+
interleaving_mles_to_mles, split_rotation_evals,
1918
},
2019
verifier::eval_batched_main_frontload_terms,
2120
},

ceno_zkvm/src/tables/shard_ram.rs

Lines changed: 50 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -43,30 +43,12 @@ use crate::{instructions::riscv::constants::UInt, scheme::constants::SEPTIC_EXTE
4343

4444
pub(crate) const Y6_LO_TOP_BYTE_LT_BOUND: u64 = 60;
4545

46-
fn shard_ram_ec_point_record<E: ExtensionField>(
47-
is_global_write: Expression<E>,
48-
ram_type: Expression<E>,
49-
addr: Expression<E>,
50-
value: impl IntoIterator<Item = Expression<E>>,
51-
shard: Expression<E>,
52-
global_clk: Expression<E>,
53-
local_clk: Expression<E>,
54-
nonce: Expression<E>,
55-
x: &[WitIn],
56-
y: &[WitIn],
57-
) -> Vec<Expression<E>> {
58-
[
59-
CustomRWTag::ShardRamEcPoint.expr::<E>(),
60-
is_global_write,
61-
ram_type,
62-
addr,
63-
]
64-
.into_iter()
65-
.chain(value)
66-
.chain([shard, global_clk, local_clk, nonce])
67-
.chain(x.iter().map(|w| w.expr()))
68-
.chain(y.iter().map(|w| w.expr()))
69-
.collect()
46+
fn shard_ram_ec_point_record<E: ExtensionField>(x: &[WitIn], y: &[WitIn]) -> Vec<Expression<E>> {
47+
[CustomRWTag::ShardRamEcPoint.expr::<E>()]
48+
.into_iter()
49+
.chain(x.iter().map(|w| w.expr()))
50+
.chain(y.iter().map(|w| w.expr()))
51+
.collect()
7052
}
7153

7254
/// A record for a read/write into the shard RAM
@@ -213,7 +195,6 @@ pub struct ShardRamConfig<E: ExtensionField> {
213195
pub(crate) is_global_write: WitIn,
214196
pub(crate) x: Vec<WitIn>,
215197
pub(crate) y: Vec<WitIn>,
216-
pub(crate) slope: Vec<WitIn>,
217198
// Byte limbs of `y6_lo`, the helper that binds `y[SEPTIC_EXTENSION_DEGREE - 1]`
218199
// to `is_global_write` in `configure`.
219200
pub(crate) y6_lo_bytes: [WitIn; 4],
@@ -229,9 +210,6 @@ impl<E: ExtensionField> ShardRamConfig<E> {
229210
let y: Vec<WitIn> = (0..SEPTIC_EXTENSION_DEGREE)
230211
.map(|i| cb.create_witin(|| format!("y{}", i)))
231212
.collect();
232-
let slope: Vec<WitIn> = (0..SEPTIC_EXTENSION_DEGREE)
233-
.map(|i| cb.create_witin(|| format!("slope{}", i)))
234-
.collect();
235213
let addr = cb.create_witin(|| "addr");
236214
let is_ram_register = cb.create_witin(|| "is_ram_register");
237215
let value = UInt::new_unchecked(|| "value", cb)?;
@@ -292,18 +270,7 @@ impl<E: ExtensionField> ShardRamConfig<E> {
292270
cb.rlc_chip_record(record),
293271
)?;
294272

295-
let ec_point_record = shard_ram_ec_point_record(
296-
is_global_write.expr(),
297-
ram_type,
298-
addr.expr(),
299-
value.memory_expr(),
300-
shard.expr(),
301-
global_clk.expr(),
302-
local_clk.expr(),
303-
nonce.expr(),
304-
&x,
305-
&y,
306-
);
273+
let ec_point_record = shard_ram_ec_point_record(&x, &y);
307274
cb.read_record(
308275
|| "shard_ram_ec_point_in",
309276
RAMType::Custom,
@@ -367,7 +334,6 @@ impl<E: ExtensionField> ShardRamConfig<E> {
367334
Ok(ShardRamConfig {
368335
x,
369336
y,
370-
slope,
371337
addr,
372338
is_ram_register,
373339
value,
@@ -383,14 +349,6 @@ impl<E: ExtensionField> ShardRamConfig<E> {
383349
}
384350

385351
pub struct ShardRamEcTreeConfig<E: ExtensionField> {
386-
pub(crate) addr: WitIn,
387-
pub(crate) is_ram_register: WitIn,
388-
pub(crate) value: UInt<E>,
389-
pub(crate) shard: WitIn,
390-
pub(crate) global_clk: WitIn,
391-
pub(crate) local_clk: WitIn,
392-
pub(crate) nonce: WitIn,
393-
pub(crate) is_global_write: WitIn,
394352
pub(crate) x: Vec<WitIn>,
395353
pub(crate) y: Vec<WitIn>,
396354
pub(crate) slope: Vec<WitIn>,
@@ -399,14 +357,6 @@ pub struct ShardRamEcTreeConfig<E: ExtensionField> {
399357

400358
impl<E: ExtensionField> ShardRamEcTreeConfig<E> {
401359
pub fn configure(cb: &mut CircuitBuilder<E>) -> Result<Self, CircuitBuilderError> {
402-
let addr = cb.create_witin(|| "addr");
403-
let is_ram_register = cb.create_witin(|| "is_ram_register");
404-
let value = UInt::new_unchecked(|| "value", cb)?;
405-
let shard = cb.create_witin(|| "shard");
406-
let global_clk = cb.create_witin(|| "global_clk");
407-
let local_clk = cb.create_witin(|| "local_clk");
408-
let nonce = cb.create_witin(|| "nonce");
409-
let is_global_write = cb.create_witin(|| "is_global_write");
410360
let x: Vec<WitIn> = (0..SEPTIC_EXTENSION_DEGREE)
411361
.map(|i| cb.create_witin(|| format!("x{i}")))
412362
.collect();
@@ -417,22 +367,7 @@ impl<E: ExtensionField> ShardRamEcTreeConfig<E> {
417367
.map(|i| cb.create_witin(|| format!("slope{i}")))
418368
.collect();
419369

420-
let is_ram_reg: Expression<E> = is_ram_register.expr();
421-
let reg: Expression<E> = RAMType::Register.into();
422-
let mem: Expression<E> = RAMType::Memory.into();
423-
let ram_type = is_ram_reg.clone() * reg + (1 - is_ram_reg) * mem;
424-
let ec_point_record = shard_ram_ec_point_record(
425-
is_global_write.expr(),
426-
ram_type,
427-
addr.expr(),
428-
value.memory_expr(),
429-
shard.expr(),
430-
global_clk.expr(),
431-
local_clk.expr(),
432-
nonce.expr(),
433-
&x,
434-
&y,
435-
);
370+
let ec_point_record = shard_ram_ec_point_record(&x, &y);
436371
cb.read_record(
437372
|| "shard_ram_ec_point_in",
438373
RAMType::Custom,
@@ -453,14 +388,6 @@ impl<E: ExtensionField> ShardRamEcTreeConfig<E> {
453388
);
454389

455390
Ok(Self {
456-
addr,
457-
is_ram_register,
458-
value,
459-
shard,
460-
global_clk,
461-
local_clk,
462-
nonce,
463-
is_global_write,
464391
x,
465392
y,
466393
slope,
@@ -573,10 +500,10 @@ impl<E: ExtensionField> ShardRamCircuit<E> {
573500
input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk);
574501
input[2 + k + 2] = E::BaseField::from_canonical_u32(*nonce);
575502

576-
config
577-
.perm_config
578-
// TODO: remove hardcoded constant 28
579-
.assign_instance(&mut instance[28 + UINT_LIMBS..], input);
503+
config.perm_config.assign_instance(
504+
&mut instance[config.perm_config.p3_cols[0].id as usize..],
505+
input,
506+
);
580507

581508
Ok(())
582509
}
@@ -811,25 +738,6 @@ impl<E: ExtensionField> ShardRamEcTreeCircuit<E> {
811738
instance: &mut [E::BaseField],
812739
input: &ShardRamInput<E>,
813740
) {
814-
let record = &input.record;
815-
let is_ram_register = match record.ram_type {
816-
RAMType::Register => 1,
817-
RAMType::Memory => 0,
818-
_ => unreachable!(),
819-
};
820-
set_val!(instance, config.addr, record.addr as u64);
821-
set_val!(instance, config.is_ram_register, is_ram_register as u64);
822-
let value = Value::new_unchecked(record.value);
823-
config.value.assign_limbs(instance, value.as_u16_limbs());
824-
set_val!(instance, config.shard, record.shard);
825-
set_val!(instance, config.global_clk, record.global_clk);
826-
set_val!(instance, config.local_clk, record.local_clk);
827-
set_val!(instance, config.nonce, input.ec_point.nonce as u64);
828-
set_val!(
829-
instance,
830-
config.is_global_write,
831-
record.is_to_write_set as u64
832-
);
833741
config
834742
.x
835743
.iter()
@@ -1259,6 +1167,24 @@ mod tests {
12591167
}
12601168
}
12611169

1170+
fn assert_record_rows_match(
1171+
left: &Arc<multilinear_extensions::mle::MultilinearExtension<'_, E>>,
1172+
left_rows: std::ops::Range<usize>,
1173+
right: &Arc<multilinear_extensions::mle::MultilinearExtension<'_, E>>,
1174+
right_rows: std::ops::Range<usize>,
1175+
label: &str,
1176+
) {
1177+
assert_eq!(left_rows.len(), right_rows.len(), "{label} row count");
1178+
let left_evals = left.get_ext_field_vec();
1179+
let right_evals = right.get_ext_field_vec();
1180+
for (left_row, right_row) in left_rows.zip(right_rows) {
1181+
assert_eq!(
1182+
left_evals[left_row], right_evals[right_row],
1183+
"{label}: left row {left_row}, right row {right_row}"
1184+
);
1185+
}
1186+
}
1187+
12621188
#[test]
12631189
fn test_shard_ram_split_selectors_and_tower_padding() {
12641190
let read_count = 2;
@@ -1405,7 +1331,7 @@ mod tests {
14051331
build_main_witness::<E, Pcs, CpuBackend<E, Pcs>, CpuProver<CpuBackend<E, Pcs>>>(
14061332
&ec_tree_composed,
14071333
&ec_tree_proof_input,
1408-
&[E::ONE, E::from_canonical_u32(11)],
1334+
&[E::ONE, E::from_canonical_u32(7)],
14091335
WitnessBuildStage::Tower,
14101336
);
14111337
let ec_tree_r_len = ec_tree_composed.zkvm_v1_css.r_expressions.len()
@@ -1422,6 +1348,25 @@ mod tests {
14221348
ec_tree_r_len..ec_tree_r_len + ec_tree_w_len,
14231349
(0..write_count).chain(write_count + read_count..ec_tree_witness[0].height()),
14241350
);
1351+
1352+
let leaf_custom_read = &leaf_records[leaf_r_len - 1];
1353+
let leaf_custom_write = &leaf_records[leaf_r_len + leaf_w_len - 1];
1354+
let ec_tree_custom_read = &ec_tree_records[ec_tree_r_len - 1];
1355+
let ec_tree_custom_write = &ec_tree_records[ec_tree_r_len + ec_tree_w_len - 1];
1356+
assert_record_rows_match(
1357+
leaf_custom_read,
1358+
0..read_count,
1359+
ec_tree_custom_write,
1360+
write_count..write_count + read_count,
1361+
"leaf read vs ec-tree write",
1362+
);
1363+
assert_record_rows_match(
1364+
leaf_custom_write,
1365+
read_count..read_count + write_count,
1366+
ec_tree_custom_read,
1367+
0..write_count,
1368+
"leaf write vs ec-tree read",
1369+
);
14251370
}
14261371

14271372
#[test]

0 commit comments

Comments
 (0)