From 3eba5c455c49cc7883411c0e459f35d2b21de571 Mon Sep 17 00:00:00 2001 From: Leandro Pacheco Date: Mon, 13 Jan 2025 12:09:45 -0300 Subject: [PATCH] Interpreter for witgen effects (#2301) Implementation of an interpreter for jit witgen effects. Performance seems to be around 0.20x of the compiled jit code (poseidon benchmark). - Disabled by default, can be enabled setting `POWDR_JIT_INTERPRETER=1`. - Also introduces a way to disable JIT in general, with `POWDR_JIT_DISABLE=1`. --- .../data_structures/finalizable_data.rs | 21 +- executor/src/witgen/jit/compiler.rs | 34 +- executor/src/witgen/jit/effect.rs | 21 + executor/src/witgen/jit/interpreter.rs | 551 ++++++++++++++++++ executor/src/witgen/jit/mod.rs | 1 + .../src/witgen/jit/symbolic_expression.rs | 28 +- 6 files changed, 605 insertions(+), 51 deletions(-) create mode 100644 executor/src/witgen/jit/interpreter.rs diff --git a/executor/src/witgen/data_structures/finalizable_data.rs b/executor/src/witgen/data_structures/finalizable_data.rs index 670e17df0f..7cb8e42fe8 100644 --- a/executor/src/witgen/data_structures/finalizable_data.rs +++ b/executor/src/witgen/data_structures/finalizable_data.rs @@ -91,7 +91,7 @@ impl CompactData { } /// Sets an entire row at the given index - pub fn set(&mut self, row: usize, new_row: Row) { + pub fn set_row(&mut self, row: usize, new_row: Row) { let idx = row * self.column_count; for (i, col_id) in self.column_ids().enumerate() { if let Some(v) = new_row.value(&col_id) { @@ -109,6 +109,7 @@ impl CompactData { self.known_cells.append_empty_rows(count); } + #[inline] fn index(&self, row: usize, col: u64) -> usize { let col = col - self.first_column_id; row * self.column_count + col as usize @@ -121,6 +122,14 @@ impl CompactData { }) } + /// Sets a single cell + pub fn set(&mut self, row: usize, col: u64, value: T) { + let idx = self.index(row, col); + self.data[idx] = value; + let relative_col = col - self.first_column_id; + self.known_cells.set(row, relative_col, true); + } + pub fn get(&self, row: usize, col: u64) -> (T, bool) { let idx = self.index(row, col); let relative_col = col - self.first_column_id; @@ -146,8 +155,8 @@ impl CompactData { /// only for a certain block of rows, starting from row index zero. /// It allows negative row indices as well. pub struct CompactDataRef<'a, T> { - data: &'a mut CompactData, - row_offset: usize, + pub data: &'a mut CompactData, + pub row_offset: usize, } impl<'a, T: FieldElement> CompactDataRef<'a, T> { @@ -160,10 +169,6 @@ impl<'a, T: FieldElement> CompactDataRef<'a, T> { pub fn as_mut_slices(&mut self) -> (&mut [T], &mut [u32]) { self.data.as_mut_slices() } - - pub fn row_offset(&self) -> usize { - self.row_offset - } } /// A data structure that stores witness data. @@ -413,7 +418,7 @@ impl<'a, T: FieldElement> FinalizableData<'a, T> { pub fn set(&mut self, i: usize, row: Row) { match self.location_of_row(i) { Location::Finalized(local) => { - self.finalized_data.set(local, row); + self.finalized_data.set_row(local, row); } Location::PostFinalized(local) => self.post_finalized_data[local] = row, } diff --git a/executor/src/witgen/jit/compiler.rs b/executor/src/witgen/jit/compiler.rs index 98e329eab6..684f88b4b2 100644 --- a/executor/src/witgen/jit/compiler.rs +++ b/executor/src/witgen/jit/compiler.rs @@ -1,4 +1,4 @@ -use std::{cmp::Ordering, ffi::c_void, iter, mem, sync::Arc}; +use std::{cmp::Ordering, ffi::c_void, mem, sync::Arc}; use itertools::Itertools; use libloading::Library; @@ -20,7 +20,7 @@ use super::{ variable::Variable, }; -pub struct WitgenFunction { +pub struct WitgenFunction { // TODO We might want to pass arguments as direct function parameters // (instead of a struct), so that // they are stored in registers instead of the stack. Should be checked. @@ -39,7 +39,7 @@ impl WitgenFunction { params: &mut [LookupCell], mut data: CompactDataRef<'_, T>, ) { - let row_offset = data.row_offset().try_into().unwrap(); + let row_offset = data.row_offset.try_into().unwrap(); let (data, known) = data.as_mut_slices(); (self.function)(WitgenFunctionParams { data: data.into(), @@ -160,7 +160,7 @@ fn witgen_code( let main_code = format_effects(effects); let vars_known = effects .iter() - .flat_map(written_vars_in_effect) + .flat_map(Effect::written_vars) .map(|(var, _)| var) .collect_vec(); let store_values = vars_known @@ -224,30 +224,6 @@ extern "C" fn witgen( ) } -/// Returns an iterator over all variables written to in the effect. -/// The flag indicates if the variable is the return value of a machine call and thus needs -/// to be declared mutable. -fn written_vars_in_effect( - effect: &Effect, -) -> Box + '_> { - match effect { - Effect::Assignment(var, _) => Box::new(iter::once((var, false))), - Effect::RangeConstraint(..) => unreachable!(), - Effect::Assertion(..) => Box::new(iter::empty()), - Effect::MachineCall(_, known, vars) => Box::new( - vars.iter() - .zip_eq(known) - .flat_map(|(v, known)| (!known).then_some((v, true))), - ), - Effect::Branch(_, first, second) => Box::new( - first - .iter() - .chain(second) - .flat_map(|e| written_vars_in_effect(e)), - ), - } -} - pub fn format_effects(effects: &[Effect]) -> String { format_effects_inner(effects, true) } @@ -321,7 +297,7 @@ fn format_effect(effect: &Effect, is_top_level: bo first .iter() .chain(second) - .flat_map(|e| written_vars_in_effect(e)) + .flat_map(|e| e.written_vars()) .sorted() .dedup() .map(|(v, needs_mut)| { diff --git a/executor/src/witgen/jit/effect.rs b/executor/src/witgen/jit/effect.rs index f193e5aff1..158663270a 100644 --- a/executor/src/witgen/jit/effect.rs +++ b/executor/src/witgen/jit/effect.rs @@ -27,6 +27,27 @@ pub enum Effect { Branch(BranchCondition, Vec>, Vec>), } +impl Effect { + /// Returns an iterator over all variables written to in the effect. + /// The flag indicates if the variable is the return value of a machine call and thus needs + /// to be declared mutable. + pub fn written_vars(&self) -> Box + '_> { + match self { + Effect::Assignment(var, _) => Box::new(iter::once((var, false))), + Effect::RangeConstraint(..) => unreachable!(), + Effect::Assertion(..) => Box::new(iter::empty()), + Effect::MachineCall(_, known, vars) => Box::new( + vars.iter() + .zip_eq(known) + .flat_map(|(v, known)| (!known).then_some((v, true))), + ), + Effect::Branch(_, first, second) => { + Box::new(first.iter().chain(second).flat_map(|e| e.written_vars())) + } + } + } +} + impl Effect { pub fn referenced_variables(&self) -> impl Iterator { let iter: Box> = match self { diff --git a/executor/src/witgen/jit/interpreter.rs b/executor/src/witgen/jit/interpreter.rs new file mode 100644 index 0000000000..ef8cde78a7 --- /dev/null +++ b/executor/src/witgen/jit/interpreter.rs @@ -0,0 +1,551 @@ +// TODO: the unused is only here because the interpreter is not integrated in the final code yet +#![allow(unused)] +use super::effect::{Assertion, Effect}; + +use super::symbolic_expression::{BinaryOperator, BitOperator, SymbolicExpression, UnaryOperator}; +use super::variable::{Cell, Variable}; +use crate::witgen::data_structures::finalizable_data::CompactDataRef; +use crate::witgen::data_structures::mutable_state::MutableState; +use crate::witgen::machines::LookupCell; +use crate::witgen::QueryCallback; +use powdr_number::FieldElement; + +use std::collections::{BTreeSet, HashMap}; + +/// Interpreter for instructions compiled from witgen effects. +pub struct EffectsInterpreter { + var_count: usize, + actions: Vec>, +} + +/// Witgen effects compiled into instructions for a stack machine. +/// Variables have been removed and replaced by their index in the variable list. +enum InterpreterAction { + ReadCell(usize, Cell), + ReadParam(usize, usize), + AssignExpression(usize, RPNExpression), + WriteCell(usize, Cell), + WriteParam(usize, usize), + MachineCall(u64, Vec), + Assertion(RPNExpression, RPNExpression, bool), +} + +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd)] +pub enum MachineCallArgumentIdx { + /// var index of the evaluated known argument expression + Known(usize), + /// var index of the unknown + Unknown(usize), +} + +impl EffectsInterpreter { + pub fn new(known_inputs: &[Variable], effects: &[Effect]) -> Self { + let mut actions = vec![]; + let mut var_mapper = VariableMapper::new(); + + Self::load_known_inputs(&mut var_mapper, &mut actions, known_inputs); + Self::process_effects(&mut var_mapper, &mut actions, effects); + Self::write_data(&mut var_mapper, &mut actions, effects); + + let ret = Self { + var_count: var_mapper.var_count(), + actions, + }; + assert!(ret.is_valid()); + ret + } + + fn load_known_inputs( + var_mapper: &mut VariableMapper, + actions: &mut Vec>, + known_inputs: &[Variable], + ) { + actions.extend(known_inputs.iter().map(|var| match var { + Variable::Cell(c) => { + let idx = var_mapper.map_var(var); + InterpreterAction::ReadCell(idx, c.clone()) + } + Variable::Param(i) => { + let idx = var_mapper.map_var(var); + InterpreterAction::ReadParam(idx, *i) + } + Variable::MachineCallParam(_) => unreachable!(), + })); + } + + fn process_effects( + var_mapper: &mut VariableMapper, + actions: &mut Vec>, + effects: &[Effect], + ) { + effects.iter().for_each(|effect| { + let action = match effect { + Effect::Assignment(var, e) => { + let idx = var_mapper.map_var(var); + InterpreterAction::AssignExpression(idx, var_mapper.map_expr_to_rpn(e)) + } + Effect::RangeConstraint(..) => { + unreachable!("Final code should not contain pure range constraints.") + } + Effect::Assertion(Assertion { + lhs, + rhs, + expected_equal, + }) => InterpreterAction::Assertion( + var_mapper.map_expr_to_rpn(lhs), + var_mapper.map_expr_to_rpn(rhs), + *expected_equal, + ), + Effect::MachineCall(id, known_inputs, arguments) => { + let arguments = known_inputs + .iter() + .zip(arguments) + .map(|(is_input, var)| { + if is_input { + MachineCallArgumentIdx::Known(var_mapper.map_var(var)) + } else { + MachineCallArgumentIdx::Unknown(var_mapper.map_var(var)) + } + }) + .collect(); + InterpreterAction::MachineCall(*id, arguments) + } + Effect::Branch(..) => { + unimplemented!("Branches are not supported in the interpreter yet") + } + }; + actions.push(action); + }) + } + + fn write_data( + var_mapper: &mut VariableMapper, + actions: &mut Vec>, + effects: &[Effect], + ) { + effects + .iter() + .flat_map(Effect::written_vars) + .for_each(|(var, _mutable)| { + match var { + Variable::Cell(cell) => { + let idx = var_mapper.get_var(var).unwrap(); + actions.push(InterpreterAction::WriteCell(idx, cell.clone())); + } + Variable::Param(i) => { + let idx = var_mapper.get_var(var).unwrap(); + actions.push(InterpreterAction::WriteParam(idx, *i)); + } + Variable::MachineCallParam(_) => { + // This is just an internal variable. + } + } + }); + } + + /// Check that actions are valid (e.g., variables written to only once, and only read after being written to) + fn is_valid(&self) -> bool { + let mut prev_writes = BTreeSet::new(); + for action in &self.actions { + let writes = action.writes(); + // writing to a variable already written? + if !writes.is_disjoint(&prev_writes) { + return false; + } + // reading a variable that was not written to? + if !action.reads().is_subset(&prev_writes) { + return false; + } + prev_writes.extend(writes); + } + true + } + + /// Execute the machine effects for the given the parameters + pub fn call>( + &self, + mutable_state: &MutableState<'_, T, Q>, + params: &mut [LookupCell], + data: CompactDataRef<'_, T>, + ) { + let mut vars = vec![T::zero(); self.var_count]; + + let mut eval_stack = vec![]; + for action in &self.actions { + match action { + InterpreterAction::AssignExpression(idx, e) => { + let val = e.evaluate(&mut eval_stack, &vars[..]); + vars[*idx] = val; + } + InterpreterAction::ReadCell(idx, c) => { + let row_offset: i32 = data.row_offset.try_into().unwrap(); + vars[*idx] = data + .data + .get((row_offset + c.row_offset).try_into().unwrap(), c.id) + .0; + } + InterpreterAction::ReadParam(idx, i) => { + vars[*idx] = get_param(params, *i); + } + InterpreterAction::WriteCell(idx, c) => { + let row_offset: i32 = data.row_offset.try_into().unwrap(); + data.data.set( + (row_offset + c.row_offset).try_into().unwrap(), + c.id, + vars[*idx], + ); + } + InterpreterAction::WriteParam(idx, i) => { + set_param(params, *i, vars[*idx]); + } + InterpreterAction::MachineCall(id, arguments) => { + // we know it's safe to escape the references here, but the compiler doesn't, so we use unsafe + let mut args = arguments + .iter() + .map(|a| match a { + MachineCallArgumentIdx::Unknown(idx) => { + let var = &mut vars[*idx] as *mut T; + LookupCell::Output(unsafe { var.as_mut().unwrap() }) + } + MachineCallArgumentIdx::Known(idx) => { + let var = &vars[*idx] as *const T; + LookupCell::Input(unsafe { var.as_ref().unwrap() }) + } + }) + .collect::>(); + mutable_state.call_direct(*id, &mut args[..]).unwrap(); + } + InterpreterAction::Assertion(e1, e2, expected_equal) => { + let lhs_value = e1.evaluate(&mut eval_stack, &vars); + let rhs_value = e2.evaluate(&mut eval_stack, &vars); + if *expected_equal { + assert_eq!(lhs_value, rhs_value, "Assertion failed"); + } else { + assert_ne!(lhs_value, rhs_value, "Assertion failed"); + } + } + } + } + assert!(eval_stack.is_empty()); + } +} + +impl InterpreterAction { + /// variable indexes written by the action + fn writes(&self) -> BTreeSet { + let mut set = BTreeSet::new(); + match self { + InterpreterAction::ReadCell(idx, _) + | InterpreterAction::ReadParam(idx, _) + | InterpreterAction::AssignExpression(idx, _) => { + set.insert(*idx); + } + InterpreterAction::MachineCall(_, params) => params.iter().for_each(|p| { + if let MachineCallArgumentIdx::Unknown(v) = p { + set.insert(*v); + } + }), + _ => {} + } + set + } + + /// variable indexes read by the action + fn reads(&self) -> BTreeSet { + let mut set = BTreeSet::new(); + match self { + InterpreterAction::WriteCell(idx, _) | InterpreterAction::WriteParam(idx, _) => { + set.insert(*idx); + } + InterpreterAction::AssignExpression(_, expr) => expr.elems.iter().for_each(|e| { + if let RPNExpressionElem::Symbol(idx) = e { + set.insert(*idx); + } + }), + InterpreterAction::MachineCall(_, params) => params.iter().for_each(|p| { + if let MachineCallArgumentIdx::Known(v) = p { + set.insert(*v); + } + }), + InterpreterAction::Assertion(lhs, rhs, _) => { + lhs.elems.iter().for_each(|e| { + if let RPNExpressionElem::Symbol(idx) = e { + set.insert(*idx); + } + }); + rhs.elems.iter().for_each(|e| { + if let RPNExpressionElem::Symbol(idx) = e { + set.insert(*idx); + } + }); + } + _ => {} + } + set + } +} + +/// Helper struct to map variables to contiguous indices, so they can be kept in +/// sequential memory and quickly refered to during execution. +pub struct VariableMapper { + var_idx: HashMap, + count: usize, +} + +impl VariableMapper { + pub fn new() -> Self { + Self { + var_idx: HashMap::new(), + count: 0, + } + } + + pub fn var_count(&self) -> usize { + self.count + } + + pub fn map_var(&mut self, var: &Variable) -> usize { + let idx = *self.var_idx.entry(var.clone()).or_insert_with(|| { + self.count += 1; + self.count - 1 + }); + idx + } + + /// reserve a new variable index + pub fn reserve_idx(&mut self) -> usize { + let idx = self.count; + self.count += 1; + idx + } + + /// get the index of a variable if it was previously mapped + pub fn get_var(&mut self, var: &Variable) -> Option { + self.var_idx.get(var).copied() + } + + pub fn map_expr_to_rpn( + &mut self, + expr: &SymbolicExpression, + ) -> RPNExpression { + RPNExpression::map_from(expr, self) + } +} + +/// An expression in Reverse Polish Notation. +pub struct RPNExpression { + pub elems: Vec>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RPNExpressionElem { + Concrete(T), + Symbol(S), + BinaryOperation(BinaryOperator), + UnaryOperation(UnaryOperator), + BitOperation(BitOperator, T::Integer), +} + +impl RPNExpression { + /// Convert a symbolic expression to RPN, mapping variables to indices + fn map_from(expr: &SymbolicExpression, var_mapper: &mut VariableMapper) -> Self { + fn inner( + expr: &SymbolicExpression, + elems: &mut Vec>, + var_mapper: &mut VariableMapper, + ) { + match expr { + SymbolicExpression::Concrete(n) => { + elems.push(RPNExpressionElem::Concrete(*n)); + } + SymbolicExpression::Symbol(s, _) => { + elems.push(RPNExpressionElem::Symbol(var_mapper.map_var(s))); + } + SymbolicExpression::BinaryOperation(lhs, op, rhs, _) => { + inner(lhs, elems, var_mapper); + inner(rhs, elems, var_mapper); + elems.push(RPNExpressionElem::BinaryOperation(op.clone())); + } + SymbolicExpression::UnaryOperation(op, expr, _) => { + inner(expr, elems, var_mapper); + elems.push(RPNExpressionElem::UnaryOperation(op.clone())); + } + SymbolicExpression::BitOperation(expr, op, n, _) => { + inner(expr, elems, var_mapper); + elems.push(RPNExpressionElem::BitOperation(op.clone(), *n)); + } + } + } + let mut elems = Vec::new(); + inner(expr, &mut elems, var_mapper); + RPNExpression { elems } + } + + /// Evaluate the expression using the provided variables. + /// The stack is used to store intermediate results, it's taken as + /// a parameter to avoid allocating on every call to evaluate. + fn evaluate(&self, stack: &mut Vec, vars: &[T]) -> T { + self.elems.iter().for_each(|elem| match elem { + RPNExpressionElem::Concrete(v) => stack.push(*v), + RPNExpressionElem::Symbol(idx) => stack.push(vars[*idx]), + RPNExpressionElem::BinaryOperation(op) => { + let right = stack.pop().unwrap(); + let left = stack.pop().unwrap(); + let result = match op { + BinaryOperator::Add => left + right, + BinaryOperator::Sub => left - right, + BinaryOperator::Mul => left * right, + BinaryOperator::Div => left / right, + BinaryOperator::IntegerDiv => { + T::from(left.to_arbitrary_integer() / right.to_arbitrary_integer()) + } + }; + stack.push(result); + } + RPNExpressionElem::UnaryOperation(op) => { + let inner = stack.pop().unwrap(); + let result = match op { + UnaryOperator::Neg => -inner, + }; + stack.push(result); + } + RPNExpressionElem::BitOperation(op, right) => { + let left = stack.pop().unwrap(); + let result = match op { + BitOperator::And => T::from(left.to_integer() & *right), + }; + stack.push(result); + } + }); + stack.pop().unwrap() + } +} + +#[inline] +fn get_param(params: &[LookupCell], i: usize) -> T { + match params[i] { + LookupCell::Input(v) => *v, + LookupCell::Output(_) => panic!("Output cell used as input"), + } +} +#[inline] +fn set_param(params: &mut [LookupCell], i: usize, value: T) { + match &mut params[i] { + LookupCell::Input(_) => panic!("Input cell used as output"), + LookupCell::Output(v) => **v = value, + } +} + +#[cfg(test)] +mod test { + use std::fs::read_to_string; + + use super::EffectsInterpreter; + use crate::witgen::data_structures::{ + finalizable_data::{CompactData, CompactDataRef}, + mutable_state::MutableState, + }; + use crate::witgen::global_constraints; + use crate::witgen::jit::block_machine_processor::BlockMachineProcessor; + use crate::witgen::jit::test_util::read_pil; + use crate::witgen::jit::variable::Variable; + use crate::witgen::machines::{ + machine_extractor::MachineExtractor, KnownMachine, LookupCell, Machine, + }; + use crate::witgen::FixedData; + + use bit_vec::BitVec; + use itertools::Itertools; + use powdr_number::GoldilocksField; + + #[test] + fn call_poseidon() { + let file = "../test_data/pil/poseidon_gl.pil"; + let machine_name = "main_poseidon"; + let (num_inputs, num_outputs) = (12, 4); + let pil = read_to_string(file).unwrap(); + + let (analyzed, fixed_col_vals) = read_pil::(&pil); + + let fixed_data = FixedData::new(&analyzed, &fixed_col_vals, &[], Default::default(), 0); + let (fixed_data, retained_identities) = + global_constraints::set_global_constraints(fixed_data, &analyzed.identities); + let machines = MachineExtractor::new(&fixed_data).split_out_machines(retained_identities); + let [KnownMachine::BlockMachine(machine)] = machines + .iter() + .filter(|m| m.name().contains(machine_name)) + .collect_vec() + .as_slice() + else { + panic!("Expected exactly one matching block machine") + }; + let (machine_parts, block_size, latch_row) = machine.machine_info(); + assert_eq!(machine_parts.connections.len(), 1); + let connection_id = *machine_parts.connections.keys().next().unwrap(); + let processor = + BlockMachineProcessor::new(&fixed_data, machine_parts.clone(), block_size, latch_row); + + let mutable_state = MutableState::new(machines.into_iter(), &|_| { + Err("Query not implemented".to_string()) + }); + + let known_values = BitVec::from_iter( + (0..num_inputs) + .map(|_| true) + .chain((0..num_outputs).map(|_| false)), + ); + + let effects = processor + .generate_code(&mutable_state, connection_id, &known_values) + .unwrap(); + + let known_inputs = (0..12).map(Variable::Param).collect::>(); + + // generate interpreter + let interpreter = EffectsInterpreter::new(&known_inputs, &effects); + // call it + let mut params = [GoldilocksField::default(); 16]; + let mut param_lookups = params + .iter_mut() + .enumerate() + .map(|(i, p)| { + if i < 12 { + LookupCell::Input(p) + } else { + LookupCell::Output(p) + } + }) + .collect::>(); + let poly_ids = analyzed + .committed_polys_in_source_order() + .flat_map(|p| p.0.array_elements().map(|e| e.1)) + .collect_vec(); + + let mut data = CompactData::new(&poly_ids); + data.append_new_rows(31); + let data_ref = CompactDataRef::new(&mut data, 0); + interpreter.call(&mutable_state, &mut param_lookups, data_ref); + + assert_eq!( + ¶ms, + &[ + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(0), + GoldilocksField::from(4330397376401421145u64), + GoldilocksField::from(14124799381142128323u64), + GoldilocksField::from(8742572140681234676u64), + GoldilocksField::from(14345658006221440202u64), + ] + ) + } +} diff --git a/executor/src/witgen/jit/mod.rs b/executor/src/witgen/jit/mod.rs index 4839e76318..0af557ca9c 100644 --- a/executor/src/witgen/jit/mod.rs +++ b/executor/src/witgen/jit/mod.rs @@ -3,6 +3,7 @@ mod block_machine_processor; mod compiler; mod effect; pub(crate) mod function_cache; +mod interpreter; mod single_step_processor; mod symbolic_expression; mod variable; diff --git a/executor/src/witgen/jit/symbolic_expression.rs b/executor/src/witgen/jit/symbolic_expression.rs index 1301d7ad85..7edc12fbdf 100644 --- a/executor/src/witgen/jit/symbolic_expression.rs +++ b/executor/src/witgen/jit/symbolic_expression.rs @@ -8,7 +8,7 @@ use std::{ fmt::{self, Display, Formatter}, iter, ops::{Add, BitAnd, Mul, Neg}, - rc::Rc, + sync::Arc, }; use crate::witgen::range_constraints::RangeConstraint; @@ -23,9 +23,9 @@ pub enum SymbolicExpression { /// A symbolic value known at run-time, referencing a cell, /// an input, a local variable or whatever it is used for. Symbol(S, RangeConstraint), - BinaryOperation(Rc, BinaryOperator, Rc, RangeConstraint), - UnaryOperation(UnaryOperator, Rc, RangeConstraint), - BitOperation(Rc, BitOperator, T::Integer, RangeConstraint), + BinaryOperation(Arc, BinaryOperator, Arc, RangeConstraint), + UnaryOperation(UnaryOperator, Arc, RangeConstraint), + BitOperation(Arc, BitOperator, T::Integer, RangeConstraint), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -199,9 +199,9 @@ impl Add for &SymbolicExpression { SymbolicExpression::Concrete(*a + *b) } _ => SymbolicExpression::BinaryOperation( - Rc::new(self.clone()), + Arc::new(self.clone()), BinaryOperator::Add, - Rc::new(rhs.clone()), + Arc::new(rhs.clone()), self.range_constraint().combine_sum(&rhs.range_constraint()), ), } @@ -226,7 +226,7 @@ impl Neg for &SymbolicExpression { } _ => SymbolicExpression::UnaryOperation( UnaryOperator::Neg, - Rc::new(self.clone()), + Arc::new(self.clone()), self.range_constraint().multiple(-T::from(1)), ), } @@ -258,9 +258,9 @@ impl Mul for &SymbolicExpression { -self } else { SymbolicExpression::BinaryOperation( - Rc::new(self.clone()), + Arc::new(self.clone()), BinaryOperator::Mul, - Rc::new(rhs.clone()), + Arc::new(rhs.clone()), Default::default(), ) } @@ -290,9 +290,9 @@ impl SymbolicExpression { } else { // TODO other simplifications like `-x / -y => x / y`, `-x / concrete => x / -concrete`, etc. SymbolicExpression::BinaryOperation( - Rc::new(self.clone()), + Arc::new(self.clone()), BinaryOperator::Div, - Rc::new(rhs.clone()), + Arc::new(rhs.clone()), Default::default(), ) } @@ -307,9 +307,9 @@ impl SymbolicExpression { self.clone() } else { SymbolicExpression::BinaryOperation( - Rc::new(self.clone()), + Arc::new(self.clone()), BinaryOperator::IntegerDiv, - Rc::new(rhs.clone()), + Arc::new(rhs.clone()), Default::default(), ) } @@ -326,7 +326,7 @@ impl BitAnd for SymbolicExpression SymbolicExpression::Concrete(T::from(0)) } else { let rc = RangeConstraint::from_mask(*self.range_constraint().mask() & rhs); - SymbolicExpression::BitOperation(Rc::new(self), BitOperator::And, rhs, rc) + SymbolicExpression::BitOperation(Arc::new(self), BitOperator::And, rhs, rc) } } }