diff --git a/Cargo.toml b/Cargo.toml index e7a97a9de7..9d997c8d43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ members = [ "riscv-syscalls", "schemas", "backend-utils", - "executor-utils", + "executor-utils", "precompile_macro", ] exclude = ["riscv-runtime"] diff --git a/analysis/src/lib.rs b/analysis/src/lib.rs index 1fa1b08af7..e0b804b0e1 100644 --- a/analysis/src/lib.rs +++ b/analysis/src/lib.rs @@ -1,9 +1,33 @@ pub mod machine_check; mod vm; +use std::collections::{BTreeMap, BTreeSet}; +use std::ops::ControlFlow; + +use powdr_ast::asm_analysis::InstructionDefinitionStatement; +use powdr_ast::parsed::asm::{ + parse_absolute_path, AbsoluteSymbolPath, CallableRef, Instruction, InstructionBody, + LinkDeclaration, MachineParams, OperationId, Param, Params, Part, SymbolPath, +}; use powdr_ast::{asm_analysis::AnalysisASMFile, parsed::asm::ASMProgram}; +use powdr_ast::{ + asm_analysis::{ + CallableSymbol, CallableSymbolDefinitions, FunctionStatement, FunctionStatements, + InstructionStatement, LabelStatement, LinkDefinition, Machine, MachineDegree, Module, + OperationSymbol, SubmachineDeclaration, + }, + parsed::{ + visitor::{ExpressionVisitable, VisitOrder}, + BinaryOperator, FunctionCall, NamespacedPolynomialReference, Number, PilStatement, + }, +}; +use powdr_number::BigUint; use powdr_number::FieldElement; +type Expression = powdr_ast::asm_analysis::Expression; + +const MAIN_MACHINE_STR: &str = "::Main"; + pub fn convert_asm_to_pil( file: ASMProgram, ) -> Result> { @@ -11,6 +35,38 @@ pub fn convert_asm_to_pil( Ok(powdr_asm_to_pil::compile::(file)) } +pub fn check(file: ASMProgram) -> Result> { + log::debug!("Run machine check analysis step"); + let mut file = machine_check::check(file)?; + annotate_basic_blocks(&mut file); + + Ok(file) +} + +pub fn analyze_precompiles( + analyzed_asm: AnalysisASMFile, + selected: &BTreeSet, +) -> AnalysisASMFile { + let main_machine_path = parse_absolute_path(MAIN_MACHINE_STR); + if analyzed_asm + .machines() + .all(|(path, _)| path != main_machine_path) + { + return analyzed_asm; + } + + create_precompiles(analyzed_asm, selected) +} + +pub fn analyze_only(file: AnalysisASMFile) -> Result> { + // run analysis on virtual machines, batching instructions + log::debug!("Start asm analysis"); + let file = vm::analyze(file)?; + log::debug!("End asm analysis"); + + Ok(file) +} + pub fn analyze(file: ASMProgram) -> Result> { log::debug!("Run machine check analysis step"); let file = machine_check::check(file)?; @@ -36,3 +92,1028 @@ pub mod utils { PIL_STATEMENT_PARSER.parse(&ctx, input).unwrap() } } + +pub fn transform_autoprecompile_blocks( + //program: &mut Vec, + program: &mut FunctionStatements, + selected: &BTreeSet, +) -> Vec<(String, Vec)> { + let mut blocks = Vec::new(); + let mut i = 0; + + let program = &mut program.inner; + while i < program.len() { + if let FunctionStatement::Label(LabelStatement { source, name }) = &program[i] { + if selected.get(name).is_some() { + // Start collecting statements in this block + let mut block_statements = Vec::new(); + let mut j = i + 1; + + // Collect until next label, branch, or return + while j < program.len() { + match &program[j] { + FunctionStatement::Label(_) => break, + FunctionStatement::Instruction(InstructionStatement { + source: _, + instruction, + inputs: _, + }) if instruction.starts_with("branch") + || instruction.starts_with("jump") + || instruction.starts_with("skip") => + { + break + } + FunctionStatement::Return(_) | FunctionStatement::Assignment(_) => break, + _ => { + block_statements.push(program[j].clone()); + j += 1; + } + } + } + + // Create a synthetic instruction + //let synthetic_instruction_name = format!("synthetic_{}", name); + let synthetic_instruction_name = name.clone(); + let synthetic_instruction = FunctionStatement::Instruction(InstructionStatement { + source: source.clone(), + instruction: synthetic_instruction_name.clone(), + inputs: vec![], + }); + + blocks.push((name.clone(), block_statements)); + + /* + println!( + "Replacing block '{}' with synthetic instruction '{}'", + name, synthetic_instruction_name + ); + */ + + // Replace the block contents + program.splice(i + 1..j, std::iter::once(synthetic_instruction)); + + // Adjust the index to continue after the synthetic instruction + i += 1; + } else { + i += 1; + } + } else { + i += 1; + } + } + + blocks +} + +pub fn generate_precompile( + statements: &Vec, + instruction_map: &BTreeMap, + degree: MachineDegree, + wits: &[&PilStatement], + identities: &[&PilStatement], + main_links: &[LinkDefinition], +) -> Machine { + let mut ssa_counter = 0; + let mut constraints: Vec = Vec::new(); + let mut links: Vec = Vec::new(); + + // let latch = 1; + let latch = PilStatement::LetStatement( + Default::default(), + "latch".to_string(), + None, + Some(Expression::Number( + Default::default(), + Number { + value: BigUint::from(1u32), + type_: None, + }, + )), + ); + + // let operation_id; + //let op_id = + // PilStatement::LetStatement(Default::default(), "operation_id".to_string(), None, None); + + // let step; + let step = PilStatement::LetStatement(Default::default(), "step".to_string(), None, None); + + // std::array::sum(sel) + let sum_fun_symbol = SymbolPath::from_parts( + ["std", "array", "sum"] + .into_iter() + .map(|p| Part::Named(p.to_string())), + ); + let sum_fun: NamespacedPolynomialReference = sum_fun_symbol.into(); + let sum_fun: Expression = Expression::Reference(Default::default(), sum_fun); + let sel_symbol = SymbolPath::from_identifier("sel".to_string()); + let sel: NamespacedPolynomialReference = sel_symbol.into(); + let sel = Expression::Reference(Default::default(), sel); + let sum_sel = Expression::FunctionCall( + Default::default(), + FunctionCall { + function: Box::new(sum_fun), + arguments: vec![sel], + }, + ); + + // let used = std::array::sum(sel); + let used = + PilStatement::LetStatement(Default::default(), "used".to_string(), None, Some(sum_sel)); + + // std::utils::force_bool(used); + let bool_fun_symbol = SymbolPath::from_parts( + ["std", "utils", "force_bool"] + .into_iter() + .map(|p| Part::Named(p.to_string())), + ); + let bool_fun: NamespacedPolynomialReference = bool_fun_symbol.into(); + let bool_fun: Expression = Expression::Reference(Default::default(), bool_fun); + let used_symbol = SymbolPath::from_identifier("used".to_string()); + let used_ref: NamespacedPolynomialReference = used_symbol.into(); + let used_ref = Expression::Reference(Default::default(), used_ref); + let bool_used_ref = Expression::FunctionCall( + Default::default(), + FunctionCall { + function: Box::new(bool_fun), + arguments: vec![used_ref.clone()], + }, + ); + let bool_used_ref = PilStatement::Expression(Default::default(), bool_used_ref); + + // operation run step; + let op_param = Param { + source: Default::default(), + name: "step".to_string(), + index: None, + ty: None, + }; + let run_symbol = OperationSymbol { + source: Default::default(), + id: OperationId { id: None }, + params: Params::::new(vec![op_param], vec![]), + }; + let run_symbol = CallableSymbol::Operation(run_symbol); + let callable_defs = + CallableSymbolDefinitions([("run".to_string(), run_symbol)].into_iter().collect()); + + constraints.push(step); + constraints.push(latch); + //constraints.push(op_id); + constraints.push(used); + constraints.push(bool_used_ref); + + let step_symbol = SymbolPath::from_identifier("step".to_string()); + let step_ref: NamespacedPolynomialReference = step_symbol.into(); + let step_ref = Expression::Reference(Default::default(), step_ref); + + let four = Expression::Number( + Default::default(), + Number { + value: BigUint::from(4u32), + type_: None, + }, + ); + + for stmt in statements { + match stmt { + FunctionStatement::Instruction(InstructionStatement { + source: _, + instruction, + inputs, + }) => { + if let Some(instr_def) = instruction_map.get(instruction) { + let instr_inputs = instr_def + .params + .inputs + .iter() + .map(|p| p.name.clone()) + .collect::>(); + + // Create initial substitution map + let sub_map: BTreeMap = + instr_inputs.into_iter().zip(inputs.clone()).collect(); + + // STEP_0 = step; + // STEP_i = STEP_{i-1} + 1; + let prev_step_ref = if ssa_counter == 0 { + step_ref.clone() + } else { + let prev_step_symbol = + SymbolPath::from_identifier(format!("STEP_{}", ssa_counter - 1)); + let prev_step_ref: NamespacedPolynomialReference = prev_step_symbol.into(); + let prev_step_ref = + Expression::Reference(Default::default(), prev_step_ref); + Expression::new_binary( + prev_step_ref.clone(), + BinaryOperator::Add, + four.clone(), + ) + }; + + let step_i = PilStatement::LetStatement( + Default::default(), + format!("STEP_{ssa_counter}"), + None, + Some(prev_step_ref), + ); + + constraints.push(step_i); + + // Witness columns from main + let local_wits = wits + .iter() + .map(|wit| { + if let PilStatement::PolynomialCommitDeclaration( + source, + stage, + names, + value, + ) = wit + { + PilStatement::PolynomialCommitDeclaration( + source.clone(), + *stage, + names + .iter() + .map(|n| format!("{}_{}", n.clone(), ssa_counter).into()) + .collect::>(), + value.clone(), + ) + } else { + panic!("Expected PolynomialCommitDeclaration") + } + }) + .collect::>(); + + // Constraints from main + let local_identities = identities + .iter() + .map(|id| { + if let PilStatement::Expression(source, expr) = id { + let mut expr = expr.clone(); + append_suffix_mut(&mut expr, &ssa_counter.to_string()); + PilStatement::Expression(source.clone(), expr) + } else { + panic!("Expected PolynomialCommitDeclaration") + } + }) + .collect::>(); + + constraints.extend(local_wits); + constraints.extend(local_identities); + + // Links from main + for link in main_links { + let mut link = link.clone(); + for e in &mut link.to.params.inputs_and_outputs_mut() { + substitute(e, &sub_map); + append_suffix_mut(e, &ssa_counter.to_string()); + } + links.push(link); + } + + // Process Links + for link in &instr_def.links { + let sub_inputs = link + .link + .params + .inputs + .clone() + .into_iter() + .map(|mut p| { + substitute(&mut p, &sub_map); + append_suffix_mut(&mut p, &ssa_counter.to_string()); + p + }) + .collect::>(); + let sub_outputs = link + .link + .params + .outputs + .clone() + .into_iter() + .map(|mut p| { + substitute(&mut p, &sub_map); + append_suffix_mut(&mut p, &ssa_counter.to_string()); + p + }) + .collect::>(); + let sub_link_link = CallableRef { + instance: link.link.instance.clone(), + callable: link.link.callable.clone(), + params: Params::::new(sub_inputs, sub_outputs), + }; + + links.push(LinkDefinition { + source: Default::default(), + instr_flag: None, + link_flag: used_ref.clone(), + to: sub_link_link, + is_permutation: link.is_permutation, + }); + } + + // Process constraints + for pil_stmt in &instr_def.body.0 { + if let PilStatement::Expression(source, expr) = pil_stmt { + let mut expr = expr.clone(); + substitute(&mut expr, &sub_map); + append_suffix_mut(&mut expr, &ssa_counter.to_string()); + constraints.push(PilStatement::Expression(source.clone(), expr)); + } + } + + ssa_counter += 1; + } + } + _ => { + // Handle other statement types if necessary + } + } + } + + let regs_param = Param { + source: Default::default(), + name: "regs".to_string(), + index: None, + ty: Some(SymbolPath::from_parts( + ["std", "machines", "large_field", "memory", "Memory"] + .iter() + .map(|p| Part::Named(p.to_string())), + )), + }; + let mem_param = Param { + source: Default::default(), + name: "memory".to_string(), + index: None, + ty: Some(SymbolPath::from_parts( + ["std", "machines", "large_field", "memory", "Memory"] + .iter() + .map(|p| Part::Named(p.to_string())), + )), + }; + let split_param = Param { + source: Default::default(), + name: "split_gl".to_string(), + index: None, + ty: Some(SymbolPath::from_parts( + ["std", "machines", "split", "split_gl", "SplitGL"] + .iter() + .map(|p| Part::Named(p.to_string())), + )), + }; + let binary_param = Param { + source: Default::default(), + name: "binary".to_string(), + index: None, + ty: Some(SymbolPath::from_parts( + ["std", "machines", "large_field", "binary", "Binary"] + .iter() + .map(|p| Part::Named(p.to_string())), + )), + }; + let shift_param = Param { + source: Default::default(), + name: "shift".to_string(), + index: None, + ty: Some(SymbolPath::from_parts( + ["std", "machines", "large_field", "shift", "Shift"] + .iter() + .map(|p| Part::Named(p.to_string())), + )), + }; + let byte_param = Param { + source: Default::default(), + name: "byte".to_string(), + index: None, + ty: Some(SymbolPath::from_parts( + ["std", "machines", "range", "Byte"] + .iter() + .map(|p| Part::Named(p.to_string())), + )), + }; + let bit2_param = Param { + source: Default::default(), + name: "bit2".to_string(), + index: None, + ty: Some(SymbolPath::from_parts( + ["std", "machines", "range", "Bit2"] + .iter() + .map(|p| Part::Named(p.to_string())), + )), + }; + let bit6_param = Param { + source: Default::default(), + name: "bit6".to_string(), + index: None, + ty: Some(SymbolPath::from_parts( + ["std", "machines", "range", "Bit6"] + .iter() + .map(|p| Part::Named(p.to_string())), + )), + }; + let bit7_param = Param { + source: Default::default(), + name: "bit7".to_string(), + index: None, + ty: Some(SymbolPath::from_parts( + ["std", "machines", "range", "Bit7"] + .iter() + .map(|p| Part::Named(p.to_string())), + )), + }; + Machine { + degree, + latch: Some("latch".to_string()), + //operation_id: Some("operation_id".to_string()), + operation_id: None, + call_selectors: Some("sel".to_string()), + params: MachineParams(vec![ + regs_param, + mem_param, + split_param, + binary_param, + shift_param, + byte_param, + bit2_param, + bit6_param, + bit7_param, + ]), + registers: Vec::new(), + pc: None, + pil: constraints, + instructions: Vec::new(), + links, + callable: callable_defs, + submachines: Vec::new(), + } +} + +fn substitute(expr: &mut Expression, sub: &BTreeMap) { + expr.visit_expressions_mut( + &mut |expr| { + if let Expression::Reference(_, ref mut r) = expr { + if let Some(sub_expr) = sub.get(&r.path.to_string()) { + *expr = sub_expr.clone(); + } + } + ControlFlow::Continue::<()>(()) + }, + VisitOrder::Pre, + ); +} + +fn append_suffix_mut(expr: &mut Expression, suffix: &str) { + expr.visit_expressions_mut( + &mut |expr| match expr { + Expression::FunctionCall(_, ref mut fun_call) => { + for arg in &mut fun_call.arguments { + append_suffix_mut(arg, suffix); + } + ControlFlow::Break::<()>(()) + } + Expression::Reference(_, ref mut r) => { + let name = r.path.try_last_part().unwrap(); + let name = format!("{name}_{suffix}"); + *r.path.try_last_part_mut().unwrap() = name; + ControlFlow::Continue::<()>(()) + } + _ => ControlFlow::Continue::<()>(()), + }, + VisitOrder::Pre, + ); +} + +fn collect_columns(expr: &Expression) -> Vec { + let mut cols: Vec<_> = Vec::new(); + expr.visit_expressions( + &mut |expr| match expr { + Expression::FunctionCall(_, ref fun_call) => { + for arg in &fun_call.arguments { + cols.extend(collect_columns(arg)); + } + ControlFlow::Break::<()>(()) + } + Expression::Reference(_, ref r) => { + let name = r.path.try_last_part().unwrap(); + cols.push(name.clone()); + ControlFlow::Continue::<()>(()) + } + _ => ControlFlow::Continue::<()>(()), + }, + VisitOrder::Pre, + ); + cols +} + +fn create_precompiles( + mut analyzed_asm: AnalysisASMFile, + selected: &BTreeSet, +) -> AnalysisASMFile { + let machine = analyzed_asm + .get_machine_mut(&parse_absolute_path("::Main")) + .unwrap(); + let CallableSymbol::Function(ref mut main_function) = + &mut machine.callable.0.get_mut("main").unwrap() + else { + panic!("main function missing") + }; + + let blocks = transform_autoprecompile_blocks(&mut main_function.body.statements, selected); + + if blocks.is_empty() { + return analyzed_asm; + } + + let new_fs = FunctionStatements::new(main_function.body.statements.inner.clone()); + main_function.body.statements = new_fs.clone(); + //println!("new statements: {new_fs}"); + + let wits = machine + .pil + .iter() + .filter(|stmt| matches!(stmt, PilStatement::PolynomialCommitDeclaration(_, _, _, _))) + .collect::>(); + + let identities = machine + .pil + .iter() + .filter(|stmt| matches!(stmt, PilStatement::Expression(_, _))) + .collect::>(); + + let name_to_instr: BTreeMap = machine + .instructions + .iter() + .map(|instr| (instr.name.clone(), instr.instruction.clone())) + .collect(); + + let degree = machine.degree.clone(); + + let mut module_names = Vec::new(); + let mut modules = Vec::new(); + + for block in &blocks { + let precompile_module_name = block.0.clone(); + let precompile_machine_name = format!("Precompile_{precompile_module_name}"); + let precompile_submachine_name = format!("instance_{precompile_module_name}"); + let precompile_instr_name = precompile_module_name.clone(); + + let precompile = generate_precompile( + &block.1, + &name_to_instr, + degree.clone(), + &wits, + &identities, + &machine.links, + ); + + let precompile = optimize_precompile(precompile); + println!("New precompile:\n{precompile}"); + + let mut module = Module::new(Default::default(), Default::default(), Default::default()); + module.push_machine(precompile_machine_name.clone(), precompile); + let module_path = parse_absolute_path(&format!("::{precompile_module_name}")); + + let mut submachine_path = module_path.clone(); + submachine_path.push(precompile_machine_name); + + let args = vec![ + "regs", + "memory", + "split_gl", + "binary", + "shift", + "byte", + "bit2", + "bit6", + "bit7", + "MIN_DEGREE", + "MAIN_MAX_DEGREE", + ]; + let precompile_decl = SubmachineDeclaration { + name: precompile_submachine_name.clone(), + ty: submachine_path, + args: args + .into_iter() + .map(|s| { + Expression::Reference( + Default::default(), + NamespacedPolynomialReference::from_identifier(s.to_string()), + ) + }) + .collect(), + }; + machine.submachines.push(precompile_decl); + + let step = Expression::Reference( + Default::default(), + NamespacedPolynomialReference::from_identifier("STEP".to_string()), + ); + + let link_callable = CallableRef { + instance: precompile_submachine_name, + callable: "run".to_string(), + params: Params::new(vec![step], vec![]), + }; + let one = Expression::Number( + Default::default(), + Number { + value: BigUint::from(1u32), + type_: None, + }, + ); + let link = LinkDeclaration { + flag: one, + link: link_callable, + is_permutation: true, + }; + let instruction = Instruction { + params: Default::default(), + links: vec![link], + body: InstructionBody(vec![]), + }; + let instr_decl = InstructionDefinitionStatement { + source: Default::default(), + name: precompile_instr_name, + instruction, + }; + machine.instructions.push(instr_decl); + + //analyzed_asm.modules.insert(module_path.clone(), module); + + //main_module.push_module(precompile_module_name); + module_names.push(precompile_module_name); + modules.push((module_path.clone(), module)); + } + + for module in modules { + analyzed_asm.modules.insert(module.0, module.1); + } + + let main_module = analyzed_asm + .modules + .get_mut(&AbsoluteSymbolPath::default()) + .unwrap(); + + for module in module_names { + main_module.push_module(module); + } + + println!("Optimized ASM:\n{analyzed_asm}"); + + analyzed_asm +} + +fn optimize_precompile(mut machine: Machine) -> Machine { + let mut scc: BTreeSet = BTreeSet::new(); + + // We use operation inputs/outputs as scc sources + for callable in machine.callable.0.values() { + match callable { + CallableSymbol::Operation(symbol) => { + symbol.params.inputs_and_outputs().for_each(|p| { + scc.insert(p.name.clone()); + }); + } + CallableSymbol::Function(symbol) => { + symbol.params.inputs_and_outputs().for_each(|p| { + scc.insert(p.name.clone()); + }); + } + } + } + + // Use args of mstore links as sources too + for link in &machine.links { + let callable = &link.to.callable; + if callable == "mstore" { + scc.extend(collect_columns(&link.link_flag)); + if let Some(ref flag) = &link.instr_flag { + scc.extend(collect_columns(flag)); + } + for p in link.to.params.inputs_and_outputs() { + scc.extend(collect_columns(p)); + } + } + } + + //println!("Optimizer source SCC = {:?}", scc); + + // Collect connected items until we can't anymore. + // This is ofc slower than a proper SCC algorithm, but it's fine for now. + loop { + let pre_len = scc.len(); + + // Collect all cols in links. + // For a given link, if one col is part of the SCC, + // add all others. + for link in &machine.links { + //println!("Checking link {}", link); + let mut local: BTreeSet = BTreeSet::new(); + local.extend(collect_columns(&link.link_flag)); + if let Some(ref flag) = &link.instr_flag { + local.extend(collect_columns(flag)); + } + for p in link.to.params.inputs_and_outputs() { + local.extend(collect_columns(p)); + } + //println!("Local SCC = {:?}", local); + if local.iter().any(|c| scc.contains(c)) { + scc.extend(local); + } + } + //println!("Extended SCC after links = {:?}", scc); + + // Collect all cols in identities. + // For a given identity, if one col is part of the SCC, + // add all others. + for stmt in &machine.pil { + let mut local: BTreeSet = BTreeSet::new(); + if let PilStatement::Expression(_, ref expr) = stmt { + local.extend(collect_columns(expr)); + } + if local.iter().any(|c| scc.contains(c)) { + scc.extend(local); + } + } + //println!("Extended SCC after pils = {:?}", scc); + + let post_len = scc.len(); + if pre_len == post_len { + break; + } + } + + //println!("Optimizer converged SCC = {:?}", scc); + + // Remove all links that are not part of the SCC + machine.links.retain(|link| { + let mut local: BTreeSet = BTreeSet::new(); + local.extend(collect_columns(&link.link_flag)); + if let Some(ref flag) = &link.instr_flag { + local.extend(collect_columns(flag)); + } + for p in link.to.params.inputs_and_outputs() { + local.extend(collect_columns(p)); + } + local.iter().all(|c| scc.contains(c)) + }); + + // Remove all identities that are not part of the SCC + machine.pil.retain(|stmt| { + if let PilStatement::Expression(_, ref expr) = stmt { + let cols = collect_columns(expr); + cols.iter().all(|c| scc.contains(c)) + } else { + true + } + }); + + println!("Optimized machine: before mloads\n{machine}"); + + // Optimize mloads. + let mut mem: BTreeMap = BTreeMap::new(); + machine.links.retain(|link| { + let callable = &link.to.callable; + if link.to.instance != "regs" || (callable != "mload" && callable != "mstore") { + return true; + } + + let inputs = &link.to.params.inputs; + let outputs = &link.to.params.outputs; + + let reg: u64 = match &inputs[0] { + Expression::Number(_, ref n) => n.value.clone().try_into().unwrap(), + _ => panic!("Expected number"), + }; + + if callable == "mload" { + assert_eq!(outputs.len(), 1); + let cols = collect_columns(&outputs[0]); + assert_eq!(cols.len(), 1); + let output_col = &cols[0]; + + if let Some(col) = mem.get(®) { + machine.pil.push(PilStatement::Expression( + Default::default(), + Expression::new_binary( + Expression::Reference( + Default::default(), + NamespacedPolynomialReference::from_identifier(output_col.clone()), + ), + BinaryOperator::Identity, + col.clone(), /* + Expression::Reference( + Default::default(), + NamespacedPolynomialReference::from_identifier(col.clone()), + )*/ + ), + )); + return false; + } else { + //mem.insert(reg, output_col.clone()); + mem.insert(reg, outputs[0].clone()); + } + } else if callable == "mstore" { + assert_eq!(inputs.len(), 3); + + //let cols = collect_columns(&inputs[2]); + //println!("cols = {cols:?}"); + //assert_eq!(cols.len(), 1); + //let value = &cols[0]; + + //mem.insert(reg, value.clone()); + mem.insert(reg, inputs[2].clone()); + } + true + }); + + // Optimize mstores. + let mut last_store: BTreeMap = BTreeMap::new(); + for (i, link) in machine.links.iter().enumerate() { + let callable = &link.to.callable; + if link.to.instance != "regs" || callable != "mstore" { + continue; + } + + let inputs = &link.to.params.inputs; + + let reg: u64 = match &inputs[0] { + Expression::Number(_, ref n) => n.value.clone().try_into().unwrap(), + _ => panic!("Expected number"), + }; + + last_store.insert(reg, i); + } + + machine.links = machine + .links + .into_iter() + .enumerate() + .filter_map(|(i, link)| { + let callable = &link.to.callable; + if link.to.instance != "regs" || callable != "mstore" { + // Retain non-mstore links + return Some(link); + } + + let inputs = &link.to.params.inputs; + + let reg: u64 = match &inputs[0] { + Expression::Number(_, ref n) => n.value.clone().try_into().unwrap(), + _ => panic!("Expected number"), + }; + + // Retain only if this index is the last `mstore` for this address + if last_store + .get(®) + .is_some_and(|&last_index| last_index == i) + { + Some(link) + } else { + None + } + }) + .collect(); + + // Move all memory links to the end because of witgen. + let mut memory_links = vec![]; + machine.links.retain(|link| { + if link.to.instance == "memory" || link.to.instance == "regs" { + memory_links.push(link.clone()); + false + } else { + true + } + }); + machine.links.extend(memory_links); + + machine +} + +pub fn collect_basic_blocks( + analyzed_asm: &AnalysisASMFile, +) -> Vec<(String, Vec)> { + let machine = analyzed_asm + .get_machine(&parse_absolute_path("::Main")) + .unwrap(); + let CallableSymbol::Function(ref main_function) = &mut machine.callable.0.get("main").unwrap() + else { + panic!("main function missing") + }; + + let program = &main_function.body.statements.inner; + + let mut blocks = Vec::new(); + //let ghost_labels = 0; + //let mut curr_label = format!("ghost_label_{ghost_labels}"); + //let mut curr_label = "block_init".to_string(); + let mut curr_label: Option = None; + let mut block_statements = Vec::new(); + + for op in program { + match &op { + FunctionStatement::Label(LabelStatement { source: _, name }) => { + if let Some(label) = curr_label { + assert!(!blocks.iter().any(|(l, _)| l == &label)); + blocks.push((label.clone(), block_statements.clone())); + } + block_statements.clear(); + curr_label = Some(name.clone()); + } + FunctionStatement::Instruction(InstructionStatement { + source: _, + instruction, + inputs: _, + }) if instruction.starts_with("branch") + || instruction.starts_with("jump") + || instruction.starts_with("skip") => + { + if let Some(label) = curr_label { + assert!(!blocks.iter().any(|(l, _)| l == &label)); + blocks.push((label.clone(), block_statements.clone())); + } + block_statements.clear(); + curr_label = None; + //assert!(!blocks.iter().any(|(label, _)| label == &curr_label)); + //blocks.push((curr_label.clone(), block_statements.clone())); + //block_statements.clear(); + //ghost_labels += 1; + //curr_label = format!("ghost_label_{ghost_labels}"); + } + FunctionStatement::Instruction(InstructionStatement { + source: _, + instruction: _, + inputs: _, + }) => { + block_statements.push(op.clone()); + } + FunctionStatement::Return(_) | FunctionStatement::Assignment(_) => { + if let Some(label) = curr_label { + assert!(!blocks.iter().any(|(l, _)| l == &label)); + blocks.push((label.clone(), block_statements.clone())); + } + block_statements.clear(); + curr_label = None; + //blocks.push((curr_label.clone(), block_statements.clone())); + //block_statements.clear(); + //ghost_labels += 1; + //curr_label = format!("ghost_label_{ghost_labels}"); + } + _ => {} + } + } + + blocks +} + +pub fn annotate_basic_blocks(analyzed_asm: &mut AnalysisASMFile) { + let machine = analyzed_asm + .get_machine_mut(&parse_absolute_path("::Main")) + .unwrap(); + let CallableSymbol::Function(ref mut main_function) = + &mut machine.callable.0.get_mut("main").unwrap() + else { + panic!("main function missing") + }; + + let program = &mut main_function.body.statements.inner; + + let mut ghost_labels = 0; + + let mut i = 0; + while i < program.len() { + match &program[i] { + FunctionStatement::Instruction(InstructionStatement { + source, + instruction, + inputs: _, + }) if instruction.starts_with("branch") + || instruction.starts_with("jump") + || instruction.starts_with("skip") => + { + let curr_label = format!("ghost_label_{ghost_labels}"); + let new_label = FunctionStatement::Label(LabelStatement { + source: source.clone(), + name: curr_label.clone(), + }); + program.insert(i + 1, new_label); + ghost_labels += 1; + i += 1; + } + FunctionStatement::Return(_) | FunctionStatement::Assignment(_) => { + let curr_label = format!("ghost_label_{ghost_labels}"); + let new_label = FunctionStatement::Label(LabelStatement { + source: Default::default(), + name: curr_label.clone(), + }); + program.insert(i + 1, new_label); + ghost_labels += 1; + i += 1; + } + _ => {} + } + i += 1; + } +} diff --git a/asmopt/Cargo.toml b/asmopt/Cargo.toml index 57067c6eae..02f1d410ae 100644 --- a/asmopt/Cargo.toml +++ b/asmopt/Cargo.toml @@ -9,6 +9,7 @@ repository.workspace = true [dependencies] powdr-ast.workspace = true powdr-analysis.workspace = true +powdr-number.workspace = true powdr-pilopt.workspace = true powdr-parser.workspace = true diff --git a/asmopt/src/lib.rs b/asmopt/src/lib.rs index 4b3ea135a8..b158b7f3a2 100644 --- a/asmopt/src/lib.rs +++ b/asmopt/src/lib.rs @@ -1,10 +1,10 @@ use std::collections::{HashMap, HashSet}; use std::iter::once; -use powdr_ast::parsed::asm::parse_absolute_path; +use powdr_ast::parsed::asm::{parse_absolute_path, AbsoluteSymbolPath}; use powdr_ast::{ asm_analysis::{AnalysisASMFile, Machine}, - parsed::{asm::AbsoluteSymbolPath, NamespacedPolynomialReference}, + parsed::NamespacedPolynomialReference, }; use powdr_pilopt::referenced_symbols::ReferencedSymbols; diff --git a/ast/src/asm_analysis/display.rs b/ast/src/asm_analysis/display.rs index d5b33c06fb..02487cff1c 100644 --- a/ast/src/asm_analysis/display.rs +++ b/ast/src/asm_analysis/display.rs @@ -83,6 +83,10 @@ impl Display for MachineDegree { impl Display for Machine { fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let params = &self.params.0; + if !params.is_empty() { + write!(f, "({params}) ", params = params.iter().format(", "))?; + } let props = std::iter::once(&self.degree) .map(|d| format!("{d}")) .chain(self.latch.as_ref().map(|s| format!("latch: {s}"))) diff --git a/ast/src/asm_analysis/mod.rs b/ast/src/asm_analysis/mod.rs index df82e6bdaa..f648a1be58 100644 --- a/ast/src/asm_analysis/mod.rs +++ b/ast/src/asm_analysis/mod.rs @@ -92,7 +92,7 @@ pub fn combine_flags(instr_flag: Option, link_flag: Expression) -> E #[derive(Clone, Debug, Default)] pub struct FunctionStatements { - inner: Vec, + pub inner: Vec, batches: Option>, } @@ -837,6 +837,12 @@ impl AnalysisASMFile { let name = path.pop().unwrap(); self.modules[&path].machines.get(&name) } + + pub fn get_machine_mut(&mut self, ty: &AbsoluteSymbolPath) -> Option<&mut Machine> { + let mut path = ty.clone(); + let name = path.pop().unwrap(); + self.modules.get_mut(&path).unwrap().machines.get_mut(&name) + } } #[derive(Clone, Debug)] diff --git a/cli-rs/Cargo.toml b/cli-rs/Cargo.toml index 97f7466e58..614540b6f2 100644 --- a/cli-rs/Cargo.toml +++ b/cli-rs/Cargo.toml @@ -17,6 +17,8 @@ stwo = ["powdr/stwo"] [dependencies] powdr.workspace = true +powdr-analysis.workspace = true +itertools = "0.13" clap = { version = "^4.3", features = ["derive"] } env_logger = "0.10.0" @@ -27,4 +29,4 @@ clap-markdown = "0.1.3" [[bin]] name = "powdr-rs" path = "src/main.rs" -bench = false # See https://github.com/bheisler/criterion.rs/issues/458 +bench = false # See https://github.com/bheisler/criterion.rs/issues/458 diff --git a/cli-rs/src/main.rs b/cli-rs/src/main.rs index 269633c175..f5bda65809 100644 --- a/cli-rs/src/main.rs +++ b/cli-rs/src/main.rs @@ -14,12 +14,15 @@ use powdr::riscv::{CompilerOptions, RuntimeLibs}; use powdr::riscv_executor::{write_executor_csv, ProfilerOptions}; use powdr::Pipeline; +use itertools::Itertools; +use std::collections::{BTreeMap, BTreeSet}; use std::ffi::OsStr; use std::time::Instant; use std::{ io::{self, Write}, path::Path, }; + use strum::{Display, EnumString, EnumVariantNames}; #[derive(Clone, EnumString, EnumVariantNames, Display)] @@ -149,6 +152,10 @@ enum Commands { #[arg(long)] #[arg(default_value_t = false)] generate_callgrind: bool, + + #[arg(long)] + #[arg(default_value_t = false)] + auto_precompiles: bool, }, /// Execute and generate a valid witness for a RISCV powdr-asm file with the given inputs. Witgen { @@ -280,6 +287,7 @@ fn run_command(command: Commands) { output_directory, generate_flamegraph, generate_callgrind, + auto_precompiles, } => { let profiling = if generate_callgrind || generate_flamegraph { Some(ProfilerOptions { @@ -294,12 +302,20 @@ fn run_command(command: Commands) { } else { None }; - call_with_field!(execute_fast::( - Path::new(&file), - split_inputs(&inputs), - Path::new(&output_directory), - profiling - )) + if !auto_precompiles { + call_with_field!(execute_fast::( + Path::new(&file), + split_inputs(&inputs), + Path::new(&output_directory), + profiling + )) + } else { + call_with_field!(autoprecompiles::( + Path::new(&file), + split_inputs(&inputs), + Path::new(&output_directory) + )) + } } Commands::Witgen { file, @@ -392,12 +408,13 @@ fn execute_fast( let start = Instant::now(); - let trace_len = powdr::riscv_executor::execute::( + let (trace_len, _) = powdr::riscv_executor::execute::( &asm, powdr::riscv_executor::MemoryState::new(), pipeline.data_callback().unwrap(), &[], profiling, + Default::default(), ); let duration = start.elapsed(); @@ -406,6 +423,132 @@ fn execute_fast( Ok(()) } +fn autoprecompiles( + file_name: &Path, + inputs: Vec, + output_dir: &Path, +) -> Result<(), Vec> { + let mut pipeline = Pipeline::::default() + .from_asm_file(file_name.to_path_buf()) + .with_prover_inputs(inputs) + .with_backend(powdr::backend::BackendType::Plonky3Composite, None) + .with_output(output_dir.into(), true); + + pipeline.compute_checked_asm().unwrap(); + let checked_asm = pipeline.checked_asm().unwrap().clone(); + + let asm = pipeline.compute_analyzed_asm().unwrap().clone(); + let initial_memory = + powdr::riscv::continuations::load_initial_memory(&asm, pipeline.initial_memory()); + + println!("Running powdr-riscv executor in fast mode..."); + let start = Instant::now(); + + let (trace_len, label_freq) = powdr::riscv_executor::execute( + &asm, + initial_memory, + pipeline.data_callback().unwrap(), + &powdr::riscv::continuations::bootloader::default_input(&[]), + None, + Default::default(), + ); + + let duration = start.elapsed(); + println!("Fast executor took: {duration:?}"); + println!("Trace length: {trace_len}"); + + let blocks = powdr_analysis::collect_basic_blocks(&checked_asm); + //println!("Basic blocks:\n{blocks:?}"); + + let blocks = blocks + .into_iter() + .map(|(name, b)| { + let freq = label_freq.get(&name).unwrap_or(&0); + let l = b.len() as u64; + (name, b, freq, freq * l) + }) + .sorted_by_key(|(_, _, _, cost)| std::cmp::Reverse(*cost)) + .collect::>(); + for (name, block, freq, cost) in &blocks { + println!( + "{name}: size = {}, freq = {freq}, cost = {cost}", + block.len() + ); + } + + let total_cost = blocks.iter().map(|(_, _, _, cost)| cost).sum::(); + + //let dont_eq = vec!["__data_init", "main", "halt"]; + let dont_eq: Vec<&str> = vec![]; + //let dont_contain = vec!["powdr_riscv_runtime", "page_ok"]; + let dont_contain: Vec<&str> = vec![]; + let selected: BTreeSet = blocks + .iter() + .skip(0) + .filter(|(name, block, _, cost)| { + !dont_eq.contains(&name.as_str()) + && !dont_contain.iter().any(|s| name.contains(s)) + && block.len() > 1 + && *cost > 2 + }) + //.take(5) + .map(|block| block.0.clone()) + .into_iter() + .collect(); + let auto_asm = powdr_analysis::analyze_precompiles(checked_asm.clone(), &selected); + + println!("Selected blocks: {selected:?}"); + println!("Selected {} blocks", selected.len()); + + //println!("New auto_asm:\n{auto_asm}"); + let cost_unopt = blocks + .iter() + .filter(|(name, _, _, _)| !selected.contains(name)) + .map(|(name, _, _, cost)| { + println!("Did not select block {name} with cost {cost}"); + cost + }) + .sum::(); + + println!("Total cost = {total_cost}"); + println!("Total cost unopt = {cost_unopt}"); + + let selected_blocks: BTreeMap<_, _> = blocks + .into_iter() + .filter(|(name, _, _, _)| selected.contains(name)) + .map(|(name, block, _, _)| (name, block)) + .collect(); + + pipeline.rollback_from_checked_asm(); + pipeline.set_checked_asm(auto_asm); + let asm = pipeline.compute_analyzed_asm().unwrap().clone(); + let initial_memory = + powdr::riscv::continuations::load_initial_memory(&asm, pipeline.initial_memory()); + + println!("Running powdr-riscv executor in fast mode with autoprecomiles..."); + let start = Instant::now(); + + let (trace_len, _) = powdr::riscv_executor::execute( + &asm, + initial_memory, + pipeline.data_callback().unwrap(), + &powdr::riscv::continuations::bootloader::default_input(&[]), + None, + selected_blocks, + ); + + let duration = start.elapsed(); + println!("Fast executor with autoprecompiles took: {duration:?}"); + println!("Trace length with autoprecompiles: {trace_len}"); + + /* + pipeline.compute_witness()?; + pipeline.compute_proof()?; + */ + + Ok(()) +} + #[allow(clippy::too_many_arguments)] fn execute( file_name: &Path, @@ -436,7 +579,7 @@ fn execute( let start = Instant::now(); - let execution = powdr::riscv_executor::execute_with_witness::( + let (execution, _) = powdr::riscv_executor::execute_with_trace::( &asm, &pil, fixed, diff --git a/examples/fibonacci/Cargo.toml b/examples/fibonacci/Cargo.toml index 31d0f4d12e..67df231db5 100644 --- a/examples/fibonacci/Cargo.toml +++ b/examples/fibonacci/Cargo.toml @@ -8,7 +8,8 @@ default = [] simd = ["powdr/plonky3-simd"] [dependencies] -powdr = { git = "https://github.com/powdr-labs/powdr", tag = "v0.1.3", features = [ +#powdr = { git = "https://github.com/powdr-labs/powdr", tag = "v0.1.3", features = [ +powdr = { path = "../../powdr", features = [ "plonky3", ] } diff --git a/examples/fibonacci/guest/Cargo.toml b/examples/fibonacci/guest/Cargo.toml index bc04c62b19..62b9f3f550 100644 --- a/examples/fibonacci/guest/Cargo.toml +++ b/examples/fibonacci/guest/Cargo.toml @@ -4,7 +4,8 @@ version = "0.1.0" edition = "2021" [dependencies] -powdr-riscv-runtime = { git = "https://github.com/powdr-labs/powdr", tag = "v0.1.3", features = [ +#powdr-riscv-runtime = { git = "https://github.com/powdr-labs/powdr", tag = "v0.1.3", features = [ +powdr-riscv-runtime = { path = "../../../riscv-runtime", features = [ "std", ] } diff --git a/examples/fibonacci/guest/src/main.rs b/examples/fibonacci/guest/src/main.rs index f094c6f8ff..7f264a6018 100644 --- a/examples/fibonacci/guest/src/main.rs +++ b/examples/fibonacci/guest/src/main.rs @@ -11,7 +11,7 @@ fn fib(n: u32) -> u32 { fn main() { // Read input from stdin. - let n: u32 = read(0); + let n: u32 = read(); let r = fib(n); // Write result to stdout. write(1, r); diff --git a/examples/fibonacci/src/main.rs b/examples/fibonacci/src/main.rs index 2da28fa1cb..73e4bdc0dc 100644 --- a/examples/fibonacci/src/main.rs +++ b/examples/fibonacci/src/main.rs @@ -10,10 +10,12 @@ fn main() { .chunk_size_log2(18) .build() // Compute Fibonacci of 21 in the guest. - .write(0, &n); + .write(&n); // Fast dry run to test execution. - session.run(); + //session.run(); + + session.optimize_autoprecompile(); session.prove(); } diff --git a/executor/src/witgen/jit/function_cache.rs b/executor/src/witgen/jit/function_cache.rs index ee376e2c06..84e3b96f03 100644 --- a/executor/src/witgen/jit/function_cache.rs +++ b/executor/src/witgen/jit/function_cache.rs @@ -67,6 +67,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> { identity_id: u64, known_args: &BitVec, ) -> &Option> { + return &None; let cache_key = CacheKey { identity_id, known_args: known_args.clone(), @@ -76,6 +77,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> { } fn ensure_cache(&mut self, can_process: impl CanProcessCall, cache_key: &CacheKey) { + return; if self.witgen_functions.contains_key(cache_key) { return; } diff --git a/executor/src/witgen/machines/machine_extractor.rs b/executor/src/witgen/machines/machine_extractor.rs index e3bc28fb26..7a4e531aa7 100644 --- a/executor/src/witgen/machines/machine_extractor.rs +++ b/executor/src/witgen/machines/machine_extractor.rs @@ -467,6 +467,7 @@ fn build_machine<'a, T: FieldElement>( BlockMachine::try_new(name_with_type("BlockMachine"), fixed_data, &machine_parts) { log::debug!("Detected machine: {machine}"); + log_extracted_machine(machine.name(), &machine_parts); KnownMachine::BlockMachine(machine) } else { log::debug!("Detected machine: Dynamic machine."); diff --git a/pipeline/src/pipeline.rs b/pipeline/src/pipeline.rs index e5de154dc8..34b8d4d2dd 100644 --- a/pipeline/src/pipeline.rs +++ b/pipeline/src/pipeline.rs @@ -52,6 +52,8 @@ pub struct Artifacts { /// A tree of .asm modules (with all dependencies potentially imported /// from other files) with all references resolved to absolute symbol paths. resolved_module_tree: Option, + /// The machine checked .asm file. + checked_asm: Option, /// The analyzed .asm file: Assignment registers are inferred, instructions /// are batched and some properties are checked. analyzed_asm: Option, @@ -159,6 +161,7 @@ impl Clone for Artifacts { asm_string: self.asm_string.clone(), parsed_asm_file: self.parsed_asm_file.clone(), resolved_module_tree: self.resolved_module_tree.clone(), + checked_asm: self.checked_asm.clone(), analyzed_asm: self.analyzed_asm.clone(), optimized_asm: self.optimized_asm.clone(), constrained_machine_collection: self.constrained_machine_collection.clone(), @@ -700,6 +703,20 @@ impl Pipeline { self.arguments.external_witness_values.clear(); } + pub fn rollback_from_checked_asm(&mut self) { + self.rollback_from_witness(); + self.artifact.analyzed_asm = None; + self.artifact.optimized_asm = None; + self.artifact.constrained_machine_collection = None; + self.artifact.linked_machine_graph = None; + self.artifact.parsed_pil_file = None; + self.artifact.pil_file_path = None; + self.artifact.pil_string = None; + self.artifact.analyzed_pil = None; + self.artifact.optimized_pil = None; + self.artifact.fixed_cols = None; + } + // ===== Compute and retrieve artifacts ===== pub fn asm_file_path(&self) -> Result<&PathBuf, Vec> { @@ -782,14 +799,40 @@ impl Pipeline { Ok(self.artifact.resolved_module_tree.as_ref().unwrap()) } + pub fn compute_checked_asm(&mut self) -> Result<&AnalysisASMFile, Vec> { + if self.artifact.checked_asm.is_none() { + self.artifact.checked_asm = Some({ + self.compute_resolved_module_tree()?; + let resolved = self.artifact.resolved_module_tree.take().unwrap(); + + self.log("Run analysis machine check"); + let checked_asm = powdr_analysis::check(resolved)?; + self.log("Analysis machine check done"); + log::trace!("{checked_asm}"); + + checked_asm + }); + } + + Ok(self.artifact.checked_asm.as_ref().unwrap()) + } + + pub fn checked_asm(&self) -> Result<&AnalysisASMFile, Vec> { + Ok(self.artifact.checked_asm.as_ref().unwrap()) + } + + pub fn set_checked_asm(&mut self, asm: AnalysisASMFile) { + self.artifact.checked_asm = Some(asm); + } + pub fn compute_analyzed_asm(&mut self) -> Result<&AnalysisASMFile, Vec> { if self.artifact.analyzed_asm.is_none() { self.artifact.analyzed_asm = Some({ - self.compute_resolved_module_tree()?; - let resolved = self.artifact.resolved_module_tree.take().unwrap(); + self.compute_checked_asm()?; + let checked = self.artifact.checked_asm.take().unwrap(); self.log("Run analysis"); - let analyzed_asm = powdr_analysis::analyze(resolved)?; + let analyzed_asm = powdr_analysis::analyze_only(checked)?; self.log("Analysis done"); log::trace!("{analyzed_asm}"); diff --git a/powdr/Cargo.toml b/powdr/Cargo.toml index f754c623fc..49c16ba7c3 100644 --- a/powdr/Cargo.toml +++ b/powdr/Cargo.toml @@ -9,6 +9,7 @@ repository = { workspace = true } [dependencies] powdr-ast.workspace = true +powdr-analysis.workspace = true powdr-backend.workspace = true powdr-executor.workspace = true powdr-number.workspace = true @@ -19,6 +20,7 @@ powdr-pipeline.workspace = true powdr-riscv.workspace = true powdr-riscv-executor.workspace = true +itertools = "0.13" log = "0.4.17" serde = { version = "1.0", default-features = false, features = [ diff --git a/powdr/src/lib.rs b/powdr/src/lib.rs index 67ce734496..e60bdae146 100644 --- a/powdr/src/lib.rs +++ b/powdr/src/lib.rs @@ -17,11 +17,14 @@ pub use powdr_number::{FieldElement, LargeInt}; use riscv::{CompilerOptions, RuntimeLibs}; +use std::collections::{BTreeMap, BTreeSet}; use std::fs::{self, File}; use std::path::Path; use std::path::PathBuf; use std::time::Instant; +use itertools::Itertools; + #[derive(Default)] pub struct SessionBuilder { guest_path: String, @@ -154,6 +157,118 @@ impl Session { run_with_profiler(&mut self.pipeline, profiler) } + pub fn optimize_autoprecompile(&mut self, n: usize) { + self.pipeline.compute_checked_asm().unwrap(); + let checked_asm = self.pipeline.checked_asm().unwrap().clone(); + + let asm = self.pipeline.compute_analyzed_asm().unwrap().clone(); + let initial_memory = + riscv::continuations::load_initial_memory(&asm, self.pipeline.initial_memory()); + + println!("Running powdr-riscv executor in fast mode..."); + let start = Instant::now(); + + let (trace_len, label_freq) = riscv_executor::execute( + &asm, + initial_memory, + self.pipeline.data_callback().unwrap(), + &riscv::continuations::bootloader::default_input(&[]), + None, + Default::default(), + ); + + let duration = start.elapsed(); + println!("Fast executor took: {duration:?}"); + println!("Trace length: {trace_len}"); + + let blocks = powdr_analysis::collect_basic_blocks(&checked_asm); + let blocks = blocks + .into_iter() + .map(|(name, b)| { + let freq = label_freq.get(&name).unwrap_or(&0); + let l = b.len() as u64; + (name, b, freq, freq * l) + }) + .sorted_by_key(|(_, _, _, cost)| std::cmp::Reverse(*cost)) + .collect::>(); + for (name, block, freq, cost) in &blocks { + println!( + "{name}: size = {}, freq = {freq}, cost = {cost}", + block.len() + ); + } + + let total_cost = blocks.iter().map(|(_, _, _, cost)| cost).sum::(); + + //let dont_eq = vec!["__data_init", "main", "halt"]; + let dont_eq: Vec<&str> = vec![]; + //let dont_contain = vec!["powdr_riscv_runtime", "page_ok"]; + let dont_contain: Vec<&str> = vec![]; + let selected: BTreeSet = blocks + .iter() + .skip(0) + .filter(|(name, block, _, cost)| { + !dont_eq.contains(&name.as_str()) + && !dont_contain.iter().any(|s| name.contains(s)) + && block.len() > 1 + && *cost > 2 + }) + //.take(n) + .map(|block| block.0.clone()) + .into_iter() + .collect(); + let auto_asm = powdr_analysis::analyze_precompiles(checked_asm.clone(), &selected); + + println!("Selected blocks: {selected:?}"); + println!("Selected {} blocks", selected.len()); + + let cost_unopt = blocks + .iter() + .filter(|(name, _, _, _)| !selected.contains(name)) + .map(|(name, _, _, cost)| { + println!("Did not select block {name} with cost {cost}"); + cost + }) + .sum::(); + //println!("New auto_asm:\n{auto_asm}"); + + println!("Total cost = {total_cost}"); + println!("Total cost unopt = {cost_unopt}"); + + let selected_blocks: BTreeMap<_, _> = blocks + .into_iter() + .filter(|(name, _, _, _)| selected.contains(name)) + .map(|(name, block, _, _)| (name, block)) + .collect(); + + self.pipeline.rollback_from_checked_asm(); + self.pipeline.set_checked_asm(auto_asm); + let asm = self.pipeline.compute_analyzed_asm().unwrap().clone(); + let initial_memory = + riscv::continuations::load_initial_memory(&asm, self.pipeline.initial_memory()); + + println!("Running powdr-riscv executor in fast mode with autoprecomiles..."); + let start = Instant::now(); + + let (trace_len, _) = riscv_executor::execute( + &asm, + initial_memory, + self.pipeline.data_callback().unwrap(), + &riscv::continuations::bootloader::default_input(&[]), + None, + selected_blocks, + ); + + let duration = start.elapsed(); + println!("Fast executor with autoprecompiles took: {duration:?}"); + println!("Trace length with autoprecompiles: {trace_len}"); + + /* + pipeline.compute_witness()?; + pipeline.compute_proof()?; + */ + } + pub fn prove(&mut self) { let asm_name = self.pipeline.asm_string().unwrap().0.clone().unwrap(); let pil_file = pil_file_path(&asm_name); @@ -312,12 +427,13 @@ fn run_internal( let asm = pipeline.compute_analyzed_asm().unwrap().clone(); let initial_memory = riscv::continuations::load_initial_memory(&asm, pipeline.initial_memory()); - let trace_len = riscv_executor::execute( + let (trace_len, _) = riscv_executor::execute( &asm, initial_memory, pipeline.data_callback().unwrap(), &riscv::continuations::bootloader::default_input(&[]), profiler, + Default::default(), ); let duration = start.elapsed(); diff --git a/precompile_macro/Cargo.toml b/precompile_macro/Cargo.toml new file mode 100644 index 0000000000..17978db496 --- /dev/null +++ b/precompile_macro/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "precompile_macro" +version.workspace = true +edition.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +[lib] +proc-macro = true + +[dependencies] +quote = "1.0" +syn = { version = "2.0", features = ["full"] } + +[lints] +workspace = true diff --git a/precompile_macro/src/lib.rs b/precompile_macro/src/lib.rs new file mode 100644 index 0000000000..1828a00562 --- /dev/null +++ b/precompile_macro/src/lib.rs @@ -0,0 +1,45 @@ +extern crate proc_macro; + +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, ItemFn}; + +#[proc_macro_attribute] +pub fn precompile(_attr: TokenStream, item: TokenStream) -> TokenStream { + let input = parse_macro_input!(item as ItemFn); + + // Extract the function's components + let fn_name = &input.sig.ident; + let fn_attrs = &input.attrs; + let fn_vis = &input.vis; + let fn_inputs = &input.sig.inputs; + let fn_output = &input.sig.output; + let fn_block = &input.block; + + // Create a new function name with the "autoprecompile_" prefix + let new_fn_name = syn::Ident::new(&format!("autoprecompile_{fn_name}"), fn_name.span()); + + // Extract parameter names (assumes the inputs are named parameters) + let param_names = fn_inputs.iter().map(|arg| match arg { + syn::FnArg::Typed(pat_type) => &pat_type.pat, + syn::FnArg::Receiver(_) => panic!("Methods with `self` are not supported"), + }); + + // Generate the output + let output = quote! { + // Define the new function with the body of the original function + #[no_mangle] + #[inline(never)] + fn #new_fn_name(#fn_inputs) #fn_output { + #fn_block + } + + // The original function now just calls the new function + #(#fn_attrs)* + #fn_vis fn #fn_name(#fn_inputs) #fn_output { + #new_fn_name(#(#param_names),*) + } + }; + + TokenStream::from(output) +} diff --git a/riscv-executor/Cargo.toml b/riscv-executor/Cargo.toml index 7e5c426704..9ff0746977 100644 --- a/riscv-executor/Cargo.toml +++ b/riscv-executor/Cargo.toml @@ -9,6 +9,7 @@ repository = { workspace = true } [dependencies] powdr-ast.workspace = true +powdr-analysis.workspace = true powdr-executor.workspace = true powdr-number.workspace = true powdr-parser.workspace = true diff --git a/riscv-executor/src/lib.rs b/riscv-executor/src/lib.rs index 7fb1df594a..b080c8f7db 100644 --- a/riscv-executor/src/lib.rs +++ b/riscv-executor/src/lib.rs @@ -1288,6 +1288,7 @@ fn preprocess_main_function(machine: &Machine) -> PreprocessedM if !name.contains("___dot_L") { function_starts.insert(batch_idx + PC_INITIAL_VAL, name.as_str()); } + statements.push(s); } } } @@ -2166,14 +2167,42 @@ impl Executor<'_, '_, F> { None } Instruction::mul => { - let read_reg1 = args[0].u(); - let read_reg2 = args[1].u(); + let aaa = args[0].bin(); + let read_reg1: u32 = match aaa.try_into() { + Ok(v) => v, + Err(_) => panic!("noooooooooooooo 111111111"), + }; + let aaa = args[1].bin(); + let read_reg2: u32 = match aaa.try_into() { + Ok(v) => v, + Err(_) => panic!("noooooooooooooo 222222222"), + }; let lid = self.instr_link_id(instr, MachineInstance::regs, 0); let val1 = self.reg_read(0, read_reg1, lid); let lid = self.instr_link_id(instr, MachineInstance::regs, 1); let val2 = self.reg_read(1, read_reg2, lid); - let write_reg1 = args[2].u(); - let write_reg2 = args[3].u(); + let aaa = args[2].bin(); + let write_reg1: u32 = match aaa.try_into() { + Ok(v) => v, + Err(_) => panic!("noooooooooooooo 333333333333"), + }; + let aaa = args[3].bin(); + let write_reg2: u32 = match aaa.try_into() { + Ok(v) => v, + Err(_) => panic!("noooooooooooooo 44444444444444"), + }; + + let aaa = val1.bin(); + let val1_u32: u32 = match aaa.try_into() { + Ok(v) => v, + Err(_) => panic!("noooooooooooooo 5555555555555 {aaa:?}"), + }; + + let aaa = val2.bin(); + let val2_u32: u32 = match aaa.try_into() { + Ok(v) => v, + Err(_) => panic!("noooooooooooooo 6666666666666 {aaa:?}"), + }; let r = val1.u() as u64 * val2.u() as u64; let lo = r as u32; @@ -2872,9 +2901,10 @@ pub fn execute( prover_ctx: &Callback, bootloader_inputs: &[F], profiling: Option, -) -> usize { + precompile_blocks: BTreeMap>, +) -> (usize, BTreeMap) { log::info!("Executing..."); - execute_inner( + let res = execute_inner( asm, None, None, @@ -2884,8 +2914,9 @@ pub fn execute( usize::MAX, ExecMode::Fast, profiling, - ) - .trace_len + precompile_blocks, + ); + (res.0.trace_len, res.1) } /// Execute generating a witness for the PC and powdr asm registers. @@ -2899,7 +2930,7 @@ pub fn execute_with_trace( bootloader_inputs: &[F], max_steps_to_execute: Option, profiling: Option, -) -> Execution { +) -> (Execution, BTreeMap) { log::info!("Executing (trace generation)..."); execute_inner( @@ -2912,6 +2943,7 @@ pub fn execute_with_trace( max_steps_to_execute.unwrap_or(usize::MAX), ExecMode::Trace, profiling, + Default::default(), ) } @@ -2926,7 +2958,7 @@ pub fn execute_with_witness( bootloader_inputs: &[F], max_steps_to_execute: Option, profiling: Option, -) -> Execution { +) -> (Execution, BTreeMap) { log::info!("Executing (trace generation)..."); execute_inner( @@ -2939,6 +2971,7 @@ pub fn execute_with_witness( max_steps_to_execute.unwrap_or(usize::MAX), ExecMode::Witness, profiling, + Default::default(), ) } @@ -2953,10 +2986,14 @@ fn execute_inner( max_steps_to_execute: usize, mode: ExecMode, profiling: Option, -) -> Execution { + precompile_blocks: BTreeMap>, +) -> (Execution, BTreeMap) { let start = Instant::now(); let main_machine = get_main_machine(asm); + let mut label_freq: BTreeMap = Default::default(); + let mut instr_freq: BTreeMap = Default::default(); + let PreprocessedMain { statements, label_map, @@ -3016,7 +3053,7 @@ fn execute_inner( mode, ) { Ok(proc) => proc, - Err(ret) => return *ret, + Err(ret) => return (*ret, Default::default()), }; let bootloader_inputs = bootloader_inputs @@ -3052,6 +3089,10 @@ fn execute_inner( e.proc.push_row(PC_INITIAL_VAL as u32); let mut last = Instant::now(); let mut count = 0; + let mut label_count = 0; + let mut ass_count = 0; + let mut debug_count = 0; + let mut precompile_calls = 0; loop { let stm = statements[curr_pc as usize]; @@ -3066,12 +3107,14 @@ fn execute_inner( if elapsed.as_secs_f64() > 1.0 { last = now; log::debug!("instructions/s: {}", count as f64 / elapsed.as_secs_f64(),); - count = 0; + //count = 0; } } match stm { FunctionStatement::Assignment(a) => { + ass_count += 1; + let pc = e.proc.get_pc().u(); e.proc.set_col(KnownWitnessCol::_operation_id, 2.into()); if let Some(p) = &mut profiler { @@ -3125,6 +3168,25 @@ fn execute_inner( e.proc.set_reg(dest, val); } } + FunctionStatement::Instruction(i) + if precompile_blocks.contains_key(&i.instruction.to_string()) => + { + precompile_calls += 1; + let name = i.instruction.to_string(); + let pc = e.proc.get_pc(); + + //println!("Executing precompile {}", i.instruction.to_string()); + for stmt in precompile_blocks.get(&name).unwrap() { + match stmt { + FunctionStatement::Instruction(i) => { + e.exec_instruction(&i.instruction, &i.inputs); + } + a => unreachable!("{a:?}"), + } + } + + e.proc.set_pc(pc.add(&Elem::Binary(1))); + } FunctionStatement::Instruction(i) => { e.proc.set_col(KnownWitnessCol::_operation_id, 2.into()); @@ -3132,6 +3194,12 @@ fn execute_inner( p.add_instruction_cost(e.proc.get_pc().u() as usize); } + let name = i.instruction.to_string(); + instr_freq + .entry(name.clone()) + .and_modify(|e| *e += 1) + .or_insert(1); + if ["jump", "jump_dyn"].contains(&i.instruction.as_str()) { let pc_before = e.proc.get_pc().u(); @@ -3164,7 +3232,10 @@ fn execute_inner( break; } FunctionStatement::DebugDirective(dd) => { + debug_count += 1; + e.step -= 4; + count -= 1; match &dd.directive { DebugDirective::Loc(file, line, column) => { let (dir, file) = debug_files[file - 1]; @@ -3176,8 +3247,16 @@ fn execute_inner( DebugDirective::File(_, _, _) => unreachable!(), }; } - FunctionStatement::Label(_) => { - unreachable!() + FunctionStatement::Label(LabelStatement { source: _, name }) => { + label_count += 1; + + e.step -= 4; + count -= 1; + label_freq + .entry(name.clone()) + .and_modify(|e| *e += 1) + .or_insert(1); + //unreachable!() } }; @@ -3186,6 +3265,12 @@ fn execute_inner( None => break, }; } + println!("Finish executor loop with true instruction count = {count}"); + + let total_freq = instr_freq.values().sum::(); + println!("Instr freq:\n{instr_freq:?}"); + println!("Total freq: {total_freq}"); + println!("Precompile count = {precompile_calls}, ass count = {ass_count}, label_count = {label_count}, debug count = {debug_count}"); if let Some(mut p) = profiler { p.finish(); @@ -3232,7 +3317,26 @@ fn execute_inner( } } - e.proc.finish(opt_pil, program_columns) + /* + let blocks = powdr_analysis::collect_basic_blocks(&asm); + let blocks = blocks + .into_iter() + .map(|(name, b)| { + let freq = label_freq.get(&name).unwrap_or(&0); + let l = b.len() as u64; + (name, b, freq, freq * l) + }) + .sorted_by_key(|(_, _, _, cost)| std::cmp::Reverse(*cost)) + .collect::>(); + for (name, block, freq, cost) in &blocks { + println!( + "{name}: size = {}, freq = {freq}, cost = {cost}", + block.len() + ); + } + */ + + (e.proc.finish(opt_pil, program_columns), label_freq) } /// Utility function for writing the executor witness CSV file. diff --git a/riscv/src/continuations.rs b/riscv/src/continuations.rs index 8abdca8064..512c220a42 100644 --- a/riscv/src/continuations.rs +++ b/riscv/src/continuations.rs @@ -362,7 +362,7 @@ pub fn rust_continuations_dry_run( // TODO: commit to the merkle_tree root in the verifier. log::info!("Initial execution..."); - let full_exec = powdr_riscv_executor::execute_with_trace::( + let (full_exec, _) = powdr_riscv_executor::execute_with_trace::( &asm, &pil, fixed.clone(), @@ -472,7 +472,7 @@ pub fn rust_continuations_dry_run( // execute the chunk log::info!("Simulating chunk execution..."); - let chunk_exec = powdr_riscv_executor::execute_with_trace::( + let (chunk_exec, _) = powdr_riscv_executor::execute_with_trace::( &asm, &pil, fixed.clone(), diff --git a/riscv/src/lib.rs b/riscv/src/lib.rs index 75f4bfa6f1..fb0ceb25fe 100644 --- a/riscv/src/lib.rs +++ b/riscv/src/lib.rs @@ -360,7 +360,7 @@ fn build_cargo_command( let mut cmd = Command::new("cargo"); cmd.env( "RUSTFLAGS", - "-g -C link-arg=-Tpowdr.x -C link-arg=--emit-relocs -C passes=lower-atomic -C panic=abort", + "-C opt-level=3 -C link-arg=-Tpowdr.x -C link-arg=--emit-relocs -C passes=lower-atomic,simplifycfg -C panic=abort", ); // keep debug info for the profiler (callgrind/flamegraph) cmd.env("CARGO_PROFILE_RELEASE_DEBUG", "true"); diff --git a/riscv/src/riscv_gl.asm b/riscv/src/riscv_gl.asm index 9347e91dab..8b79b1c8eb 100644 --- a/riscv/src/riscv_gl.asm +++ b/riscv/src/riscv_gl.asm @@ -97,7 +97,7 @@ machine Main with min_degree: MIN_DEGREE, max_degree: {{MAIN_MAX_DEGREE}} { // =========================================== // Increased by 4 in each step, because we do up to 4 register memory accesses per step - col fixed STEP(i) { 4 * i }; + col fixed STEP(i) { 3000 * i }; // ============== memory instructions ============== diff --git a/riscv/tests/riscv_data/keccak/src/main.rs b/riscv/tests/riscv_data/keccak/src/main.rs index 3286e796e0..3c4a468336 100644 --- a/riscv/tests/riscv_data/keccak/src/main.rs +++ b/riscv/tests/riscv_data/keccak/src/main.rs @@ -14,12 +14,14 @@ pub fn main() { } hasher.finalize(&mut output); - assert_eq!( - output, - [ - 0xb2, 0x60, 0x1c, 0x72, 0x12, 0xd8, 0x26, 0x0d, 0xa4, 0x6d, 0xde, 0x19, 0x8d, 0x50, - 0xa7, 0xe4, 0x67, 0x1f, 0xc1, 0xbb, 0x8f, 0xf2, 0xd1, 0x72, 0x5a, 0x8d, 0xa1, 0x08, - 0x11, 0xb5, 0x81, 0x69 - ], - ); + /* + assert_eq!( + output, + [ + 0xb2, 0x60, 0x1c, 0x72, 0x12, 0xd8, 0x26, 0x0d, 0xa4, 0x6d, 0xde, 0x19, 0x8d, 0x50, + 0xa7, 0xe4, 0x67, 0x1f, 0xc1, 0xbb, 0x8f, 0xf2, 0xd1, 0x72, 0x5a, 0x8d, 0xa1, 0x08, + 0x11, 0xb5, 0x81, 0x69 + ], + ); + */ }