Skip to content

Commit

Permalink
feat(prt): handle rollups/compute in clients
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenctw committed Oct 31, 2024
1 parent f31a1fb commit 6083e85
Show file tree
Hide file tree
Showing 18 changed files with 251 additions and 104 deletions.
2 changes: 1 addition & 1 deletion cartesi-rollups/node/compute-runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ where
.state_manager
.machine_state_hashes(last_sealed_epoch.epoch_number)?;
let mut player = Player::new(
inputs.into_iter().map(|i| Input(i)).collect(),
Some(inputs.into_iter().map(|i| Input(i)).collect()),
leafs
.into_iter()
.map(|l| {
Expand Down
17 changes: 5 additions & 12 deletions cartesi-rollups/node/machine-runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,18 +168,11 @@ where

fn process_input(&mut self, data: &[u8]) -> Result<(), SM> {
// TODO: review caclulations
let big_steps_in_stride = max_uint(LOG2_STRIDE - LOG2_UARCH_SPAN);
let stride_count_in_input = max_uint(LOG2_EMULATOR_SPAN + LOG2_UARCH_SPAN - LOG2_STRIDE);
let big_steps_in_stride = 1 << (LOG2_STRIDE - LOG2_UARCH_SPAN);
let stride_count_in_input = 1 << (LOG2_EMULATOR_SPAN + LOG2_UARCH_SPAN - LOG2_STRIDE);

// take snapshot and make it available to the compute client
// the snapshot taken before input insersion is for log/proof generation
self.snapshot(0)?;
self.feed_input(data)?;
self.run_machine(1)?;
// take snapshot and make it available to the compute client
// the snapshot taken after insersion and step is for commitment builder
self.snapshot(1)?;
self.run_machine(big_steps_in_stride - 1)?;
self.run_machine(big_steps_in_stride)?;

let mut i: u64 = 0;
while !self.machine.read_iflags_y()? {
Expand Down Expand Up @@ -228,12 +221,12 @@ where
Ok(())
}

fn snapshot(&self, offset: u64) -> Result<(), SM> {
fn take_snapshot(&self) -> Result<(), SM> {
// TODO: make sure "/rollups_data/{epoch_number}" exists
let snapshot_path = PathBuf::from(format!(
"/rollups_data/{}/{}",
self.epoch_number,
self.next_input_index_in_epoch << LOG2_EMULATOR_SPAN + offset
self.next_input_index_in_epoch << LOG2_EMULATOR_SPAN
));
if !snapshot_path.exists() {
self.machine.store(&snapshot_path)?;
Expand Down
21 changes: 14 additions & 7 deletions prt/client-lua/computation/commitment.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ local function run_uarch_span(machine)
end

local function build_small_machine_commitment(base_cycle, log2_stride_count, machine, snapshot_dir)
local machine_state = machine:run(base_cycle)
local machine_state = machine:state()
if save_snapshot then
-- taking snapshot for leafs to save time in next level
machine:snapshot(snapshot_dir, base_cycle)
machine:take_snapshot(snapshot_dir, base_cycle)
end
local initial_state = machine_state.root_hash

Expand All @@ -60,10 +60,10 @@ local function build_small_machine_commitment(base_cycle, log2_stride_count, mac
end

local function build_big_machine_commitment(base_cycle, log2_stride, log2_stride_count, machine, snapshot_dir)
local machine_state = machine:run(base_cycle)
local machine_state = machine:state()
if save_snapshot then
-- taking snapshot for leafs to save time in next level
machine:snapshot(snapshot_dir, base_cycle)
machine:take_snapshot(snapshot_dir, base_cycle)
end
local initial_state = machine_state.root_hash

Expand All @@ -88,9 +88,16 @@ local function build_big_machine_commitment(base_cycle, log2_stride, log2_stride
return initial_state, builder:build(initial_state)
end

local function build_commitment(base_cycle, log2_stride, log2_stride_count, machine_path, snapshot_dir)
local function build_commitment(base_cycle, log2_stride, log2_stride_count, machine_path, snapshot_dir, inputs)
local machine = Machine:new_from_path(machine_path)
machine:load_snapshot(snapshot_dir, base_cycle)
if inputs then
-- treat it as rollups
machine:run_with_inputs(base_cycle, inputs)
else
-- treat it as compute
machine:run(base_cycle)
end

if log2_stride >= consts.log2_uarch_span then
assert(
Expand Down Expand Up @@ -120,15 +127,15 @@ function CommitmentBuilder:new(machine_path, snapshot_dir, root_commitment)
return c
end

function CommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_count)
function CommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_count, inputs)
if not self.commitments[level] then
self.commitments[level] = {}
elseif self.commitments[level][base_cycle] then
return self.commitments[level][base_cycle]
end

local _, commitment = build_commitment(base_cycle, log2_stride, log2_stride_count, self.machine_path,
self.snapshot_dir)
self.snapshot_dir, inputs)
self.commitments[level][base_cycle] = commitment
return commitment
end
Expand Down
72 changes: 54 additions & 18 deletions prt/client-lua/computation/machine.lua
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ local function find_closest_snapshot(path, current_cycle, cycle)
return closest_dir
end

function Machine:snapshot(snapshot_dir, cycle)
function Machine:take_snapshot(snapshot_dir, cycle)
if helper.exists(snapshot_dir) then
local snapshot_path = snapshot_dir .. "/" .. tostring(cycle)

Expand Down Expand Up @@ -163,6 +163,34 @@ function Machine:run_uarch(ucycle)
self.ucycle = ucycle
end

function Machine:run_with_inputs(cycle, inputs)
local input_mask = arithmetic.max_uint(consts.log2_emulator_span)
local current_input_index = self.cycle >> consts.log2_emulator_span

local next_input_index

if self.cycle & input_mask == 0 then
next_input_index = current_input_index
else
next_input_index = current_input_index + 1
end
local next_input_cycle = next_input_index << consts.log2_emulator_span

while next_input_cycle < cycle do
self:run(next_input_cycle)
local input = inputs[next_input_index]
if input then
self.machine:send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input);
end

next_input_index = next_input_index + 1
next_input_cycle = next_input_index << consts.log2_emulator_span
end
self:run(cycle)

return self:state()
end

function Machine:increment_uarch()
self.machine:run_uarch(self.ucycle + 1)
self.ucycle = self.ucycle + 1
Expand Down Expand Up @@ -223,29 +251,37 @@ local function encode_access_log(logs)
return '"' .. hex_data .. '"'
end

function Machine.get_logs(path, snapshot_dir, cycle, ucycle, input)
function Machine.get_logs(path, snapshot_dir, cycle, ucycle, inputs)
local machine = Machine:new_from_path(path)
machine:load_snapshot(snapshot_dir, cycle)
local logs
local log_type = { annotations = true, proofs = true }
machine:run(cycle)

local mask = 1 << consts.log2_emulator_span - 1;
if cycle & mask == 0 and input then
-- need to process input
if ucycle == 0 then
logs = machine.machine:log_send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input,
log_type
)
local step_logs = machine.machine:log_uarch_step(log_type)
-- append step logs to cmio logs
for _, log in ipairs(step_logs) do
table.insert(logs, log)
if inputs then
-- treat it as rollups
machine:run_with_inputs(cycle, inputs)

local mask = arithmetic.max_uint(consts.log2_emulator_span);
local input = inputs[cycle >> consts.log2_emulator_span]
if cycle & mask == 0 and input then
-- need to process input
if ucycle == 0 then
-- need to log cmio
logs = machine.machine:log_send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input,
log_type
)
local step_logs = machine.machine:log_uarch_step(log_type)
-- append step logs to cmio logs
for _, log in ipairs(step_logs) do
table.insert(logs, log)
end
return encode_access_log(logs)
else
machine.machine:send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input)
end
return encode_access_log(logs)
else
machine.machine:send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input)
end
else
-- treat it as compute
machine:run(cycle)
end

machine:run_uarch(ucycle)
Expand Down
9 changes: 5 additions & 4 deletions prt/client-lua/player/strategy.lua
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ function HonestStrategy:_react_match(match, commitment, log)

local cycle = match.base_big_cycle
local ucycle = (match.leaf_cycle & constants.uarch_span):touinteger()
local input = self.inputs[cycle >> constants.log2_emulator_span]
local logs = Machine.get_logs(self.machine_path, self.commitment_builder.snapshot_dir, cycle, ucycle, input)
local logs = Machine.get_logs(self.machine_path, self.commitment_builder.snapshot_dir, cycle, ucycle, inputs)

helper.log_full(self.sender.index, string.format(
"win leaf match in tournament %s of level %d for commitment %s",
Expand Down Expand Up @@ -281,7 +280,8 @@ function HonestStrategy:_react_tournament(tournament, log)
tournament.base_big_cycle,
tournament.level,
tournament.log2_stride,
tournament.log2_stride_count
tournament.log2_stride_count,
self.inputs
)

table.insert(log.tournaments, tournament)
Expand All @@ -299,7 +299,8 @@ function HonestStrategy:_react_tournament(tournament, log)
tournament.parent.base_big_cycle,
tournament.parent.level,
tournament.parent.log2_stride,
tournament.parent.log2_stride_count
tournament.parent.log2_stride_count,
self.inputs
)
if tournament_winner.commitment ~= old_commitment then
helper.log_full(self.sender.index, "player lost tournament")
Expand Down
41 changes: 27 additions & 14 deletions prt/client-rs/src/db/compute_state_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ use std::{

#[derive(Debug, Serialize, Deserialize)]
pub struct InputsAndLeafs {
#[serde(default)]
inputs: Vec<Input>,
inputs: Option<Vec<Input>>,
leafs: Vec<Leaf>,
}

Expand All @@ -30,6 +29,7 @@ pub struct Leaf(#[serde(with = "alloy_hex::serde")] pub [u8; 32], pub u64);
#[derive(Debug)]
pub struct ComputeStateAccess {
connection: Mutex<Connection>,
pub handle_rollups: bool,
pub work_path: PathBuf,
}

Expand All @@ -46,7 +46,7 @@ fn read_json_file(file_path: &Path) -> Result<InputsAndLeafs> {

impl ComputeStateAccess {
pub fn new(
inputs: Vec<Input>,
inputs: Option<Vec<Input>>,
leafs: Vec<Leaf>,
root_tournament: String,
compute_data_path: &str,
Expand All @@ -59,13 +59,16 @@ impl ComputeStateAccess {
let work_path = PathBuf::from(work_dir);
let db_path = work_path.join("db");
let no_create_flags = OpenFlags::default() & !OpenFlags::SQLITE_OPEN_CREATE;
let handle_rollups;
match Connection::open_with_flags(&db_path, no_create_flags) {
// database already exists, return it
Ok(connection) => {
handle_rollups = compute_data::handle_rollups(&connection)?;
return Ok(Self {
connection: Mutex::new(connection),
handle_rollups,
work_path,
})
});
}
Err(_) => {
// create new database
Expand All @@ -77,24 +80,29 @@ impl ComputeStateAccess {
// prioritize json file over parameters
match read_json_file(&json_path) {
Ok(inputs_and_leafs) => {
handle_rollups = inputs_and_leafs.inputs.is_some();
compute_data::insert_handle_rollups(&connection, handle_rollups)?;
compute_data::insert_compute_data(
&connection,
inputs_and_leafs.inputs.iter(),
inputs_and_leafs.inputs.unwrap_or_default().iter(),
inputs_and_leafs.leafs.iter(),
)?;
}
Err(_) => {
info!("load inputs and leafs from parameters");
handle_rollups = inputs.is_some();
compute_data::insert_handle_rollups(&connection, handle_rollups)?;
compute_data::insert_compute_data(
&connection,
inputs.iter(),
inputs.unwrap_or_default().iter(),
leafs.iter(),
)?;
}
}

Ok(Self {
connection: Mutex::new(connection),
handle_rollups,
work_path,
})
}
Expand All @@ -106,6 +114,11 @@ impl ComputeStateAccess {
compute_data::input(&conn, id)
}

pub fn inputs(&self) -> Result<Vec<Vec<u8>>> {
let conn = self.connection.lock().unwrap();
compute_data::inputs(&conn)
}

pub fn insert_compute_leafs<'a>(
&self,
level: u64,
Expand Down Expand Up @@ -205,7 +218,7 @@ mod compute_state_access_tests {
create_directory(&work_dir).unwrap();
{
let access =
ComputeStateAccess::new(Vec::new(), Vec::new(), String::from("0x12345678"), "/tmp")
ComputeStateAccess::new(None, Vec::new(), String::from("0x12345678"), "/tmp")
.unwrap();

assert_eq!(access.closest_snapshot(0).unwrap(), None);
Expand Down Expand Up @@ -264,8 +277,7 @@ mod compute_state_access_tests {
remove_directory(&work_dir).unwrap();
create_directory(&work_dir).unwrap();
let access =
ComputeStateAccess::new(Vec::new(), Vec::new(), String::from("0x12345678"), "/tmp")
.unwrap();
ComputeStateAccess::new(None, Vec::new(), String::from("0x12345678"), "/tmp").unwrap();

let root = [
1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1,
Expand All @@ -288,7 +300,7 @@ mod compute_state_access_tests {
fn test_deserialize() {
let json_str_1 = r#"{"leafs": [["0x01020304050607abcdef01020304050607abcdef01020304050607abcdef0102", 20], ["0x01020304050607fedcba01020304050607fedcba01020304050607fedcba0102", 13]]}"#;
let inputs_and_leafs_1: InputsAndLeafs = serde_json::from_str(json_str_1).unwrap();
assert_eq!(inputs_and_leafs_1.inputs.len(), 0);
assert_eq!(inputs_and_leafs_1.inputs.unwrap_or_default().len(), 0);
assert_eq!(inputs_and_leafs_1.leafs.len(), 2);
assert_eq!(
inputs_and_leafs_1.leafs[0].0,
Expand All @@ -307,14 +319,15 @@ mod compute_state_access_tests {

let json_str_2 = r#"{"inputs": [], "leafs": [["0x01020304050607abcdef01020304050607abcdef01020304050607abcdef0102", 20], ["0x01020304050607fedcba01020304050607fedcba01020304050607fedcba0102", 13]]}"#;
let inputs_and_leafs_2: InputsAndLeafs = serde_json::from_str(json_str_2).unwrap();
assert_eq!(inputs_and_leafs_2.inputs.len(), 0);
assert_eq!(inputs_and_leafs_2.inputs.unwrap_or_default().len(), 0);
assert_eq!(inputs_and_leafs_2.leafs.len(), 2);

let json_str_3 = r#"{"inputs": ["0x12345678", "0x22345678"], "leafs": [["0x01020304050607abcdef01020304050607abcdef01020304050607abcdef0102", 20], ["0x01020304050607fedcba01020304050607fedcba01020304050607fedcba0102", 13]]}"#;
let inputs_and_leafs_3: InputsAndLeafs = serde_json::from_str(json_str_3).unwrap();
assert_eq!(inputs_and_leafs_3.inputs.len(), 2);
let inputs_3 = inputs_and_leafs_3.inputs.unwrap();
assert_eq!(inputs_3.len(), 2);
assert_eq!(inputs_and_leafs_3.leafs.len(), 2);
assert_eq!(inputs_and_leafs_3.inputs[0].0, [18, 52, 86, 120]);
assert_eq!(inputs_and_leafs_3.inputs[1].0, [34, 52, 86, 120]);
assert_eq!(inputs_3[0].0, [18, 52, 86, 120]);
assert_eq!(inputs_3[1].0, [34, 52, 86, 120]);
}
}
Loading

0 comments on commit 6083e85

Please sign in to comment.