Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Block machine: only use JIT on Goldilocks #2541

Draft
wants to merge 13 commits into
base: inline-witness-assignments
Choose a base branch
from
3 changes: 3 additions & 0 deletions .github/workflows/pr-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ env:
CARGO_TERM_COLOR: always
POWDR_GENERATE_PROOFS: "true"
MAX_DEGREE_LOG: "20"
RUST_MIN_STACK: "1073741824"

jobs:
build:
Expand Down Expand Up @@ -234,6 +235,8 @@ jobs:
- name: Install pilcom
run: git clone https://github.com/0xPolygonHermez/pilcom.git && cd pilcom && npm install
- uses: taiki-e/install-action@nextest
- name: Increase Stack Size
run: ulimit -s 32768
- name: Run slow tests
# The number of threads is set to 2 because the runner does not have enough memory for more.
run: |
Expand Down
2 changes: 2 additions & 0 deletions executor/src/witgen/eval_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ pub enum IncompleteCause<K = usize> {
SymbolicEvaluationOfChallenge,
/// Some knowledge was learnt, but not a concrete value. Example: `Y = X` if we know that `Y` is boolean. We learn that `X` is boolean, but not its exact value.
NotConcrete,
/// The JIT compiler was not able to generate a function that computes a unique witness.
JitCompilationFailed,
Multiple(Vec<IncompleteCause<K>>),
}

Expand Down
31 changes: 31 additions & 0 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,37 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {

result.code = self.try_ensure_block_shape(result.code, &requested_known)?;

let needed_machine_call_variables = result
.code
.iter()
.flat_map(|effect| {
if let Effect::MachineCall(_, _, arguments) = effect {
for a in arguments {
assert!(matches!(a, Variable::MachineCallParam(_)));
}
arguments.clone()
} else {
vec![]
}
})
.collect::<BTreeSet<_>>();

result.code = result
.code
.into_iter()
.filter(|effect| {
if let Effect::Assignment(variable, _) = effect {
if let Variable::MachineCallParam(_) = variable {
needed_machine_call_variables.contains(variable)
} else {
true
}
} else {
true
}
})
.collect();

Ok((result, prover_functions))
}

