From a44f5f7b03c3cd5689886c543a859028cbb287ad Mon Sep 17 00:00:00 2001 From: PatStiles Date: Fri, 23 Aug 2024 01:12:52 -0500 Subject: [PATCH] add flag + features --- Cargo.toml | 2 +- src/main.rs | 21 +++------------ src/risc0.rs | 54 +++++++++++++++++++++++++++++++------ src/sp1.rs | 15 ++++++++++- workspaces/base_files/risc0 | 2 ++ 5 files changed, 67 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e050149..1aadf5a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ tokio = "1.38.0" sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", tag = "v1.0.1" } # Risc 0 -risc0-zkvm = { git = "https://github.com/risc0/risc0.git", tag = "v1.0.1" } +risc0-zkvm = { git = "https://github.com/risc0/risc0.git", tag = "v1.0.1", features = ["metal", "cuda"] } # Aligned SDK aligned-sdk = { git = "https://github.com/yetanotherco/aligned_layer", tag = "v0.4.0" } diff --git a/src/main.rs b/src/main.rs index b60a087..25db324 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,7 @@ use aligned_sdk::core::types::ProvingSystemId; -use clap::{Args, Parser, Subcommand}; +use clap::{Parser, Subcommand}; use log::info; use std::io; -use std::path::PathBuf; use zkRust::risc0; use zkRust::sp1; use zkRust::submit_proof_to_aligned; @@ -21,21 +20,9 @@ struct Cli { #[derive(Subcommand)] enum Commands { #[clap(about = "Generate a proof of execution of a program using SP1")] - ProveSp1(ProofArgs), + ProveSp1(sp1::Sp1Args), #[clap(about = "Generate a proof of execution of a program using RISC0")] - ProveRisc0(ProofArgs), -} - -#[derive(Args, Debug)] -struct ProofArgs { - guest_path: String, - output_proof_path: String, - #[clap(long)] - submit_to_aligned_with_keystore: Option, - #[clap(long)] - std: bool, - #[clap(long)] - precompiles: bool, + ProveRisc0(risc0::Risc0Args), } fn main() -> io::Result<()> { @@ -96,7 +83,7 @@ fn main() -> io::Result<()> { if args.precompiles { utils::insert(risc0::RISC0_GUEST_CARGO_TOML, risc0::RISC0_ACCELERATION_IMPORT, "[workspace]").unwrap(); } - risc0::generate_risc0_proof()?; + risc0::generate_risc0_proof(args)?; info!("risc0 proof and image ID generated"); diff --git a/src/risc0.rs b/src/risc0.rs index 94ea21b..310dbe1 100644 --- a/src/risc0.rs +++ b/src/risc0.rs @@ -1,4 +1,5 @@ -use std::{fs, io, process::Command}; +use std::{fs, io, path::PathBuf, process::Command}; +use clap::Args; use crate::utils; @@ -22,6 +23,22 @@ pub const RISC0_GUEST_PROGRAM_HEADER_STD: &str = /// RISC0 Cargo patch for accelerated SHA-256, K256, and bigint-multiplication circuits pub const RISC0_ACCELERATION_IMPORT: &str = "\n[patch.crates-io]\nsha2 = { git = \"https://github.com/risc0/RustCrypto-hashes\", tag = \"sha2-v0.10.6-risczero.0\" }\nk256 = { git = \"https://github.com/risc0/RustCrypto-elliptic-curves\", tag = \"k256/v0.13.1-risczero.1\" }\ncrypto-bigint = { git = \"https://github.com/risc0/RustCrypto-crypto-bigint\", tag = \"v0.5.2-risczero.0\" }"; +#[derive(Args, Debug)] +pub struct Risc0Args { + pub guest_path: String, + pub output_proof_path: String, + #[clap(long)] + pub submit_to_aligned_with_keystore: Option, + #[clap(long)] + pub std: bool, + #[clap(long)] + pub precompiles: bool, + #[clap(long)] + pub cuda: bool, + #[clap(long)] + pub metal: bool, +} + /// This function mainly adds this header to the guest in order for it to be proven by /// risc0: /// @@ -34,15 +51,36 @@ pub fn prepare_risc0_guest() -> io::Result<()> { } /// Generates RISC0 proof and image ID -pub fn generate_risc0_proof() -> io::Result<()> { +pub fn generate_risc0_proof(args: &Risc0Args) -> io::Result<()> { let guest_path = fs::canonicalize(RISC0_WORKSPACE_DIR)?; - Command::new("cargo") - .arg("run") - .arg("--release") - .current_dir(guest_path) - .status() - .unwrap(); + if args.cuda { + Command::new("cargo") + .arg("run") + .arg("--release") + .arg("-F") + .arg("cuda") + .current_dir(guest_path) + .status() + .unwrap(); + + } else if args.metal { + Command::new("cargo") + .arg("run") + .arg("--release") + .arg("-F") + .arg("metal") + .current_dir(guest_path) + .status() + .unwrap(); + } else { + Command::new("cargo") + .arg("run") + .arg("--release") + .current_dir(guest_path) + .status() + .unwrap(); + } Ok(()) } diff --git a/src/sp1.rs b/src/sp1.rs index 71c6f6b..8213d71 100644 --- a/src/sp1.rs +++ b/src/sp1.rs @@ -1,4 +1,5 @@ -use std::{fs, io, process::Command}; +use std::{fs, io, path::PathBuf, process::Command}; +use clap::Args; use crate::utils; @@ -20,6 +21,18 @@ pub const SP1_PROGRAM_HEADER: &str = "#![no_main]\nsp1_zkvm::entrypoint!(main);\ /// SP1 Cargo patch for accelerated SHA-256, K256, and bigint-multiplication circuits pub const SP1_ACCELERATION_IMPORT: &str = "\n[patch.crates-io]\nsha2-v0-10-8 = { git = \"https://github.com/sp1-patches/RustCrypto-hashes\", package = \"sha2\", branch = \"patch-sha2-v0.10.8\" }\nsha3-v0-10-8 = { git = \"https://github.com/sp1-patches/RustCrypto-hashes\", package = \"sha3\", branch = \"patch-sha3-v0.10.8\" }\ncrypto-bigint = { git = \"https://github.com/sp1-patches/RustCrypto-bigint\", branch = \"patch-v0.5.5\" }\ntiny-keccak = { git = \"https://github.com/sp1-patches/tiny-keccak\", branch = \"patch-v2.0.2\" }\ned25519-consensus = { git = \"https://github.com/sp1-patches/ed25519-consensus\", branch = \"patch-v2.1.0\" }\necdsa-core = { git = \"https://github.com/sp1-patches/signatures\", package = \"ecdsa\", branch = \"patch-ecdsa-v0.16.9\" }\n"; +#[derive(Args, Debug)] +pub struct Sp1Args { + pub guest_path: String, + pub output_proof_path: String, + #[clap(long)] + pub submit_to_aligned_with_keystore: Option, + #[clap(long)] + pub std: bool, + #[clap(long)] + pub precompiles: bool, +} + /// This function mainly adds this header to the guest in order for it to be proven by /// sp1: /// diff --git a/workspaces/base_files/risc0 b/workspaces/base_files/risc0 index a33a780..0be79be 100644 --- a/workspaces/base_files/risc0 +++ b/workspaces/base_files/risc0 @@ -9,4 +9,6 @@ edition = "2021" risc0-zkvm = { git = "https://github.com/risc0/risc0", features = [ "std", "getrandom", + "metal", + "cuda" ], tag = "v1.0.1" }