Skip to content

Commit

Permalink
Simplify range constraints. (#2314)
Browse files Browse the repository at this point in the history
Fixes #2313
  • Loading branch information
chriseth authored Jan 8, 2025
1 parent 379267d commit 707522d
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 86 deletions.
2 changes: 1 addition & 1 deletion executor/src/witgen/data_structures/mutable_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl<'a, T: FieldElement, Q: QueryCallback<T>> MutableState<'a, T, Q> {
&self,
identity_id: u64,
known_inputs: &BitVec,
range_constraints: &[Option<RangeConstraint<T>>],
range_constraints: &[RangeConstraint<T>],
) -> bool {
// TODO We are currently ignoring bus interaction (also, but not only because there is no
// unique machine responsible for handling a bus send), so just answer "false" if the identity
Expand Down
35 changes: 17 additions & 18 deletions executor/src/witgen/jit/affine_symbolic_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ impl<T: FieldElement, V> From<T> for AffineSymbolicExpression<T, V> {
}

impl<T: FieldElement, V: Ord + Clone + Display> AffineSymbolicExpression<T, V> {
pub fn from_known_symbol(symbol: V, rc: Option<RangeConstraint<T>>) -> Self {
pub fn from_known_symbol(symbol: V, rc: RangeConstraint<T>) -> Self {
SymbolicExpression::from_symbol(symbol, rc).into()
}
pub fn from_unknown_variable(var: V, rc: Option<RangeConstraint<T>>) -> Self {
pub fn from_unknown_variable(var: V, rc: RangeConstraint<T>) -> Self {
AffineSymbolicExpression {
coefficients: [(var.clone(), T::from(1).into())].into_iter().collect(),
offset: SymbolicExpression::from(T::from(0)),
range_constraints: rc.into_iter().map(|rc| (var.clone(), rc)).collect(),
range_constraints: [(var.clone(), rc)].into_iter().collect(),
}
}

Expand Down Expand Up @@ -275,7 +275,7 @@ impl<T: FieldElement, V: Ord + Clone + Display> AffineSymbolicExpression<T, V> {
let rc = self.range_constraints.get(var)?;
Some(rc.multiple(coeff))
})
.chain(std::iter::once(self.offset.range_constraint()))
.chain(std::iter::once(Some(self.offset.range_constraint())))
.collect::<Option<Vec<_>>>()?;
let constraint = summands.into_iter().reduce(|c1, c2| c1.combine_sum(&c2))?;
let constraint = if solve_for_coefficient.is_known_one() {
Expand Down Expand Up @@ -408,8 +408,8 @@ mod test {

#[test]
fn unsolvable_with_vars() {
let x = &Ase::from_known_symbol("X", None);
let y = &Ase::from_known_symbol("Y", None);
let x = &Ase::from_known_symbol("X", Default::default());
let y = &Ase::from_known_symbol("Y", Default::default());
let constr = x + y - from_number(10);
// We cannot solve it, but we can also not learn anything new from it.
let result = constr.solve().unwrap();
Expand All @@ -427,8 +427,8 @@ mod test {

#[test]
fn solve_simple_eq() {
let y = Ase::from_known_symbol("y", None);
let x = Ase::from_unknown_variable("X", None);
let y = Ase::from_known_symbol("y", Default::default());
let x = Ase::from_unknown_variable("X", Default::default());
// 2 * X + 7 * y - 10 = 0
let two = from_number(2);
let seven = from_number(7);
Expand All @@ -446,18 +446,17 @@ mod test {

#[test]
fn solve_div_by_range_constrained_var() {
let y = Ase::from_known_symbol("y", None);
let z = Ase::from_known_symbol("z", None);
let x = Ase::from_unknown_variable("X", None);
let y = Ase::from_known_symbol("y", Default::default());
let z = Ase::from_known_symbol("z", Default::default());
let x = Ase::from_unknown_variable("X", Default::default());
// z * X + 7 * y - 10 = 0
let seven = from_number(7);
let ten = from_number(10);
let constr = mul(&z, &x) + mul(&seven, &y) - ten.clone();
// If we do not range-constrain z, we cannot solve since we don't know if it might be zero.
let result = constr.solve().unwrap();
assert!(!result.complete && result.effects.is_empty());
let z =
Ase::from_known_symbol("z", Some(RangeConstraint::from_range(10.into(), 20.into())));
let z = Ase::from_known_symbol("z", RangeConstraint::from_range(10.into(), 20.into()));
let constr = mul(&z, &x) + mul(&seven, &y) - ten;
let result = constr.solve().unwrap();
assert!(result.complete);
Expand All @@ -471,12 +470,12 @@ mod test {

#[test]
fn solve_bit_decomposition() {
let rc = Some(RangeConstraint::from_mask(0xffu32));
let rc = RangeConstraint::from_mask(0xffu32);
// First try without range constrain on a
let a = Ase::from_unknown_variable("a", None);
let a = Ase::from_unknown_variable("a", Default::default());
let b = Ase::from_unknown_variable("b", rc.clone());
let c = Ase::from_unknown_variable("c", rc.clone());
let z = Ase::from_known_symbol("Z", None);
let z = Ase::from_known_symbol("Z", Default::default());
// a * 0x100 + b * 0x10000 + c * 0x1000000 + 10 + Z = 0
let ten = from_number(10);
let constr = mul(&a, &from_number(0x100))
Expand Down Expand Up @@ -527,11 +526,11 @@ assert (-(10 + Z) & 18446744069414584575) == 0;

#[test]
fn solve_constraint_transfer() {
let rc = Some(RangeConstraint::from_mask(0xffu32));
let rc = RangeConstraint::from_mask(0xffu32);
let a = Ase::from_unknown_variable("a", rc.clone());
let b = Ase::from_unknown_variable("b", rc.clone());
let c = Ase::from_unknown_variable("c", rc.clone());
let z = Ase::from_unknown_variable("Z", None);
let z = Ase::from_unknown_variable("Z", Default::default());
// a * 0x100 + b * 0x10000 + c * 0x1000000 + 10 - Z = 0
let ten = from_number(10);
let constr = mul(&a, &from_number(0x100))
Expand Down
2 changes: 1 addition & 1 deletion executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ mod tests {
}

fn symbol(var: &Variable) -> SymbolicExpression<GoldilocksField, Variable> {
SymbolicExpression::from_symbol(var.clone(), None)
SymbolicExpression::from_symbol(var.clone(), Default::default())
}

fn number(n: u64) -> SymbolicExpression<GoldilocksField, Variable> {
Expand Down
2 changes: 1 addition & 1 deletion executor/src/witgen/jit/single_step_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> {
let Some(most_constrained_var) = witgen
.known_variables()
.iter()
.filter_map(|var| witgen.range_constraint(var).map(|rc| (var, rc)))
.map(|var| (var, witgen.range_constraint(var)))
.filter(|(_, rc)| rc.try_to_single_value().is_none())
.sorted()
.min_by_key(|(_, rc)| rc.range_width())
Expand Down
51 changes: 14 additions & 37 deletions executor/src/witgen/jit/symbolic_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,10 @@ pub enum SymbolicExpression<T: FieldElement, S> {
Concrete(T),
/// A symbolic value known at run-time, referencing a cell,
/// an input, a local variable or whatever it is used for.
Symbol(S, Option<RangeConstraint<T>>),
BinaryOperation(
Rc<Self>,
BinaryOperator,
Rc<Self>,
Option<RangeConstraint<T>>,
),
UnaryOperation(UnaryOperator, Rc<Self>, Option<RangeConstraint<T>>),
BitOperation(
Rc<Self>,
BitOperator,
T::Integer,
Option<RangeConstraint<T>>,
),
Symbol(S, RangeConstraint<T>),
BinaryOperation(Rc<Self>, BinaryOperator, Rc<Self>, RangeConstraint<T>),
UnaryOperation(UnaryOperator, Rc<Self>, RangeConstraint<T>),
BitOperation(Rc<Self>, BitOperator, T::Integer, RangeConstraint<T>),
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand All @@ -56,7 +46,7 @@ pub enum UnaryOperator {
}

impl<T: FieldElement, S> SymbolicExpression<T, S> {
pub fn from_symbol(symbol: S, rc: Option<RangeConstraint<T>>) -> Self {
pub fn from_symbol(symbol: S, rc: RangeConstraint<T>) -> Self {
SymbolicExpression::Symbol(symbol, rc)
}

Expand All @@ -75,17 +65,12 @@ impl<T: FieldElement, S> SymbolicExpression<T, S> {
pub fn is_known_nonzero(&self) -> bool {
// Only checking range constraint is enough since if this is a known
// fixed value, we will get a range constraint with just a single value.
if let Some(rc) = self.range_constraint() {
!rc.allows_value(0.into())
} else {
// unknown
false
}
!self.range_constraint().allows_value(0.into())
}

pub fn range_constraint(&self) -> Option<RangeConstraint<T>> {
pub fn range_constraint(&self) -> RangeConstraint<T> {
match self {
SymbolicExpression::Concrete(v) => Some(RangeConstraint::from_value(*v)),
SymbolicExpression::Concrete(v) => RangeConstraint::from_value(*v),
SymbolicExpression::Symbol(.., rc)
| SymbolicExpression::BinaryOperation(.., rc)
| SymbolicExpression::UnaryOperation(.., rc)
Expand Down Expand Up @@ -179,9 +164,7 @@ impl<T: FieldElement, V: Clone> Add for &SymbolicExpression<T, V> {
Rc::new(self.clone()),
BinaryOperator::Add,
Rc::new(rhs.clone()),
self.range_constraint()
.zip(rhs.range_constraint())
.map(|(a, b)| a.combine_sum(&b)),
self.range_constraint().combine_sum(&rhs.range_constraint()),
),
}
}
Expand All @@ -206,7 +189,7 @@ impl<T: FieldElement, V: Clone> Neg for &SymbolicExpression<T, V> {
_ => SymbolicExpression::UnaryOperation(
UnaryOperator::Neg,
Rc::new(self.clone()),
self.range_constraint().map(|rc| rc.multiple(-T::from(1))),
self.range_constraint().multiple(-T::from(1)),
),
}
}
Expand Down Expand Up @@ -240,7 +223,7 @@ impl<T: FieldElement, V: Clone> Mul for &SymbolicExpression<T, V> {
Rc::new(self.clone()),
BinaryOperator::Mul,
Rc::new(rhs.clone()),
None,
Default::default(),
)
}
}
Expand Down Expand Up @@ -272,7 +255,7 @@ impl<T: FieldElement, V: Clone> SymbolicExpression<T, V> {
Rc::new(self.clone()),
BinaryOperator::Div,
Rc::new(rhs.clone()),
None,
Default::default(),
)
}
}
Expand All @@ -286,7 +269,7 @@ impl<T: FieldElement, V: Clone> SymbolicExpression<T, V> {
Rc::new(self.clone()),
BinaryOperator::IntegerDiv,
Rc::new(rhs.clone()),
None,
Default::default(),
)
}
}
Expand All @@ -301,13 +284,7 @@ impl<T: FieldElement, V: Clone> BitAnd<T::Integer> for SymbolicExpression<T, V>
} else if self.is_known_zero() || rhs.is_zero() {
SymbolicExpression::Concrete(T::from(0))
} else {
let rc = Some(RangeConstraint::from_mask(
if let Some(rc) = self.range_constraint() {
*rc.mask() & rhs
} else {
rhs
},
));
let rc = RangeConstraint::from_mask(*self.range_constraint().mask() & rhs);
SymbolicExpression::BitOperation(Rc::new(self), BitOperator::And, rhs, rc)
}
}
Expand Down
37 changes: 17 additions & 20 deletions executor/src/witgen/jit/witgen_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F

pub fn value(&self, variable: &Variable) -> Value<T> {
let rc = self.range_constraint(variable);
if let Some(val) = rc.as_ref().and_then(|rc| rc.try_to_single_value()) {
if let Some(val) = rc.try_to_single_value() {
Value::Concrete(val)
} else if self.is_known(variable) {
Value::Known
Expand All @@ -119,7 +119,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
// The variable needs to be known, we need to have a range constraint but
// it cannot be a single value.
assert!(self.known_variables.contains(variable));
let rc = self.range_constraint(variable).unwrap();
let rc = self.range_constraint(variable);
assert!(rc.try_to_single_value().is_none());

log::trace!(
Expand Down Expand Up @@ -245,7 +245,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
.collect::<Vec<_>>();
let range_constraints = evaluated
.iter()
.map(|e| e.as_ref().and_then(|e| e.range_constraint()))
.map(|e| e.as_ref().map(|e| e.range_constraint()).unwrap_or_default())
.collect_vec();
let known = evaluated.iter().map(|e| e.is_some()).collect();

Expand Down Expand Up @@ -318,11 +318,9 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
match &e {
Effect::Assignment(variable, assignment) => {
assert!(self.known_variables.insert(variable.clone()));
if let Some(rc) = assignment.range_constraint() {
// If the variable was determined to be a constant, we add this
// as a range constraint, so we can use it in future evaluations.
self.add_range_constraint(variable.clone(), rc);
}
// If the variable was determined to be a constant, we add this
// as a range constraint, so we can use it in future evaluations.
self.add_range_constraint(variable.clone(), assignment.range_constraint());
progress = true;
self.code.push(e);
}
Expand Down Expand Up @@ -353,9 +351,11 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F

/// Adds a range constraint to the set of derived range constraints. Returns true if progress was made.
fn add_range_constraint(&mut self, variable: Variable, rc: RangeConstraint<T>) -> bool {
let rc = self
.range_constraint(&variable)
.map_or(rc.clone(), |existing_rc| existing_rc.conjunction(&rc));
let old_rc = self.range_constraint(&variable);
let rc = old_rc.conjunction(&rc);
if rc == old_rc {
return false;
}
if !self.known_variables.contains(&variable) {
if let Some(v) = rc.try_to_single_value() {
// Special case: Variable is fixed to a constant by range constraints only.
Expand All @@ -364,17 +364,13 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
.push(Effect::Assignment(variable.clone(), v.into()));
}
}
let old_rc = self
.derived_range_constraints
.insert(variable.clone(), rc.clone());

// If the range constraint changed, we made progress.
old_rc != Some(rc)
self.derived_range_constraints.insert(variable.clone(), rc);
true
}

/// Returns the current best-known range constraint on the given variable
/// combining global range constraints and newly derived local range constraints.
pub fn range_constraint(&self, variable: &Variable) -> Option<RangeConstraint<T>> {
pub fn range_constraint(&self, variable: &Variable) -> RangeConstraint<T> {
variable
.try_to_witness_poly_id()
.and_then(|poly_id| {
Expand All @@ -390,6 +386,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
.chain(self.derived_range_constraints.get(variable))
.cloned()
.reduce(|gc, rc| gc.conjunction(&rc))
.unwrap_or_default()
}

fn evaluate(
Expand Down Expand Up @@ -537,7 +534,7 @@ pub trait CanProcessCall<T: FieldElement> {
&self,
_identity_id: u64,
_known_inputs: &BitVec,
_range_constraints: &[Option<RangeConstraint<T>>],
_range_constraints: &[RangeConstraint<T>],
) -> bool;
}

Expand All @@ -546,7 +543,7 @@ impl<T: FieldElement, Q: QueryCallback<T>> CanProcessCall<T> for &MutableState<'
&self,
identity_id: u64,
known_inputs: &BitVec,
range_constraints: &[Option<RangeConstraint<T>>],
range_constraints: &[RangeConstraint<T>],
) -> bool {
MutableState::can_process_call_fully(self, identity_id, known_inputs, range_constraints)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ impl<'a, T: FieldElement> Machine<'a, T> for DoubleSortedWitnesses32<'a, T> {
&mut self,
identity_id: u64,
known_arguments: &BitVec,
range_constraints: &[Option<RangeConstraint<T>>],
range_constraints: &[RangeConstraint<T>],
) -> bool {
assert!(self.parts.connections.contains_key(&identity_id));
assert_eq!(known_arguments.len(), 4);
Expand All @@ -209,10 +209,8 @@ impl<'a, T: FieldElement> Machine<'a, T> for DoubleSortedWitnesses32<'a, T> {
true
} else {
// It is not known, so we can only process if we do not write.
range_constraints[0].as_ref().is_some_and(|rc| {
!rc.allows_value(T::from(OPERATION_ID_BOOTLOADER_WRITE))
&& !rc.allows_value(T::from(OPERATION_ID_WRITE))
})
!range_constraints[0].allows_value(T::from(OPERATION_ID_BOOTLOADER_WRITE))
&& !range_constraints[0].allows_value(T::from(OPERATION_ID_WRITE))
}
}

Expand Down
2 changes: 1 addition & 1 deletion executor/src/witgen/machines/fixed_lookup_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ impl<'a, T: FieldElement> Machine<'a, T> for FixedLookup<'a, T> {
&mut self,
identity_id: u64,
known_arguments: &BitVec,
_range_constraints: &[Option<RangeConstraint<T>>],
_range_constraints: &[RangeConstraint<T>],
) -> bool {
if !Self::is_responsible(&self.connections[&identity_id]) {
return false;
Expand Down
4 changes: 2 additions & 2 deletions executor/src/witgen/machines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub trait Machine<'a, T: FieldElement>: Send + Sync {
&mut self,
_identity_id: u64,
_known_arguments: &BitVec,
_range_constraints: &[Option<RangeConstraint<T>>],
_range_constraints: &[RangeConstraint<T>],
) -> bool {
false
}
Expand Down Expand Up @@ -187,7 +187,7 @@ impl<'a, T: FieldElement> Machine<'a, T> for KnownMachine<'a, T> {
&mut self,
identity_id: u64,
known_arguments: &BitVec,
range_constraints: &[Option<RangeConstraint<T>>],
range_constraints: &[RangeConstraint<T>],
) -> bool {
match self {
KnownMachine::SecondStageMachine(m) => {
Expand Down
Loading

0 comments on commit 707522d

Please sign in to comment.