Expand Down
46 changes: 23 additions & 23 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,22 +287,7 @@ fn witgen_code<T: FieldElement>(
.format("\n");
// We do not store "known" together with the values, because we hope
// that this way, the optimizer can group them better.
let store_known = vars_known
.iter()
.filter_map(|var| match var {
Variable::WitnessCell(cell) => Some(cell),
Variable::Param(_)
| Variable::FixedCell(_)
| Variable::IntermediateCell(_)
| Variable::MachineCallParam(_) => None,
})
.map(|cell| {
format!(
" set_known(known, row_offset, {}, {});",
cell.row_offset, cell.id
)
})
.format("\n");
let store_known = "".to_string();
format!(
r#"
#[no_mangle]
Expand Down Expand Up @@ -353,13 +338,26 @@ fn format_effects_inner<T: FieldElement>(
) -> String {
effects
.iter()
.map(|effect| format_effect(effect, is_top_level))
.filter_map(|effect| {
let code = format_effect(effect, is_top_level);
if code.is_empty() {
None
} else {
Some(code)
}
})
.join("\n")
}

fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bool) -> String {
match effect {
Effect::Assignment(var, e) => set(var, &format_expression(e), is_top_level, false),
Effect::Assignment(var, e) => {
if let Variable::MachineCallParam(_) = var {
return "".to_string();
} else {
set(var, &format_expression(e), is_top_level, false)
}
}
Effect::RangeConstraint(..) => {
unreachable!("Final code should not contain pure range constraints.")
}
Expand Down Expand Up @@ -403,9 +401,10 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
.to_string()
+ "\n"
};
format!(
"{var_decls}assert!(call_machine(mutable_state, {id}.into(), MutSlice::from((&mut [{args}]).as_mut_slice())));"
)
// format!(
// "{var_decls}assert!(call_machine(mutable_state, {id}.into(), MutSlice::from((&mut [{args}]).as_mut_slice())));"
// )
format!("{var_decls}// Skipping machine call")
}
Effect::ProverFunctionCall(ProverFunctionCall {
targets,
Expand All @@ -422,8 +421,9 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
.enumerate()
.map(|(i, v)| set(v, &format!("result[{i}]"), is_top_level, false))
.format("\n");
let block = format!("{function_call}\n{store_results}");
format!("{{\n{}\n}}", indent(block, 1))
let block = format!("{}\n{}", function_call, store_results);
// format!("{{\n{}\n}}", indent(block, 1))
"// Skipping prover function".to_string()
}
Effect::Branch(condition, first, second) => {
let var_decls = if is_top_level {
Expand Down
34 changes: 19 additions & 15 deletions executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ pub struct FunctionCache<'a, T: FieldElement> {
/// The processor that generates the JIT code
processor: BlockMachineProcessor<'a, T>,
/// The cache of JIT functions and the returned range constraints.
/// If the entry is None, we attempted to generate the function but failed.
witgen_functions: HashMap<CacheKey<T>, Option<CacheEntry<T>>>,
/// If the entry is Err, we attempted to generate the function but failed.
witgen_functions: HashMap<CacheKey<T>, Result<CacheEntry<T>, CompilationError>>,
column_layout: ColumnLayout,
block_size: usize,
machine_name: String,
Expand All @@ -49,6 +49,12 @@ pub struct CacheEntry<T: FieldElement> {
pub range_constraints: Vec<RangeConstraint<T>>,
}

/// Reason why no JIT function is available for a cache key.
///
/// Stored in the function cache in place of a compiled entry, so that callers
/// can tell the different failure modes apart.
#[derive(Debug)]
pub enum CompilationError {
/// JIT compilation was not attempted because the field is not supported
/// (the cache only compiles when `T::known_field()` is Goldilocks).
UnsupportedField,
/// Code generation was attempted but failed; carries the error message
/// that was reported by the code generator.
Other(String),
}

impl<'a, T: FieldElement> FunctionCache<'a, T> {
pub fn new(
fixed_data: &'a FixedData<'a, T>,
Expand Down Expand Up @@ -81,15 +87,15 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
bus_id: T,
known_args: &BitVec,
known_concrete: Option<(usize, T)>,
) -> &Option<CacheEntry<T>> {
) -> &Result<CacheEntry<T>, CompilationError> {
// First try the generic version, then the specific.
let mut key = CacheKey {
bus_id,
known_args: known_args.clone(),
known_concrete: None,
};

if self.ensure_cache(can_process.clone(), &key).is_none() && known_concrete.is_some() {
if self.ensure_cache(can_process.clone(), &key).is_err() && known_concrete.is_some() {
key = CacheKey {
bus_id,
known_args: known_args.clone(),
Expand All @@ -104,15 +110,15 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
&mut self,
can_process: impl CanProcessCall<T>,
cache_key: &CacheKey<T>,
) -> &Option<CacheEntry<T>> {
) -> &Result<CacheEntry<T>, CompilationError> {
if !self.witgen_functions.contains_key(cache_key) {
record_start("Auto-witgen code derivation");
let f = match T::known_field() {
// Currently, we only support the Goldilocks field.
Some(KnownField::GoldilocksField) => {
self.compile_witgen_function(can_process, cache_key)
}
_ => None,
_ => Err(CompilationError::UnsupportedField),
};
assert!(self.witgen_functions.insert(cache_key.clone(), f).is_none());
record_end("Auto-witgen code derivation");
Expand All @@ -124,7 +130,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
&self,
can_process: impl CanProcessCall<T>,
cache_key: &CacheKey<T>,
) -> Option<CacheEntry<T>> {
) -> Result<CacheEntry<T>, CompilationError> {
log::info!(
"Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}{}",
self.machine_name,
Expand All @@ -151,13 +157,9 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
cache_key.known_concrete,
)
.map_err(|e| {
// These errors can be pretty verbose and are quite common currently.
log::info!(
"=> Error generating JIT code: {}\n...",
e.to_string().lines().take(5).join("\n")
);
})
.ok()?;
log::info!("{e}");
CompilationError::Other(e)
})?;

log::info!("=> Success!");
let out_of_bounds_vars = code
Expand Down Expand Up @@ -198,7 +200,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
.unwrap();
log::info!("Compilation done.");

Some(CacheEntry {
Ok(CacheEntry {
function,
range_constraints,
})
Expand All @@ -223,6 +225,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
known_concrete,
};

log::info!("Calling compiled function for {:?}", cache_key);
self.witgen_functions
.get(&cache_key)
.or_else(|| {
Expand All @@ -237,6 +240,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
.expect("compile_cached() returned false!")
.function
.call(self.fixed_data, mutable_state, values, data);
log::info!("Done calling function");

Ok(true)
}
Expand Down
10 changes: 0 additions & 10 deletions executor/src/witgen/jit/includes/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,30 @@
// const column_count: u64 = ...;
// const first_column_id: u64 = ...;

#[inline]
/// Views the raw `known` pointer as a mutable slice of 32-bit words.
///
/// The slice length is `len / column_count` rows multiplied by
/// `(column_count + 31) / 32` words per row.
fn known_to_slice<'a>(known: *mut u32, len: u64) -> &'a mut [u32] {
    let row_count = len / column_count;
    let word_count = row_count * ((column_count + 31) / 32);
    // NOTE(review): soundness relies on the caller passing a pointer that is
    // valid for `word_count` words for the whole lifetime `'a` — confirm at the call site.
    unsafe { std::slice::from_raw_parts_mut(known, word_count as usize) }
}

#[inline]
/// Computes the flat position of a cell in the row-major data array.
///
/// `column` is first made relative to `first_column_id`; the row is
/// `global_offset + local_offset`, and each row spans `column_count` cells.
fn index(global_offset: u64, local_offset: i32, column: u64) -> usize {
    let relative_column = column - first_column_id;
    // The local offset may be negative, so the sum is formed in signed arithmetic.
    let absolute_row = (global_offset as i64 + i64::from(local_offset)) as u64;
    (absolute_row * column_count + relative_column) as usize
}

#[inline]
/// Locates the "known" bit of a cell inside the packed `u32` bit field.
///
/// Returns `(word_index, bit_index)`: every row occupies
/// `(column_count + 31) / 32` words, and the column (relative to
/// `first_column_id`) selects the word within the row and the bit within
/// that word.
fn index_known(global_offset: u64, local_offset: i32, column: u64) -> (u64, u64) {
    let relative_column = column - first_column_id;
    // The local offset may be negative, so the sum is formed in signed arithmetic.
    let absolute_row = (global_offset as i64 + i64::from(local_offset)) as u64;
    let words_per_row = (column_count + 31) / 32;
    let word = absolute_row * words_per_row + relative_column / 32;
    let bit = relative_column % 32;
    (word, bit)
}

#[inline]
/// Reads the cell at (`global_offset + local_offset`, `column`) from `data`.
///
/// The position is computed by `index`; indexing panics if it lies outside `data`.
fn get(data: &[FieldElement], global_offset: u64, local_offset: i32, column: u64) -> FieldElement {
    let position = index(global_offset, local_offset, column);
    data[position]
}

#[inline]
fn set(
data: &mut [FieldElement],
global_offset: u64,
Expand All @@ -42,20 +37,17 @@ fn set(
data[i] = value;
}

#[inline]
/// Marks the cell at (`global_offset + local_offset`, `column`) as known
/// by setting its bit in the packed `known` bit field.
///
/// Bits that are already set stay set; the word/bit position comes from `index_known`.
fn set_known(known: &mut [u32], global_offset: u64, local_offset: i32, column: u64) {
    let (word, bit) = index_known(global_offset, local_offset, column);
    let mask: u32 = 1 << bit;
    known[word as usize] |= mask;
}

#[inline]
/// Returns the value of the `i`-th parameter, which must be an input cell.
///
/// Panics if the cell is an output cell, or if `i` is out of bounds for `params`.
fn get_param(params: &[LookupCell<FieldElement>], i: usize) -> FieldElement {
    match params[i] {
        // Reading from an output cell is a code-generation bug, not a recoverable condition.
        LookupCell::Output(_) => panic!("Output cell used as input"),
        LookupCell::Input(input) => *input,
    }
}
#[inline]
fn set_param(params: &mut [LookupCell<FieldElement>], i: usize, value: FieldElement) {
match &mut params[i] {
LookupCell::Input(_) => panic!("Input cell used as output"),
Expand All @@ -78,7 +70,6 @@ pub struct MutSlice<T> {
}

impl<T> From<&mut [T]> for MutSlice<T> {
#[inline]
fn from(slice: &mut [T]) -> Self {
MutSlice {
data: slice.as_mut_ptr(),
Expand All @@ -88,7 +79,6 @@ impl<T> From<&mut [T]> for MutSlice<T> {
}

impl<T> MutSlice<T> {
#[inline]
fn to_mut_slice<'a>(self) -> &'a mut [T] {
unsafe { std::slice::from_raw_parts_mut(self.data, self.len as usize) }
}
Expand Down
37 changes: 24 additions & 13 deletions executor/src/witgen/machines/block_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::witgen::data_structures::caller_data::CallerData;
use crate::witgen::data_structures::finalizable_data::FinalizableData;
use crate::witgen::data_structures::mutable_state::MutableState;
use crate::witgen::global_constraints::RangeConstraintSet;
use crate::witgen::jit::function_cache::FunctionCache;
use crate::witgen::jit::function_cache::{CompilationError, FunctionCache};
use crate::witgen::jit::witgen_inference::CanProcessCall;
use crate::witgen::processor::{OuterQuery, Processor, SolverState};
use crate::witgen::range_constraints::RangeConstraint;
Expand Down Expand Up @@ -183,8 +183,8 @@ impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> {
known_arguments,
fixed_first_input,
) {
Some(entry) => (true, entry.range_constraints.clone()),
None => (false, range_constraints),
Ok(entry) => (true, entry.range_constraints.clone()),
Err(_) => (false, range_constraints),
}
}

Expand Down Expand Up @@ -454,16 +454,27 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
let fixed_first_input = arguments
.first()
.and_then(|a| a.constant_value().map(|v| (0, v)));
if self
.function_cache
.compile_cached(mutable_state, bus_id, &known_inputs, fixed_first_input)
.is_some()
{
let caller_data = CallerData::new(arguments, range_constraints);
let updates = self.process_lookup_via_jit(mutable_state, bus_id, caller_data)?;
assert!(updates.is_complete());
self.block_count_jit += 1;
return Ok(updates);
match self.function_cache.compile_cached(
mutable_state,
bus_id,
&known_inputs,
fixed_first_input,
) {
Ok(_) => {
let caller_data = CallerData::new(arguments, range_constraints);
let updates = self.process_lookup_via_jit(mutable_state, bus_id, caller_data)?;
assert!(updates.is_complete());
self.block_count_jit += 1;
return Ok(updates);
}
Err(CompilationError::Other(_e)) => {
// Assuming the JIT compiler is feature-complete, this means that the witness is not
// unique, which could happen e.g. if not all required arguments are provided.
return Ok(EvalValue::incomplete(IncompleteCause::JitCompilationFailed));
}
// If we're on an unsupported field, this won't be fixed in future invocations.
// Fall back to run-time witgen.
Err(CompilationError::UnsupportedField) => {}
}

let outer_query = OuterQuery::new(
Expand Down
2 changes: 1 addition & 1 deletion jit-compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ fn cargo_toml(opt_level: Option<u32>) -> String {
}
}

const DEBUG: bool = false;
const DEBUG: bool = true;

/// Compiles the given code and returns the path to the
/// temporary directory containing the compiled library
Expand Down
Loading
Loading