diff --git a/Cargo.lock b/Cargo.lock index 67fc79bd..d9b610be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -640,7 +640,7 @@ dependencies = [ "lru 0.13.0", "parking_lot 0.12.3", "pin-project", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "thiserror 2.0.12", @@ -709,7 +709,7 @@ dependencies = [ "async-stream", "futures", "pin-project", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "tokio", @@ -980,7 +980,7 @@ checksum = "21238d7ce1425bdb42269624b59a4fcc1c744bcfcb3cc762cbb251761c45d488" dependencies = [ "alloy-json-rpc", "alloy-transport", - "reqwest", + "reqwest 0.12.15", "serde_json", "tower", "tracing", @@ -2615,7 +2615,7 @@ dependencies = [ "log", "redis", "redis-test", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "shared", @@ -3357,7 +3357,7 @@ dependencies = [ "google-cloud-token", "home", "jsonwebtoken", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "thiserror 1.0.69", @@ -3377,9 +3377,9 @@ dependencies = [ "base64 0.22.1", "derive_builder", "http 1.3.1", - "reqwest", + "reqwest 0.12.15", "rustls", - "rustls-pemfile", + "rustls-pemfile 2.2.0", "serde", "serde_json", "thiserror 2.0.12", @@ -3393,7 +3393,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d901aeb453fd80e51d64df4ee005014f6cf39f2d736dd64f7239c132d9d39a6a" dependencies = [ - "reqwest", + "reqwest 0.12.15", "thiserror 1.0.69", "tokio", ] @@ -3418,7 +3418,7 @@ dependencies = [ "percent-encoding", "pkcs8", "regex", - "reqwest", + "reqwest 0.12.15", "reqwest-middleware", "ring 0.17.14", "serde", @@ -3854,6 +3854,19 @@ dependencies = [ "webpki-roots 0.26.9", ] +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper 0.14.32", + "native-tls", + "tokio", + "tokio-native-tls", +] + [[package]] name = "hyper-tls" version = "0.6.0" @@ -4116,7 
+4129,7 @@ dependencies = [ "netlink-proto", "netlink-sys", "rtnetlink 0.13.1", - "system-configuration", + "system-configuration 0.6.1", "tokio", "windows 0.52.0", ] @@ -4215,6 +4228,12 @@ dependencies = [ "serde", ] +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + [[package]] name = "inout" version = "0.1.4" @@ -4329,7 +4348,7 @@ dependencies = [ "portmapper", "rand 0.8.5", "rcgen 0.13.2", - "reqwest", + "reqwest 0.12.15", "ring 0.17.14", "rustls", "rustls-webpki 0.102.8", @@ -4462,7 +4481,7 @@ dependencies = [ "pkarr", "postcard", "rand 0.8.5", - "reqwest", + "reqwest 0.12.15", "rustls", "rustls-webpki 0.102.8", "serde", @@ -5489,6 +5508,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.17" @@ -5760,7 +5788,7 @@ dependencies = [ "netlink-packet-route 0.17.1", "netlink-sys", "once_cell", - "system-configuration", + "system-configuration 0.6.1", "windows-sys 0.52.0", ] @@ -5884,7 +5912,7 @@ dependencies = [ "bitflags 1.3.2", "cfg-if", "libc", - "memoffset", + "memoffset 0.7.1", "pin-utils", ] @@ -6152,6 +6180,31 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "operations" +version = "0.1.0" +dependencies = [ + "actix-web", + "alloy", + "alloy-provider", + "anyhow", + "env_logger", + "futures-util", + "hex", + "log", + "p2p", + "rand 0.8.5", + "redis", + "serde", + "serde_json", + "shared", + "subtle", + "tokio", + "tokio-util", + "url", + "uuid", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -6175,12 +6228,13 @@ dependencies = [ "hex", "log", "mockito", + "operations", "p2p", "prometheus 0.14.0", "rand 0.9.1", "redis", "redis-test", - "reqwest", + 
"reqwest 0.12.15", "serde", "serde_json", "shared", @@ -6865,6 +6919,117 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "pyo3" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" +dependencies = [ + "indoc", + "libc", + "memoffset 0.9.1", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-log" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + +[[package]] +name = "pyo3-macros" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8725c0a622b374d6cb051d11a0983786448f7785336139c3c94f5aa6bef7e50" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4109984c22491085343c05b0dbc54ddc405c3cf7b4374fc533f5c3313a572ccc" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "python-sdk" +version = "0.1.0" +dependencies = [ + "alloy", + "alloy-provider", + "anyhow", + "futures", + "hex", + "log", + 
"operations", + "p2p", + "pyo3", + "pyo3-log", + "pythonize", + "rand 0.8.5", + "reqwest 0.11.27", + "serde", + "serde_json", + "shared", + "test-log", + "thiserror 1.0.69", + "tokio", + "tokio-test", + "tokio-util", + "url", +] + +[[package]] +name = "pythonize" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597907139a488b22573158793aa7539df36ae863eba300c75f3a0d65fc475e27" +dependencies = [ + "pyo3", + "serde", +] + [[package]] name = "quanta" version = "0.10.1" @@ -7246,6 +7411,46 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "reqwest" +version = "0.11.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" +dependencies = [ + "base64 0.21.7", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper-tls 0.5.0", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls-pemfile 1.0.4", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 0.1.2", + "system-configuration 0.5.1", + "tokio", + "tokio-native-tls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winreg", +] + [[package]] name = "reqwest" version = "0.12.15" @@ -7264,7 +7469,7 @@ dependencies = [ "http-body-util", "hyper 1.6.0", "hyper-rustls", - "hyper-tls", + "hyper-tls 0.6.0", "hyper-util", "ipnet", "js-sys", @@ -7277,13 +7482,13 @@ dependencies = [ "pin-project-lite", "quinn", "rustls", - "rustls-pemfile", + "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", - "system-configuration", + "sync_wrapper 1.0.2", + "system-configuration 0.6.1", "tokio", 
"tokio-native-tls", "tokio-rustls", @@ -7308,7 +7513,7 @@ dependencies = [ "anyhow", "async-trait", "http 1.3.1", - "reqwest", + "reqwest 0.12.15", "serde", "thiserror 1.0.69", "tower-service", @@ -7687,6 +7892,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + [[package]] name = "rustls-pemfile" version = "2.2.0" @@ -8166,6 +8380,7 @@ dependencies = [ "rand 0.9.1", "redis", "regex", + "reqwest 0.12.15", "serde", "serde_json", "subtle", @@ -8573,6 +8788,12 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "sync_wrapper" version = "1.0.2" @@ -8608,6 +8829,17 @@ dependencies = [ "windows 0.52.0", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys 0.5.0", +] + [[package]] name = "system-configuration" version = "0.6.1" @@ -8616,7 +8848,17 @@ checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.9.0", "core-foundation", - "system-configuration-sys", + "system-configuration-sys 0.6.0", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", ] [[package]] @@ -8641,6 +8883,12 @@ version = "1.0.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-lexicon" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" + [[package]] name = "tempfile" version = "3.14.0" @@ -8654,6 +8902,28 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "test-log" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -8840,6 +9110,19 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-tungstenite" version = "0.24.0" @@ -8929,7 +9212,7 @@ dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper", + "sync_wrapper 1.0.2", "tokio", "tower-layer", "tower-service", @@ -9010,7 +9293,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3beec919fbdf99d719de8eda6adae3281f8a5b71ae40431f44dc7423053d34" dependencies = [ "loki-api", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "snap", @@ -9173,6 +9456,12 @@ version = "0.2.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + [[package]] name = "universal-hash" version = "0.5.1" @@ -9303,7 +9592,7 @@ dependencies = [ "base64 0.22.1", "mime_guess", "regex", - "reqwest", + "reqwest 0.12.15", "rust-embed", "serde", "serde_json", @@ -9351,7 +9640,7 @@ dependencies = [ "redis", "redis-test", "regex", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "shared", @@ -10189,10 +10478,11 @@ dependencies = [ "libc", "log", "nvml-wrapper", + "operations", "p2p", "rand 0.8.5", "rand 0.9.1", - "reqwest", + "reqwest 0.12.15", "rust-ipfs", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 1bc9e2ac..f35a3e7e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,12 +7,15 @@ members = [ "crates/orchestrator", "crates/p2p", "crates/dev-utils", + "crates/python-sdk", + "crates/operations", ] resolver = "2" [workspace.dependencies] shared = { path = "crates/shared" } p2p = { path = "crates/p2p" } +operations = { path = "crates/operations" } actix-web = "4.9.0" clap = { version = "4.5.27", features = ["derive"] } @@ -42,7 +45,6 @@ mockito = "1.7.0" iroh = "0.34.1" rand_v8 = { package = "rand", version = "0.8.5", features = ["std"] } rand_core_v6 = { package = "rand_core", version = "0.6.4", features = ["std"] } -ipld-core = "0.4" rust-ipfs = "0.14" cid = "0.11" tracing = "0.1.41" @@ -59,3 +61,10 @@ manual_let_else = "warn" [workspace.lints.rust] unreachable_pub = "warn" + +[workspace.metadata.rust-analyzer] +# Help rust-analyzer with proc-macros +procMacro.enable = true +procMacro.attributes.enable = true +# Use a separate target directory for rust-analyzer +targetDir = true diff --git a/Makefile b/Makefile index decd07f6..fb4b62be 100644 --- a/Makefile +++ b/Makefile 
@@ -97,6 +97,18 @@ up: @# Attach to session @tmux attach-session -t prime-dev +# Start Docker services and deploy contracts only +.PHONY: bootstrap +bootstrap: + @echo "Starting Docker services and deploying contracts..." + @# Start Docker services + @docker compose up -d reth redis discovery --wait --wait-timeout 180 + @# Deploy contracts + @cd smart-contracts && sh deploy.sh && sh deploy_work_validation.sh && cd .. + @# Run setup + @$(MAKE) setup + @echo "Bootstrap complete - Docker services running and contracts deployed" + # Stop development environment .PHONY: down down: @@ -268,3 +280,12 @@ deregister-worker: set -a; source ${ENV_FILE}; set +a; \ cargo run --bin worker -- deregister --compute-pool-id $${WORKER_COMPUTE_POOL_ID} --private-key-provider $${PRIVATE_KEY_PROVIDER} --private-key-node $${PRIVATE_KEY_NODE} --rpc-url $${RPC_URL} +# Python Package +.PHONY: python-install +python-install: + @cd crates/prime-protocol-py && make install + +.PHONY: python-test +python-test: + @cd crates/prime-protocol-py && make test + diff --git a/crates/dev-utils/examples/compute_pool.rs b/crates/dev-utils/examples/compute_pool.rs index 2569980c..8870a528 100644 --- a/crates/dev-utils/examples/compute_pool.rs +++ b/crates/dev-utils/examples/compute_pool.rs @@ -58,7 +58,7 @@ async fn main() -> Result<()> { let compute_limit = U256::from(0); - let tx = contracts + let _tx = contracts .compute_pool .create_compute_pool( domain_id, @@ -68,31 +68,21 @@ async fn main() -> Result<()> { compute_limit, ) .await; - println!("Transaction: {:?}", tx); let rewards_distributor_address = contracts .compute_pool .get_reward_distributor_address(U256::from(0)) .await .unwrap(); - println!( - "Rewards distributor address: {:?}", - rewards_distributor_address - ); let rewards_distributor = RewardsDistributor::new( rewards_distributor_address, wallet.provider(), "rewards_distributor.json", ); let rate = U256::from(10000000000000000u64); - let tx = 
rewards_distributor.set_reward_rate(rate).await; - println!("Setting reward rate: {:?}", tx); + let _tx = rewards_distributor.set_reward_rate(rate).await; - let reward_rate = rewards_distributor.get_reward_rate().await.unwrap(); - println!( - "Reward rate: {}", - reward_rate.to_string().parse::<f64>().unwrap_or(0.0) / 10f64.powf(18.0) - ); + let _reward_rate = rewards_distributor.get_reward_rate().await.unwrap(); Ok(()) } diff --git a/crates/dev-utils/examples/create_domain.rs b/crates/dev-utils/examples/create_domain.rs index 4365c764..d1da5ea2 100644 --- a/crates/dev-utils/examples/create_domain.rs +++ b/crates/dev-utils/examples/create_domain.rs @@ -59,6 +59,6 @@ async fn main() -> Result<()> { .await; println!("Creating domain: {}", args.domain_name); println!("Validation logic: {}", args.validation_logic); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/eject_node.rs b/crates/dev-utils/examples/eject_node.rs index e2ed03a3..142aa1cd 100644 --- a/crates/dev-utils/examples/eject_node.rs +++ b/crates/dev-utils/examples/eject_node.rs @@ -52,20 +52,20 @@ async fn main() -> Result<()> { .compute_registry .get_node(provider_address, node_address) .await; - println!("Node info: {:?}", node_info); + println!("Node info: {node_info:?}"); let tx = contracts .compute_pool .eject_node(args.pool_id, node_address) .await; println!("Ejected node {} from pool {}", args.node, args.pool_id); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); let node_info = contracts .compute_registry .get_node(provider_address, node_address) .await; - println!("Post ejection node info: {:?}", node_info); + println!("Post ejection node info: {node_info:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/get_node_info.rs b/crates/dev-utils/examples/get_node_info.rs index fec5f526..79c7c120 100644 --- a/crates/dev-utils/examples/get_node_info.rs +++ b/crates/dev-utils/examples/get_node_info.rs @@ -55,9 +55,6 
@@ async fn main() -> Result<()> { .await .unwrap(); - println!( - "Node Active: {}, Validated: {}, In Pool: {}", - active, validated, is_node_in_pool - ); + println!("Node Active: {active}, Validated: {validated}, In Pool: {is_node_in_pool}"); Ok(()) } diff --git a/crates/dev-utils/examples/invalidate_work.rs b/crates/dev-utils/examples/invalidate_work.rs index 78154b07..c93c8cee 100644 --- a/crates/dev-utils/examples/invalidate_work.rs +++ b/crates/dev-utils/examples/invalidate_work.rs @@ -65,7 +65,7 @@ async fn main() -> Result<()> { "Invalidated work in pool {} with penalty {}", args.pool_id, args.penalty ); - println!("Transaction hash: {:?}", tx); + println!("Transaction hash: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/mint_ai_token.rs b/crates/dev-utils/examples/mint_ai_token.rs index 5e572b40..bc43b78d 100644 --- a/crates/dev-utils/examples/mint_ai_token.rs +++ b/crates/dev-utils/examples/mint_ai_token.rs @@ -45,9 +45,9 @@ async fn main() -> Result<()> { let amount = U256::from(args.amount) * Unit::ETHER.wei(); let tx = contracts.ai_token.mint(address, amount).await; println!("Minting to address: {}", args.address); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); let balance = contracts.ai_token.balance_of(address).await; - println!("Balance: {:?}", balance); + println!("Balance: {balance:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/set_min_stake_amount.rs b/crates/dev-utils/examples/set_min_stake_amount.rs index 82644e61..2858f5c7 100644 --- a/crates/dev-utils/examples/set_min_stake_amount.rs +++ b/crates/dev-utils/examples/set_min_stake_amount.rs @@ -36,13 +36,13 @@ async fn main() -> Result<()> { .unwrap(); let min_stake_amount = U256::from(args.min_stake_amount) * Unit::ETHER.wei(); - println!("Min stake amount: {}", min_stake_amount); + println!("Min stake amount: {min_stake_amount}"); let tx = contracts .prime_network .set_stake_minimum(min_stake_amount) .await; - println!("Transaction: {:?}", tx); + 
println!("Transaction: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/start_compute_pool.rs b/crates/dev-utils/examples/start_compute_pool.rs index b11e2b2c..a94c0b6f 100644 --- a/crates/dev-utils/examples/start_compute_pool.rs +++ b/crates/dev-utils/examples/start_compute_pool.rs @@ -41,6 +41,6 @@ async fn main() -> Result<()> { .start_compute_pool(U256::from(args.pool_id)) .await; println!("Started compute pool with id: {}", args.pool_id); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/submit_work.rs b/crates/dev-utils/examples/submit_work.rs index aa3b489c..0fcf20d0 100644 --- a/crates/dev-utils/examples/submit_work.rs +++ b/crates/dev-utils/examples/submit_work.rs @@ -64,7 +64,7 @@ async fn main() -> Result<()> { "Submitted work for node {} in pool {}", args.node, args.pool_id ); - println!("Transaction hash: {:?}", tx); + println!("Transaction hash: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/test_concurrent_calls.rs b/crates/dev-utils/examples/test_concurrent_calls.rs index 47f7bbea..1bef230a 100644 --- a/crates/dev-utils/examples/test_concurrent_calls.rs +++ b/crates/dev-utils/examples/test_concurrent_calls.rs @@ -38,7 +38,7 @@ async fn main() -> Result<()> { let wallet = Arc::new(Wallet::new(&args.key, Url::parse(&args.rpc_url)?).unwrap()); let price = wallet.provider.get_gas_price().await?; - println!("Gas price: {:?}", price); + println!("Gas price: {price:?}"); let current_nonce = wallet .provider @@ -50,8 +50,8 @@ async fn main() -> Result<()> { .block_id(BlockId::Number(BlockNumberOrTag::Pending)) .await?; - println!("Pending nonce: {:?}", pending_nonce); - println!("Current nonce: {:?}", current_nonce); + println!("Pending nonce: {pending_nonce:?}"); + println!("Current nonce: {current_nonce:?}"); // Unfortunately have to build all contracts atm let contracts = Arc::new( @@ -67,7 +67,7 @@ async fn main() -> Result<()> { let address = 
Address::from_str(&args.address).unwrap(); let amount = U256::from(args.amount) * Unit::ETHER.wei(); let random = (rand::random::() % 10) + 1; - println!("Random: {:?}", random); + println!("Random: {random:?}"); let contracts_one = contracts.clone(); let wallet_one = wallet.clone(); @@ -80,7 +80,7 @@ async fn main() -> Result<()> { let tx = retry_call(mint_call, 5, wallet_one.provider(), None) .await .unwrap(); - println!("Transaction hash I: {:?}", tx); + println!("Transaction hash I: {tx:?}"); }); let contracts_two = contracts.clone(); @@ -93,11 +93,11 @@ async fn main() -> Result<()> { let tx = retry_call(mint_call_two, 5, wallet_two.provider(), None) .await .unwrap(); - println!("Transaction hash II: {:?}", tx); + println!("Transaction hash II: {tx:?}"); }); let balance = contracts.ai_token.balance_of(address).await.unwrap(); - println!("Balance: {:?}", balance); + println!("Balance: {balance:?}"); tokio::time::sleep(tokio::time::Duration::from_secs(40)).await; Ok(()) } diff --git a/crates/discovery/src/api/routes/node.rs b/crates/discovery/src/api/routes/node.rs index b2cf780f..aa6ca45a 100644 --- a/crates/discovery/src/api/routes/node.rs +++ b/crates/discovery/src/api/routes/node.rs @@ -465,12 +465,10 @@ mod tests { assert_eq!(body.data, "Node registered successfully"); let nodes = app_state.node_store.get_nodes().await; - let nodes = match nodes { - Ok(nodes) => nodes, - Err(_) => { - panic!("Error getting nodes"); - } + let Ok(nodes) = nodes else { + panic!("Error getting nodes"); }; + assert_eq!(nodes.len(), 1); assert_eq!(nodes[0].id, node.id); assert_eq!(nodes[0].last_updated, None); @@ -611,12 +609,10 @@ mod tests { assert_eq!(body.data, "Node registered successfully"); let nodes = app_state.node_store.get_nodes().await; - let nodes = match nodes { - Ok(nodes) => nodes, - Err(_) => { - panic!("Error getting nodes"); - } + let Ok(nodes) = nodes else { + panic!("Error getting nodes"); }; + assert_eq!(nodes.len(), 1); assert_eq!(nodes[0].id, node.id); } 
diff --git a/crates/discovery/src/store/redis.rs b/crates/discovery/src/store/redis.rs index 508815c2..c0a0c36b 100644 --- a/crates/discovery/src/store/redis.rs +++ b/crates/discovery/src/store/redis.rs @@ -45,8 +45,8 @@ impl RedisStore { _ => panic!("Expected TCP connection"), }; - let redis_url = format!("redis://{}:{}", host, port); - debug!("Starting test Redis server at {}", redis_url); + let redis_url = format!("redis://{host}:{port}"); + debug!("Starting test Redis server at {redis_url}"); // Add a small delay to ensure server is ready thread::sleep(Duration::from_millis(100)); diff --git a/crates/operations/Cargo.toml b/crates/operations/Cargo.toml new file mode 100644 index 00000000..875478cd --- /dev/null +++ b/crates/operations/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "operations" +version = "0.1.0" +edition = "2021" + +[lints] +workspace = true + +[lib] +name = "operations" +path = "src/lib.rs" + +[dependencies] +shared = { workspace = true } +p2p = { workspace = true } +alloy = { workspace = true } +alloy-provider = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +url = { workspace = true } +actix-web = { workspace = true } +anyhow = { workspace = true } +futures-util = { workspace = true } +hex = { workspace = true } +uuid = { workspace = true } +log = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } +redis = { workspace = true, features = ["aio", "tokio-comp"] } +rand_v8 = { workspace = true } +env_logger = { workspace = true } +subtle = "2.6.1" diff --git a/crates/operations/src/invite/admin.rs b/crates/operations/src/invite/admin.rs new file mode 100644 index 00000000..02de9a1e --- /dev/null +++ b/crates/operations/src/invite/admin.rs @@ -0,0 +1,81 @@ +use alloy::primitives::utils::keccak256 as keccak; +use alloy::primitives::{Address, U256}; +use alloy::signers::Signer; +use anyhow::Result; +use rand_v8::prelude::*; +use shared::web3::wallet::Wallet; +use 
std::time::{SystemTime, UNIX_EPOCH}; + +/// Generates an invite signature for a node +/// +/// This function is used by pool owners/admins to create signed invites +/// that authorize nodes to join their pool. +pub async fn generate_invite_signature( + wallet: &Wallet, + domain_id: u32, + pool_id: u32, + node_address: Address, + nonce: [u8; 32], + expiration: [u8; 32], +) -> Result<[u8; 65]> { + let domain_id_bytes: [u8; 32] = U256::from(domain_id).to_be_bytes(); + let pool_id_bytes: [u8; 32] = U256::from(pool_id).to_be_bytes(); + + let digest = keccak( + [ + &domain_id_bytes, + &pool_id_bytes, + node_address.as_slice(), + &nonce, + &expiration, + ] + .concat(), + ); + + let signature = wallet + .signer + .sign_message(digest.as_slice()) + .await? + .as_bytes() + .to_owned(); + + Ok(signature) +} + +/// Generates an invite expiration timestamp +/// +/// Creates an expiration timestamp for an invite, defaulting to 1000 seconds from now +pub fn generate_invite_expiration(seconds_from_now: Option<u64>) -> Result<[u8; 32]> { + let duration = seconds_from_now.unwrap_or(1000); + let expiration = U256::from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| anyhow::anyhow!("System time error: {}", e))?
+ .as_secs() + + duration, + ); + Ok(expiration.to_be_bytes()) +} + +/// Generates a random nonce for invite +pub fn generate_invite_nonce() -> [u8; 32] { + rand_v8::rngs::OsRng.gen() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_invite_nonce() { + let nonce1 = generate_invite_nonce(); + let nonce2 = generate_invite_nonce(); + assert_ne!(nonce1, nonce2, "Nonces should be unique"); + } + + #[test] + fn test_generate_invite_expiration() { + let expiration = generate_invite_expiration(Some(3600)).unwrap(); + assert_eq!(expiration.len(), 32); + } +} diff --git a/crates/operations/src/invite/common.rs b/crates/operations/src/invite/common.rs new file mode 100644 index 00000000..dd153980 --- /dev/null +++ b/crates/operations/src/invite/common.rs @@ -0,0 +1,108 @@ +use alloy::primitives::Address; +use anyhow::Result; +use p2p::{InviteRequest, InviteRequestUrl}; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Builder for creating invite requests +pub struct InviteBuilder { + pool_id: u32, + url: InviteRequestUrl, +} + +impl InviteBuilder { + /// Creates a new InviteBuilder with a master URL + pub fn with_url(pool_id: u32, url: String) -> Self { + Self { + pool_id, + url: InviteRequestUrl::MasterUrl(url), + } + } + + /// Creates a new InviteBuilder with IP and port + pub fn with_ip_port(pool_id: u32, ip: String, port: u16) -> Self { + Self { + pool_id, + url: InviteRequestUrl::MasterIpPort(ip, port), + } + } + + /// Builds an InviteRequest with the given parameters + pub fn build( + self, + invite_signature: [u8; 65], + nonce: [u8; 32], + expiration: [u8; 32], + ) -> Result<InviteRequest> { + Ok(InviteRequest { + invite: hex::encode(invite_signature), + pool_id: self.pool_id, + url: self.url, + timestamp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| anyhow::anyhow!("System time error: {}", e))?
+ .as_secs(), + expiration, + nonce, + }) + } +} + +/// Metadata for an invite request that includes worker information +#[derive(Debug, Clone)] +pub struct InviteMetadata { + pub worker_wallet_address: Address, + pub worker_p2p_id: String, + pub worker_addresses: Vec, +} + +/// Full invite request with metadata +#[derive(Debug)] +pub struct InviteWithMetadata { + pub metadata: InviteMetadata, + pub request: InviteRequest, +} + +/// Helper to parse InviteRequestUrl into a usable endpoint +pub fn get_endpoint_from_url(url: &InviteRequestUrl, path: &str) -> String { + match url { + InviteRequestUrl::MasterIpPort(ip, port) => { + format!("http://{ip}:{port}/{path}") + } + InviteRequestUrl::MasterUrl(url) => { + let url = url.trim_end_matches('/'); + format!("{url}/{path}") + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_invite_builder_with_url() { + let builder = InviteBuilder::with_url(1, "https://example.com".to_string()); + let signature = [0u8; 65]; + let nonce = [1u8; 32]; + let expiration = [2u8; 32]; + + let invite = builder.build(signature, nonce, expiration).unwrap(); + assert_eq!(invite.pool_id, 1); + assert!(matches!(invite.url, InviteRequestUrl::MasterUrl(_))); + } + + #[test] + fn test_get_endpoint_from_url() { + let url = InviteRequestUrl::MasterUrl("https://example.com".to_string()); + assert_eq!( + get_endpoint_from_url(&url, "heartbeat"), + "https://example.com/heartbeat" + ); + + let ip_port = InviteRequestUrl::MasterIpPort("192.168.1.1".to_string(), 8080); + assert_eq!( + get_endpoint_from_url(&ip_port, "heartbeat"), + "http://192.168.1.1:8080/heartbeat" + ); + } +} diff --git a/crates/operations/src/invite/mod.rs b/crates/operations/src/invite/mod.rs new file mode 100644 index 00000000..5d7d044d --- /dev/null +++ b/crates/operations/src/invite/mod.rs @@ -0,0 +1,7 @@ +pub mod admin; +pub mod common; +pub mod worker; + +pub use admin::*; +pub use common::*; +pub use worker::*; diff --git 
a/crates/operations/src/invite/worker.rs b/crates/operations/src/invite/worker.rs new file mode 100644 index 00000000..682855a5 --- /dev/null +++ b/crates/operations/src/invite/worker.rs @@ -0,0 +1,153 @@ +use alloy::primitives::utils::keccak256 as keccak; +use alloy::primitives::{Address, Signature, U256}; +use anyhow::{bail, Result}; +use p2p::InviteRequest; +use std::str::FromStr; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Verifies an invite signature +/// +/// This function is used by workers to verify that an invite is valid +/// and was signed by the correct pool owner/admin. +pub fn verify_invite_signature( + invite: &InviteRequest, + domain_id: u32, + worker_address: Address, + signer_address: Address, +) -> Result<bool> { + // Parse the signature from hex string + let signature = Signature::from_str(&invite.invite) + .map_err(|e| anyhow::anyhow!("Failed to parse invite signature: {}", e))?; + + // Recreate the message that was signed + let domain_id_bytes: [u8; 32] = U256::from(domain_id).to_be_bytes(); + let pool_id_bytes: [u8; 32] = U256::from(invite.pool_id).to_be_bytes(); + + let message = keccak( + [ + &domain_id_bytes, + &pool_id_bytes, + worker_address.as_slice(), + &invite.nonce, + &invite.expiration, + ] + .concat(), + ); + + // Verify the signature + let recovered = signature + .recover_address_from_msg(message.as_slice()) + .map_err(|e| anyhow::anyhow!("Failed to recover address: {}", e))?; + + Ok(recovered == signer_address) +} + +/// Checks if an invite has expired +pub fn is_invite_expired(invite: &InviteRequest) -> Result<bool> { + let expiration_u256 = U256::from_be_bytes(invite.expiration); + let expiration_secs = expiration_u256.to::<u64>(); + + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| anyhow::anyhow!("System time error: {}", e))?
+ .as_secs(); + + Ok(current_time > expiration_secs) +} + +/// Validates an invite for a worker +/// +/// Performs full validation including signature verification and expiration check +pub fn validate_invite( + invite: &InviteRequest, + domain_id: u32, + worker_address: Address, + pool_owner_address: Address, +) -> Result<()> { + // Check expiration + if is_invite_expired(invite)? { + bail!("Invite has expired"); + } + + // Verify signature + if !verify_invite_signature(invite, domain_id, worker_address, pool_owner_address)? { + bail!("Invalid invite signature"); + } + + Ok(()) +} + +/// Extract pool information from an invite +#[derive(Debug, Clone)] +pub struct PoolInfo { + pub pool_id: u32, + pub endpoint_url: String, +} + +impl PoolInfo { + pub fn from_invite(invite: &InviteRequest) -> Self { + use crate::invite::common::get_endpoint_from_url; + + Self { + pool_id: invite.pool_id, + endpoint_url: get_endpoint_from_url(&invite.url, ""), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use p2p::InviteRequestUrl; + + #[test] + fn test_is_invite_expired() { + // Create an expired invite + let past_time = U256::from(1000u64); + let invite = InviteRequest { + invite: "test".to_string(), + pool_id: 1, + url: InviteRequestUrl::MasterUrl("https://example.com".to_string()), + timestamp: 0, + expiration: past_time.to_be_bytes(), + nonce: [0u8; 32], + }; + + assert!(is_invite_expired(&invite).unwrap()); + + // Create a future invite + let future_time = U256::from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 3600, + ); + let future_invite = InviteRequest { + invite: "test".to_string(), + pool_id: 1, + url: InviteRequestUrl::MasterUrl("https://example.com".to_string()), + timestamp: 0, + expiration: future_time.to_be_bytes(), + nonce: [0u8; 32], + }; + + assert!(!is_invite_expired(&future_invite).unwrap()); + } + + #[test] + fn test_pool_info_from_invite() { + let invite = InviteRequest { + invite: "test".to_string(), + pool_id: 
42, + url: InviteRequestUrl::MasterUrl("https://example.com/api".to_string()), + timestamp: 0, + expiration: [0u8; 32], + nonce: [0u8; 32], + }; + + let pool_info = PoolInfo::from_invite(&invite); + assert_eq!(pool_info.pool_id, 42); + assert_eq!(pool_info.endpoint_url, "https://example.com/api/"); + } +} diff --git a/crates/operations/src/lib.rs b/crates/operations/src/lib.rs new file mode 100644 index 00000000..9734c863 --- /dev/null +++ b/crates/operations/src/lib.rs @@ -0,0 +1,2 @@ +pub mod invite; +pub mod operations; diff --git a/crates/operations/src/operations/compute_node.rs b/crates/operations/src/operations/compute_node.rs new file mode 100644 index 00000000..c294291a --- /dev/null +++ b/crates/operations/src/operations/compute_node.rs @@ -0,0 +1,92 @@ +use alloy::{primitives::utils::keccak256 as keccak, primitives::U256, signers::Signer}; +use anyhow::Result; +use shared::web3::wallet::Wallet; +use shared::web3::{contracts::core::builder::Contracts, wallet::WalletProvider}; + +pub struct ComputeNodeOperations<'c> { + provider_wallet: &'c Wallet, + node_wallet: &'c Wallet, + contracts: Contracts, +} + +impl<'c> ComputeNodeOperations<'c> { + pub fn new( + provider_wallet: &'c Wallet, + node_wallet: &'c Wallet, + contracts: Contracts, + ) -> Self { + Self { + provider_wallet, + node_wallet, + contracts, + } + } + + pub async fn check_compute_node_exists(&self) -> Result> { + let compute_node = self + .contracts + .compute_registry + .get_node( + self.provider_wallet.wallet.default_signer().address(), + self.node_wallet.wallet.default_signer().address(), + ) + .await; + + match compute_node { + Ok(_) => Ok(true), + Err(_) => Ok(false), + } + } + + // Returns true if the compute node was added, false if it already exists + pub async fn add_compute_node( + &self, + compute_units: U256, + ) -> Result> { + log::info!("🔄 Adding compute node"); + + if self.check_compute_node_exists().await? 
{ + return Ok(false); + } + + log::info!("Adding compute node"); + let provider_address = self.provider_wallet.wallet.default_signer().address(); + let node_address = self.node_wallet.wallet.default_signer().address(); + let digest = keccak([provider_address.as_slice(), node_address.as_slice()].concat()); + + let signature = self + .node_wallet + .signer + .sign_message(digest.as_slice()) + .await? + .as_bytes(); + + // Create the signature bytes + let add_node_tx = self + .contracts + .prime_network + .add_compute_node(node_address, compute_units, signature.to_vec()) + .await?; + log::info!("Add node tx: {add_node_tx:?}"); + Ok(true) + } + + pub async fn remove_compute_node(&self) -> Result> { + log::info!("🔄 Removing compute node"); + + if !self.check_compute_node_exists().await? { + return Ok(false); + } + + log::info!("Removing compute node"); + let provider_address = self.provider_wallet.wallet.default_signer().address(); + let node_address = self.node_wallet.wallet.default_signer().address(); + let remove_node_tx = self + .contracts + .prime_network + .remove_compute_node(provider_address, node_address) + .await?; + log::info!("Remove node tx: {remove_node_tx:?}"); + Ok(true) + } +} diff --git a/crates/operations/src/operations/mod.rs b/crates/operations/src/operations/mod.rs new file mode 100644 index 00000000..089315f5 --- /dev/null +++ b/crates/operations/src/operations/mod.rs @@ -0,0 +1,2 @@ +pub mod compute_node; +pub mod provider; diff --git a/crates/worker/src/operations/provider.rs b/crates/operations/src/operations/provider.rs similarity index 67% rename from crates/worker/src/operations/provider.rs rename to crates/operations/src/operations/provider.rs index fb8aba5f..c07f6189 100644 --- a/crates/worker/src/operations/provider.rs +++ b/crates/operations/src/operations/provider.rs @@ -1,4 +1,3 @@ -use crate::console::Console; use alloy::primitives::utils::format_ether; use alloy::primitives::{Address, U256}; use log::error; @@ -9,18 +8,14 @@ use 
std::{fmt, io}; use tokio::time::{sleep, Duration}; use tokio_util::sync::CancellationToken; -pub(crate) struct ProviderOperations { +pub struct ProviderOperations { wallet: Wallet, contracts: Contracts, auto_accept: bool, } impl ProviderOperations { - pub(crate) fn new( - wallet: Wallet, - contracts: Contracts, - auto_accept: bool, - ) -> Self { + pub fn new(wallet: Wallet, contracts: Contracts, auto_accept: bool) -> Self { Self { wallet, contracts, @@ -44,7 +39,7 @@ impl ProviderOperations { } } - pub(crate) fn start_monitoring(&self, cancellation_token: CancellationToken) { + pub fn start_monitoring(&self, cancellation_token: CancellationToken) { let provider_address = self.wallet.wallet.default_signer().address(); let contracts = self.contracts.clone(); @@ -58,12 +53,12 @@ impl ProviderOperations { loop { tokio::select! { _ = cancellation_token.cancelled() => { - Console::info("Monitor", "Shutting down provider status monitor..."); + log::info!("Shutting down provider status monitor..."); break; } _ = async { let Some(stake_manager) = contracts.stake_manager.as_ref() else { - Console::user_error("Cannot start monitoring - stake manager not initialized"); + log::error!("Cannot start monitoring - stake manager not initialized"); return; }; @@ -71,21 +66,21 @@ impl ProviderOperations { match stake_manager.get_stake(provider_address).await { Ok(stake) => { if first_check || stake != last_stake { - Console::info("🔄 Chain Sync - Provider stake", &format_ether(stake)); + log::info!("🔄 Chain Sync - Provider stake: {}", format_ether(stake)); if !first_check { if stake < last_stake { - Console::warning(&format!("Stake decreased - possible slashing detected: From {} to {}", + log::warn!("Stake decreased - possible slashing detected: From {} to {}", format_ether(last_stake), format_ether(stake) - )); + ); if stake == U256::ZERO { - Console::warning("Stake is 0 - you might have to restart the node to increase your stake (if you still have balance left)"); + 
log::warn!("Stake is 0 - you might have to restart the node to increase your stake (if you still have balance left)"); } } else { - Console::info("🔄 Chain Sync - Stake changed", &format!("From {} to {}", + log::info!("🔄 Chain Sync - Stake increased: From {} to {}", format_ether(last_stake), format_ether(stake) - )); + ); } } last_stake = stake; @@ -102,13 +97,7 @@ impl ProviderOperations { match contracts.ai_token.balance_of(provider_address).await { Ok(balance) => { if first_check || balance != last_balance { - Console::info("🔄 Chain Sync - Balance", &format_ether(balance)); - if !first_check { - Console::info("🔄 Chain Sync - Balance changed", &format!("From {} to {}", - format_ether(last_balance), - format_ether(balance) - )); - } + log::info!("🔄 Chain Sync - Balance: {}", format_ether(balance)); last_balance = balance; } Some(balance) @@ -123,12 +112,12 @@ impl ProviderOperations { match contracts.compute_registry.get_provider(provider_address).await { Ok(provider) => { if first_check || provider.is_whitelisted != last_whitelist_status { - Console::info("🔄 Chain Sync - Whitelist status", &format!("{}", provider.is_whitelisted)); + log::info!("🔄 Chain Sync - Whitelist status: {}", provider.is_whitelisted); if !first_check { - Console::info("🔄 Chain Sync - Whitelist status changed", &format!("From {} to {}", + log::info!("🔄 Chain Sync - Whitelist status changed: {} -> {}", last_whitelist_status, provider.is_whitelisted - )); + ); } last_whitelist_status = provider.is_whitelisted; } @@ -146,7 +135,7 @@ impl ProviderOperations { }); } - pub(crate) async fn check_provider_exists(&self) -> Result { + pub async fn check_provider_exists(&self) -> Result { let address = self.wallet.wallet.default_signer().address(); let provider = self @@ -159,7 +148,7 @@ impl ProviderOperations { Ok(provider.provider_address != Address::default()) } - pub(crate) async fn check_provider_whitelisted(&self) -> Result { + pub async fn check_provider_whitelisted(&self) -> Result { let 
address = self.wallet.wallet.default_signer().address(); let provider = self @@ -171,29 +160,32 @@ impl ProviderOperations { Ok(provider.is_whitelisted) } - - pub(crate) async fn retry_register_provider( + pub async fn retry_register_provider( &self, stake: U256, max_attempts: u32, - cancellation_token: CancellationToken, + cancellation_token: Option, ) -> Result<(), ProviderError> { - Console::title("Registering Provider"); + log::info!("Registering Provider"); let mut attempts = 0; while attempts < max_attempts || max_attempts == 0 { - Console::progress("Registering provider..."); + log::info!("Registering provider..."); match self.register_provider(stake).await { Ok(_) => { return Ok(()); } Err(e) => match e { ProviderError::NotWhitelisted | ProviderError::InsufficientBalance => { - Console::info("Info", "Retrying in 10 seconds..."); - tokio::select! { - _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {} - _ = cancellation_token.cancelled() => { - return Err(e); + log::info!("Retrying in 10 seconds..."); + if let Some(ref token) = cancellation_token { + tokio::select! 
{ + _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {} + _ = token.cancelled() => { + return Err(e); + } } + } else { + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; } attempts += 1; continue; @@ -206,7 +198,7 @@ impl ProviderOperations { Err(ProviderError::Other) } - pub(crate) async fn register_provider(&self, stake: U256) -> Result<(), ProviderError> { + pub async fn register_provider(&self, stake: U256) -> Result<(), ProviderError> { let address = self.wallet.wallet.default_signer().address(); let balance: U256 = self .contracts @@ -224,42 +216,39 @@ impl ProviderOperations { let provider_exists = self.check_provider_exists().await?; if !provider_exists { - Console::info("Balance", &format_ether(balance)); - Console::info( - "ETH Balance", + log::info!("Balance: {}", &format_ether(balance)); + log::info!( + "ETH Balance: {}", &format!("{} ETH", format_ether(U256::from(eth_balance))), ); if balance < stake { - Console::user_error(&format!( - "Insufficient balance for stake: {}", - format_ether(stake) - )); + log::error!("Insufficient balance for stake: {}", format_ether(stake)); return Err(ProviderError::InsufficientBalance); } if !self.prompt_user_confirmation(&format!( "Do you want to approve staking {}?", format_ether(stake) )) { - Console::info("Operation cancelled by user", "Staking approval declined"); + log::info!("Operation cancelled by user: Staking approval declined"); return Err(ProviderError::UserCancelled); } - Console::progress("Approving for Stake transaction"); + log::info!("Approving for Stake transaction"); self.contracts .ai_token .approve(stake) .await .map_err(|_| ProviderError::Other)?; - Console::progress("Registering Provider"); + log::info!("Registering Provider"); let Ok(register_tx) = self.contracts.prime_network.register_provider(stake).await else { return Err(ProviderError::Other); }; - Console::info("Registration tx", &format!("{register_tx:?}")); + log::info!("Registration tx: {}", 
&format!("{register_tx:?}")); } // Get provider details again - cleanup later - Console::progress("Getting provider details"); + log::info!("Getting provider details"); let _ = self .contracts .compute_registry @@ -270,32 +259,29 @@ impl ProviderOperations { let provider_exists = self.check_provider_exists().await?; if !provider_exists { - Console::info("Balance", &format_ether(balance)); - Console::info( - "ETH Balance", + log::info!("Balance: {}", &format_ether(balance)); + log::info!( + "ETH Balance: {}", &format!("{} ETH", format_ether(U256::from(eth_balance))), ); if balance < stake { - Console::user_error(&format!( - "Insufficient balance for stake: {}", - format_ether(stake) - )); + log::error!("Insufficient balance for stake: {}", format_ether(stake)); return Err(ProviderError::InsufficientBalance); } if !self.prompt_user_confirmation(&format!( "Do you want to approve staking {}?", format_ether(stake) )) { - Console::info("Operation cancelled by user", "Staking approval declined"); + log::info!("Operation cancelled by user: Staking approval declined"); return Err(ProviderError::UserCancelled); } - Console::progress("Approving Stake transaction"); + log::info!("Approving Stake transaction"); self.contracts.ai_token.approve(stake).await.map_err(|e| { error!("Failed to approve stake: {e}"); ProviderError::Other })?; - Console::progress("Registering Provider"); + log::info!("Registering Provider"); let register_tx = match self.contracts.prime_network.register_provider(stake).await { Ok(tx) => tx, Err(e) => { @@ -303,7 +289,7 @@ impl ProviderOperations { return Err(ProviderError::Other); } }; - Console::info("Registration tx", &format!("{register_tx:?}")); + log::info!("Registration tx: {register_tx:?}"); } let provider = self @@ -315,23 +301,23 @@ impl ProviderOperations { let provider_exists = provider.provider_address != Address::default(); if !provider_exists { - Console::user_error( - "Provider could not be registered. 
Please ensure your balance is high enough.", + log::error!( + "Provider could not be registered. Please ensure your balance is high enough." ); return Err(ProviderError::Other); } - Console::success("Provider registered"); + log::info!("Provider registered"); if !provider.is_whitelisted { - Console::user_error("Provider is not whitelisted yet."); + log::error!("Provider is not whitelisted yet."); return Err(ProviderError::NotWhitelisted); } Ok(()) } - pub(crate) async fn increase_stake(&self, additional_stake: U256) -> Result<(), ProviderError> { - Console::title("💰 Increasing Provider Stake"); + pub async fn increase_stake(&self, additional_stake: U256) -> Result<(), ProviderError> { + log::info!("💰 Increasing Provider Stake"); let address = self.wallet.wallet.default_signer().address(); let balance: U256 = self @@ -341,11 +327,14 @@ impl ProviderOperations { .await .map_err(|_| ProviderError::Other)?; - Console::info("Current Balance", &format_ether(balance)); - Console::info("Additional stake amount", &format_ether(additional_stake)); + log::info!("Current Balance: {}", &format_ether(balance)); + log::info!( + "Additional stake amount: {}", + &format_ether(additional_stake) + ); if balance < additional_stake { - Console::user_error("Insufficient balance for stake increase"); + log::error!("Insufficient balance for stake increase"); return Err(ProviderError::Other); } @@ -353,20 +342,20 @@ impl ProviderOperations { "Do you want to approve staking {} additional funds?", format_ether(additional_stake) )) { - Console::info("Operation cancelled by user", "Staking approval declined"); + log::info!("Operation cancelled by user: Staking approval declined"); return Err(ProviderError::UserCancelled); } - Console::progress("Approving additional stake"); + log::info!("Approving additional stake"); let approve_tx = self .contracts .ai_token .approve(additional_stake) .await .map_err(|_| ProviderError::Other)?; - Console::info("Transaction approved", 
&format!("{approve_tx:?}")); + log::info!("Transaction approved: {}", &format!("{approve_tx:?}")); - Console::progress("Increasing stake"); + log::info!("Increasing stake"); let stake_tx = match self.contracts.prime_network.stake(additional_stake).await { Ok(tx) => tx, Err(e) => { @@ -374,17 +363,15 @@ impl ProviderOperations { return Err(ProviderError::Other); } }; - Console::info( - "Stake increase transaction completed: ", - &format!("{stake_tx:?}"), + log::info!( + "Stake increase transaction completed: {}", + &format!("{stake_tx:?}") ); - Console::success("Provider stake increased successfully"); Ok(()) } - pub(crate) async fn reclaim_stake(&self, amount: U256) -> Result<(), ProviderError> { - Console::progress("Reclaiming stake"); + pub async fn reclaim_stake(&self, amount: U256) -> Result<(), ProviderError> { let reclaim_tx = match self.contracts.prime_network.reclaim_stake(amount).await { Ok(tx) => tx, Err(e) => { @@ -392,17 +379,16 @@ impl ProviderOperations { return Err(ProviderError::Other); } }; - Console::info( - "Stake reclaim transaction completed: ", - &format!("{reclaim_tx:?}"), + log::info!( + "Stake reclaim transaction completed: {}", + &format!("{reclaim_tx:?}") ); - Console::success("Provider stake reclaimed successfully"); Ok(()) } } #[derive(Debug)] -pub(crate) enum ProviderError { +pub enum ProviderError { NotWhitelisted, UserCancelled, Other, diff --git a/crates/orchestrator/Cargo.toml b/crates/orchestrator/Cargo.toml index ce733ee6..594df148 100644 --- a/crates/orchestrator/Cargo.toml +++ b/crates/orchestrator/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] p2p = { workspace = true} shared = { workspace = true } +operations = { workspace = true } actix-web = { workspace = true } alloy = { workspace = true } diff --git a/crates/orchestrator/src/api/routes/heartbeat.rs b/crates/orchestrator/src/api/routes/heartbeat.rs index a8110e61..4d6261f9 100644 --- a/crates/orchestrator/src/api/routes/heartbeat.rs +++ 
b/crates/orchestrator/src/api/routes/heartbeat.rs @@ -404,7 +404,7 @@ mod tests { let task = match task.try_into() { Ok(task) => task, - Err(e) => panic!("Failed to convert TaskRequest to Task: {}", e), + Err(e) => panic!("Failed to convert TaskRequest to Task: {e}"), }; let _ = app_state.store_context.task_store.add_task(task).await; diff --git a/crates/orchestrator/src/api/routes/task.rs b/crates/orchestrator/src/api/routes/task.rs index 7cff4b6d..fa167dc7 100644 --- a/crates/orchestrator/src/api/routes/task.rs +++ b/crates/orchestrator/src/api/routes/task.rs @@ -315,8 +315,8 @@ mod tests { // Add tasks in sequence with delays for i in 1..=3 { let task: Task = TaskRequest { - image: format!("test{}", i), - name: format!("test{}", i), + image: format!("test{i}"), + name: format!("test{i}"), ..Default::default() } .try_into() diff --git a/crates/orchestrator/src/node/invite.rs b/crates/orchestrator/src/node/invite.rs index 8391d047..7aa68558 100644 --- a/crates/orchestrator/src/node/invite.rs +++ b/crates/orchestrator/src/node/invite.rs @@ -3,19 +3,17 @@ use crate::models::node::OrchestratorNode; use crate::p2p::InviteRequest as InviteRequestWithMetadata; use crate::store::core::StoreContext; use crate::utils::loop_heartbeats::LoopHeartbeats; -use alloy::primitives::utils::keccak256 as keccak; -use alloy::primitives::U256; -use alloy::signers::Signer; use anyhow::{bail, Result}; use futures::stream; use futures::StreamExt; use log::{debug, error, info, warn}; -use p2p::InviteRequest; +use operations::invite::{ + admin::{generate_invite_expiration, generate_invite_nonce, generate_invite_signature}, + common::InviteBuilder, +}; use p2p::InviteRequestUrl; use shared::web3::wallet::Wallet; use std::sync::Arc; -use std::time::SystemTime; -use std::time::UNIX_EPOCH; use tokio::sync::mpsc::Sender; use tokio::time::{interval, Duration}; @@ -89,29 +87,15 @@ impl NodeInviter { nonce: [u8; 32], expiration: [u8; 32], ) -> Result<[u8; 65]> { - let domain_id: [u8; 32] = 
U256::from(self.domain_id).to_be_bytes(); - let pool_id: [u8; 32] = U256::from(self.pool_id).to_be_bytes(); - - let digest = keccak( - [ - &domain_id, - &pool_id, - node.address.as_slice(), - &nonce, - &expiration, - ] - .concat(), - ); - - let signature = self - .wallet - .signer - .sign_message(digest.as_slice()) - .await? - .as_bytes() - .to_owned(); - - Ok(signature) + generate_invite_signature( + &self.wallet, + self.domain_id, + self.pool_id, + node.address, + nonce, + expiration, + ) + .await } async fn send_invite(&self, node: &OrchestratorNode) -> Result<(), anyhow::Error> { @@ -122,29 +106,21 @@ impl NodeInviter { let p2p_addresses = node.worker_p2p_addresses.as_ref().unwrap(); // Generate random nonce and expiration - let nonce: [u8; 32] = rand::random(); - let expiration: [u8; 32] = U256::from( - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|e| anyhow::anyhow!("System time error: {}", e))? - .as_secs() - + 1000, - ) - .to_be_bytes(); + let nonce = generate_invite_nonce(); + let expiration = generate_invite_expiration(Some(1000))?; let invite_signature = self.generate_invite(node, nonce, expiration).await?; - let payload = InviteRequest { - invite: hex::encode(invite_signature), - pool_id: self.pool_id, - url: self.url.clone(), - timestamp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|e| anyhow::anyhow!("System time error: {}", e))? 
- .as_secs(), - expiration, - nonce, + + // Build the invite request using the builder + let builder = match &self.url { + InviteRequestUrl::MasterUrl(url) => InviteBuilder::with_url(self.pool_id, url.clone()), + InviteRequestUrl::MasterIpPort(ip, port) => { + InviteBuilder::with_ip_port(self.pool_id, ip.clone(), *port) + } }; + let payload = builder.build(invite_signature, nonce, expiration)?; + info!("Sending invite to node: {p2p_id}"); let (response_tx, response_rx) = tokio::sync::oneshot::channel(); diff --git a/crates/orchestrator/src/plugins/node_groups/tests.rs b/crates/orchestrator/src/plugins/node_groups/tests.rs index a7d73b36..5fc22430 100644 --- a/crates/orchestrator/src/plugins/node_groups/tests.rs +++ b/crates/orchestrator/src/plugins/node_groups/tests.rs @@ -276,9 +276,7 @@ async fn test_group_formation_with_multiple_configs() { let _ = plugin.try_form_new_groups().await; let mut conn = plugin.store.client.get_connection().unwrap(); - let groups: Vec = conn - .keys(format!("{}*", GROUP_KEY_PREFIX).as_str()) - .unwrap(); + let groups: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*").as_str()).unwrap(); assert_eq!(groups.len(), 2); // Verify group was created @@ -1102,7 +1100,7 @@ async fn test_node_cannot_be_in_multiple_groups() { ); // Get all group keys - let group_keys: Vec = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); let group_copy = group_keys.clone(); // There should be exactly one group @@ -1167,7 +1165,7 @@ async fn test_node_cannot_be_in_multiple_groups() { let _ = plugin.try_form_new_groups().await; // Get updated group keys - let group_keys: Vec = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); // There should now be exactly two groups assert_eq!( @@ -1544,7 +1542,7 @@ async fn test_task_observer() { let _ = store_context.task_store.add_task(task2.clone()).await; let _ = 
plugin.try_form_new_groups().await; let all_tasks = store_context.task_store.get_all_tasks().await.unwrap(); - println!("All tasks: {:?}", all_tasks); + println!("All tasks: {all_tasks:?}"); assert_eq!(all_tasks.len(), 2); assert!(all_tasks[0].id != all_tasks[1].id); let topologies = get_task_topologies(&task).unwrap(); @@ -1588,7 +1586,7 @@ async fn test_task_observer() { .unwrap(); assert!(group_3.is_some()); let all_tasks = store_context.task_store.get_all_tasks().await.unwrap(); - println!("All tasks: {:?}", all_tasks); + println!("All tasks: {all_tasks:?}"); assert_eq!(all_tasks.len(), 2); // Manually assign the first task to the group to test immediate dissolution let group_3_before = plugin @@ -1615,7 +1613,7 @@ async fn test_task_observer() { .get_node_group(&node_3.address.to_string()) .await .unwrap(); - println!("Group 3 after task deletion: {:?}", group_3); + println!("Group 3 after task deletion: {group_3:?}"); // With new behavior, group should be dissolved immediately when its assigned task is deleted assert!(group_3.is_none()); @@ -1833,7 +1831,7 @@ async fn test_group_formation_priority() { let nodes: Vec<_> = (1..=4) .map(|i| { create_test_node( - &format!("0x{}234567890123456789012345678901234567890", i), + &format!("0x{i}234567890123456789012345678901234567890"), NodeStatus::Healthy, None, ) @@ -1863,7 +1861,7 @@ async fn test_group_formation_priority() { // Verify: Should form one 3-node group + one 1-node group // NOT four 1-node groups let mut conn = plugin.store.client.get_connection().unwrap(); - let group_keys: Vec = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); assert_eq!(group_keys.len(), 2, "Should form exactly 2 groups"); // Check group compositions @@ -1944,7 +1942,7 @@ async fn test_multiple_groups_same_configuration() { let nodes: Vec<_> = (1..=6) .map(|i| { create_test_node( - &format!("0x{}234567890123456789012345678901234567890", i), + 
&format!("0x{i}234567890123456789012345678901234567890"), NodeStatus::Healthy, None, ) @@ -1958,7 +1956,7 @@ async fn test_multiple_groups_same_configuration() { // Verify: Should create 3 groups of 2 nodes each let mut conn = plugin.store.client.get_connection().unwrap(); - let group_keys: Vec = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); assert_eq!(group_keys.len(), 3, "Should form exactly 3 groups"); // Verify all groups have exactly 2 nodes and same configuration @@ -2663,7 +2661,7 @@ async fn test_no_merge_when_policy_disabled() { // Create 3 nodes let nodes: Vec<_> = (1..=3) - .map(|i| create_test_node(&format!("0x{:040x}", i), NodeStatus::Healthy, None)) + .map(|i| create_test_node(&format!("0x{i:040x}"), NodeStatus::Healthy, None)) .collect(); for node in &nodes { diff --git a/crates/orchestrator/src/scheduler/mod.rs b/crates/orchestrator/src/scheduler/mod.rs index 711f313f..d5ffa506 100644 --- a/crates/orchestrator/src/scheduler/mod.rs +++ b/crates/orchestrator/src/scheduler/mod.rs @@ -144,12 +144,12 @@ mod tests { ); assert_eq!( env_vars.get("NODE_VAR").unwrap(), - &format!("node-{}", node_address) + &format!("node-{node_address}") ); // Check cmd replacement let cmd = returned_task.cmd.unwrap(); assert_eq!(cmd[0], format!("--task={}", task.id)); - assert_eq!(cmd[1], format!("--node={}", node_address)); + assert_eq!(cmd[1], format!("--node={node_address}")); } } diff --git a/crates/orchestrator/src/status_update/mod.rs b/crates/orchestrator/src/status_update/mod.rs index b2738488..67140cbc 100644 --- a/crates/orchestrator/src/status_update/mod.rs +++ b/crates/orchestrator/src/status_update/mod.rs @@ -372,6 +372,7 @@ async fn process_node( } #[cfg(test)] +#[allow(clippy::unused_async)] async fn is_node_in_pool(_: Contracts, _: u32, _: &OrchestratorNode) -> bool { true } @@ -433,7 +434,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); 
+ error!("Error adding node: {e}"); } let heartbeat = HeartbeatRequest { address: node.address.to_string(), @@ -451,7 +452,7 @@ mod tests { .beat(&heartbeat) .await { - error!("Heartbeat Error: {}", e); + error!("Heartbeat Error: {e}"); } let _ = updater.process_nodes().await; @@ -510,7 +511,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( @@ -563,7 +564,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( @@ -623,7 +624,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } if let Err(e) = app_state .store_context @@ -631,7 +632,7 @@ mod tests { .set_unhealthy_counter(&node.address, 2) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); } let mode = ServerMode::Full; @@ -687,7 +688,7 @@ mod tests { .set_unhealthy_counter(&node.address, 2) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); }; let heartbeat = HeartbeatRequest { @@ -702,7 +703,7 @@ mod tests { .beat(&heartbeat) .await { - error!("Heartbeat Error: {}", e); + error!("Heartbeat Error: {e}"); } if let Err(e) = app_state .store_context @@ -710,7 +711,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; @@ -772,7 +773,7 @@ mod tests { .set_unhealthy_counter(&node1.address, 1) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); }; if let Err(e) = app_state .store_context @@ -780,7 +781,7 @@ mod tests { .add_node(node1.clone()) .await { - error!("Error adding node: {}", e); + 
error!("Error adding node: {e}"); } let node2 = OrchestratorNode { @@ -797,7 +798,7 @@ mod tests { .add_node(node2.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; @@ -873,7 +874,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } if let Err(e) = app_state .store_context @@ -881,7 +882,7 @@ mod tests { .set_unhealthy_counter(&node.address, 2) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); } let mode = ServerMode::Full; @@ -926,7 +927,7 @@ mod tests { .beat(&heartbeat) .await { - error!("Heartbeat Error: {}", e); + error!("Heartbeat Error: {e}"); } sleep(Duration::from_secs(5)).await; @@ -960,7 +961,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( @@ -1029,7 +1030,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( diff --git a/crates/orchestrator/src/store/core/redis.rs b/crates/orchestrator/src/store/core/redis.rs index 79f57ce8..3b524b33 100644 --- a/crates/orchestrator/src/store/core/redis.rs +++ b/crates/orchestrator/src/store/core/redis.rs @@ -45,8 +45,8 @@ impl RedisStore { _ => panic!("Expected TCP connection"), }; - let redis_url = format!("redis://{}:{}", host, port); - debug!("Starting test Redis server at {}", redis_url); + let redis_url = format!("redis://{host}:{port}"); + debug!("Starting test Redis server at {redis_url}"); // Add a small delay to ensure server is ready thread::sleep(Duration::from_millis(100)); diff --git a/crates/orchestrator/src/store/domains/heartbeat_store.rs b/crates/orchestrator/src/store/domains/heartbeat_store.rs index b2f8138a..8bb43374 
100644 --- a/crates/orchestrator/src/store/domains/heartbeat_store.rs +++ b/crates/orchestrator/src/store/domains/heartbeat_store.rs @@ -80,7 +80,7 @@ impl HeartbeatStore { .get_multiplexed_async_connection() .await .map_err(|_| anyhow!("Failed to get connection"))?; - let key = format!("{}:{}", ORCHESTRATOR_UNHEALTHY_COUNTER_KEY, address); + let key = format!("{ORCHESTRATOR_UNHEALTHY_COUNTER_KEY}:{address}"); con.set(key, counter.to_string()) .await .map_err(|_| anyhow!("Failed to set value")) diff --git a/crates/orchestrator/src/store/domains/metrics_store.rs b/crates/orchestrator/src/store/domains/metrics_store.rs index 1a0d79ac..5520860a 100644 --- a/crates/orchestrator/src/store/domains/metrics_store.rs +++ b/crates/orchestrator/src/store/domains/metrics_store.rs @@ -145,7 +145,7 @@ impl MetricsStore { task_id: &str, ) -> Result> { let mut con = self.redis.client.get_multiplexed_async_connection().await?; - let pattern = format!("{}:*", ORCHESTRATOR_NODE_METRICS_STORE); + let pattern = format!("{ORCHESTRATOR_NODE_METRICS_STORE}:*"); // Scan all node keys let mut iter: redis::AsyncIter = con.scan_match(&pattern).await?; diff --git a/crates/p2p/src/lib.rs b/crates/p2p/src/lib.rs index f5bc648c..228660e9 100644 --- a/crates/p2p/src/lib.rs +++ b/crates/p2p/src/lib.rs @@ -274,6 +274,8 @@ impl NodeBuilder { cancellation_token, } = self; + println!("multi addrs: {listen_addrs:?}"); + let keypair = keypair.unwrap_or(identity::Keypair::generate_ed25519()); let peer_id = keypair.public().to_peer_id(); diff --git a/crates/p2p/src/message/mod.rs b/crates/p2p/src/message/mod.rs index 74b09c5a..5d4431d8 100644 --- a/crates/p2p/src/message/mod.rs +++ b/crates/p2p/src/message/mod.rs @@ -229,7 +229,7 @@ impl From for Response { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GeneralRequest { - data: Vec, + pub data: Vec, } impl From for Request { @@ -240,7 +240,7 @@ impl From for Request { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GeneralResponse { 
- data: Vec, + pub data: Vec, } impl From for Response { diff --git a/crates/python-sdk/.gitignore b/crates/python-sdk/.gitignore new file mode 100644 index 00000000..454f9f33 --- /dev/null +++ b/crates/python-sdk/.gitignore @@ -0,0 +1,24 @@ +# Python +__pycache__/ +*.py[cod] +*.so +*.pyd +*.egg-info/ +dist/ + +# Virtual environments +.venv/ + +# Testing +.pytest_cache/ + +# IDE +.vscode/ +.idea/ + +# Rust/Maturin +target/ +Cargo.lock + +# OS +.DS_Store \ No newline at end of file diff --git a/crates/python-sdk/.python-version b/crates/python-sdk/.python-version new file mode 100644 index 00000000..4b7e4839 --- /dev/null +++ b/crates/python-sdk/.python-version @@ -0,0 +1 @@ +3.11 \ No newline at end of file diff --git a/crates/python-sdk/Cargo.toml b/crates/python-sdk/Cargo.toml new file mode 100644 index 00000000..fffe1880 --- /dev/null +++ b/crates/python-sdk/Cargo.toml @@ -0,0 +1,43 @@ +[package] +name = "python-sdk" +version = "0.1.0" +authors = ["Prime Protocol"] +edition = "2021" +rust-version = "1.70" + +[lib] +name = "primeprotocol" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.25.1", features = ["extension-module"] } +thiserror = "1.0" +shared = { workspace = true } +operations = { workspace = true } +p2p = { workspace = true } +alloy = { workspace = true } +alloy-provider = { workspace = true } +tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "sync", "time", "macros"] } +url = "2.5" +log = { workspace = true } +pyo3-log = "0.12.4" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +pythonize = "0.25" +futures = { workspace = true } +anyhow = { workspace = true } +tokio-util = { workspace = true } +rand = { version = "0.8", features = ["std"] } +hex = "0.4" +reqwest = { version = "0.11", features = ["json"] } + +[dev-dependencies] +test-log = "0.2" +tokio-test = "0.4" + +[profile.release] +opt-level = 3 +lto = true +codegen-units = 1 +strip = true + diff --git a/crates/python-sdk/Makefile 
b/crates/python-sdk/Makefile new file mode 100644 index 00000000..dfb10ac9 --- /dev/null +++ b/crates/python-sdk/Makefile @@ -0,0 +1,21 @@ +.PHONY: install +install: + @command -v uv > /dev/null || (echo "Please install uv first: curl -LsSf https://astral.sh/uv/install.sh | sh" && exit 1) + @./setup.sh # Uses uv for fast package management + +.PHONY: build +build: install + @uv cache clean + @source .venv/bin/activate && maturin develop + @source .venv/bin/activate && uv pip install --force-reinstall -e . + +.PHONY: clean +clean: + @rm -rf target/ dist/ *.egg-info .pytest_cache __pycache__ .venv/ + +.PHONY: help +help: + @echo "Available commands:" + @echo " make install - Setup environment and install dependencies" + @echo " make build - Build development version (includes install and cache clear)" + @echo " make clean - Clean build artifacts" \ No newline at end of file diff --git a/crates/python-sdk/README.md b/crates/python-sdk/README.md new file mode 100644 index 00000000..1b3cdc28 --- /dev/null +++ b/crates/python-sdk/README.md @@ -0,0 +1,80 @@ +# Prime Protocol Python SDK + +## Features + +- Worker Client +- Validator Client +- Orchestrator Client + +## Local Development Setup +Startup the local dev chain with contracts using the following Make cmd from the base folder: +```bash +make bootstrap +``` + +Within this folder run the following to install the setup: +```bash +make install +``` + +Whenever you change something in your code, simply run `make build`. + +## Example flow: +The following example flow implements the typical prime protocol topology with an orchestrator, validator and worker node. +It still uses a centralized discovery service which will be replaced shortly. 
+ +```bash +export PRIVATE_KEY_NODE=XYZ +export PRIVATE_KEY_PROVIDER=XYZ +``` +(for the private keys you can just use the keys from the `.env.example`) + +```bash +uv run examples/worker.py +``` + +You'll likely need to whitelist the provider from the base folder: +```bash +make whitelist-provider +``` + +You should be able to see the node now on the discovery service: +```bash +curl http://localhost:8089/api/platform -H "Authorization: Bearer prime" | jq +``` + +### Validator: +Next run the validator to ensure nodes are validated on chain. This is very basic logic. The pool owner could come up with something more sophisticated in the future. +```bash +uv run examples/validator.py +``` + +On the underlying Python side, all the logic here is within: +```python +from primeprotocol import ValidatorClient +``` + +You should actually see that the node is validated now on the discovery service. + +### Orchestrator / pool owner +Once we actually have validated nodes, we still want to invite them to the compute pool and have them contribute work. + +Use the pool owner private key from the .env.example and export: +```bash +export PRIVATE_KEY_ORCHESTRATOR=xyz +``` +To run the orchestrator simply execute: +```bash +uv run examples/orchestrator.py +``` +You should see that the orchestrator invites the worker node that you started in the earlier step and keeps sending messages via p2p to this node. + +## TODOs: +- [ ] restart keeps increasing provider stake? +- [ ] I keep forgetting to run make build +- [ ] what about formatting? + +## Known Limitations: +1. Orchestrator can no longer send messages to a node when the worker node restarts. +This is because the discovery info is not refreshed when the node was already active in the pool. This will change when introducing a proper DHT layer. 
+Shortcut here is to simply reset the discovery service using: `redis-cli -p 6380 "FLUSHALL"` \ No newline at end of file diff --git a/crates/python-sdk/examples/orchestrator.py b/crates/python-sdk/examples/orchestrator.py new file mode 100644 index 00000000..4fba6717 --- /dev/null +++ b/crates/python-sdk/examples/orchestrator.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +"""Example usage of the Prime Protocol Orchestrator Client.""" + +import os +import logging +import time +from primeprotocol import OrchestratorClient + +# Configure logging +FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' +logging.basicConfig(format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + +def main(): + # Get configuration from environment variables + rpc_url = os.getenv("RPC_URL", "http://localhost:8545") + private_key = os.getenv("PRIVATE_KEY_ORCHESTRATOR") + discovery_urls_str = os.getenv("DISCOVERY_URLS", "http://localhost:8089") + discovery_urls = [url.strip() for url in discovery_urls_str.split(",")] + pool_id = int(os.getenv("POOL_ID", "0")) + p2p_port = int(os.getenv("ORCHESTRATOR_P2P_PORT", "8180")) + + if not private_key: + print("Error: PRIVATE_KEY_ORCHESTRATOR environment variable is required") + return + + print(f"Initializing orchestrator client...") + print(f"RPC URL: {rpc_url}") + print(f"Discovery URLs: {discovery_urls}") + print(f"Pool ID: {pool_id}") + print(f"P2P Port: {p2p_port}") + + # Initialize and start the orchestrator + orchestrator = OrchestratorClient( + rpc_url=rpc_url, + private_key=private_key, + discovery_urls=discovery_urls, + ) + + print("Starting orchestrator client...") + orchestrator.start(p2p_port=p2p_port) + print(f"Orchestrator started with peer ID: {orchestrator.get_peer_id()}") + + print("\nStarting orchestrator loop...") + print("Press Ctrl+C to stop\n") + + try: + while True: + print(f"{'='*50}") + print(f"Cycle at {time.strftime('%H:%M:%S')}") + + # Check for a single message + message = 
orchestrator.get_next_message() + print(f"Got message - python orchestrator: {message}") + + # 1. Invite validated but inactive nodes + pool_nodes = orchestrator.list_nodes_for_pool(pool_id) + validated_inactive = [n for n in pool_nodes if n.is_validated and not n.is_active and n.worker_p2p_id] + + if validated_inactive: + print(f"\nInviting {len(validated_inactive)} validated inactive nodes...") + for node in validated_inactive: + try: + print(f" 📨 Inviting {node.id[:8]}...") + orchestrator.invite_node( + peer_id=node.worker_p2p_id, + worker_address=node.id, + pool_id=pool_id, + multiaddrs=node.worker_p2p_addresses, + domain_id=0, + orchestrator_url=None, + expiration_seconds=1000 + ) + except Exception as e: + print(f" ❌ Error inviting {node.id[:8]}: {e}") + + # 2. Send messages to active nodes + active_nodes = [n for n in pool_nodes if n.is_active and n.worker_p2p_id] + + if active_nodes: + print(f"\nMessaging {len(active_nodes)} active nodes...") + for node in active_nodes: + try: + orchestrator.send_message( + peer_id=node.worker_p2p_id, + multiaddrs=node.worker_p2p_addresses, + data=b"Hello from orchestrator!", + ) + print(f" 💬 Sent to {node.id[:8]}...") + except Exception as e: + print(f" ❌ Error messaging {node.id[:8]}: {e}") + + # Wait before next cycle + print(f"\nWaiting 10 seconds...") + time.sleep(10) + print() + + except KeyboardInterrupt: + print("\n\nShutting down orchestrator...") + orchestrator.stop() + print("Orchestrator stopped") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/python-sdk/examples/validator.py b/crates/python-sdk/examples/validator.py new file mode 100644 index 00000000..2fdf2f79 --- /dev/null +++ b/crates/python-sdk/examples/validator.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""Example usage of the Prime Protocol Validator Client.""" + +import os +import logging +import time +from primeprotocol import ValidatorClient + +# Configure logging +FORMAT = '%(levelname)s %(name)s 
%(asctime)-15s %(filename)s:%(lineno)d %(message)s' +logging.basicConfig(format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + +def main(): + # Get configuration from environment variables + rpc_url = os.getenv("RPC_URL", "http://localhost:8545") + private_key = os.getenv("PRIVATE_KEY_VALIDATOR") + discovery_urls_str = os.getenv("DISCOVERY_URLS", "http://localhost:8089") + discovery_urls = [url.strip() for url in discovery_urls_str.split(",")] + p2p_port = int(os.getenv("VALIDATOR_P2P_PORT", "8665")) + + if not private_key: + print("Error: PRIVATE_KEY_VALIDATOR environment variable is required") + return + + print(f"Initializing validator client...") + print(f"RPC URL: {rpc_url}") + print(f"Discovery URLs: {discovery_urls}") + print(f"P2P Port: {p2p_port}") + + # Initialize and start the validator + validator = ValidatorClient( + rpc_url=rpc_url, + private_key=private_key, + discovery_urls=discovery_urls, + ) + + print("Starting validator client...") + validator.start(p2p_port=p2p_port) + print(f"Validator started with peer ID: {validator.get_peer_id()}") + + print("\nStarting validator loop...") + print("Press Ctrl+C to stop\n") + + try: + while True: + print(f"{'='*50}") + print(f"Cycle at {time.strftime('%H:%M:%S')}") + # Check for a single message + message = validator.get_next_message() + print(f"Message: {message}") + if message: + msg_data = message.get('message', {}) + if msg_data.get('type') == 'general': + data = bytes(msg_data.get('data', [])) + print(f"Message payload: {data}") + non_validated = validator.list_non_validated_nodes() + if non_validated: + print(f"\nValidating {len(non_validated)} nodes...") + for node in non_validated: + try: + print(f" ✅ Validating {node.id[:8]}...") + validator.validate_node(node.id, node.provider_address) + except Exception as e: + print(f" ❌ Error validating {node.id[:8]}: {e}") + else: + print("\nNo nodes to validate") + + # 2. 
Send messages to validated nodes + all_nodes = validator.list_all_nodes_dict() + validated = [n for n in all_nodes if n.get('is_validated') and n.get('worker_p2p_id')] + + if validated: + print(f"\nMessaging {len(validated)} validated nodes...") + for node in validated: + try: + validator.send_message( + peer_id=node['worker_p2p_id'], + multiaddrs=node.get('worker_p2p_addresses', []), + data=b"Hello from validator!", + ) + print(f" 💬 Sent to {node['id'][:8]}...") + except Exception as e: + print(f" ❌ Error messaging {node['id'][:8]}: {e}") + + # Wait before next cycle + print(f"\nWaiting 10 seconds...") + time.sleep(10) + print() + + except KeyboardInterrupt: + print("\n\nShutting down validator...") + # ValidatorClient doesn't have a stop() method + print("Validator stopped") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/python-sdk/examples/worker.py b/crates/python-sdk/examples/worker.py new file mode 100644 index 00000000..cd708040 --- /dev/null +++ b/crates/python-sdk/examples/worker.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +"""Example usage of the Prime Protocol Python client.""" + +import asyncio +import logging +import os +import signal +import sys +import time +from typing import Dict, Any, Optional +from primeprotocol import WorkerClient + +FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' +logging.basicConfig(format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + + +def main(): + rpc_url = os.getenv("RPC_URL", "http://localhost:8545") + pool_id = os.getenv("POOL_ID", 0) + private_key_provider = os.getenv("PRIVATE_KEY_PROVIDER", None) + private_key_node = os.getenv("PRIVATE_KEY_NODE", None) + + logging.info(f"Connecting to: {rpc_url}") + + port = int(os.getenv("PORT", 8003)) + client = WorkerClient(pool_id, rpc_url, private_key_provider, private_key_node, port) + + # Track known peer addresses + known_peers = {} + + def signal_handler(sig, frame): + logging.info("Received 
interrupt signal, shutting down gracefully...") + try: + client.stop() + logging.info("Client stopped successfully") + except Exception as e: + logging.error(f"Error during shutdown: {e}") + sys.exit(0) + + # Register signal handler for Ctrl+C before starting client + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + logging.info("Starting client... (Press Ctrl+C to interrupt)") + client.start() + + client.upload_to_discovery("127.0.0.1", None) + + my_peer_id = client.get_own_peer_id() + logging.info(f"My Peer ID: {my_peer_id}") + + time.sleep(5) + + # Note: To send messages to other peers manually, you can use: + # peer_id = "12D3KooWELi4p1oR3QBSYiq1rvPpyjbkiQVhQJqCobBBUS7C6JrX" + # peer_multi_addr = "/ip4/127.0.0.1/tcp/8002" + # client.send_message(peer_id, b"Hello, world!", [peer_multi_addr]) + + logging.info("Setup completed. Starting message polling loop...") + print("Worker client started. Polling for orchestrator/validator messages. Press Ctrl+C to stop.") + + # Message polling loop - listening for orchestrator and validator messages + while True: + try: + message = client.get_next_message() + if message: + msg_data = message.get('message', {}) + peer_id = message['peer_id'] + multiaddrs = message.get('multiaddrs', []) + + # Determine sender type + sender_type = "worker" + if message.get('is_sender_validator'): + sender_type = "VALIDATOR" + elif message.get('is_sender_pool_owner'): + sender_type = "POOL_OWNER" + + if msg_data.get('type') == 'general': + data = bytes(msg_data.get('data', [])) + print(f"\n[{time.strftime('%H:%M:%S')}] Message from {peer_id} ({sender_type}): {data}") + print(f" Multiaddrs received: {multiaddrs}") + + # Check if it's an error message + if data.startswith(b"ERROR:"): + print(f" ⚠️ Received error: {data}") + continue + + # Respond to validators and pool owners + if sender_type in ["VALIDATOR", "POOL_OWNER"]: + try: + response_msg = f"Hello {sender_type}! 
Worker received: {data.decode('utf-8', errors='ignore')}" + + print(f" Attempting to respond to {peer_id}...") + print(f" Response message: {response_msg}") + + # Send response using empty multiaddrs (peer should already be connected) + client.send_message( + peer_id=peer_id, + multiaddrs=[], # Empty - peer already connected + data=response_msg.encode() + ) + print(f" ✓ Response sent to {sender_type}") + except Exception as e: + print(f" ✗ Error sending response: {e}") + print(f" Error type: {type(e).__name__}") + elif msg_data.get('type') == 'authentication_complete': + print(f"\n[{time.strftime('%H:%M:%S')}] ✓ Authentication complete with {peer_id} ({sender_type})") + else: + print(f"\n[{time.strftime('%H:%M:%S')}] Message from {peer_id} ({sender_type}): type={msg_data.get('type')}") + + time.sleep(0.1) # Small delay to prevent busy waiting + except KeyboardInterrupt: + # Handle Ctrl+C during message polling + logging.info("Keyboard interrupt received during polling") + signal_handler(signal.SIGINT, None) + break + + except KeyboardInterrupt: + # Handle Ctrl+C during client startup + logging.info("Keyboard interrupt received during startup") + signal_handler(signal.SIGINT, None) + except Exception as e: + logging.error(f"Unexpected error: {e}") + try: + client.stop() + except: + pass + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/python-sdk/pyproject.toml b/crates/python-sdk/pyproject.toml new file mode 100644 index 00000000..9834d8b4 --- /dev/null +++ b/crates/python-sdk/pyproject.toml @@ -0,0 +1,37 @@ +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[project] +name = "primeprotocol" +description = "Simple Python bindings for Prime Protocol client" +readme = "README.md" +requires-python = ">=3.8" +license = {text = "MIT"} +keywords = ["prime", "protocol"] +authors = [ + {name = "Prime Protocol", email = "jannik@primeintellect.ai"} +] +classifiers = [ + "Programming Language :: 
Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dynamic = ["version"] + +[project.urls] +"Homepage" = "https://github.com/primeprotocol/protocol" +"Bug Tracker" = "https://github.com/primeprotocol/protocol/issues" + +[tool.maturin] +features = ["pyo3/extension-module"] +module-name = "primeprotocol" \ No newline at end of file diff --git a/crates/python-sdk/requirements-dev.txt b/crates/python-sdk/requirements-dev.txt new file mode 100644 index 00000000..f2af3c5d --- /dev/null +++ b/crates/python-sdk/requirements-dev.txt @@ -0,0 +1,3 @@ +# Development dependencies +maturin>=1.0,<2.0 +pytest>=7.0 \ No newline at end of file diff --git a/crates/python-sdk/setup.sh b/crates/python-sdk/setup.sh new file mode 100755 index 00000000..7609b236 --- /dev/null +++ b/crates/python-sdk/setup.sh @@ -0,0 +1,16 @@ +#!/bin/bash +set -e + +# Check if uv is installed +if ! command -v uv &> /dev/null; then + echo "Please install uv first: curl -LsSf https://astral.sh/uv/install.sh | sh" + exit 1 +fi + +# Setup environment +uv venv +source .venv/bin/activate +uv pip install -r requirements-dev.txt +maturin develop + +echo "Setup complete." 
\ No newline at end of file diff --git a/crates/python-sdk/src/common/mod.rs b/crates/python-sdk/src/common/mod.rs new file mode 100644 index 00000000..89aae39a --- /dev/null +++ b/crates/python-sdk/src/common/mod.rs @@ -0,0 +1,71 @@ +use pyo3::prelude::*; +use shared::models::node::DiscoveryNode; + +/// Node details structure shared between validator and orchestrator +#[pyclass] +#[derive(Clone)] +pub struct NodeDetails { + #[pyo3(get)] + pub id: String, + #[pyo3(get)] + pub provider_address: String, + #[pyo3(get)] + pub ip_address: String, + #[pyo3(get)] + pub port: u16, + #[pyo3(get)] + pub compute_pool_id: u32, + #[pyo3(get)] + pub is_validated: bool, + #[pyo3(get)] + pub is_active: bool, + #[pyo3(get)] + pub is_provider_whitelisted: bool, + #[pyo3(get)] + pub is_blacklisted: bool, + #[pyo3(get)] + pub worker_p2p_id: Option, + #[pyo3(get)] + pub worker_p2p_addresses: Option>, + #[pyo3(get)] + pub last_updated: Option, + #[pyo3(get)] + pub created_at: Option, +} + +impl From for NodeDetails { + fn from(node: DiscoveryNode) -> Self { + Self { + id: node.node.id, + provider_address: node.node.provider_address, + ip_address: node.node.ip_address, + port: node.node.port, + compute_pool_id: node.node.compute_pool_id, + is_validated: node.is_validated, + is_active: node.is_active, + is_provider_whitelisted: node.is_provider_whitelisted, + is_blacklisted: node.is_blacklisted, + last_updated: node.last_updated.map(|dt| dt.to_rfc3339()), + created_at: node.created_at.map(|dt| dt.to_rfc3339()), + worker_p2p_id: node.node.worker_p2p_id, + worker_p2p_addresses: node.node.worker_p2p_addresses, + } + } +} + +#[pymethods] +impl NodeDetails { + /// Get compute specifications as a Python dictionary + pub fn get_compute_specs(&self, py: Python) -> PyResult { + // This method would need access to the original node data + // For now, return None since we don't store compute specs in NodeDetails + Ok(py.None()) + } + + /// Get location information as a Python dictionary + pub fn 
get_location(&self, py: Python) -> PyResult { + // This method would need access to the original node data + // For now, return None since we don't store location in NodeDetails + Ok(py.None()) + } +} diff --git a/crates/python-sdk/src/constants.rs b/crates/python-sdk/src/constants.rs new file mode 100644 index 00000000..800b5e49 --- /dev/null +++ b/crates/python-sdk/src/constants.rs @@ -0,0 +1,26 @@ +use std::time::Duration; + +/// Default P2P port for worker nodes +pub const DEFAULT_P2P_PORT: u16 = 8000; + +/// Default number of retries for funding operations +pub const DEFAULT_FUNDING_RETRY_COUNT: u32 = 10; + +/// Default compute units for node registration +pub const DEFAULT_COMPUTE_UNITS: u64 = 1; + +/// Timeout for blockchain operations +pub const BLOCKCHAIN_OPERATION_TIMEOUT: Duration = Duration::from_secs(300); + +/// Timeout for message queue operations +pub const MESSAGE_QUEUE_TIMEOUT: Duration = Duration::from_millis(100); + +/// Pool status check interval +pub const POOL_STATUS_CHECK_INTERVAL: Duration = Duration::from_secs(15); + +/// P2P shutdown timeout +pub const P2P_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); + +/// Channel sizes +pub const P2P_CHANNEL_SIZE: usize = 100; +pub const MESSAGE_QUEUE_CHANNEL_SIZE: usize = 300; diff --git a/crates/python-sdk/src/error.rs b/crates/python-sdk/src/error.rs new file mode 100644 index 00000000..cf561595 --- /dev/null +++ b/crates/python-sdk/src/error.rs @@ -0,0 +1,21 @@ +use thiserror::Error; + +/// Result type alias for Prime Protocol operations +pub type Result = std::result::Result; + +/// Errors that can occur in the Prime Protocol client +#[derive(Debug, Error)] +pub enum PrimeProtocolError { + /// Invalid configuration provided + #[error("Invalid configuration: {0}")] + InvalidConfig(String), + + /// Blockchain interaction error + #[error("Blockchain error: {0}")] + BlockchainError(String), + + /// General runtime error + #[error("Runtime error: {0}")] + #[allow(dead_code)] + 
RuntimeError(String), +} diff --git a/crates/python-sdk/src/lib.rs b/crates/python-sdk/src/lib.rs new file mode 100644 index 00000000..a2e721b9 --- /dev/null +++ b/crates/python-sdk/src/lib.rs @@ -0,0 +1,24 @@ +use crate::common::NodeDetails; +use crate::orchestrator::OrchestratorClient; +use crate::validator::ValidatorClient; +use crate::worker::WorkerClient; +use pyo3::prelude::*; + +mod common; +mod constants; +mod error; +mod orchestrator; +mod p2p_handler; +mod utils; +mod validator; +mod worker; + +#[pymodule] +fn primeprotocol(m: &Bound<'_, PyModule>) -> PyResult<()> { + pyo3_log::init(); + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/crates/python-sdk/src/orchestrator/mod.rs b/crates/python-sdk/src/orchestrator/mod.rs new file mode 100644 index 00000000..34f51cba --- /dev/null +++ b/crates/python-sdk/src/orchestrator/mod.rs @@ -0,0 +1,723 @@ +use crate::common::NodeDetails; +use crate::p2p_handler::auth::AuthenticationManager; +use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig}; +use crate::p2p_handler::{Message, MessageType, Service as P2PService}; +use p2p::{Keypair, PeerId}; +use pyo3::prelude::*; +use pythonize::pythonize; +use shared::web3::wallet::Wallet; +use std::sync::Arc; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::Mutex; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; +use url::Url; + +// Add new imports for discovery functionality +use shared::discovery::fetch_nodes_from_discovery_urls; + +// Add imports for invite functionality +use alloy::primitives::Address; +use operations::invite::{ + admin::{generate_invite_expiration, generate_invite_nonce, generate_invite_signature}, + common::InviteBuilder, +}; +use std::str::FromStr; + +// Add imports for compute pool contract functionality +use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; +use shared::web3::wallet::WalletProvider; + +/// Prime 
Protocol Orchestrator Client - for managing and distributing tasks +#[pyclass] +pub struct OrchestratorClient { + runtime: Option, + wallet: Option, + cancellation_token: CancellationToken, + // P2P fields + auth_manager: Option>, + outbound_tx: Option>>>, + user_message_rx: Option>>>, + message_processor_handle: Option>, + peer_id: Option, + // Discovery service URLs + discovery_urls: Vec, + // Contracts + contracts: Option>>, +} + +#[pymethods] +impl OrchestratorClient { + #[new] + #[pyo3(signature = (rpc_url, private_key=None, discovery_urls=vec!["http://localhost:8089".to_string()]))] + pub fn new( + rpc_url: String, + private_key: Option, + discovery_urls: Vec, + ) -> PyResult { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .map_err(|e| PyErr::new::(e.to_string()))?; + + let cancellation_token = CancellationToken::new(); + + // Create wallet if private key is provided + let wallet = if let Some(key) = private_key { + let rpc_url_parsed = Url::parse(&rpc_url).map_err(|e| { + PyErr::new::(format!("Invalid RPC URL: {}", e)) + })?; + let w = Wallet::new(&key, rpc_url_parsed) + .map_err(|e| PyErr::new::(e.to_string()))?; + + let wallet_address = w.wallet.default_signer().address().to_string(); + log::info!("Orchestrator wallet address: {}", wallet_address); + + Some(w) + } else { + None + }; + + Ok(Self { + runtime: Some(runtime), + wallet, + cancellation_token, + auth_manager: None, + outbound_tx: None, + user_message_rx: None, + message_processor_handle: None, + peer_id: None, + discovery_urls, + contracts: None, + }) + } + + /// Initialize the orchestrator client with optional P2P support + #[pyo3(signature = (p2p_port=None))] + pub fn start(&mut self, py: Python, p2p_port: Option) -> PyResult<()> { + // Initialize contracts if not already done + if self.contracts.is_none() { + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::( + "Wallet not initialized. 
Provide private_key when creating client.", + ) + })?; + + let wallet_provider = wallet.provider(); + + // Get runtime reference and use it before mutating self + let rt = self.runtime.as_ref().ok_or_else(|| { + PyErr::new::("Runtime not initialized") + })?; + + let contracts = py.allow_threads(|| { + rt.block_on(async { + // Build all contracts (required due to known bug) + ContractBuilder::new(wallet_provider) + .with_compute_pool() + .with_compute_registry() + .with_ai_token() + .with_prime_network() + .with_stake_manager() + .build() + .map_err(|e| { + PyErr::new::(format!( + "Failed to build contracts: {}", + e + )) + }) + }) + })?; + + self.contracts = Some(Arc::new(contracts)); + } + + let rt = self.get_or_create_runtime()?; + + if let Some(port) = p2p_port { + // Initialize P2P if port is provided + let wallet = self + .wallet + .as_ref() + .ok_or_else(|| { + PyErr::new::( + "Wallet not initialized. Provide private_key when creating client.", + ) + })? + .clone(); + + let cancellation_token = self.cancellation_token.clone(); + + // Create the P2P components + let (auth_manager, peer_id, outbound_tx, user_message_rx, message_processor_handle) = + py.allow_threads(|| { + rt.block_on(async { + Self::create_p2p_components(wallet, port, cancellation_token) + .await + .map_err(|e| { + PyErr::new::(e.to_string()) + }) + }) + })?; + + // Update self with the created components + self.auth_manager = Some(auth_manager); + self.peer_id = Some(peer_id); + self.outbound_tx = Some(outbound_tx); + self.user_message_rx = Some(user_message_rx); + self.message_processor_handle = Some(message_processor_handle); + } + + Ok(()) + } + + /// Send a message to a peer + pub fn send_message( + &self, + py: Python, + peer_id: String, + multiaddrs: Vec, + data: Vec, + ) -> PyResult<()> { + let rt = self.get_or_create_runtime()?; + + let auth_manager = self.auth_manager.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. 
Call start() with p2p_port parameter.", + ) + })?; + + let outbound_tx = self.outbound_tx.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. Call start() with p2p_port parameter.", + ) + })?; + + let message = Message { + message_type: MessageType::General { data }, + peer_id, + multiaddrs, + sender_address: self + .wallet + .as_ref() + .map(|w| w.wallet.default_signer().address().to_string()), + is_sender_validator: false, + is_sender_pool_owner: false, // This will be determined by the receiver + response_tx: None, + }; + + py.allow_threads(|| { + rt.block_on(async { + crate::p2p_handler::send_message_with_auth(message, auth_manager, outbound_tx) + .await + .map_err(|e| PyErr::new::(e.to_string())) + }) + }) + } + + /// Get the next message from the P2P network + pub fn get_next_message(&self, py: Python) -> PyResult> { + let rt = self.get_or_create_runtime()?; + + let user_message_rx = self.user_message_rx.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. 
Call start() with p2p_port parameter.", + ) + })?; + + let message = py.allow_threads(|| { + rt.block_on(async { + tokio::time::timeout( + crate::constants::MESSAGE_QUEUE_TIMEOUT, + user_message_rx.lock().await.recv(), + ) + .await + .ok() + .flatten() + }) + }); + + match message { + Some(msg) => { + let py_msg = pythonize(py, &msg) + .map_err(|e| PyErr::new::(e.to_string()))?; + Ok(Some(py_msg.into())) + } + None => Ok(None), + } + } + + /// Get the orchestrator's peer ID + pub fn get_peer_id(&self) -> PyResult> { + Ok(self.peer_id.map(|id| id.to_string())) + } + + /// Get the orchestrator's wallet address + pub fn get_wallet_address(&self) -> PyResult> { + Ok(self + .wallet + .as_ref() + .map(|w| w.wallet.default_signer().address().to_string())) + } + + /// Stop the orchestrator client + pub fn stop(&mut self, _py: Python) -> PyResult<()> { + self.cancellation_token.cancel(); + + // Stop message processor + if let Some(handle) = self.message_processor_handle.take() { + handle.abort(); + } + + Ok(()) + } + + /// List nodes for a specific compute pool + pub fn list_nodes_for_pool(&self, py: Python, pool_id: u32) -> PyResult> { + let rt = self.get_or_create_runtime()?; + + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::( + "Wallet not initialized. Provide private_key when creating client.", + ) + })?; + + let discovery_urls = self.discovery_urls.clone(); + let route = format!("/api/pool/{}", pool_id); + + let nodes = py.allow_threads(|| { + rt.block_on(async { + fetch_nodes_from_discovery_urls(&discovery_urls, &route, wallet) + .await + .map_err(|e| PyErr::new::(e.to_string())) + }) + })?; + + Ok(nodes.into_iter().map(NodeDetails::from).collect()) + } + + /// Invite a node to join a compute pool + /// + /// This method creates a signed invite and sends it to the specified worker node. + /// The invite contains pool information and authentication details that the worker + /// can validate before joining the pool. 
+ #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (peer_id, worker_address, pool_id, multiaddrs, domain_id=1, orchestrator_url=None, expiration_seconds=1000))] + pub fn invite_node( + &self, + py: Python, + peer_id: String, + worker_address: String, + pool_id: u32, + multiaddrs: Vec, + domain_id: u32, + orchestrator_url: Option, + expiration_seconds: u64, + ) -> PyResult<()> { + let rt = self.get_or_create_runtime()?; + + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::( + "Wallet not initialized. Provide private_key when creating client.", + ) + })?; + + let outbound_tx = self.outbound_tx.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. Call start() with p2p_port parameter.", + ) + })?; + + let auth_manager = self.auth_manager.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. Call start() with p2p_port parameter.", + ) + })?; + + // Parse worker address + let worker_addr = Address::from_str(&worker_address).map_err(|e| { + PyErr::new::(format!( + "Invalid worker address: {}", + e + )) + })?; + + let wallet = wallet.clone(); + let outbound_tx = outbound_tx.clone(); + let auth_manager = auth_manager.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Generate invite parameters + let nonce = generate_invite_nonce(); + let expiration = + generate_invite_expiration(Some(expiration_seconds)).map_err(|e| { + PyErr::new::(e.to_string()) + })?; + + // Generate the invite signature + let invite_signature = generate_invite_signature( + &wallet, + domain_id, + pool_id, + worker_addr, + nonce, + expiration, + ) + .await + .map_err(|e| PyErr::new::(e.to_string()))?; + + // Build the invite request + let invite_builder = if let Some(url) = orchestrator_url { + InviteBuilder::with_url(pool_id, url) + } else { + // Default to localhost if no URL provided + InviteBuilder::with_url(pool_id, "http://localhost:8080".to_string()) + }; + + let invite_request = invite_builder + .build(invite_signature, nonce, expiration) 
+ .map_err(|e| { + PyErr::new::(e.to_string()) + })?; + + // Serialize the invite request + let invite_data = serde_json::to_vec(&invite_request).map_err(|e| { + PyErr::new::(e.to_string()) + })?; + + // Create a general message with the invite data + let message = Message { + message_type: MessageType::General { data: invite_data }, + peer_id, + multiaddrs, + sender_address: Some(wallet.wallet.default_signer().address().to_string()), + is_sender_validator: false, + is_sender_pool_owner: false, // This will be determined by the receiver + response_tx: None, + }; + + // Send the invite + crate::p2p_handler::send_message_with_auth(message, &auth_manager, &outbound_tx) + .await + .map_err(|e| { + PyErr::new::(e.to_string()) + })?; + + Ok(()) + }) + }) + } + + /// Eject a node from a compute pool + /// + /// This method removes a node from the specified compute pool. Only the pool's + /// compute manager can eject nodes from their pool. + /// + /// Args: + /// pool_id: The ID of the compute pool + /// node_address: The address of the node to eject + /// + /// Returns: + /// The transaction hash as a hex string + pub fn eject_node(&self, py: Python, pool_id: u32, node_address: String) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. 
Call start() first.", + ) + })?; + + // Parse node address + let node_addr = Address::from_str(&node_address).map_err(|e| { + PyErr::new::(format!("Invalid node address: {}", e)) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Call eject_node on the contract + let tx_hash = contracts_clone + .compute_pool + .eject_node(pool_id, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to eject node: {}", + e + )) + })?; + + // Convert transaction hash to hex string + Ok(format!("0x{}", tx_hash)) + }) + }) + } + + /// Blacklist a node from a compute pool + /// + /// This method blacklists a node from the specified compute pool, preventing it from + /// rejoining. Only the pool's compute manager can blacklist nodes in their pool. + /// + /// Args: + /// pool_id: The ID of the compute pool + /// node_address: The address of the node to blacklist + /// + /// Returns: + /// The transaction hash as a hex string + pub fn blacklist_node( + &self, + py: Python, + pool_id: u32, + node_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. 
Call start() first.", + ) + })?; + + // Parse node address + let node_addr = Address::from_str(&node_address).map_err(|e| { + PyErr::new::(format!("Invalid node address: {}", e)) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Call blacklist_node on the contract + let tx_hash = contracts_clone + .compute_pool + .blacklist_node(pool_id, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to blacklist node: {}", + e + )) + })?; + + // Convert transaction hash to hex string + Ok(format!("0x{}", tx_hash)) + }) + }) + } + + /// Check if a node is blacklisted from a compute pool + /// + /// Args: + /// pool_id: The ID of the compute pool + /// node_address: The address of the node to check + /// + /// Returns: + /// True if the node is blacklisted, False otherwise + pub fn is_node_blacklisted( + &self, + py: Python, + pool_id: u32, + node_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. 
Call start() first.", + ) + })?; + + // Parse node address + let node_addr = Address::from_str(&node_address).map_err(|e| { + PyErr::new::(format!("Invalid node address: {}", e)) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Check if node is blacklisted + contracts_clone + .compute_pool + .is_node_blacklisted(pool_id, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to check blacklist status: {}", + e + )) + }) + }) + }) + } + + /// Get all blacklisted nodes for a compute pool + /// + /// Args: + /// pool_id: The ID of the compute pool + /// + /// Returns: + /// List of blacklisted node addresses + pub fn get_blacklisted_nodes(&self, py: Python, pool_id: u32) -> PyResult> { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. Call start() first.", + ) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Get blacklisted nodes + let nodes = contracts_clone + .compute_pool + .get_blacklisted_nodes(pool_id) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to get blacklisted nodes: {}", + e + )) + })?; + + // Convert addresses to strings + Ok(nodes.iter().map(|addr| addr.to_string()).collect()) + }) + }) + } + + /// Check if a node is in a compute pool + /// + /// Args: + /// pool_id: The ID of the compute pool + /// node_address: The address of the node to check + /// + /// Returns: + /// True if the node is in the pool, False otherwise + pub fn is_node_in_pool( + &self, + py: Python, + pool_id: u32, + node_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. 
Call start() first.", + ) + })?; + + // Parse node address + let node_addr = Address::from_str(&node_address).map_err(|e| { + PyErr::new::(format!("Invalid node address: {}", e)) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Check if node is in pool + contracts_clone + .compute_pool + .is_node_in_pool(pool_id, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to check if node is in pool: {}", + e + )) + }) + }) + }) + } +} + +// Private implementation methods +impl OrchestratorClient { + fn get_or_create_runtime(&self) -> PyResult<&tokio::runtime::Runtime> { + if let Some(ref rt) = self.runtime { + Ok(rt) + } else { + Err(PyErr::new::( + "Runtime not initialized. Call start() first.", + )) + } + } + + async fn create_p2p_components( + wallet: Wallet, + port: u16, + cancellation_token: CancellationToken, + ) -> Result< + ( + Arc, + PeerId, + Arc>>, + Arc>>, + JoinHandle<()>, + ), + anyhow::Error, + > { + // Initialize authentication manager + let auth_manager = Arc::new(AuthenticationManager::new(Arc::new(wallet.clone()))); + + // Create P2P service + let keypair = Keypair::generate_ed25519(); + let wallet_address = Some(wallet.wallet.default_signer().address().to_string()); + + let (user_message_tx, user_message_rx) = tokio::sync::mpsc::channel::(1000); + + let (p2p_service, outbound_tx, message_queue_rx, authenticated_peers) = P2PService::new( + keypair, + port, + cancellation_token.clone(), + wallet_address.clone(), + )?; + + let peer_id = p2p_service.node.peer_id(); + let outbound_tx = Arc::new(Mutex::new(outbound_tx)); + let user_message_rx = Arc::new(Mutex::new(user_message_rx)); + + // Start P2P service + tokio::task::spawn(p2p_service.run()); + + // Start message processor + let config = MessageProcessorConfig { + auth_manager: auth_manager.clone(), + message_queue_rx: Arc::new(Mutex::new(message_queue_rx)), + user_message_tx, + outbound_tx: outbound_tx.clone(), + 
authenticated_peers, + cancellation_token: cancellation_token.clone(), + validator_addresses: Arc::new(std::collections::HashSet::new()), // Orchestrator doesn't need validator info + pool_owner_address: wallet_address + .as_ref() + .and_then(|addr| alloy::primitives::Address::from_str(addr).ok()), // Parse orchestrator's address + compute_manager_address: None, // Orchestrator doesn't need this + }; + + let message_processor = MessageProcessor::from_config(config); + let message_processor_handle = message_processor.spawn(); + + log::info!( + "P2P service started on port {} with peer ID: {:?}", + port, + peer_id + ); + + Ok(( + auth_manager, + peer_id, + outbound_tx, + user_message_rx, + message_processor_handle, + )) + } +} diff --git a/crates/python-sdk/src/p2p_handler/auth.rs b/crates/python-sdk/src/p2p_handler/auth.rs new file mode 100644 index 00000000..a6aa1c99 --- /dev/null +++ b/crates/python-sdk/src/p2p_handler/auth.rs @@ -0,0 +1,300 @@ +use crate::error::{PrimeProtocolError, Result}; +use crate::p2p_handler::Message; +use alloy::primitives::{Address, Signature}; +use rand::Rng; +use shared::security::request_signer::sign_message; +use shared::web3::wallet::Wallet; +use std::collections::{HashMap, HashSet}; +use std::str::FromStr; +use std::sync::Arc; +use tokio::sync::RwLock; + +/* + * Authentication Flow: + * + * This module implements a mutual authentication protocol between peers using cryptographic signatures. + * The flow works as follows: + * + * 1. INITIATION: Peer A wants to send a message to Peer B but they're not authenticated + * - Peer A generates a random challenge (32 bytes, hex-encoded) + * - Peer A stores the challenge and queues the original message + * - Peer A sends AuthenticationInitiation{challenge} to Peer B + * + * 2. 
RESPONSE: Peer B receives the initiation + * - Peer B signs Peer A's challenge with their private key + * - Peer B generates their own challenge for Peer A + * - Peer B sends AuthenticationResponse{challenge: B's challenge, signature: signed A's challenge} + * + * 3. SOLUTION: Peer A receives the response + * - Peer A verifies Peer B's signature against the original challenge to get B's wallet address + * - Peer A signs Peer B's challenge with their private key + * - Peer A marks Peer B as authenticated + * - Peer A sends AuthenticationSolution{signature: signed B's challenge} + * - Peer A sends the originally queued message + * + * 4. COMPLETION: Peer B receives the solution + * - Peer B verifies Peer A's signature against their challenge to get A's wallet address + * - Peer B marks Peer A as authenticated + * - Both peers can now exchange messages freely + * + * Security Properties: + * - Mutual authentication: Both peers prove they control their private keys + * - Replay protection: Each challenge is randomly generated + * - Address verification: Signature recovery provides cryptographic proof of wallet ownership + */ + +/// Represents an ongoing authentication challenge with a peer +#[derive(Debug)] +pub struct OngoingAuthChallenge { + pub peer_wallet_address: Address, + pub auth_challenge_request_message: String, + pub outgoing_message: Message, + pub their_challenge: Option, // The challenge we received from them that we need to sign +} + +/// Manages authentication state and operations +pub struct AuthenticationManager { + /// Track ongoing authentication requests (when we initiate) + ongoing_auth_requests: Arc>>, + /// Track outgoing challenges (when they initiate and we respond) + outgoing_challenges: Arc>>, + /// Track peers we're responding to (to prevent initiating auth with them) + responding_to_peers: Arc>>, + /// Queue messages for peers we're authenticating with as responder + responder_message_queue: Arc>>>, + /// Track authenticated peers + 
authenticated_peers: Arc>>, + /// Track peers we're waiting for authentication acknowledgment from + pending_auth_acknowledgment: Arc>>, + /// Our wallet for signing + node_wallet: Arc, +} + +impl AuthenticationManager { + pub fn new(node_wallet: Arc) -> Self { + Self { + ongoing_auth_requests: Arc::new(RwLock::new(HashMap::new())), + outgoing_challenges: Arc::new(RwLock::new(HashMap::new())), + responding_to_peers: Arc::new(RwLock::new(HashSet::new())), + responder_message_queue: Arc::new(RwLock::new(HashMap::new())), + authenticated_peers: Arc::new(RwLock::new(HashSet::new())), + pending_auth_acknowledgment: Arc::new(RwLock::new(HashMap::new())), + node_wallet, + } + } + + /// Check if a peer is already authenticated + pub async fn is_authenticated(&self, peer_id: &str) -> bool { + self.authenticated_peers.read().await.contains(peer_id) + } + + /// Mark a peer as authenticated + pub async fn mark_authenticated(&self, peer_id: String) { + self.authenticated_peers.write().await.insert(peer_id); + } + + /// Check our role in ongoing authentication + pub async fn get_auth_role(&self, peer_id: &str) -> Option { + if self + .ongoing_auth_requests + .read() + .await + .contains_key(peer_id) + { + Some("initiator".to_string()) + } else if self.responding_to_peers.read().await.contains(peer_id) { + Some("responder".to_string()) + } else { + None + } + } + + /// Queue a message for a peer we're responding to + pub async fn queue_message_as_responder( + &self, + peer_id: String, + message: Message, + ) -> Result<()> { + let mut queue = self.responder_message_queue.write().await; + queue.entry(peer_id).or_insert_with(Vec::new).push(message); + Ok(()) + } + + /// Start authentication with a peer + pub async fn start_authentication( + &self, + peer_id: String, + outgoing_message: Message, + ) -> Result { + // Generate authentication challenge + let challenge_bytes: [u8; 32] = rand::thread_rng().gen(); + let auth_challenge_message = hex::encode(challenge_bytes); + + // Store 
the ongoing auth challenge + let mut ongoing_auth = self.ongoing_auth_requests.write().await; + ongoing_auth.insert( + peer_id, + OngoingAuthChallenge { + peer_wallet_address: Address::ZERO, // Will be updated when we get their signature + auth_challenge_request_message: auth_challenge_message.clone(), + outgoing_message, + their_challenge: None, + }, + ); + + Ok(auth_challenge_message) + } + + /// Handle incoming authentication initiation + pub async fn handle_auth_initiation( + &self, + peer_id: &str, + challenge: &str, + ) -> Result<(String, String)> { + // Mark that we're responding to this peer + self.responding_to_peers + .write() + .await + .insert(peer_id.to_string()); + + // Sign the challenge + let signature = sign_message(challenge, &self.node_wallet) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to sign challenge: {}", e)) + })?; + + // Generate our own challenge for the peer + let our_challenge_bytes: [u8; 32] = rand::thread_rng().gen(); + let our_challenge = hex::encode(our_challenge_bytes); + + // Store the challenge we're sending so we can verify the signature later + self.outgoing_challenges + .write() + .await + .insert(peer_id.to_string(), our_challenge.clone()); + + Ok((our_challenge, signature)) + } + + /// Handle authentication response from peer (when we initiated) + pub async fn handle_auth_response( + &self, + peer_id: &str, + their_challenge: &str, + their_signature: &str, + ) -> Result { + // Get our ongoing auth challenge + let mut ongoing_auth = self.ongoing_auth_requests.write().await; + let auth_challenge = ongoing_auth.get_mut(peer_id).ok_or_else(|| { + PrimeProtocolError::InvalidConfig(format!( + "No ongoing auth request for peer {}", + peer_id + )) + })?; + + // Verify the signature matches the challenge we sent + let parsed_signature = Signature::from_str(their_signature).map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Invalid signature format: {}", e)) + })?; + + let recovered_address 
= parsed_signature + .recover_address_from_msg(&auth_challenge.auth_challenge_request_message) + .map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Failed to recover address: {}", e)) + })?; + + // Update the peer's wallet address and store their challenge + auth_challenge.peer_wallet_address = recovered_address; + auth_challenge.their_challenge = Some(their_challenge.to_string()); + log::debug!("Recovered peer address: {}", recovered_address); + + // Sign their challenge + let our_signature = sign_message(their_challenge, &self.node_wallet) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to sign challenge: {}", e)) + })?; + + // Store the queued message for sending after acknowledgment + if let Some(auth) = ongoing_auth.remove(peer_id) { + self.queue_for_acknowledgment(peer_id.to_string(), auth.outgoing_message) + .await; + } + + Ok(our_signature) + } + + /// Handle authentication solution from peer + pub async fn handle_auth_solution( + &self, + peer_id: &str, + signature: &str, + ) -> Result<(Address, Vec)> { + // Get the challenge we sent to this peer + let mut outgoing_challenges = self.outgoing_challenges.write().await; + let challenge = outgoing_challenges.remove(peer_id).ok_or_else(|| { + PrimeProtocolError::InvalidConfig(format!( + "No outgoing challenge found for peer {}", + peer_id + )) + })?; + + // Parse and verify the signature + let parsed_signature = Signature::from_str(signature).map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Invalid signature format: {}", e)) + })?; + + // Recover the peer's address from their signature of our challenge + let recovered_address = parsed_signature + .recover_address_from_msg(&challenge) + .map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Failed to recover address: {}", e)) + })?; + + log::debug!( + "Verified auth solution from peer {} with address {}", + peer_id, + recovered_address + ); + self.mark_authenticated(peer_id.to_string()).await; + + // Clean 
up responding state + self.responding_to_peers.write().await.remove(peer_id); + + // Get any queued messages to send now that we're authenticated + let queued_messages = self + .responder_message_queue + .write() + .await + .remove(peer_id) + .unwrap_or_default(); + + Ok((recovered_address, queued_messages)) + } + + /// Get wallet address + pub fn wallet_address(&self) -> String { + self.node_wallet + .wallet + .default_signer() + .address() + .to_string() + } + + /// Store a message to send after receiving authentication acknowledgment + pub async fn queue_for_acknowledgment(&self, peer_id: String, message: Message) { + self.pending_auth_acknowledgment + .write() + .await + .insert(peer_id, message); + } + + /// Handle authentication acknowledgment and return queued message if any + pub async fn handle_auth_acknowledgment(&self, peer_id: &str) -> Option { + self.pending_auth_acknowledgment + .write() + .await + .remove(peer_id) + } +} diff --git a/crates/python-sdk/src/p2p_handler/common.rs b/crates/python-sdk/src/p2p_handler/common.rs new file mode 100644 index 00000000..22703dcd --- /dev/null +++ b/crates/python-sdk/src/p2p_handler/common.rs @@ -0,0 +1,83 @@ +use crate::error::{PrimeProtocolError, Result}; +use crate::p2p_handler::auth::AuthenticationManager; +use crate::p2p_handler::{Message, MessageType}; +use std::sync::Arc; +use tokio::sync::mpsc::Sender; +use tokio::sync::Mutex; + +/// Shared send_message function that handles authentication and message sending +pub async fn send_message_with_auth( + message: Message, + auth_manager: &Arc, + outbound_tx: &Arc>>, +) -> Result<()> { + log::debug!("Sending message to peer: {}", message.peer_id); + + // Check if we're already authenticated with this peer + if auth_manager.is_authenticated(&message.peer_id).await { + log::debug!( + "Already authenticated with peer {}, sending message directly", + message.peer_id + ); + return outbound_tx.lock().await.send(message).await.map_err(|e| { + 
PrimeProtocolError::InvalidConfig(format!("Failed to send message: {}", e)) + }); + } + + // Not authenticated yet, check if we have ongoing authentication + log::debug!("Not authenticated with peer {}", message.peer_id); + + // Check if there's already an ongoing auth request + if let Some(role) = auth_manager.get_auth_role(&message.peer_id).await { + match role.as_str() { + "initiator" => { + return Err(PrimeProtocolError::InvalidConfig(format!( + "Already initiated authentication with peer {}", + message.peer_id + ))); + } + "responder" => { + // We're responding to their auth, queue the message + log::debug!( + "Queuing message for peer {} (we're responding to their auth)", + message.peer_id + ); + return auth_manager + .queue_message_as_responder(message.peer_id.clone(), message) + .await; + } + _ => {} + } + } + + // Extract fields we need before moving the message + let peer_id = message.peer_id.clone(); + let multiaddrs = message.multiaddrs.clone(); + + // Start authentication (takes ownership of message) + let auth_challenge = auth_manager + .start_authentication(peer_id.clone(), message) + .await?; + + // Send authentication initiation + let auth_message = Message { + message_type: MessageType::AuthenticationInitiation { + challenge: auth_challenge, + }, + peer_id, + multiaddrs, + sender_address: Some(auth_manager.wallet_address()), + is_sender_validator: false, + is_sender_pool_owner: false, + response_tx: None, + }; + + outbound_tx + .lock() + .await + .send(auth_message) + .await + .map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Failed to send auth message: {}", e)) + }) +} diff --git a/crates/python-sdk/src/p2p_handler/message_processor.rs b/crates/python-sdk/src/p2p_handler/message_processor.rs new file mode 100644 index 00000000..8345e4df --- /dev/null +++ b/crates/python-sdk/src/p2p_handler/message_processor.rs @@ -0,0 +1,400 @@ +use crate::error::Result; +use crate::p2p_handler::auth::AuthenticationManager; +use 
crate::p2p_handler::{Message, MessageType}; +use alloy::primitives::Address; +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; +use tokio::sync::{ + mpsc::{Receiver, Sender}, + Mutex, RwLock, +}; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; + +/// Configuration for creating a MessageProcessor +pub struct MessageProcessorConfig { + pub auth_manager: Arc, + pub message_queue_rx: Arc>>, + pub user_message_tx: Sender, + pub outbound_tx: Arc>>, + pub authenticated_peers: Arc>>, + pub cancellation_token: CancellationToken, + pub validator_addresses: Arc>, + pub pool_owner_address: Option
, + pub compute_manager_address: Option
, +} + +/// Handles processing of incoming P2P messages +pub struct MessageProcessor { + auth_manager: Arc, + message_queue_rx: Arc>>, + user_message_tx: Sender, + outbound_tx: Arc>>, + authenticated_peers: Arc>>, + cancellation_token: CancellationToken, + validator_addresses: Arc>, + pool_owner_address: Option
, + compute_manager_address: Option
, +} + +impl MessageProcessor { + #[allow(clippy::too_many_arguments)] + pub fn new( + auth_manager: Arc, + message_queue_rx: Arc>>, + user_message_tx: Sender, + outbound_tx: Arc>>, + authenticated_peers: Arc>>, + cancellation_token: CancellationToken, + validator_addresses: Arc>, + pool_owner_address: Option
, + compute_manager_address: Option
, + ) -> Self { + Self { + auth_manager, + message_queue_rx, + user_message_tx, + outbound_tx, + authenticated_peers, + cancellation_token, + validator_addresses, + pool_owner_address, + compute_manager_address, + } + } + + /// Create a MessageProcessor from a config struct + pub fn from_config(config: MessageProcessorConfig) -> Self { + Self::new( + config.auth_manager, + config.message_queue_rx, + config.user_message_tx, + config.outbound_tx, + config.authenticated_peers, + config.cancellation_token, + config.validator_addresses, + config.pool_owner_address, + config.compute_manager_address, + ) + } + + /// Start the message processor as a background task + /// Returns a JoinHandle that can be used to await or abort the task + pub fn spawn(self) -> JoinHandle<()> { + tokio::task::spawn(self.run()) + } + + /// Run the message processing loop + pub async fn run(self) { + loop { + tokio::select! { + _ = self.cancellation_token.cancelled() => { + log::info!("Message processor shutting down"); + break; + } + message_result = async { + let mut rx = self.message_queue_rx.lock().await; + tokio::time::timeout( + crate::constants::MESSAGE_QUEUE_TIMEOUT, + rx.recv() + ).await + } => { + let message = match message_result { + Ok(Some(msg)) => msg, + Ok(None) => { + log::debug!("Message queue closed"); + break; + } + Err(_) => continue, // Timeout, continue loop + }; + log::debug!("Received message: {:?}", message); + + if let Err(e) = self.process_message(message).await { + log::error!("Failed to process message: {}", e); + } + } + } + } + } + + /// Process a single message + async fn process_message(&self, message: Message) -> Result<()> { + let Message { + message_type, + peer_id, + multiaddrs, + sender_address, + response_tx, + .. 
+ } = message; + + match message_type { + MessageType::AuthenticationInitiation { challenge } => { + if let Some(tx) = response_tx { + self.handle_auth_initiation_with_response( + peer_id, + multiaddrs, + sender_address, + challenge, + tx, + ) + .await + } else { + log::error!("AuthenticationInitiation received without response tx"); + Ok(()) + } + } + MessageType::AuthenticationResponse { + challenge, + signature, + } => { + let msg = Message { + message_type: MessageType::AuthenticationResponse { + challenge: challenge.clone(), + signature: signature.clone(), + }, + peer_id, + multiaddrs, + sender_address: sender_address.clone(), + is_sender_validator: self.is_address_validator(&sender_address), + is_sender_pool_owner: self.is_address_pool_owner(&sender_address), + response_tx: None, + }; + self.handle_auth_response(msg, challenge, signature).await + } + MessageType::AuthenticationSolution { signature } => { + if let Some(tx) = response_tx { + self.handle_auth_solution_with_response( + peer_id, + multiaddrs, + sender_address, + signature, + tx, + ) + .await + } else { + log::error!("AuthenticationSolution received without response tx"); + Ok(()) + } + } + MessageType::AuthenticationComplete => { + // Authentication has been acknowledged by the peer + self.handle_auth_complete(peer_id).await + } + MessageType::General { data } => { + // Forward general messages to user + let msg = Message { + message_type: MessageType::General { data }, + peer_id, + multiaddrs, + sender_address: sender_address.clone(), + is_sender_validator: self.is_address_validator(&sender_address), + is_sender_pool_owner: self.is_address_pool_owner(&sender_address), + response_tx: None, + }; + self.user_message_tx.send(msg).await.map_err(|e| { + crate::error::PrimeProtocolError::InvalidConfig(format!( + "Failed to forward message to user: {}", + e + )) + }) + } + } + } + + async fn handle_auth_initiation_with_response( + &self, + peer_id: String, + _multiaddrs: Vec, + _sender_address: Option, + 
challenge: String, + response_tx: tokio::sync::oneshot::Sender, + ) -> Result<()> { + let (our_challenge, signature) = self + .auth_manager + .handle_auth_initiation(&peer_id, &challenge) + .await?; + + // Send authentication response via the one-shot channel + let response = p2p::AuthenticationInitiationResponse { + message: our_challenge, + signature, + }; + + response_tx.send(response.into()).map_err(|_| { + crate::error::PrimeProtocolError::InvalidConfig( + "Failed to send auth response: receiver dropped".to_string(), + ) + }) + } + + async fn handle_auth_response( + &self, + message: Message, + challenge: String, + signature: String, + ) -> Result<()> { + let our_signature = self + .auth_manager + .handle_auth_response(&message.peer_id, &challenge, &signature) + .await?; + + // Extract and store the peer's wallet address from their signature + if let Ok(parsed_signature) = alloy::primitives::Signature::from_str(&signature) { + if let Ok(recovered_address) = parsed_signature.recover_address_from_msg(&challenge) { + self.authenticated_peers + .write() + .await + .insert(message.peer_id.clone(), recovered_address.to_string()); + } + } + + // Send authentication solution + let solution = Message { + message_type: MessageType::AuthenticationSolution { + signature: our_signature, + }, + peer_id: message.peer_id.clone(), + multiaddrs: message.multiaddrs, + sender_address: Some(self.auth_manager.wallet_address()), + is_sender_validator: false, // We're sending this message, so we check our own address + is_sender_pool_owner: false, // We're sending this message, so we check our own address + response_tx: None, + }; + + self.outbound_tx + .lock() + .await + .send(solution) + .await + .map_err(|e| { + crate::error::PrimeProtocolError::InvalidConfig(format!( + "Failed to send auth solution: {}", + e + )) + })?; + + Ok(()) + } + + async fn handle_auth_solution_with_response( + &self, + peer_id: String, + _multiaddrs: Vec, + _sender_address: Option, + signature: String, + 
response_tx: tokio::sync::oneshot::Sender, + ) -> Result<()> { + let (peer_address, queued_messages) = self + .auth_manager + .handle_auth_solution(&peer_id, &signature) + .await?; + + // Store the peer's wallet address for future message handling + self.authenticated_peers + .write() + .await + .insert(peer_id.clone(), peer_address.to_string()); + + // Send any queued messages now that we're authenticated + for msg in queued_messages { + self.outbound_tx.lock().await.send(msg).await.map_err(|e| { + crate::error::PrimeProtocolError::InvalidConfig(format!( + "Failed to send queued message: {}", + e + )) + })?; + } + + // Send the response + let response = p2p::AuthenticationSolutionResponse::Granted; + response_tx.send(response.into()).map_err(|_| { + crate::error::PrimeProtocolError::InvalidConfig( + "Failed to send auth solution response: receiver dropped".to_string(), + ) + }) + } + + async fn handle_auth_complete(&self, peer_id: String) -> Result<()> { + log::info!("Authentication complete for peer: {}", peer_id); + + // Mark peer as authenticated + self.auth_manager.mark_authenticated(peer_id.clone()).await; + + // Get and send any queued message + if let Some(queued_message) = self.auth_manager.handle_auth_acknowledgment(&peer_id).await { + log::debug!("Sending queued message to peer {}", peer_id); + self.outbound_tx + .lock() + .await + .send(queued_message) + .await + .map_err(|e| { + crate::error::PrimeProtocolError::InvalidConfig(format!( + "Failed to send queued message: {}", + e + )) + })?; + } + + Ok(()) + } + + /// Check if an address is a validator + fn is_address_validator(&self, address: &Option) -> bool { + if let Some(addr_str) = address { + // Parse the string address to an Address type + match Address::from_str(addr_str) { + Ok(addr) => { + let is_validator = self.validator_addresses.contains(&addr); + log::debug!( + "Checking if {} is a validator: {} (validator set has {} addresses)", + addr_str, + is_validator, + self.validator_addresses.len() 
+ ); + is_validator + } + Err(e) => { + log::warn!("Failed to parse address {}: {}", addr_str, e); + false + } + } + } else { + log::debug!("No sender address provided, cannot check validator status"); + false + } + } + + /// Check if an address is the pool owner + fn is_address_pool_owner(&self, address: &Option) -> bool { + if let Some(addr_str) = address { + // Parse the string address to an Address type + match Address::from_str(addr_str) { + Ok(addr) => { + // Check if it's the pool creator OR the compute manager + let is_creator = self.pool_owner_address == Some(addr); + let is_compute_manager = self.compute_manager_address == Some(addr); + let is_owner = is_creator || is_compute_manager; + + log::debug!( + "Checking if {} is pool owner/manager: {} (pool owner: {:?}, compute manager: {:?})", + addr_str, + is_owner, + self.pool_owner_address, + self.compute_manager_address + ); + is_owner + } + Err(e) => { + log::warn!("Failed to parse address {}: {}", addr_str, e); + false + } + } + } else { + log::debug!("No sender address provided, cannot check pool owner status"); + false + } + } +} diff --git a/crates/python-sdk/src/p2p_handler/mod.rs b/crates/python-sdk/src/p2p_handler/mod.rs new file mode 100644 index 00000000..8e9f160c --- /dev/null +++ b/crates/python-sdk/src/p2p_handler/mod.rs @@ -0,0 +1,552 @@ +use anyhow::{Context, Result}; +use p2p::{IncomingMessage, Keypair, Node, NodeBuilder, OutgoingMessage, PeerId, Protocols}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; +use tokio::sync::{ + mpsc::{Receiver, Sender}, + RwLock, +}; +use tokio_util::sync::CancellationToken; + +use crate::constants::{MESSAGE_QUEUE_CHANNEL_SIZE, P2P_CHANNEL_SIZE}; + +pub(crate) mod auth; +pub(crate) mod common; +pub(crate) mod message_processor; + +pub use common::send_message_with_auth; + +// Type alias for the complex return type of Service::new +type ServiceNewResult = Result<( + Service, + Sender, + 
Receiver, + Arc>>, +)>; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum MessageType { + General { + data: Vec, + }, + AuthenticationInitiation { + challenge: String, + }, + AuthenticationResponse { + challenge: String, + signature: String, + }, + AuthenticationSolution { + signature: String, + }, + AuthenticationComplete, +} + +impl Default for MessageType { + fn default() -> Self { + MessageType::General { data: vec![] } + } +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Message { + pub message_type: MessageType, + pub peer_id: String, + pub multiaddrs: Vec, + pub sender_address: Option, // Ethereum address of the sender + pub is_sender_validator: bool, // Whether the sender is a validator + pub is_sender_pool_owner: bool, // Whether the sender is the pool owner + #[serde(skip)] + pub response_tx: Option>, // For sending responses to auth requests +} + +pub struct Service { + p2p_rx: Receiver, + outbound_message_tx: Sender, + incoming_message_rx: Receiver, + message_queue_tx: Sender, + pub node: Node, + cancellation_token: CancellationToken, + wallet_address: Option, // Our wallet address for authentication + // Map peer_id to their wallet address after authentication + authenticated_peers: Arc>>, +} + +impl Service { + pub(crate) fn new( + keypair: Keypair, + port: u16, + cancellation_token: CancellationToken, + wallet_address: Option, + ) -> ServiceNewResult { + // Channels for application <-> P2P service communication + let (p2p_tx, p2p_rx) = tokio::sync::mpsc::channel(P2P_CHANNEL_SIZE); + let (message_queue_tx, message_queue_rx) = + tokio::sync::mpsc::channel(MESSAGE_QUEUE_CHANNEL_SIZE); + + let protocols = Protocols::new() + .with_general() + .with_authentication() + .with_invite(); + + let listen_addr = format!("/ip4/0.0.0.0/tcp/{}", port) + .parse() + .context("Failed to parse listen address")?; + + let (node, incoming_messages_rx, outgoing_messages_tx) = NodeBuilder::new() + .with_keypair(keypair) + .with_port(port) + 
.with_listen_addr(listen_addr) + .with_protocols(protocols) + .with_cancellation_token(cancellation_token.clone()) + .try_build() + .context("Failed to create P2P node")?; + + let authenticated_peers = Arc::new(RwLock::new(HashMap::new())); + + Ok(( + Self { + p2p_rx, + outbound_message_tx: outgoing_messages_tx, + incoming_message_rx: incoming_messages_rx, + message_queue_tx, + node, + cancellation_token, + wallet_address, + authenticated_peers: authenticated_peers.clone(), + }, + p2p_tx, + message_queue_rx, + authenticated_peers, + )) + } + + async fn handle_outgoing_message( + mut message: Message, + outgoing_message_tx: &Sender, + wallet_address: &Option, + ) -> Result<()> { + // Add our wallet address if not already set + if message.sender_address.is_none() { + message.sender_address = wallet_address.clone(); + } + + let req = match &message.message_type { + MessageType::General { data } => { + p2p::Request::General(p2p::GeneralRequest { data: data.clone() }) + } + MessageType::AuthenticationInitiation { challenge } => p2p::Request::Authentication( + p2p::AuthenticationRequest::Initiation(p2p::AuthenticationInitiationRequest { + message: challenge.clone(), + }), + ), + MessageType::AuthenticationResponse { + challenge: _, + signature: _, + } => { + // This message type should not be sent as a request + // It should be handled via response channels + log::error!( + "AuthenticationResponse should be sent via response channel, not as a request" + ); + return Ok(()); + } + MessageType::AuthenticationSolution { signature } => p2p::Request::Authentication( + p2p::AuthenticationRequest::Solution(p2p::AuthenticationSolutionRequest { + signature: signature.clone(), + }), + ), + MessageType::AuthenticationComplete => { + // This message type should not be sent as a request + // It should be handled via response channels + log::error!( + "AuthenticationComplete should be sent via response channel, not as a request" + ); + return Ok(()); + } + }; + + let peer_id = 
PeerId::from_str(&message.peer_id).context("Failed to parse peer ID")?; + + let multiaddrs = message + .multiaddrs + .iter() + .map(|addr| addr.parse()) + .collect::, _>>() + .context("Failed to parse multiaddresses")?; + + log::debug!( + "Sending message to peer: {}, multiaddrs: {:?}", + peer_id, + multiaddrs + ); + + outgoing_message_tx + .send(req.into_outgoing_message(peer_id, multiaddrs)) + .await + .context("Failed to send outgoing message")?; + + Ok(()) + } + + pub(crate) async fn run(self) -> Result<()> { + // Extract all necessary fields before the async move block + let node = self.node; + let node_cancel_token = self.cancellation_token.clone(); + let mut p2p_rx = self.p2p_rx; + let outbound_message_tx = self.outbound_message_tx; + let mut incoming_message_rx = self.incoming_message_rx; + let message_queue_tx = self.message_queue_tx; + let cancellation_token = self.cancellation_token; + let wallet_address = self.wallet_address; + let authenticated_peers = self.authenticated_peers; + + // Start the P2P node in a separate task + let node_handle = tokio::spawn(async move { + tokio::select! { + result = node.run() => { + if let Err(e) = result { + log::error!("P2P node error: {}", e); + } + } + _ = node_cancel_token.cancelled() => { + log::info!("P2P node shutdown requested"); + } + } + }); + + log::info!("P2P service started"); + + // Main event loop + loop { + tokio::select! 
{ + _ = cancellation_token.cancelled() => { + log::info!("P2P service shutdown requested"); + break; + } + + // Handle outgoing messages from application + Some(message) = p2p_rx.recv() => { + if let Err(e) = Self::handle_outgoing_message(message, &outbound_message_tx, &wallet_address).await { + log::error!("Failed to handle outgoing message: {}", e); + } + } + + // Handle incoming messages from network + Some(incoming) = incoming_message_rx.recv() => { + let peer_id = incoming.peer; + match incoming.message { + p2p::Libp2pIncomingMessage::Request { + request_id, + request, + channel, + } => { + if let Err(e) = Self::handle_incoming_request_static( + &message_queue_tx, + &outbound_message_tx, + &authenticated_peers, + peer_id, + request_id, + request, + channel + ).await { + log::error!("Failed to handle incoming request: {}", e); + } + } + p2p::Libp2pIncomingMessage::Response { + request_id, + response, + } => { + if let Err(e) = Self::handle_incoming_response_static( + &message_queue_tx, + peer_id, + request_id, + response + ).await { + log::error!("Failed to handle incoming response: {}", e); + } + } + } + } + } + } + + // Wait for the node task to complete + if let Err(e) = node_handle.await { + log::error!("P2P node task error: {}", e); + } + + log::info!("P2P service shutdown complete"); + Ok(()) + } + + async fn handle_incoming_request_static( + message_queue_tx: &Sender, + outbound_message_tx: &Sender, + authenticated_peers: &Arc>>, + peer_id: PeerId, + request_id: T, + request: p2p::Request, + channel: p2p::ResponseChannel, + ) -> Result<()> + where + T: std::fmt::Debug, + { + log::debug!( + "Received request from peer: {} (ID: {:?})", + peer_id, + request_id + ); + + match request { + p2p::Request::General(p2p::GeneralRequest { data }) => { + log::debug!( + "Processing GeneralRequest with {} bytes from {}", + data.len(), + peer_id + ); + + // Check if peer is authenticated + let sender_address = { + let auth_peers = authenticated_peers.read().await; + 
match auth_peers.get(&peer_id.to_string()) { + Some(address) => address.clone(), + None => { + log::warn!("Rejecting message from unauthenticated peer: {}", peer_id); + // Send error response + let response = p2p::Response::General(p2p::GeneralResponse { + data: b"ERROR: Not authenticated".to_vec(), + }); + outbound_message_tx + .send(response.into_outgoing_message(channel)) + .await + .context("Failed to send error response")?; + return Ok(()); + } + } + }; + + // Forward message to application + let message = Message { + message_type: MessageType::General { data: data.clone() }, + peer_id: peer_id.to_string(), + multiaddrs: vec![], // TODO: Extract multiaddrs from peer info + sender_address: Some(sender_address), + is_sender_validator: false, // Will be populated by message processor + is_sender_pool_owner: false, // Will be populated by message processor + response_tx: None, // General messages don't need response channels + }; + + if let Err(e) = message_queue_tx.send(message).await { + log::error!("Failed to forward message to application: {}", e); + } + + // Send acknowledgment + let response = p2p::Response::General(p2p::GeneralResponse { + data: b"ACK".to_vec(), + }); + + outbound_message_tx + .send(response.into_outgoing_message(channel)) + .await + .context("Failed to send acknowledgment")?; + } + p2p::Request::Authentication(auth_req) => { + log::debug!("Processing Authentication request from {}", peer_id); + + match auth_req { + p2p::AuthenticationRequest::Initiation(init_req) => { + // Create a one-shot channel for the response + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + + // Forward to application for handling + let message = Message { + message_type: MessageType::AuthenticationInitiation { + challenge: init_req.message, + }, + peer_id: peer_id.to_string(), + multiaddrs: vec![], + sender_address: None, + is_sender_validator: false, + is_sender_pool_owner: false, + response_tx: Some(response_tx), // Pass the sender for 
response + }; + + if let Err(e) = message_queue_tx.send(message).await { + log::error!("Failed to forward auth initiation to application: {}", e); + return Ok(()); + } + + // Wait for the response from the message processor + match response_rx.await { + Ok(response) => { + outbound_message_tx + .send(response.into_outgoing_message(channel)) + .await + .context("Failed to send auth response")?; + } + Err(_) => { + log::error!("Response channel closed for auth initiation"); + } + } + } + p2p::AuthenticationRequest::Solution(sol_req) => { + // Create a one-shot channel for the response + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + + // Forward to application for handling + let message = Message { + message_type: MessageType::AuthenticationSolution { + signature: sol_req.signature, + }, + peer_id: peer_id.to_string(), + multiaddrs: vec![], + sender_address: None, + is_sender_validator: false, + is_sender_pool_owner: false, + response_tx: Some(response_tx), // Pass the sender for response + }; + + if let Err(e) = message_queue_tx.send(message).await { + log::error!("Failed to forward auth solution to application: {}", e); + return Ok(()); + } + + // Wait for the response from the message processor + match response_rx.await { + Ok(response) => { + outbound_message_tx + .send(response.into_outgoing_message(channel)) + .await + .context("Failed to send auth solution response")?; + } + Err(_) => { + log::error!("Response channel closed for auth solution"); + } + } + } + } + } + _ => { + log::warn!("Received unsupported request type: {:?}", request); + } + } + + Ok(()) + } + + async fn handle_incoming_response_static( + message_queue_tx: &Sender, + peer_id: PeerId, + request_id: T, + response: p2p::Response, + ) -> Result<()> + where + T: std::fmt::Debug, + { + log::debug!( + "Received response from peer: {} (ID: {:?})", + peer_id, + request_id + ); + + match response { + p2p::Response::General(p2p::GeneralResponse { data }) => { + 
log::debug!("General response received: {} bytes", data.len()); + } + p2p::Response::Authentication(auth_resp) => { + log::debug!("Authentication response received from {}", peer_id); + + match auth_resp { + p2p::AuthenticationResponse::Initiation(init_resp) => { + // Forward to application + let message = Message { + message_type: MessageType::AuthenticationResponse { + challenge: init_resp.message, + signature: init_resp.signature, + }, + peer_id: peer_id.to_string(), + multiaddrs: vec![], + sender_address: None, + is_sender_validator: false, + is_sender_pool_owner: false, + response_tx: None, + }; + + if let Err(e) = message_queue_tx.send(message).await { + log::error!("Failed to forward auth response to application: {}", e); + } + } + p2p::AuthenticationResponse::Solution(sol_resp) => { + log::debug!("Auth solution response: {:?}", sol_resp); + // Forward the auth solution response to the message processor + match sol_resp { + p2p::AuthenticationSolutionResponse::Granted => { + // Create a special message to indicate auth is complete + let message = Message { + message_type: MessageType::AuthenticationComplete, + peer_id: peer_id.to_string(), + multiaddrs: vec![], + sender_address: None, + is_sender_validator: false, + is_sender_pool_owner: false, + response_tx: None, + }; + + if let Err(e) = message_queue_tx.send(message).await { + log::error!( + "Failed to forward auth complete to application: {}", + e + ); + } + } + p2p::AuthenticationSolutionResponse::Rejected => { + log::warn!("Authentication rejected by peer: {}", peer_id); + } + } + } + } + } + _ => { + log::debug!("Received other response type: {:?}", response); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_message_serialization() { + let message = Message { + message_type: MessageType::General { + data: vec![1, 2, 3, 4], + }, + peer_id: "12D3KooWExample".to_string(), + multiaddrs: vec!["/ip4/127.0.0.1/tcp/4001".to_string()], + sender_address: 
Some("0x1234567890123456789012345678901234567890".to_string()), + is_sender_validator: false, + is_sender_pool_owner: false, + response_tx: None, + }; + + let json = serde_json::to_string(&message).unwrap(); + let deserialized: Message = serde_json::from_str(&json).unwrap(); + + match (&message.message_type, &deserialized.message_type) { + (MessageType::General { data: data1 }, MessageType::General { data: data2 }) => { + assert_eq!(data1, data2); + } + _ => panic!("Message types don't match"), + } + assert_eq!(message.peer_id, deserialized.peer_id); + assert_eq!(message.multiaddrs, deserialized.multiaddrs); + assert_eq!(message.sender_address, deserialized.sender_address); + } +} diff --git a/crates/python-sdk/src/utils/json_parser.rs b/crates/python-sdk/src/utils/json_parser.rs new file mode 100644 index 00000000..b5ed4aa2 --- /dev/null +++ b/crates/python-sdk/src/utils/json_parser.rs @@ -0,0 +1,8 @@ +use pyo3::prelude::*; +use pythonize::pythonize; + +/// Convert a serde_json::Value to a Python object +pub fn json_to_pyobject(py: Python, value: &serde_json::Value) -> PyObject { + // pythonize handles all the conversion automatically! 
+ pythonize(py, value).unwrap().into() +} diff --git a/crates/python-sdk/src/utils/mod.rs b/crates/python-sdk/src/utils/mod.rs new file mode 100644 index 00000000..0ab14864 --- /dev/null +++ b/crates/python-sdk/src/utils/mod.rs @@ -0,0 +1 @@ +pub(crate) mod json_parser; diff --git a/crates/python-sdk/src/validator/mod.rs b/crates/python-sdk/src/validator/mod.rs new file mode 100644 index 00000000..514ed6f5 --- /dev/null +++ b/crates/python-sdk/src/validator/mod.rs @@ -0,0 +1,577 @@ +use crate::common::NodeDetails; +use crate::p2p_handler::auth::AuthenticationManager; +use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig}; +use crate::p2p_handler::{Message, MessageType, Service as P2PService}; +use alloy::primitives::Address; +use p2p::{Keypair, PeerId}; +use pyo3::prelude::*; +use pythonize::pythonize; +use shared::models::node::DiscoveryNode; +use shared::security::request_signer::sign_request_with_nonce; +use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; +use shared::web3::wallet::{Wallet, WalletProvider}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::Mutex; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; +use url::Url; + +/// Prime Protocol Validator Client - for validating nodes and tasks +#[pyclass] +pub struct ValidatorClient { + runtime: Option, + wallet: Option, + discovery_urls: Vec, + cancellation_token: CancellationToken, + // P2P fields + auth_manager: Option>, + outbound_tx: Option>>>, + user_message_rx: Option>>>, + message_processor_handle: Option>, + peer_id: Option, + contracts: Option>>, +} + +#[pymethods] +impl ValidatorClient { + #[new] + #[pyo3(signature = (rpc_url, private_key, discovery_urls=vec!["http://localhost:8089".to_string()]))] + pub fn new( + rpc_url: String, + private_key: String, + discovery_urls: Vec, + ) -> PyResult { + let rpc_url_parsed = Url::parse(&rpc_url).map_err(|e| { + 
PyErr::new::(format!("Invalid RPC URL: {}", e)) + })?; + let wallet = Wallet::new(&private_key, rpc_url_parsed) + .map_err(|e| PyErr::new::(e.to_string()))?; + + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .map_err(|e| PyErr::new::(e.to_string()))?; + + let cancellation_token = CancellationToken::new(); + + Ok(Self { + runtime: Some(runtime), + wallet: Some(wallet), + discovery_urls, + cancellation_token, + auth_manager: None, + outbound_tx: None, + user_message_rx: None, + message_processor_handle: None, + peer_id: None, + contracts: None, + }) + } + + /// List all nodes that are not validated yet + pub fn list_non_validated_nodes(&self, py: Python) -> PyResult> { + let rt = self.get_or_create_runtime()?; + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::("Wallet not initialized") + })?; + + let discovery_urls = self.discovery_urls.clone(); + + // Release the GIL while performing async operations + py.allow_threads(|| { + rt.block_on(async { + self.fetch_non_validated_nodes(&discovery_urls, wallet) + .await + }) + }) + } + + /// List all nodes with their details as Python dictionaries + pub fn list_all_nodes_dict(&self, py: Python) -> PyResult> { + let rt = self.get_or_create_runtime()?; + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::("Wallet not initialized") + })?; + + let discovery_urls = self.discovery_urls.clone(); + + // Release the GIL while performing async operations + let nodes = py.allow_threads(|| { + rt.block_on(async { self.fetch_all_nodes(&discovery_urls, wallet).await }) + })?; + + // Convert to Python dictionaries + let result: Result, _> = + nodes.into_iter().map(|node| pythonize(py, &node)).collect(); + + let python_objects = + result.map_err(|e| PyErr::new::(e.to_string()))?; + + // Convert Bound to Py + let py_objects: Vec = python_objects + .into_iter() + .map(|bound| bound.into()) + .collect(); + + Ok(py_objects) + } + + /// Get the number of non-validated nodes + 
pub fn get_non_validated_count(&self, py: Python) -> PyResult { + let nodes = self.list_non_validated_nodes(py)?; + Ok(nodes.len()) + } + + /// Initialize the validator client with optional P2P support + #[pyo3(signature = (p2p_port=None))] + pub fn start(&mut self, py: Python, p2p_port: Option) -> PyResult<()> { + // Initialize contracts if not already done + if self.contracts.is_none() { + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::("Wallet not initialized") + })?; + + let wallet_provider = wallet.provider(); + + // Get runtime reference and use it before mutating self + let rt = self.runtime.as_ref().ok_or_else(|| { + PyErr::new::("Runtime not initialized") + })?; + + let contracts = py.allow_threads(|| { + rt.block_on(async { + // Build all contracts (required due to known bug) + ContractBuilder::new(wallet_provider) + .with_compute_pool() + .with_compute_registry() + .with_ai_token() + .with_prime_network() + .with_stake_manager() + .build() + .map_err(|e| { + PyErr::new::(format!( + "Failed to build contracts: {}", + e + )) + }) + }) + })?; + + self.contracts = Some(Arc::new(contracts)); + } + + if let Some(port) = p2p_port { + // Initialize P2P if port is provided + let wallet = self + .wallet + .as_ref() + .ok_or_else(|| { + PyErr::new::("Wallet not initialized") + })? 
+ .clone(); + + let cancellation_token = self.cancellation_token.clone(); + + // Get runtime reference for P2P initialization + let rt = self.runtime.as_ref().ok_or_else(|| { + PyErr::new::("Runtime not initialized") + })?; + + // Create the P2P components + let (auth_manager, peer_id, outbound_tx, user_message_rx, message_processor_handle) = + py.allow_threads(|| { + rt.block_on(async { + Self::create_p2p_components(wallet, port, cancellation_token) + .await + .map_err(|e| { + PyErr::new::(e.to_string()) + }) + }) + })?; + + // Update self with the created components + self.auth_manager = Some(auth_manager); + self.peer_id = Some(peer_id); + self.outbound_tx = Some(outbound_tx); + self.user_message_rx = Some(user_message_rx); + self.message_processor_handle = Some(message_processor_handle); + } + + Ok(()) + } + + /// Send a message to a peer + pub fn send_message( + &self, + py: Python, + peer_id: String, + multiaddrs: Vec, + data: Vec, + ) -> PyResult<()> { + let rt = self.get_or_create_runtime()?; + + let auth_manager = self.auth_manager.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. Call start() with p2p_port parameter.", + ) + })?; + + let outbound_tx = self.outbound_tx.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. 
Call start() with p2p_port parameter.", + ) + })?; + + let message = Message { + message_type: MessageType::General { data }, + peer_id, + multiaddrs, + sender_address: None, + is_sender_validator: true, // Validator is sending this message + is_sender_pool_owner: false, + response_tx: None, + }; + + py.allow_threads(|| { + rt.block_on(async { + crate::p2p_handler::send_message_with_auth(message, auth_manager, outbound_tx) + .await + .map_err(|e| PyErr::new::(e.to_string())) + }) + }) + } + + /// Get the next message from the P2P network + pub fn get_next_message(&self, py: Python) -> PyResult> { + let rt = self.get_or_create_runtime()?; + + let user_message_rx = self.user_message_rx.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. Call start() with p2p_port parameter.", + ) + })?; + + let message = py.allow_threads(|| { + rt.block_on(async { + tokio::time::timeout( + crate::constants::MESSAGE_QUEUE_TIMEOUT, + user_message_rx.lock().await.recv(), + ) + .await + .ok() + .flatten() + }) + }); + + match message { + Some(msg) => { + let py_msg = pythonize(py, &msg) + .map_err(|e| PyErr::new::(e.to_string()))?; + Ok(Some(py_msg.into())) + } + None => Ok(None), + } + } + + /// Validate a node on the Prime Network contract + /// + /// Args: + /// node_address: The node address to validate + /// provider_address: The provider's address + /// + /// Returns: + /// Transaction hash as a string if successful + pub fn validate_node( + &self, + py: Python, + node_address: String, + provider_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. 
Call start() first.", + ) + })?; + + let contracts_clone = contracts.clone(); + + // Release the GIL while performing async operations + py.allow_threads(|| { + rt.block_on(async { + // Parse addresses + let provider_addr = + Address::parse_checksummed(&provider_address, None).map_err(|e| { + PyErr::new::(format!( + "Invalid provider address: {}", + e + )) + })?; + + let node_addr = Address::parse_checksummed(&node_address, None).map_err(|e| { + PyErr::new::(format!( + "Invalid node address: {}", + e + )) + })?; + + let tx_hash = contracts_clone + .prime_network + .validate_node(provider_addr, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to validate node: {}", + e + )) + })?; + + Ok(format!("0x{}", hex::encode(tx_hash))) + }) + }) + } + + /// Validate a node on the Prime Network contract with explicit addresses + /// + /// Args: + /// provider_address: The provider's address + /// node_address: The node's address + /// + /// Returns: + /// Transaction hash as a string if successful + pub fn validate_node_with_addresses( + &self, + py: Python, + provider_address: String, + node_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. 
Call start() first.", + ) + })?; + + let contracts_clone = contracts.clone(); + + // Release the GIL while performing async operations + py.allow_threads(|| { + rt.block_on(async { + // Parse addresses + let provider_addr = + Address::parse_checksummed(&provider_address, None).map_err(|e| { + PyErr::new::(format!( + "Invalid provider address: {}", + e + )) + })?; + + let node_addr = Address::parse_checksummed(&node_address, None).map_err(|e| { + PyErr::new::(format!( + "Invalid node address: {}", + e + )) + })?; + + let tx_hash = contracts_clone + .prime_network + .validate_node(provider_addr, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to validate node: {}", + e + )) + })?; + + Ok(format!("0x{}", hex::encode(tx_hash))) + }) + }) + } + + /// Get the validator's peer ID + pub fn get_peer_id(&self) -> PyResult> { + Ok(self.peer_id.map(|id| id.to_string())) + } +} + +// Private implementation methods +impl ValidatorClient { + fn get_or_create_runtime(&self) -> PyResult<&tokio::runtime::Runtime> { + if let Some(ref rt) = self.runtime { + Ok(rt) + } else { + Err(PyErr::new::( + "Runtime not initialized. 
Call start() first.", + )) + } + } + + async fn create_p2p_components( + wallet: Wallet, + port: u16, + cancellation_token: CancellationToken, + ) -> Result< + ( + Arc, + PeerId, + Arc>>, + Arc>>, + JoinHandle<()>, + ), + anyhow::Error, + > { + // Initialize authentication manager + let auth_manager = Arc::new(AuthenticationManager::new(Arc::new(wallet.clone()))); + + // Create P2P service + let keypair = Keypair::generate_ed25519(); + let wallet_address = Some(wallet.wallet.default_signer().address().to_string()); + + let (user_message_tx, user_message_rx) = tokio::sync::mpsc::channel::(1000); + + let (p2p_service, outbound_tx, message_queue_rx, authenticated_peers) = + P2PService::new(keypair, port, cancellation_token.clone(), wallet_address)?; + + let peer_id = p2p_service.node.peer_id(); + let outbound_tx = Arc::new(Mutex::new(outbound_tx)); + let user_message_rx = Arc::new(Mutex::new(user_message_rx)); + + // Start P2P service + tokio::task::spawn(p2p_service.run()); + + // Start message processor + let config = MessageProcessorConfig { + auth_manager: auth_manager.clone(), + message_queue_rx: Arc::new(Mutex::new(message_queue_rx)), + user_message_tx, + outbound_tx: outbound_tx.clone(), + authenticated_peers, + cancellation_token, + validator_addresses: Arc::new(std::collections::HashSet::new()), // Validator doesn't need to check other validators + pool_owner_address: None, // Validator doesn't need pool owner info + compute_manager_address: None, // Validator doesn't need this + }; + + let message_processor = MessageProcessor::from_config(config); + let message_processor_handle = message_processor.spawn(); + + log::info!( + "P2P service started on port {} with peer ID: {:?}", + port, + peer_id + ); + + Ok(( + auth_manager, + peer_id, + outbound_tx, + user_message_rx, + message_processor_handle, + )) + } + + async fn fetch_non_validated_nodes( + &self, + discovery_urls: &[String], + wallet: &Wallet, + ) -> PyResult> { + let nodes = 
self.fetch_all_nodes(discovery_urls, wallet).await?; + Ok(nodes + .into_iter() + .filter(|node| !node.is_validated) + .map(NodeDetails::from) + .collect()) + } + + async fn fetch_all_nodes( + &self, + discovery_urls: &[String], + wallet: &Wallet, + ) -> PyResult> { + let mut all_nodes = Vec::new(); + let mut any_success = false; + + for discovery_url in discovery_urls { + match self + .fetch_nodes_from_discovery_url(discovery_url, wallet) + .await + { + Ok(nodes) => { + all_nodes.extend(nodes); + any_success = true; + } + Err(e) => { + // Log error but continue with other discovery services + log::error!("Failed to fetch nodes from {}: {}", discovery_url, e); + } + } + } + + if !any_success { + return Err(PyErr::new::( + "Failed to fetch nodes from all discovery services", + )); + } + + // Remove duplicates based on node ID + let mut unique_nodes = Vec::new(); + let mut seen_ids = std::collections::HashSet::new(); + for node in all_nodes { + if seen_ids.insert(node.node.id.clone()) { + unique_nodes.push(node); + } + } + + Ok(unique_nodes) + } + + async fn fetch_nodes_from_discovery_url( + &self, + discovery_url: &str, + wallet: &Wallet, + ) -> Result, anyhow::Error> { + let address = wallet.wallet.default_signer().address().to_string(); + + let discovery_route = "/api/validator"; + let signature = sign_request_with_nonce(discovery_route, wallet, None) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "x-address", + reqwest::header::HeaderValue::from_str(&address)?, + ); + headers.insert( + "x-signature", + reqwest::header::HeaderValue::from_str(&signature.signature)?, + ); + + let response = reqwest::Client::new() + .get(format!("{}{}", discovery_url, discovery_route)) + .query(&[("nonce", signature.nonce)]) + .headers(headers) + .timeout(Duration::from_secs(10)) + .send() + .await?; + + let response_text = response.text().await?; + let parsed_response: shared::models::api::ApiResponse> 
= + serde_json::from_str(&response_text)?; + + if !parsed_response.success { + return Err(anyhow::anyhow!( + "Failed to fetch nodes from {}: {:?}", + discovery_url, + parsed_response + )); + } + + Ok(parsed_response.data) + } +} diff --git a/crates/python-sdk/src/worker/blockchain.rs b/crates/python-sdk/src/worker/blockchain.rs new file mode 100644 index 00000000..0ed4ea5b --- /dev/null +++ b/crates/python-sdk/src/worker/blockchain.rs @@ -0,0 +1,453 @@ +use alloy::primitives::utils::format_ether; +use alloy::primitives::{Address, U256}; +use anyhow::{Context, Result}; +use operations::operations::compute_node::ComputeNodeOperations; +use operations::operations::provider::ProviderOperations; +use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; +use shared::web3::contracts::structs::compute_pool::PoolStatus; +use shared::web3::wallet::{Wallet, WalletProvider}; +use std::sync::Arc; +use url::Url; + +use crate::constants::{BLOCKCHAIN_OPERATION_TIMEOUT, DEFAULT_COMPUTE_UNITS}; + +/// Configuration for blockchain operations +pub struct BlockchainConfig { + pub rpc_url: String, + pub compute_pool_id: u64, + pub private_key_provider: String, + pub private_key_node: String, + pub auto_accept_transactions: bool, + pub funding_retry_count: u32, +} + +/// Handles all blockchain-related operations for the worker +pub struct BlockchainService { + config: BlockchainConfig, + provider_wallet: Option, + node_wallet: Option, + contracts: Option>, +} + +impl BlockchainService { + pub fn new(config: BlockchainConfig) -> Result { + // Validate RPC URL + Url::parse(&config.rpc_url).context("Invalid RPC URL format")?; + + Ok(Self { + config, + provider_wallet: None, + node_wallet: None, + contracts: None, + }) + } + + /// Get the node wallet (used for authentication) + pub fn node_wallet(&self) -> Option<&Wallet> { + self.node_wallet.as_ref() + } + + /// Get the provider wallet + pub fn provider_wallet(&self) -> Option<&Wallet> { + self.provider_wallet.as_ref() + } 
+ + /// Initialize blockchain components and ensure the node is properly registered + pub async fn initialize(&mut self) -> Result<()> { + let (provider_wallet, node_wallet, contracts) = self.create_wallets_and_contracts().await?; + + // Store the wallets + self.provider_wallet = Some(provider_wallet.clone()); + self.node_wallet = Some(node_wallet.clone()); + self.contracts = Some(contracts.clone()); + + self.wait_for_active_pool(&contracts).await?; + self.ensure_provider_registered(&provider_wallet, &contracts) + .await?; + self.ensure_compute_node_registered(&provider_wallet, &node_wallet, &contracts) + .await?; + + Ok(()) + } + + async fn create_wallets_and_contracts( + &self, + ) -> Result<(Wallet, Wallet, Contracts)> { + let rpc_url = Url::parse(&self.config.rpc_url)?; + + let provider_wallet = Wallet::new(&self.config.private_key_provider, rpc_url.clone()) + .map_err(|e| anyhow::anyhow!("Failed to create provider wallet: {}", e))?; + + let node_wallet = Wallet::new(&self.config.private_key_node, rpc_url.clone()) + .map_err(|e| anyhow::anyhow!("Failed to create node wallet: {}", e))?; + + let contracts = ContractBuilder::new(provider_wallet.provider()) + .with_compute_pool() + .with_compute_registry() + .with_ai_token() + .with_prime_network() + .with_stake_manager() + .build() + .context("Failed to build contracts")?; + + Ok((provider_wallet, node_wallet, contracts)) + } + + async fn wait_for_active_pool( + &self, + contracts: &Contracts, + ) -> Result { + loop { + match contracts + .compute_pool + .get_pool_info(U256::from(self.config.compute_pool_id)) + .await + { + Ok(pool) if pool.status == PoolStatus::ACTIVE => { + log::info!("Pool {} is active", self.config.compute_pool_id); + return Ok(pool); + } + Ok(pool) => { + log::info!( + "Pool {} is not active yet (status: {:?}), waiting...", + self.config.compute_pool_id, + pool.status + ); + tokio::time::sleep(crate::constants::POOL_STATUS_CHECK_INTERVAL).await; + } + Err(e) => { + return 
Err(anyhow::anyhow!("Failed to get pool info: {}", e)); + } + } + } + } + + async fn ensure_provider_registered( + &self, + provider_wallet: &Wallet, + contracts: &Contracts, + ) -> Result<()> { + let provider_ops = ProviderOperations::new( + provider_wallet.clone(), + contracts.clone(), + self.config.auto_accept_transactions, + ); + + let provider_exists = provider_ops + .check_provider_exists() + .await + .map_err(|e| anyhow::anyhow!("Failed to check if provider exists: {}", e))?; + + let is_whitelisted = provider_ops + .check_provider_whitelisted() + .await + .map_err(|e| anyhow::anyhow!("Failed to check provider whitelist status: {}", e))?; + + if !provider_exists || !is_whitelisted { + self.register_provider(&provider_ops, contracts).await?; + } else { + log::info!("Provider is already registered and whitelisted"); + } + + self.ensure_adequate_stake(&provider_ops, provider_wallet, contracts) + .await?; + + Ok(()) + } + + async fn register_provider( + &self, + provider_ops: &ProviderOperations, + contracts: &Contracts, + ) -> Result<()> { + let stake_manager = contracts + .stake_manager + .as_ref() + .context("Stake manager not initialized")?; + + let compute_units = U256::from(DEFAULT_COMPUTE_UNITS); + let required_stake = stake_manager + .calculate_stake(compute_units, U256::from(0)) + .await + .map_err(|e| anyhow::anyhow!("Failed to calculate required stake: {}", e))?; + + log::info!( + "Required stake for registration: {}", + format_ether(required_stake) + ); + + tokio::time::timeout( + BLOCKCHAIN_OPERATION_TIMEOUT, + provider_ops.retry_register_provider( + required_stake, + self.config.funding_retry_count, + None, + ), + ) + .await + .context("Provider registration timed out")? 
+ .map_err(|e| anyhow::anyhow!("Failed to register provider: {}", e))?; + + log::info!("Provider registered successfully"); + Ok(()) + } + + async fn ensure_adequate_stake( + &self, + provider_ops: &ProviderOperations, + provider_wallet: &Wallet, + contracts: &Contracts, + ) -> Result<()> { + let stake_manager = contracts + .stake_manager + .as_ref() + .context("Stake manager not initialized")?; + + let provider_address = provider_wallet.wallet.default_signer().address(); + + let provider_total_compute = contracts + .compute_registry + .get_provider_total_compute(provider_address) + .await + .map_err(|e| anyhow::anyhow!("Failed to get provider total compute: {}", e))?; + + let provider_stake = stake_manager + .get_stake(provider_address) + .await + .unwrap_or_default(); + + let compute_units = U256::from(DEFAULT_COMPUTE_UNITS); + let required_stake = stake_manager + .calculate_stake(compute_units, provider_total_compute) + .await + .map_err(|e| anyhow::anyhow!("Failed to calculate required stake: {}", e))?; + + if required_stake > provider_stake { + log::info!( + "Increasing provider stake. Required: {} tokens, Current: {} tokens", + format_ether(required_stake), + format_ether(provider_stake) + ); + + tokio::time::timeout( + BLOCKCHAIN_OPERATION_TIMEOUT, + provider_ops.increase_stake(required_stake - provider_stake), + ) + .await + .context("Stake increase timed out")? 
+ .map_err(|e| anyhow::anyhow!("Failed to increase stake: {}", e))?; + + log::info!("Successfully increased stake"); + } + + Ok(()) + } + + async fn ensure_compute_node_registered( + &self, + provider_wallet: &Wallet, + node_wallet: &Wallet, + contracts: &Contracts, + ) -> Result<()> { + let compute_node_ops = + ComputeNodeOperations::new(provider_wallet, node_wallet, contracts.clone()); + + let compute_node_exists = compute_node_ops + .check_compute_node_exists() + .await + .map_err(|e| anyhow::anyhow!("Failed to check if compute node exists: {}", e))?; + + if compute_node_exists { + log::info!("Compute node is already registered"); + return Ok(()); + } + + let compute_units = U256::from(DEFAULT_COMPUTE_UNITS); + + tokio::time::timeout( + BLOCKCHAIN_OPERATION_TIMEOUT, + compute_node_ops.add_compute_node(compute_units), + ) + .await + .context("Compute node registration timed out")? + .map_err(|e| anyhow::anyhow!("Failed to register compute node: {}", e))?; + + log::info!("Compute node registered successfully"); + Ok(()) + } + + /// Join compute pool using an invite + pub async fn join_compute_pool_with_invite(&self, invite: &p2p::InviteRequest) -> Result<()> { + use shared::web3::contracts::core::builder::ContractBuilder; + use shared::web3::contracts::helpers::utils::retry_call; + + let invite_bytes = hex::decode(&invite.invite).context("Failed to decode invite")?; + + if invite_bytes.len() < 65 { + anyhow::bail!("Invite data is too short"); + } + + let bytes_array: [u8; 65] = invite_bytes[..65] + .try_into() + .map_err(|_| anyhow::anyhow!("Failed to convert invite to array"))?; + + // Get wallets + let provider_wallet = self + .provider_wallet + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Provider wallet not initialized"))?; + let node_wallet = self + .node_wallet + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Node wallet not initialized"))?; + + // Create contracts + let contracts = ContractBuilder::new(provider_wallet.provider()) + .with_compute_pool() + 
.with_compute_registry() + .with_ai_token() + .with_prime_network() + .with_stake_manager() + .build() + .context("Failed to build contracts")?; + + let pool_id = alloy::primitives::U256::from(invite.pool_id); + let provider_address = provider_wallet.wallet.default_signer().address(); + let node_address = vec![node_wallet.wallet.default_signer().address()]; + let signatures = vec![alloy::primitives::FixedBytes::from(&bytes_array)]; + + let call = contracts + .compute_pool + .build_join_compute_pool_call( + pool_id, + provider_address, + node_address, + vec![invite.nonce], + vec![invite.expiration], + signatures, + ) + .map_err(|e| anyhow::anyhow!("Failed to build join compute pool call: {}", e))?; + + let result = retry_call(call, 3, provider_wallet.provider.clone(), None) + .await + .context("Failed to join compute pool")?; + + log::info!("Successfully joined compute pool with tx: {}", result); + Ok(()) + } + + /// Get all validator addresses from the PrimeNetwork contract + pub async fn get_validator_addresses(&self) -> Arc> { + let contracts = match self.contracts.as_ref() { + Some(contracts) => contracts, + None => { + log::error!("Contracts not initialized"); + return Arc::new(std::collections::HashSet::new()); + } + }; + + match contracts.prime_network.get_validator_role().await { + Ok(validators) => { + log::info!( + "Fetched {} validator addresses from chain", + validators.len() + ); + let validator_set: std::collections::HashSet
= validators + .into_iter() + .inspect(|&addr| { + log::debug!("Validator address: {:?}", addr); + }) + .collect(); + log::info!("Validator addresses: {:?}", validator_set); + Arc::new(validator_set) + } + Err(e) => { + log::error!("Failed to get validator addresses: {}", e); + Arc::new(std::collections::HashSet::new()) + } + } + } + + /// Get the pool owner address for the current compute pool + pub async fn get_pool_owner_address(&self) -> Option
{ + let contracts = match self.contracts.as_ref() { + Some(contracts) => contracts, + None => { + log::error!("Contracts not initialized"); + return None; + } + }; + + log::info!( + "Fetching pool owner for pool ID: {}", + self.config.compute_pool_id + ); + + match contracts + .compute_pool + .get_pool_info(U256::from(self.config.compute_pool_id)) + .await + { + Ok(pool_info) => { + log::info!( + "Pool {} owner address: {:?}", + self.config.compute_pool_id, + pool_info.creator + ); + log::info!( + "Pool {} owner address (checksummed): {}", + self.config.compute_pool_id, + pool_info.creator + ); + Some(pool_info.creator) + } + Err(e) => { + log::error!( + "Failed to get pool info for pool {}: {}", + self.config.compute_pool_id, + e + ); + None + } + } + } + + /// Get the compute manager address for the current compute pool + pub async fn get_compute_manager_address(&self) -> Option
{ + let contracts = match self.contracts.as_ref() { + Some(contracts) => contracts, + None => { + log::error!("Contracts not initialized"); + return None; + } + }; + + match contracts + .compute_pool + .get_pool_info(U256::from(self.config.compute_pool_id)) + .await + { + Ok(pool_info) => { + log::info!( + "Pool {} compute manager address: {:?}", + self.config.compute_pool_id, + pool_info.compute_manager_key + ); + log::info!( + "Pool {} compute manager address (checksummed): {}", + self.config.compute_pool_id, + pool_info.compute_manager_key + ); + Some(pool_info.compute_manager_key) + } + Err(e) => { + log::error!( + "Failed to get pool info for pool {}: {}", + self.config.compute_pool_id, + e + ); + None + } + } + } +} diff --git a/crates/python-sdk/src/worker/client.rs b/crates/python-sdk/src/worker/client.rs new file mode 100644 index 00000000..ec69e448 --- /dev/null +++ b/crates/python-sdk/src/worker/client.rs @@ -0,0 +1,569 @@ +use crate::constants::{DEFAULT_FUNDING_RETRY_COUNT, MESSAGE_QUEUE_TIMEOUT, P2P_SHUTDOWN_TIMEOUT}; +use crate::error::{PrimeProtocolError, Result}; +use crate::p2p_handler::auth::AuthenticationManager; +use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig}; +use crate::worker::blockchain::{BlockchainConfig, BlockchainService}; +use crate::worker::p2p_handler::{Message, MessageType, Service as P2PService}; +use alloy::primitives::Address; +use p2p::{Keypair, PeerId}; +use std::sync::Arc; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::Mutex; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; +use url::Url; + +/// Core worker client that handles P2P networking and blockchain operations +pub struct WorkerClientCore { + config: ClientConfig, + blockchain_service: Option, + p2p_state: P2PState, + auth_manager: Option>, + cancellation_token: CancellationToken, + // Message processing + user_message_tx: Option>, + user_message_rx: Option>>>, + message_processor_handle: 
Option>, + // Validator and pool owner info + validator_addresses: Option>>, + pool_owner_address: Option
, + compute_manager_address: Option
, +} + +/// Configuration for the worker client +struct ClientConfig { + rpc_url: String, + compute_pool_id: u64, + private_key_provider: Option, + private_key_node: Option, + auto_accept_transactions: bool, + funding_retry_count: u32, + p2p_port: u16, +} + +/// P2P networking state +struct P2PState { + keypair: Keypair, + peer_id: Option, + outbound_tx: Option>>>, + message_queue_rx: Option>>>, + handle: Option>>, + authenticated_peers: + Option>>>, +} + +impl WorkerClientCore { + /// Create a new worker client + #[allow(clippy::too_many_arguments)] + pub fn new( + compute_pool_id: u64, + rpc_url: String, + private_key_provider: Option, + private_key_node: Option, + auto_accept_transactions: Option, + funding_retry_count: Option, + cancellation_token: CancellationToken, + p2p_port: u16, + ) -> Result { + // Validate inputs + if rpc_url.is_empty() { + return Err(PrimeProtocolError::InvalidConfig( + "RPC URL cannot be empty".to_string(), + )); + } + + Url::parse(&rpc_url) + .map_err(|_| PrimeProtocolError::InvalidConfig("Invalid RPC URL format".to_string()))?; + + let config = ClientConfig { + rpc_url, + compute_pool_id, + private_key_provider, + private_key_node, + auto_accept_transactions: auto_accept_transactions.unwrap_or(true), + funding_retry_count: funding_retry_count.unwrap_or(DEFAULT_FUNDING_RETRY_COUNT), + p2p_port, + }; + + let p2p_state = P2PState { + keypair: Keypair::generate_ed25519(), + peer_id: None, + outbound_tx: None, + message_queue_rx: None, + handle: None, + authenticated_peers: None, + }; + + // Create user message channel + let (user_message_tx, user_message_rx) = tokio::sync::mpsc::channel::(1000); + + Ok(Self { + config, + blockchain_service: None, + p2p_state, + auth_manager: None, + cancellation_token, + user_message_tx: Some(user_message_tx), + user_message_rx: Some(Arc::new(Mutex::new(user_message_rx))), + message_processor_handle: None, + validator_addresses: None, + pool_owner_address: None, + compute_manager_address: None, + }) + } 
+ + /// Start the worker client asynchronously + pub async fn start_async(&mut self) -> Result<()> { + log::info!("Starting WorkerClient"); + + self.initialize_blockchain().await?; + self.initialize_validator_and_pool_info().await?; + self.initialize_auth_manager()?; + self.start_p2p_service().await?; + self.start_message_processor().await?; + + log::info!("WorkerClient started successfully"); + Ok(()) + } + + /// Stop the worker client asynchronously + pub async fn stop_async(&mut self) -> Result<()> { + log::info!("Stopping worker client..."); + self.cancellation_token.cancel(); + + // Stop message processor + if let Some(handle) = self.message_processor_handle.take() { + handle.abort(); + } + + // Give P2P service time to shutdown gracefully + tokio::time::sleep(P2P_SHUTDOWN_TIMEOUT).await; + + // Stop P2P service + if let Some(handle) = self.p2p_state.handle.take() { + let _ = tokio::time::timeout(P2P_SHUTDOWN_TIMEOUT, handle).await; + } + + log::info!("Worker client stopped"); + Ok(()) + } + + /// Get the peer ID of this node + pub fn get_peer_id(&self) -> Option { + self.p2p_state.peer_id + } + + /// Get the next message from the P2P network + pub async fn get_next_message(&self) -> Option { + let rx = self.user_message_rx.as_ref()?; + + match tokio::time::timeout(MESSAGE_QUEUE_TIMEOUT, rx.lock().await.recv()).await { + Ok(Some(message)) => { + // Check if it's an invite and process it automatically + if let MessageType::General { ref data } = message.message_type { + if let Ok(invite) = serde_json::from_slice::(data) { + log::info!("Received invite from peer: {}", message.peer_id); + + // Check if invite has expired + if let Ok(true) = operations::invite::worker::is_invite_expired(&invite) { + log::warn!("Received expired invite from peer: {}", message.peer_id); + return Some(message); // Return it so user can see the expired invite + } + + // Verify pool ID matches + if invite.pool_id != self.config.compute_pool_id as u32 { + log::warn!( + "Received invite 
for wrong pool: expected {}, got {}", + self.config.compute_pool_id, + invite.pool_id + ); + return Some(message); // Return it so user can see the wrong pool invite + } + + // Process the invite automatically + if let Some(blockchain_service) = &self.blockchain_service { + match blockchain_service + .join_compute_pool_with_invite(&invite) + .await + { + Ok(()) => { + log::info!( + "Successfully processed invite and joined compute pool" + ); + // Don't return the invite message since we handled it + return None; + } + Err(e) => { + log::error!("Failed to process invite: {}", e); + // Return the message so user knows about the failed invite + } + } + } + } + } + Some(message) + } + Ok(None) => None, + Err(_) => None, // Timeout + } + } + + /// Send a message to a peer + pub async fn send_message(&self, message: Message) -> Result<()> { + let auth_manager = self.auth_manager.as_ref().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Authentication manager not initialized".to_string()) + })?; + + let tx = self.p2p_state.outbound_tx.as_ref().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("P2P service not initialized".to_string()) + })?; + + crate::p2p_handler::send_message_with_auth(message, auth_manager, tx).await + } + + // Private helper methods + + async fn initialize_blockchain(&mut self) -> Result<()> { + let blockchain_config = BlockchainConfig { + rpc_url: self.config.rpc_url.clone(), + compute_pool_id: self.config.compute_pool_id, + private_key_provider: self.get_private_key_provider()?, + private_key_node: self.get_private_key_node()?, + auto_accept_transactions: self.config.auto_accept_transactions, + funding_retry_count: self.config.funding_retry_count, + }; + + // Create blockchain service - wallets are created internally + let mut blockchain_service = BlockchainService::new(blockchain_config).map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to create blockchain service: {}", + e + )) + })?; + + 
blockchain_service.initialize().await.map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to initialize blockchain: {}", e)) + })?; + + self.blockchain_service = Some(blockchain_service); + Ok(()) + } + + async fn initialize_validator_and_pool_info(&mut self) -> Result<()> { + let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string()) + })?; + + let validator_addresses = blockchain_service.get_validator_addresses().await; + self.validator_addresses = Some(validator_addresses); + + let pool_owner_address = blockchain_service.get_pool_owner_address().await; + self.pool_owner_address = pool_owner_address; + + let compute_manager_address = blockchain_service.get_compute_manager_address().await; + self.compute_manager_address = compute_manager_address; + + log::info!( + "Validator addresses: {:?}", + self.validator_addresses.as_ref().map(|s| s.len()) + ); + log::info!("Pool owner address: {:?}", self.pool_owner_address); + log::info!( + "Compute manager address: {:?}", + self.compute_manager_address + ); + + Ok(()) + } + + fn initialize_auth_manager(&mut self) -> Result<()> { + let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string()) + })?; + + let node_wallet = blockchain_service.node_wallet().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Node wallet not initialized".to_string()) + })?; + + self.auth_manager = Some(Arc::new(AuthenticationManager::new(Arc::new( + node_wallet.clone(), + )))); + Ok(()) + } + + async fn start_p2p_service(&mut self) -> Result<()> { + // Get wallet address from auth manager + let wallet_address = self.auth_manager.as_ref().map(|auth| auth.wallet_address()); + + let (p2p_service, outbound_tx, message_queue_rx, authenticated_peers) = P2PService::new( + self.p2p_state.keypair.clone(), + self.config.p2p_port, + 
self.cancellation_token.clone(), + wallet_address, + ) + .map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Failed to create P2P service: {}", e)) + })?; + + self.p2p_state.peer_id = Some(p2p_service.node.peer_id()); + self.p2p_state.outbound_tx = Some(Arc::new(Mutex::new(outbound_tx))); + self.p2p_state.message_queue_rx = Some(Arc::new(Mutex::new(message_queue_rx))); + self.p2p_state.authenticated_peers = Some(authenticated_peers); + + log::info!( + "P2P service initialized with peer ID: {:?}", + self.p2p_state.peer_id + ); + + self.p2p_state.handle = Some(tokio::task::spawn(p2p_service.run())); + Ok(()) + } + + async fn start_message_processor(&mut self) -> Result<()> { + // Build the message processor configuration + let config = self.build_message_processor_config()?; + + // Create and spawn the message processor + let message_processor = MessageProcessor::from_config(config); + self.message_processor_handle = Some(message_processor.spawn()); + + log::info!("Message processor started"); + Ok(()) + } + + /// Build configuration for the message processor + /// This method is public to allow reuse in other crates + pub fn build_message_processor_config(&self) -> Result { + let message_queue_rx = self + .p2p_state + .message_queue_rx + .as_ref() + .ok_or_else(|| { + PrimeProtocolError::InvalidConfig("P2P service not initialized".to_string()) + })? + .clone(); + + let user_message_tx = self + .user_message_tx + .as_ref() + .ok_or_else(|| { + PrimeProtocolError::InvalidConfig( + "User message channel not initialized".to_string(), + ) + })? + .clone(); + + let auth_manager = self + .auth_manager + .as_ref() + .ok_or_else(|| { + PrimeProtocolError::InvalidConfig( + "Authentication manager not initialized".to_string(), + ) + })? + .clone(); + + let outbound_tx = self + .p2p_state + .outbound_tx + .as_ref() + .ok_or_else(|| { + PrimeProtocolError::InvalidConfig("P2P service not initialized".to_string()) + })? 
+ .clone(); + + let authenticated_peers = self + .p2p_state + .authenticated_peers + .as_ref() + .ok_or_else(|| { + PrimeProtocolError::InvalidConfig( + "Authenticated peers map not initialized".to_string(), + ) + })? + .clone(); + + let validator_addresses = self + .validator_addresses + .as_ref() + .ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Validator addresses not initialized".to_string()) + })? + .clone(); + + let pool_owner_address = self.pool_owner_address.ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Pool owner address not initialized".to_string()) + })?; + + let compute_manager_address = self.compute_manager_address; + + Ok(MessageProcessorConfig { + auth_manager, + message_queue_rx, + user_message_tx, + outbound_tx, + authenticated_peers, + cancellation_token: self.cancellation_token.clone(), + validator_addresses, + pool_owner_address: Some(pool_owner_address), + compute_manager_address, + }) + } + + /// Get the provider's Ethereum address + pub fn get_provider_address(&self) -> Result { + let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string()) + })?; + let provider_wallet = blockchain_service.provider_wallet().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Provider wallet not initialized".to_string()) + })?; + Ok(provider_wallet + .wallet + .default_signer() + .address() + .to_string()) + } + + // todo: move to blockchain service + pub fn get_node_address(&self) -> Result { + let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string()) + })?; + let node_wallet = blockchain_service.node_wallet().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Node wallet not initialized".to_string()) + })?; + Ok(node_wallet.wallet.default_signer().address().to_string()) + } + + /// Get the compute pool ID + pub fn get_compute_pool_id(&self) -> 
u64 { + self.config.compute_pool_id + } + + /// Get the listening multiaddresses from the P2P service + pub async fn get_listening_addresses(&self) -> Vec { + // For now, return a simple localhost address with the configured port + // In the future, this could query the actual P2P service for its listen addresses + vec![format!("/ip4/0.0.0.0/tcp/{}", self.config.p2p_port)] + } + + /// Upload node information to discovery services + pub async fn upload_to_discovery( + &self, + node_info: &crate::worker::discovery::SimpleNode, + discovery_urls: &[String], + ) -> Result<()> { + let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string()) + })?; + + let node_wallet = blockchain_service.node_wallet().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Provider wallet not initialized".to_string()) + })?; + + let node_json = node_info.to_json_value(); + + shared::discovery::upload_node_to_discovery(discovery_urls, &node_json, node_wallet) + .await + .map_err(|e| { + PrimeProtocolError::RuntimeError(format!("Failed to upload to discovery: {}", e)) + }) + } + + fn get_private_key_provider(&self) -> Result { + self.config + .private_key_provider + .clone() + .or_else(|| std::env::var("PRIVATE_KEY_PROVIDER").ok()) + .ok_or_else(|| { + PrimeProtocolError::InvalidConfig( + "PRIVATE_KEY_PROVIDER must be set either as parameter or environment variable" + .to_string(), + ) + }) + } + + fn get_private_key_node(&self) -> Result { + self.config + .private_key_node + .clone() + .or_else(|| std::env::var("PRIVATE_KEY_NODE").ok()) + .ok_or_else(|| { + PrimeProtocolError::InvalidConfig( + "PRIVATE_KEY_NODE must be set either as parameter or environment variable" + .to_string(), + ) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + fn create_test_config() -> (String, String) { + // Standard Anvil blockchain keys for local testing + let node_key = 
"0x7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6"; + let provider_key = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a"; + (node_key.to_string(), provider_key.to_string()) + } + + #[test] + fn test_client_creation() { + let (node_key, provider_key) = create_test_config(); + let cancellation_token = CancellationToken::new(); + + let result = WorkerClientCore::new( + 0, + "http://localhost:8545".to_string(), + Some(provider_key), + Some(node_key), + None, + None, + cancellation_token, + 8000, + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_invalid_rpc_url() { + let (node_key, provider_key) = create_test_config(); + let cancellation_token = CancellationToken::new(); + + let result = WorkerClientCore::new( + 0, + "invalid-url".to_string(), + Some(provider_key), + Some(node_key), + None, + None, + cancellation_token, + 8000, + ); + + assert!(result.is_err()); + } + + #[test] + fn test_empty_rpc_url() { + let (node_key, provider_key) = create_test_config(); + let cancellation_token = CancellationToken::new(); + + let result = WorkerClientCore::new( + 0, + "".to_string(), + Some(provider_key), + Some(node_key), + None, + None, + cancellation_token, + 8000, + ); + + assert!(result.is_err()); + } +} diff --git a/crates/python-sdk/src/worker/discovery.rs b/crates/python-sdk/src/worker/discovery.rs new file mode 100644 index 00000000..9be96eb3 --- /dev/null +++ b/crates/python-sdk/src/worker/discovery.rs @@ -0,0 +1,55 @@ +use serde::{Deserialize, Serialize}; + +/// Simple node information for discovery upload +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SimpleNode { + /// Unique identifier for the node + pub id: String, + /// The external IP address of the node + pub ip_address: String, + /// The port for the node (usually 0 for dynamic port allocation) + pub port: u16, + /// The Ethereum address of the node + pub provider_address: String, + /// The compute pool ID this node belongs to + pub compute_pool_id: 
u32, + /// The P2P peer ID + pub worker_p2p_id: String, + /// The P2P multiaddresses for connecting to this node + pub worker_p2p_addresses: Vec, +} + +impl SimpleNode { + pub fn new( + ip_address: String, + port: u16, + provider_wallet_address: String, + node_wallet_address: String, + compute_pool_id: u32, + peer_id: String, + multi_addresses: Vec, + ) -> Self { + Self { + id: node_wallet_address, + ip_address, + port, + provider_address: provider_wallet_address, + compute_pool_id, + worker_p2p_id: peer_id, + worker_p2p_addresses: multi_addresses, + } + } + + /// Convert to serde_json::Value for upload + pub fn to_json_value(&self) -> serde_json::Value { + serde_json::json!({ + "id": self.id, + "ip_address": self.ip_address, + "port": self.port, + "provider_address": self.provider_address, + "compute_pool_id": self.compute_pool_id, + "worker_p2p_id": self.worker_p2p_id, + "worker_p2p_addresses": self.worker_p2p_addresses, + }) + } +} diff --git a/crates/python-sdk/src/worker/mod.rs b/crates/python-sdk/src/worker/mod.rs new file mode 100644 index 00000000..bd720c9a --- /dev/null +++ b/crates/python-sdk/src/worker/mod.rs @@ -0,0 +1,238 @@ +use pyo3::prelude::*; + +mod blockchain; +mod client; +mod discovery; + +use crate::constants::DEFAULT_P2P_PORT; +use crate::p2p_handler; +use crate::p2p_handler::Message; +pub(crate) use client::WorkerClientCore; +use tokio_util::sync::CancellationToken; + +/// Prime Protocol Worker Client - for compute nodes that execute tasks +#[pyclass] +pub(crate) struct WorkerClient { + inner: WorkerClientCore, + runtime: Option, + cancellation_token: CancellationToken, +} + +#[pymethods] +impl WorkerClient { + #[new] + #[pyo3(signature = (compute_pool_id, rpc_url, private_key_provider=None, private_key_node=None, p2p_port=DEFAULT_P2P_PORT))] + pub fn new( + compute_pool_id: u64, + rpc_url: String, + private_key_provider: Option, + private_key_node: Option, + p2p_port: u16, + ) -> PyResult { + let cancellation_token = 
CancellationToken::new(); + let inner = WorkerClientCore::new( + compute_pool_id, + rpc_url, + private_key_provider, + private_key_node, + None, + None, + cancellation_token.clone(), + p2p_port, + ) + .map_err(to_py_err)?; + + Ok(Self { + inner, + runtime: None, + cancellation_token, + }) + } + + /// Start the worker client + pub fn start(&mut self, py: Python) -> PyResult<()> { + if self.runtime.is_some() { + return Err(to_py_runtime_err("Client already started")); + } + + let rt = create_runtime()?; + let result = py.allow_threads(|| rt.block_on(self.inner.start_async())); + + self.runtime = Some(rt); + result.map_err(to_py_err) + } + + /// Get the next message from the P2P network + pub fn get_next_message(&self, py: Python) -> PyResult> { + let rt = self.ensure_runtime()?; + + Ok(py.allow_threads(|| { + rt.block_on(self.inner.get_next_message()) + .map(message_to_pyobject) + })) + } + + /// Send a message to a peer + pub fn send_message( + &self, + peer_id: String, + data: Vec, + multiaddrs: Vec, + py: Python, + ) -> PyResult<()> { + let rt = self.ensure_runtime()?; + + let message = Message { + message_type: p2p_handler::MessageType::General { data }, + peer_id, + multiaddrs, + sender_address: None, // Will be filled from our wallet automatically + is_sender_validator: false, + is_sender_pool_owner: false, + response_tx: None, + }; + + py.allow_threads(|| rt.block_on(self.inner.send_message(message))) + .map_err(to_py_err) + } + + /// Get this node's peer ID + pub fn get_own_peer_id(&self) -> PyResult> { + self.ensure_runtime()?; + Ok(self.inner.get_peer_id().map(|id| id.to_string())) + } + + /// Upload node information to discovery services + /// + /// Args: + /// ip_address: External IP address of the node + /// discovery_urls: List of discovery service URLs (defaults to ["http://localhost:8089"]) + pub fn upload_to_discovery( + &self, + ip_address: String, + discovery_urls: Option>, + py: Python, + ) -> PyResult<()> { + let rt = self.ensure_runtime()?; + + 
// Get the peer ID and multiaddresses from the P2P service + let peer_id = self + .inner + .get_peer_id() + .ok_or_else(|| to_py_err("P2P service not started. Cannot get peer ID"))? + .to_string(); + + let multi_addresses = + py.allow_threads(|| rt.block_on(self.inner.get_listening_addresses())); + // Create simple node info (port 0 indicates dynamic port allocation) + let node_info = discovery::SimpleNode::new( + ip_address, + 0, // Port 0 for dynamic allocation + self.inner.get_provider_address().map_err(to_py_err)?, + self.inner.get_node_address().map_err(to_py_err)?, + self.inner.get_compute_pool_id() as u32, + peer_id, + multi_addresses, + ); + + // Use default discovery URLs if none provided + let urls = discovery_urls.unwrap_or_else(|| vec!["http://localhost:8089".to_string()]); + + // Upload to discovery + py.allow_threads(|| rt.block_on(self.inner.upload_to_discovery(&node_info, &urls))) + .map_err(to_py_err) + } + + /// Stop the worker client and clean up resources + pub fn stop(&mut self, py: Python) -> PyResult<()> { + // Cancel all background tasks + self.cancellation_token.cancel(); + + if let Some(rt) = self.runtime.as_ref() { + let inner = &mut self.inner; + py.allow_threads(|| rt.block_on(inner.stop_async())) + .map_err(to_py_err)?; + } + + // Clean up the runtime + if let Some(rt) = self.runtime.take() { + rt.shutdown_background(); + } + + Ok(()) + } +} + +// Helper methods +impl WorkerClient { + fn ensure_runtime(&self) -> PyResult<&tokio::runtime::Runtime> { + self.runtime + .as_ref() + .ok_or_else(|| to_py_runtime_err("Client not started. 
Call start() first.")) + } +} + +// Utility functions +fn create_runtime() -> PyResult { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .map_err(|e| to_py_runtime_err(&format!("Failed to create runtime: {}", e))) +} + +fn to_py_err(err: impl std::fmt::Display) -> PyErr { + PyErr::new::(err.to_string()) +} + +fn to_py_runtime_err(msg: &str) -> PyErr { + PyErr::new::(msg.to_string()) +} + +fn message_to_pyobject(message: Message) -> PyObject { + let message_data = match message.message_type { + p2p_handler::MessageType::General { data } => { + serde_json::json!({ + "type": "general", + "data": data, + }) + } + p2p_handler::MessageType::AuthenticationInitiation { challenge } => { + serde_json::json!({ + "type": "auth_initiation", + "challenge": challenge, + }) + } + p2p_handler::MessageType::AuthenticationResponse { + challenge, + signature, + } => { + serde_json::json!({ + "type": "auth_response", + "challenge": challenge, + "signature": signature, + }) + } + p2p_handler::MessageType::AuthenticationSolution { signature } => { + serde_json::json!({ + "type": "auth_solution", + "signature": signature, + }) + } + p2p_handler::MessageType::AuthenticationComplete => { + serde_json::json!({ + "type": "auth_complete", + }) + } + }; + + let json_value = serde_json::json!({ + "message": message_data, + "peer_id": message.peer_id, + "multiaddrs": message.multiaddrs, + "sender_address": message.sender_address, + "is_sender_validator": message.is_sender_validator, + "is_sender_pool_owner": message.is_sender_pool_owner, + }); + + Python::with_gil(|py| crate::utils::json_parser::json_to_pyobject(py, &json_value)) +} diff --git a/crates/python-sdk/tests/integration/test_worker.rs b/crates/python-sdk/tests/integration/test_worker.rs new file mode 100644 index 00000000..ed4e7a4e --- /dev/null +++ b/crates/python-sdk/tests/integration/test_worker.rs @@ -0,0 +1,163 @@ +#[cfg(test)] +mod worker_integration_tests { + use prime_protocol_py::worker::{Message, 
WorkerClientCore}; + use test_log::test; + use tokio_util::sync::CancellationToken; + + // Note: These tests require a running local blockchain with deployed contracts + // Run with: cargo test --test integration -- --ignored + + #[test(tokio::test)] + #[ignore = "requires local blockchain setup"] + async fn test_worker_full_lifecycle() { + // Standard Anvil test keys + let node_key = "0x7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6"; + let provider_key = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a"; + let cancellation_token = CancellationToken::new(); + + let mut worker = WorkerClientCore::new( + 0, // compute_pool_id + "http://localhost:8545".to_string(), + Some(provider_key.to_string()), + Some(node_key.to_string()), + Some(true), // auto_accept_transactions + Some(5), // funding_retry_count + cancellation_token.clone(), + 8000, // p2p_port + ) + .expect("Failed to create worker"); + + // Start the worker + worker + .start_async() + .await + .expect("Failed to start worker"); + + // Verify peer ID was assigned + let peer_id = worker.get_peer_id(); + assert!(peer_id.is_some()); + + // Test message sending to self + let test_message = Message { + data: b"Hello, self!".to_vec(), + peer_id: peer_id.unwrap().to_string(), + multiaddrs: vec![], + }; + + worker + .send_message(test_message) + .await + .expect("Failed to send message"); + + // Give some time for the message to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + // Try to receive the message + let received = worker.get_next_message().await; + assert!(received.is_some()); + + let msg = received.unwrap(); + assert_eq!(msg.data, b"Hello, self!"); + + // Clean shutdown + worker + .stop_async() + .await + .expect("Failed to stop worker"); + } + + #[test(tokio::test)] + #[ignore = "requires local blockchain setup"] + async fn test_multiple_workers_communication() { + let node_key1 = 
"0x7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6"; + let provider_key1 = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a"; + + let node_key2 = "0x47e179ec197488593b187f80a00eb0da91f1b9d0b13f8733639f19c30a34926a"; + let provider_key2 = "0x8b3a350cf5c34c9194ca85829a2df0ec3153be0318b5e2d3348e872092edffba"; + + let cancel_token1 = CancellationToken::new(); + let cancel_token2 = CancellationToken::new(); + + // Create two workers + let mut worker1 = WorkerClientCore::new( + 0, + "http://localhost:8545".to_string(), + Some(provider_key1.to_string()), + Some(node_key1.to_string()), + Some(true), + Some(5), + cancel_token1.clone(), + 8001, + ) + .expect("Failed to create worker1"); + + let mut worker2 = WorkerClientCore::new( + 0, + "http://localhost:8545".to_string(), + Some(provider_key2.to_string()), + Some(node_key2.to_string()), + Some(true), + Some(5), + cancel_token2.clone(), + 8002, + ) + .expect("Failed to create worker2"); + + // Start both workers + worker1.start_async().await.expect("Failed to start worker1"); + worker2.start_async().await.expect("Failed to start worker2"); + + let peer_id1 = worker1.get_peer_id().unwrap(); + let peer_id2 = worker2.get_peer_id().unwrap(); + + + // Worker1 sends message to Worker2 + let message = Message { + data: b"Hello from Worker1!".to_vec(), + peer_id: peer_id2.to_string(), + multiaddrs: vec!["/ip4/127.0.0.1/tcp/8002".to_string()], + }; + + worker1 + .send_message(message) + .await + .expect("Failed to send message from worker1"); + + // Give time for message delivery + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + // Worker2 should receive the message + let received = worker2.get_next_message().await; + assert!(received.is_some()); + + let msg = received.unwrap(); + assert_eq!(msg.data, b"Hello from Worker1!"); + assert_eq!(msg.peer_id, peer_id1.to_string()); + + // Clean shutdown + worker1.stop_async().await.expect("Failed to stop worker1"); + 
worker2.stop_async().await.expect("Failed to stop worker2"); + } + + #[test(tokio::test)] + async fn test_worker_without_blockchain() { + // Test that worker fails gracefully when blockchain is not available + let cancellation_token = CancellationToken::new(); + + let mut worker = WorkerClientCore::new( + 0, + "http://localhost:9999".to_string(), // Non-existent RPC + Some("0x1234".to_string()), + Some("0x5678".to_string()), + Some(true), + Some(1), // Low retry count for faster test + cancellation_token, + 8003, + ) + .expect("Failed to create worker"); + + // Starting should fail due to blockchain connection issues + let result = worker.start_async().await; + assert!(result.is_err()); + } +} \ No newline at end of file diff --git a/crates/python-sdk/tests/test_client.py b/crates/python-sdk/tests/test_client.py new file mode 100644 index 00000000..57b02400 --- /dev/null +++ b/crates/python-sdk/tests/test_client.py @@ -0,0 +1,29 @@ +"""Basic tests for the Prime Protocol Python client.""" + +import pytest +from primeprotocol import PrimeProtocolClient + + +def test_client_creation(): + """Test that client can be created with valid RPC URL.""" + client = PrimeProtocolClient("http://localhost:8545") + assert client is not None + + +def test_client_creation_with_empty_url(): + """Test that client creation fails with empty RPC URL.""" + with pytest.raises(ValueError): + PrimeProtocolClient("") + + +def test_client_creation_with_invalid_url(): + """Test that client creation fails with invalid RPC URL.""" + with pytest.raises(ValueError): + PrimeProtocolClient("not-a-valid-url") + + +def test_has_compute_pool_exists_method(): + """Test that the client has the compute_pool_exists method.""" + client = PrimeProtocolClient("http://example.com:8545") + assert hasattr(client, 'compute_pool_exists') + assert callable(getattr(client, 'compute_pool_exists')) \ No newline at end of file diff --git a/crates/python-sdk/tests/test_validator.py 
b/crates/python-sdk/tests/test_validator.py new file mode 100644 index 00000000..e15edeab --- /dev/null +++ b/crates/python-sdk/tests/test_validator.py @@ -0,0 +1,139 @@ +"""Tests for the validator client.""" + +import pytest +import os +from unittest.mock import patch, Mock + + +def test_validator_client_creation(): + """Test creating a validator client requires proper parameters.""" + # Mock the primeprotocol module since it might not be built + with patch('primeprotocol.ValidatorClient') as MockValidator: + mock_instance = Mock() + MockValidator.return_value = mock_instance + + # Test creation with required parameters + rpc_url = "http://localhost:8545" + private_key = "0x1234567890abcdef" + discovery_urls = ["http://localhost:8089"] + + from primeprotocol import ValidatorClient + validator = ValidatorClient(rpc_url, private_key, discovery_urls) + + # Verify the constructor was called with correct parameters + MockValidator.assert_called_once_with(rpc_url, private_key, discovery_urls) + + +def test_list_non_validated_nodes(): + """Test listing non-validated nodes.""" + with patch('primeprotocol.ValidatorClient') as MockValidator: + # Create mock node data + mock_node1 = Mock() + mock_node1.id = "node1" + mock_node1.provider_address = "0xabc123" + mock_node1.ip_address = "192.168.1.1" + mock_node1.port = 8080 + mock_node1.compute_pool_id = 1 + mock_node1.is_validated = False + mock_node1.is_active = True + mock_node1.is_provider_whitelisted = True + mock_node1.is_blacklisted = False + mock_node1.worker_p2p_id = "p2p_id_1" + mock_node1.created_at = "2024-01-01T00:00:00Z" + mock_node1.last_updated = "2024-01-02T00:00:00Z" + + mock_instance = Mock() + mock_instance.list_non_validated_nodes.return_value = [mock_node1] + MockValidator.return_value = mock_instance + + from primeprotocol import ValidatorClient + validator = ValidatorClient("http://localhost:8545", "0x123", ["http://localhost:8089"]) + + # Get non-validated nodes + nodes = 
validator.list_non_validated_nodes() + + # Verify results + assert len(nodes) == 1 + assert nodes[0].id == "node1" + assert nodes[0].is_validated == False + assert nodes[0].is_active == True + + +def test_list_all_nodes_dict(): + """Test listing all nodes as dictionaries.""" + with patch('primeprotocol.ValidatorClient') as MockValidator: + # Create mock response + mock_nodes = [ + { + 'id': 'node1', + 'provider_address': '0xabc123', + 'is_validated': False, + 'is_active': True, + 'node': { + 'compute_specs': { + 'gpu': { + 'count': 4, + 'model': 'NVIDIA A100', + 'memory_mb': 40000 + }, + 'cpu': { + 'cores': 32 + }, + 'ram_mb': 128000, + 'storage_gb': 1000 + } + } + }, + { + 'id': 'node2', + 'provider_address': '0xdef456', + 'is_validated': True, + 'is_active': True, + 'node': { + 'compute_specs': None + } + } + ] + + mock_instance = Mock() + mock_instance.list_all_nodes_dict.return_value = mock_nodes + MockValidator.return_value = mock_instance + + from primeprotocol import ValidatorClient + validator = ValidatorClient("http://localhost:8545", "0x123", ["http://localhost:8089"]) + + # Get all nodes + nodes = validator.list_all_nodes_dict() + + # Verify results + assert len(nodes) == 2 + assert nodes[0]['is_validated'] == False + assert nodes[1]['is_validated'] == True + + # Check compute specs + assert nodes[0]['node']['compute_specs']['gpu']['count'] == 4 + assert nodes[0]['node']['compute_specs']['gpu']['model'] == 'NVIDIA A100' + + +def test_get_non_validated_count(): + """Test getting count of non-validated nodes.""" + with patch('primeprotocol.ValidatorClient') as MockValidator: + mock_instance = Mock() + # Mock list_non_validated_nodes to return 3 nodes + mock_nodes = [Mock() for _ in range(3)] + mock_instance.list_non_validated_nodes.return_value = mock_nodes + mock_instance.get_non_validated_count.return_value = 3 + MockValidator.return_value = mock_instance + + from primeprotocol import ValidatorClient + validator = ValidatorClient("http://localhost:8545", 
"0x123", ["http://localhost:8089"]) + + # Get count + count = validator.get_non_validated_count() + + # Verify result + assert count == 3 + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/crates/python-sdk/uv.lock b/crates/python-sdk/uv.lock new file mode 100644 index 00000000..639a70ba --- /dev/null +++ b/crates/python-sdk/uv.lock @@ -0,0 +1,7 @@ +version = 1 +requires-python = ">=3.8" + +[[package]] +name = "primeprotocol" +version = "0.1.0" +source = { editable = "." } diff --git a/crates/shared/Cargo.toml b/crates/shared/Cargo.toml index 4d3a8760..a1b5f189 100644 --- a/crates/shared/Cargo.toml +++ b/crates/shared/Cargo.toml @@ -44,3 +44,4 @@ subtle = "2.6.1" utoipa = { version = "5.3.0", features = ["actix_extras", "chrono", "uuid"] } futures = { workspace = true } tokio-util = { workspace = true } +reqwest = { workspace = true } diff --git a/crates/shared/src/discovery/README.md b/crates/shared/src/discovery/README.md new file mode 100644 index 00000000..25fe8f59 --- /dev/null +++ b/crates/shared/src/discovery/README.md @@ -0,0 +1,60 @@ +# Shared Discovery Utilities + +This module provides shared utilities for interacting with the Prime Protocol discovery service. + +## Overview + +The discovery service is a temporary solution while we migrate to Kadmilla DHT. It allows nodes to register their information (IP, port, compute specs) and enables validators and orchestrators to discover nodes. + +## Functions + +### `fetch_nodes_from_discovery_url` +Fetches nodes from a single discovery URL with proper authentication. + +### `fetch_nodes_from_discovery_urls` +Fetches nodes from multiple discovery URLs with automatic deduplication. + +### `fetch_pool_nodes_from_discovery` +Convenience function to fetch nodes for a specific compute pool. + +### `fetch_validator_nodes_from_discovery` +Convenience function to fetch all validator-accessible nodes. 
+ +## Usage + +### From Rust (Validator/Orchestrator) + +```rust +use shared::discovery::fetch_validator_nodes_from_discovery; + +let nodes = fetch_validator_nodes_from_discovery( + &discovery_urls, + &wallet +).await?; +``` + +### From Python (Worker) + +```python +from primeprotocol import WorkerClient + +worker = WorkerClient(compute_pool_id=1, rpc_url="http://localhost:8545") + +# Configure discovery +worker.set_discovery_urls(["http://localhost:8089"]) +worker.set_node_config( + ip="192.168.1.100", + port=8080 +) + +# Upload to discovery +worker.start() +worker.upload_to_discovery() +``` + +## Migration Plan + +This is a temporary solution. Once Kademlia DHT is fully integrated: +1. Discovery uploads will be replaced with DHT announcements +2. Discovery queries will be replaced with DHT lookups +3. The discovery service endpoints can be deprecated \ No newline at end of file diff --git a/crates/shared/src/discovery/mod.rs b/crates/shared/src/discovery/mod.rs new file mode 100644 index 00000000..af281e6c --- /dev/null +++ b/crates/shared/src/discovery/mod.rs @@ -0,0 +1,221 @@ +use crate::models::api::ApiResponse; +use crate::models::node::DiscoveryNode; +use crate::security::request_signer::sign_request_with_nonce; +use crate::web3::wallet::Wallet; +use anyhow::{Context, Result}; +use log::{debug, error}; +use std::time::Duration; + +/// Fetch nodes from a single discovery URL +pub async fn fetch_nodes_from_discovery_url( + discovery_url: &str, + route: &str, + wallet: &Wallet, +) -> Result> { + let address = wallet.wallet.default_signer().address().to_string(); + + let signature = sign_request_with_nonce(route, wallet, None) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "x-address", + reqwest::header::HeaderValue::from_str(&address) + .context("Failed to create address header")?, + ); + headers.insert( + "x-signature", + 
reqwest::header::HeaderValue::from_str(&signature.signature) + .context("Failed to create signature header")?, + ); + + debug!("Fetching nodes from: {discovery_url}{route}"); + let response = reqwest::Client::new() + .get(format!("{discovery_url}{route}")) + .query(&[("nonce", signature.nonce)]) + .headers(headers) + .timeout(Duration::from_secs(10)) + .send() + .await + .context("Failed to fetch nodes")?; + + let response_text = response + .text() + .await + .context("Failed to get response text")?; + + let parsed_response: ApiResponse> = + serde_json::from_str(&response_text).context("Failed to parse response")?; + + if !parsed_response.success { + error!("Failed to fetch nodes from {discovery_url}: {parsed_response:?}"); + return Ok(vec![]); + } + + Ok(parsed_response.data) +} + +/// Fetch nodes from multiple discovery URLs with deduplication +pub async fn fetch_nodes_from_discovery_urls( + discovery_urls: &[String], + route: &str, + wallet: &Wallet, +) -> Result> { + let mut all_nodes = Vec::new(); + let mut any_success = false; + + for discovery_url in discovery_urls { + match fetch_nodes_from_discovery_url(discovery_url, route, wallet).await { + Ok(nodes) => { + debug!( + "Successfully fetched {} nodes from {}", + nodes.len(), + discovery_url + ); + all_nodes.extend(nodes); + any_success = true; + } + Err(e) => { + error!("Failed to fetch nodes from {discovery_url}: {e:#}"); + } + } + } + + if !any_success { + error!("Failed to fetch nodes from all discovery services"); + return Ok(vec![]); + } + + // Remove duplicates based on node ID + let mut unique_nodes = Vec::new(); + let mut seen_ids = std::collections::HashSet::new(); + for node in all_nodes { + if seen_ids.insert(node.node.id.clone()) { + unique_nodes.push(node); + } + } + + debug!( + "Total unique nodes after deduplication: {}", + unique_nodes.len() + ); + Ok(unique_nodes) +} + +/// Fetch nodes for a specific pool from discovery +pub async fn fetch_pool_nodes_from_discovery( + discovery_urls: 
&[String], + compute_pool_id: u32, + wallet: &Wallet, +) -> Result> { + let route = format!("/api/pool/{compute_pool_id}"); + fetch_nodes_from_discovery_urls(discovery_urls, &route, wallet).await +} + +/// Fetch all validator-accessible nodes from discovery +pub async fn fetch_validator_nodes_from_discovery( + discovery_urls: &[String], + wallet: &Wallet, +) -> Result> { + let route = "/api/validator"; + fetch_nodes_from_discovery_urls(discovery_urls, route, wallet).await +} + +/// Upload node information to discovery services +/// +/// This function attempts to upload node information to all provided discovery URLs. +/// It returns Ok(()) if at least one upload succeeds. +pub async fn upload_node_to_discovery( + discovery_urls: &[String], + node_data: &serde_json::Value, + wallet: &Wallet, +) -> Result<()> { + let endpoint = "/api/nodes"; + let mut last_error: Option = None; + + for discovery_url in discovery_urls { + match upload_to_single_discovery(discovery_url, endpoint, node_data, wallet).await { + Ok(_) => { + debug!("Successfully uploaded node info to {discovery_url}"); + return Ok(()); + } + Err(e) => { + error!("Failed to upload to {discovery_url}: {e}"); + last_error = Some(e.to_string()); + } + } + } + + // If we reach here, all discovery services failed + if let Some(error) = last_error { + Err(anyhow::anyhow!( + "Failed to upload to all discovery services. 
Last error: {}", + error + )) + } else { + Err(anyhow::anyhow!( + "Failed to upload to all discovery services" + )) + } +} + +async fn upload_to_single_discovery( + base_url: &str, + endpoint: &str, + node_data: &serde_json::Value, + wallet: &Wallet, +) -> Result<()> { + let signed_request = sign_request_with_nonce(endpoint, wallet, Some(node_data)) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "x-address", + wallet + .wallet + .default_signer() + .address() + .to_string() + .parse() + .context("Failed to parse address header")?, + ); + headers.insert( + "x-signature", + signed_request + .signature + .parse() + .context("Failed to parse signature header")?, + ); + + let request_url = format!("{base_url}{endpoint}"); + let response = reqwest::Client::new() + .put(&request_url) + .headers(headers) + .json( + &signed_request + .data + .expect("Signed request data should always be present for discovery upload"), + ) + .timeout(Duration::from_secs(10)) + .send() + .await + .context("Failed to send request")?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "No error message".to_string()); + return Err(anyhow::anyhow!( + "Error: Received response with status code {} from {}: {}", + status, + base_url, + error_text + )); + } + + Ok(()) +} diff --git a/crates/shared/src/lib.rs b/crates/shared/src/lib.rs index 5ce256ed..daa51c09 100644 --- a/crates/shared/src/lib.rs +++ b/crates/shared/src/lib.rs @@ -1,3 +1,4 @@ +pub mod discovery; pub mod models; pub mod p2p; pub mod security; diff --git a/crates/shared/src/models/metric.rs b/crates/shared/src/models/metric.rs index 47b27f24..b85c4926 100644 --- a/crates/shared/src/models/metric.rs +++ b/crates/shared/src/models/metric.rs @@ -58,7 +58,7 @@ mod tests { let invalid_values = vec![(f64::INFINITY, "infinite value"), (f64::NAN, "NaN value")]; for 
(value, case) in invalid_values { let entry = MetricEntry::new(key.clone(), value); - assert!(entry.is_err(), "Should fail for {}", case); + assert!(entry.is_err(), "Should fail for {case}"); } } diff --git a/crates/shared/src/p2p/service.rs b/crates/shared/src/p2p/service.rs index bf776009..d2208b6e 100644 --- a/crates/shared/src/p2p/service.rs +++ b/crates/shared/src/p2p/service.rs @@ -365,7 +365,7 @@ async fn handle_validation_authentication_response( let ongoing_auth_requests = context.ongoing_auth_requests.read().await; let Some(ongoing_challenge) = ongoing_auth_requests.get(&from) else { bail!( - "no ongoing hardware challenge for peer {from}, cannot handle ValidatorAuthenticationInitiationResponse" + "no ongoing challenge for peer {from}, cannot handle ValidatorAuthenticationInitiationResponse" ); }; diff --git a/crates/shared/src/security/auth_signature_middleware.rs b/crates/shared/src/security/auth_signature_middleware.rs index 1c4c1e10..8ba7767e 100644 --- a/crates/shared/src/security/auth_signature_middleware.rs +++ b/crates/shared/src/security/auth_signature_middleware.rs @@ -634,10 +634,10 @@ mod tests { .await; log::info!("Address: {}", wallet.wallet.default_signer().address()); - log::info!("Signature: {}", signature); - log::info!("Nonce: {}", nonce); + log::info!("Signature: {signature}"); + log::info!("Nonce: {nonce}"); let req = test::TestRequest::get() - .uri(&format!("/test?nonce={}", nonce)) + .uri(&format!("/test?nonce={nonce}")) .insert_header(( "x-address", wallet.wallet.default_signer().address().to_string(), @@ -801,8 +801,7 @@ mod tests { // Create multiple addresses let addresses: Vec
= (0..5) .map(|i| { - Address::from_str(&format!("0x{}000000000000000000000000000000000000000", i)) - .unwrap() + Address::from_str(&format!("0x{i}000000000000000000000000000000000000000")).unwrap() }) .collect(); diff --git a/crates/shared/src/security/request_signer.rs b/crates/shared/src/security/request_signer.rs index ff3e9964..c5ea3605 100644 --- a/crates/shared/src/security/request_signer.rs +++ b/crates/shared/src/security/request_signer.rs @@ -143,7 +143,7 @@ mod tests { let signature = sign_request(endpoint, &wallet, Some(&empty_data)) .await .unwrap(); - println!("Signature: {}", signature); + println!("Signature: {signature}"); assert!(signature.starts_with("0x")); assert_eq!(signature.len(), 132); } diff --git a/crates/shared/src/utils/google_cloud.rs b/crates/shared/src/utils/google_cloud.rs index 128259eb..72fae856 100644 --- a/crates/shared/src/utils/google_cloud.rs +++ b/crates/shared/src/utils/google_cloud.rs @@ -194,20 +194,14 @@ mod tests { #[tokio::test] async fn test_generate_mapping_file() { // Check if required environment variables are set - let bucket_name = match std::env::var("S3_BUCKET_NAME") { - Ok(name) => name, - Err(_) => { - println!("Skipping test: BUCKET_NAME not set"); - return; - } + let Ok(bucket_name) = std::env::var("S3_BUCKET_NAME") else { + println!("Skipping test: BUCKET_NAME not set"); + return; }; - let credentials_base64 = match std::env::var("S3_CREDENTIALS") { - Ok(credentials) => credentials, - Err(_) => { - println!("Skipping test: S3_CREDENTIALS not set"); - return; - } + let Ok(credentials_base64) = std::env::var("S3_CREDENTIALS") else { + println!("Skipping test: S3_CREDENTIALS not set"); + return; }; let storage = GcsStorageProvider::new(&bucket_name, &credentials_base64) @@ -219,15 +213,15 @@ mod tests { .generate_mapping_file(&random_sha256, "run_1/file.parquet") .await .unwrap(); - println!("mapping_content: {}", mapping_content); - println!("bucket_name: {}", bucket_name); + println!("mapping_content: 
{mapping_content}"); + println!("bucket_name: {bucket_name}"); let original_file_name = storage .resolve_mapping_for_sha(&random_sha256) .await .unwrap(); - println!("original_file_name: {}", original_file_name); + println!("original_file_name: {original_file_name}"); assert_eq!(original_file_name, "run_1/file.parquet"); } } diff --git a/crates/shared/src/utils/mod.rs b/crates/shared/src/utils/mod.rs index d4e3f1c9..290f1ae5 100644 --- a/crates/shared/src/utils/mod.rs +++ b/crates/shared/src/utils/mod.rs @@ -119,7 +119,7 @@ mod tests { provider.add_mapping_file("sha256", "file.txt").await; provider.add_file("file.txt", "content").await; let map_file_link = provider.resolve_mapping_for_sha("sha256").await.unwrap(); - println!("map_file_link: {}", map_file_link); + println!("map_file_link: {map_file_link}"); assert_eq!(map_file_link, "file.txt"); assert_eq!( diff --git a/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs b/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs index ff0a20ce..32637b62 100644 --- a/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs +++ b/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs @@ -29,6 +29,7 @@ impl ComputePool

{ .function("getComputePool", &[pool_id.into()])? .call() .await?; + let pool_info_tuple: &[DynSolValue] = pool_info_response.first().unwrap().as_tuple().unwrap(); @@ -388,8 +389,6 @@ impl ComputePool { pool_id: u32, node: Address, ) -> Result, Box> { - println!("Ejecting node"); - let arg_pool_id: U256 = U256::from(pool_id); let result = self diff --git a/crates/validator/src/store/redis.rs b/crates/validator/src/store/redis.rs index 508815c2..c0a0c36b 100644 --- a/crates/validator/src/store/redis.rs +++ b/crates/validator/src/store/redis.rs @@ -45,8 +45,8 @@ impl RedisStore { _ => panic!("Expected TCP connection"), }; - let redis_url = format!("redis://{}:{}", host, port); - debug!("Starting test Redis server at {}", redis_url); + let redis_url = format!("redis://{host}:{port}"); + debug!("Starting test Redis server at {redis_url}"); // Add a small delay to ensure server is ready thread::sleep(Duration::from_millis(100)); diff --git a/crates/validator/src/validators/hardware.rs b/crates/validator/src/validators/hardware.rs index 877861da..da5307e3 100644 --- a/crates/validator/src/validators/hardware.rs +++ b/crates/validator/src/validators/hardware.rs @@ -161,7 +161,7 @@ mod tests { let result = validator.validate_nodes(nodes).await; let elapsed = start_time.elapsed(); assert!(elapsed < std::time::Duration::from_secs(11)); - println!("Validation took: {:?}", elapsed); + println!("Validation took: {elapsed:?}"); assert!(result.is_ok()); } diff --git a/crates/validator/src/validators/synthetic_data/chain_operations.rs b/crates/validator/src/validators/synthetic_data/chain_operations.rs index 004c7e45..a0687d18 100644 --- a/crates/validator/src/validators/synthetic_data/chain_operations.rs +++ b/crates/validator/src/validators/synthetic_data/chain_operations.rs @@ -3,7 +3,7 @@ use super::*; impl SyntheticDataValidator { #[cfg(test)] pub fn soft_invalidate_work(&self, work_key: &str) -> Result<(), Error> { - info!("Soft invalidating work: {}", work_key); + 
info!("Soft invalidating work: {work_key}"); if self.disable_chain_invalidation { info!("Chain invalidation is disabled, skipping work soft invalidation"); @@ -54,7 +54,7 @@ impl SyntheticDataValidator { #[cfg(test)] pub fn invalidate_work(&self, work_key: &str) -> Result<(), Error> { - info!("Invalidating work: {}", work_key); + info!("Invalidating work: {work_key}"); if let Some(metrics) = &self.metrics { metrics.record_work_key_invalidation(); @@ -98,20 +98,27 @@ impl SyntheticDataValidator { } } } - + #[cfg(test)] + #[allow(clippy::unused_async)] pub async fn invalidate_according_to_invalidation_type( &self, work_key: &str, invalidation_type: InvalidationType, ) -> Result<(), Error> { match invalidation_type { - #[cfg(test)] InvalidationType::Soft => self.soft_invalidate_work(work_key), - #[cfg(not(test))] - InvalidationType::Soft => self.soft_invalidate_work(work_key).await, - #[cfg(test)] InvalidationType::Hard => self.invalidate_work(work_key), - #[cfg(not(test))] + } + } + + #[cfg(not(test))] + pub async fn invalidate_according_to_invalidation_type( + &self, + work_key: &str, + invalidation_type: InvalidationType, + ) -> Result<(), Error> { + match invalidation_type { + InvalidationType::Soft => self.soft_invalidate_work(work_key).await, InvalidationType::Hard => self.invalidate_work(work_key).await, } } diff --git a/crates/validator/src/validators/synthetic_data/mod.rs b/crates/validator/src/validators/synthetic_data/mod.rs index ce472c8b..bf8ce6e2 100644 --- a/crates/validator/src/validators/synthetic_data/mod.rs +++ b/crates/validator/src/validators/synthetic_data/mod.rs @@ -237,7 +237,7 @@ impl SyntheticDataValidator { let score: Option = con .zscore("incomplete_groups", group_key) .await - .map_err(|e| Error::msg(format!("Failed to check incomplete tracking: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to check incomplete tracking: {e}")))?; Ok(score.is_some()) } @@ -270,13 +270,10 @@ impl SyntheticDataValidator { let _: () = con 
.zadd("incomplete_groups", group_key, new_deadline) .await - .map_err(|e| { - Error::msg(format!("Failed to update incomplete group deadline: {}", e)) - })?; + .map_err(|e| Error::msg(format!("Failed to update incomplete group deadline: {e}")))?; debug!( - "Updated deadline for incomplete group {} to {} ({} minutes from now)", - group_key, new_deadline, minutes_from_now + "Updated deadline for incomplete group {group_key} to {new_deadline} ({minutes_from_now} minutes from now)" ); Ok(()) @@ -420,7 +417,7 @@ impl SyntheticDataValidator { let data: Option = con .get(key) .await - .map_err(|e| Error::msg(format!("Failed to get work validation status: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to get work validation status: {e}")))?; match data { Some(data) => { @@ -435,8 +432,7 @@ impl SyntheticDataValidator { reason: None, })), Err(e) => Err(Error::msg(format!( - "Failed to parse work validation data: {}", - e + "Failed to parse work validation data: {e}" ))), } } @@ -1576,8 +1572,7 @@ impl SyntheticDataValidator { .await { error!( - "Failed to update work validation status for {}: {}", - work_key, e + "Failed to update work validation status for {work_key}: {e}" ); } } diff --git a/crates/validator/src/validators/synthetic_data/tests/mod.rs b/crates/validator/src/validators/synthetic_data/tests/mod.rs index a589076f..48aaee85 100644 --- a/crates/validator/src/validators/synthetic_data/tests/mod.rs +++ b/crates/validator/src/validators/synthetic_data/tests/mod.rs @@ -34,7 +34,7 @@ fn setup_test_env() -> Result<(RedisStore, Contracts), Error> { "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97", url, ) - .map_err(|e| Error::msg(format!("Failed to create demo wallet: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to create demo wallet: {e}")))?; let contracts = ContractBuilder::new(demo_wallet.provider()) .with_compute_registry() @@ -45,7 +45,7 @@ fn setup_test_env() -> Result<(RedisStore, Contracts), Error> { .with_stake_manager() 
.with_synthetic_data_validator(Some(Address::ZERO)) .build() - .map_err(|e| Error::msg(format!("Failed to build contracts: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to build contracts: {e}")))?; Ok((store, contracts)) } @@ -197,8 +197,8 @@ async fn test_status_update() -> Result<(), Error> { ) .await .map_err(|e| { - error!("Failed to update work validation status: {}", e); - Error::msg(format!("Failed to update work validation status: {}", e)) + error!("Failed to update work validation status: {e}"); + Error::msg(format!("Failed to update work validation status: {e}")) })?; tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; @@ -206,8 +206,8 @@ async fn test_status_update() -> Result<(), Error> { .get_work_validation_status_from_redis("0x0000000000000000000000000000000000000000") .await .map_err(|e| { - error!("Failed to get work validation status: {}", e); - Error::msg(format!("Failed to get work validation status: {}", e)) + error!("Failed to get work validation status: {e}"); + Error::msg(format!("Failed to get work validation status: {e}")) })?; assert_eq!(status, Some(ValidationResult::Accept)); Ok(()) @@ -344,20 +344,20 @@ async fn test_group_e2e_accept() -> Result<(), Error> { let mock_storage = MockStorageProvider::new(); mock_storage .add_file( - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-1-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-1-0-0.parquet"), "file1", ) .await; mock_storage .add_mapping_file( FILE_SHA, - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-1-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-1-0-0.parquet"), ) .await; server .mock( "POST", - format!("/validategroup/dataset/samplingn-{}-1-0.parquet", GROUP_ID).as_str(), + format!("/validategroup/dataset/samplingn-{GROUP_ID}-1-0.parquet").as_str(), ) .match_body(mockito::Matcher::Json(serde_json::json!({ "file_shas": [FILE_SHA], @@ -371,7 +371,7 @@ async fn test_group_e2e_accept() -> Result<(), Error> { 
server .mock( "GET", - format!("/statusgroup/dataset/samplingn-{}-1-0.parquet", GROUP_ID).as_str(), + format!("/statusgroup/dataset/samplingn-{GROUP_ID}-1-0.parquet").as_str(), ) .with_status(200) .with_body(r#"{"status": "accept", "input_flops": 1, "output_flops": 1000}"#) @@ -463,7 +463,7 @@ async fn test_group_e2e_accept() -> Result<(), Error> { metrics_2.contains("validator_work_keys_to_process{pool_id=\"0\",validator_id=\"0\"} 0") ); assert!(metrics_2.contains("toploc_config_name=\"Qwen/Qwen0.6\"")); - assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{}\",pool_id=\"0\",result=\"match\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{GROUP_ID}\",pool_id=\"0\",result=\"match\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1"))); Ok(()) } @@ -490,32 +490,32 @@ async fn test_group_e2e_work_unit_mismatch() -> Result<(), Error> { let mock_storage = MockStorageProvider::new(); mock_storage .add_file( - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-0.parquet"), "file1", ) .await; mock_storage .add_file( - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-1.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-1.parquet"), "file2", ) .await; mock_storage .add_mapping_file( HONEST_FILE_SHA, - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-0.parquet"), ) .await; mock_storage .add_mapping_file( EXCESSIVE_FILE_SHA, - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-1.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-1.parquet"), ) .await; server .mock( "POST", - format!("/validategroup/dataset/samplingn-{}-2-0.parquet", GROUP_ID).as_str(), + 
format!("/validategroup/dataset/samplingn-{GROUP_ID}-2-0.parquet").as_str(), ) .match_body(mockito::Matcher::Json(serde_json::json!({ "file_shas": [HONEST_FILE_SHA, EXCESSIVE_FILE_SHA], @@ -529,7 +529,7 @@ async fn test_group_e2e_work_unit_mismatch() -> Result<(), Error> { server .mock( "GET", - format!("/statusgroup/dataset/samplingn-{}-2-0.parquet", GROUP_ID).as_str(), + format!("/statusgroup/dataset/samplingn-{GROUP_ID}-2-0.parquet").as_str(), ) .with_status(200) .with_body(r#"{"status": "accept", "input_flops": 1, "output_flops": 2000}"#) @@ -636,12 +636,12 @@ async fn test_group_e2e_work_unit_mismatch() -> Result<(), Error> { assert_eq!(plan_3.group_trigger_tasks.len(), 0); assert_eq!(plan_3.group_status_check_tasks.len(), 0); let metrics_2 = export_metrics().unwrap(); - assert!(metrics_2.contains(&format!("validator_group_validations_total{{group_id=\"{}\",pool_id=\"0\",result=\"accept\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics_2.contains(&format!("validator_group_validations_total{{group_id=\"{GROUP_ID}\",pool_id=\"0\",result=\"accept\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1"))); assert!( metrics_2.contains("validator_work_keys_to_process{pool_id=\"0\",validator_id=\"0\"} 0") ); assert!(metrics_2.contains("toploc_config_name=\"Qwen/Qwen0.6\"")); - assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{}\",pool_id=\"0\",result=\"mismatch\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{GROUP_ID}\",pool_id=\"0\",result=\"mismatch\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1"))); Ok(()) } @@ -734,26 +734,26 @@ async fn test_incomplete_group_recovery() -> Result<(), Error> { mock_storage .add_file( - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), 
"file1", ) .await; mock_storage .add_file( - &format!("TestModel/dataset/test-{}-2-0-1.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-1.parquet"), "file2", ) .await; mock_storage .add_mapping_file( FILE_SHA_1, - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), ) .await; mock_storage .add_mapping_file( FILE_SHA_2, - &format!("TestModel/dataset/test-{}-2-0-1.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-1.parquet"), ) .await; @@ -800,7 +800,7 @@ async fn test_incomplete_group_recovery() -> Result<(), Error> { assert!(group.is_none(), "Group should be incomplete"); // Check that the incomplete group is being tracked - let group_key = format!("group:{}:2:0", GROUP_ID); + let group_key = format!("group:{GROUP_ID}:2:0"); let is_tracked = validator .is_group_being_tracked_as_incomplete(&group_key) .await?; @@ -847,14 +847,14 @@ async fn test_expired_incomplete_group_soft_invalidation() -> Result<(), Error> mock_storage .add_file( - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), "file1", ) .await; mock_storage .add_mapping_file( FILE_SHA_1, - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), ) .await; @@ -902,7 +902,7 @@ async fn test_expired_incomplete_group_soft_invalidation() -> Result<(), Error> // Manually expire the incomplete group tracking by removing it and simulating expiry // In a real test, you would wait for the actual expiry, but for testing we simulate it - let group_key = format!("group:{}:2:0", GROUP_ID); + let group_key = format!("group:{GROUP_ID}:2:0"); validator.track_incomplete_group(&group_key).await?; // Process groups past grace period (this would normally find groups past deadline) @@ -936,7 +936,7 @@ async fn test_expired_incomplete_group_soft_invalidation() -> Result<(), 
Error> assert_eq!(key_status, Some(ValidationResult::IncompleteGroup)); let metrics = export_metrics().unwrap(); - assert!(metrics.contains(&format!("validator_work_keys_soft_invalidated_total{{group_key=\"group:{}:2:0\",pool_id=\"0\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics.contains(&format!("validator_work_keys_soft_invalidated_total{{group_key=\"group:{GROUP_ID}:2:0\",pool_id=\"0\",validator_id=\"0\"}} 1"))); Ok(()) } @@ -952,14 +952,14 @@ async fn test_incomplete_group_status_tracking() -> Result<(), Error> { mock_storage .add_file( - &format!("TestModel/dataset/test-{}-3-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-3-0-0.parquet"), "file1", ) .await; mock_storage .add_mapping_file( FILE_SHA_1, - &format!("TestModel/dataset/test-{}-3-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-3-0-0.parquet"), ) .await; @@ -1006,7 +1006,7 @@ async fn test_incomplete_group_status_tracking() -> Result<(), Error> { // Manually process groups past grace period to simulate what would happen // after the grace period expires (we simulate this since we can't wait in tests) - let group_key = format!("group:{}:3:0", GROUP_ID); + let group_key = format!("group:{GROUP_ID}:3:0"); // Manually add the group to tracking and then process it validator.track_incomplete_group(&group_key).await?; diff --git a/crates/validator/src/validators/synthetic_data/toploc.rs b/crates/validator/src/validators/synthetic_data/toploc.rs index 33d9f57f..f5641533 100644 --- a/crates/validator/src/validators/synthetic_data/toploc.rs +++ b/crates/validator/src/validators/synthetic_data/toploc.rs @@ -689,8 +689,7 @@ mod tests { Some(expected_idx) => { assert!( matched, - "Expected file {} to match config {}", - test_file, expected_idx + "Expected file {test_file} to match config {expected_idx}" ); assert_eq!( matched_idx, @@ -701,7 +700,7 @@ mod tests { expected_idx ); } - None => assert!(!matched, "File {} should not match any config", test_file), 
+ None => assert!(!matched, "File {test_file} should not match any config"), } } } diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml index eb041cad..51981fa5 100644 --- a/crates/worker/Cargo.toml +++ b/crates/worker/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] shared = { workspace = true } p2p = { workspace = true } +operations = { workspace = true} actix-web = { workspace = true } alloy = { workspace = true } diff --git a/crates/worker/src/checks/hardware/interconnect.rs b/crates/worker/src/checks/hardware/interconnect.rs index 21725686..d87d1819 100644 --- a/crates/worker/src/checks/hardware/interconnect.rs +++ b/crates/worker/src/checks/hardware/interconnect.rs @@ -78,7 +78,7 @@ mod tests { #[tokio::test] async fn test_check_speeds() { let result = InterconnectCheck::check_speeds().await; - println!("Test Result: {:?}", result); + println!("Test Result: {result:?}"); // Verify the result is Ok and contains expected tuple structure assert!(result.is_ok()); diff --git a/crates/worker/src/checks/hardware/storage.rs b/crates/worker/src/checks/hardware/storage.rs index 9509e731..8360993b 100644 --- a/crates/worker/src/checks/hardware/storage.rs +++ b/crates/worker/src/checks/hardware/storage.rs @@ -216,7 +216,7 @@ fn test_or_create_app_directory(path: &str) -> bool { } #[cfg(not(target_os = "linux"))] -pub fn find_largest_storage() -> Option { +pub(crate) fn find_largest_storage() -> Option { None } @@ -233,7 +233,7 @@ pub(crate) fn get_available_space(path: &str) -> Option { } #[cfg(not(target_os = "linux"))] -pub fn get_available_space(_path: &str) -> Option { +pub(crate) fn get_available_space(_path: &str) -> Option { None } diff --git a/crates/worker/src/checks/hardware/storage.rs:236:1 b/crates/worker/src/checks/hardware/storage.rs:236:1 new file mode 100644 index 00000000..e69de29b diff --git a/crates/worker/src/checks/stun.rs b/crates/worker/src/checks/stun.rs index 5830b49e..734f2795 100644 --- a/crates/worker/src/checks/stun.rs 
+++ b/crates/worker/src/checks/stun.rs @@ -139,7 +139,7 @@ mod tests { async fn test_get_public_ip() { let stun_check = StunCheck::new(Duration::from_secs(5), 0); let public_ip = stun_check.get_public_ip().await.unwrap(); - println!("Public IP: {}", public_ip); + println!("Public IP: {public_ip}"); assert!(!public_ip.is_empty()); } } diff --git a/crates/worker/src/cli/command.rs b/crates/worker/src/cli/command.rs index 1e9e5825..e0214d83 100644 --- a/crates/worker/src/cli/command.rs +++ b/crates/worker/src/cli/command.rs @@ -6,9 +6,8 @@ use crate::console::Console; use crate::docker::taskbridge::TaskBridge; use crate::docker::DockerService; use crate::metrics::store::MetricsStore; -use crate::operations::compute_node::ComputeNodeOperations; use crate::operations::heartbeat::service::HeartbeatService; -use crate::operations::provider::ProviderOperations; +use crate::operations::node_monitor::NodeMonitor; use crate::services::discovery::DiscoveryService; use crate::services::discovery_updater::DiscoveryUpdater; use crate::state::system_state::SystemState; @@ -20,6 +19,8 @@ use alloy::signers::local::PrivateKeySigner; use alloy::signers::Signer; use clap::{Parser, Subcommand}; use log::{error, info}; +use operations::operations::compute_node::ComputeNodeOperations; +use operations::operations::provider::ProviderOperations; use shared::models::node::ComputeRequirements; use shared::models::node::Node; use shared::web3::contracts::core::builder::ContractBuilder; @@ -294,12 +295,10 @@ pub async fn execute_command( let provider_ops_cancellation = cancellation_token.clone(); - let compute_node_state = state.clone(); let compute_node_ops = ComputeNodeOperations::new( &provider_wallet_instance, &node_wallet_instance, contracts.clone(), - compute_node_state, ); let discovery_urls = vec![discovery_url @@ -606,7 +605,7 @@ pub async fn execute_command( .retry_register_provider( required_stake, *funding_retry_count, - cancellation_token.clone(), + 
Some(cancellation_token.clone()), ) .await { @@ -709,7 +708,7 @@ pub async fn execute_command( let heartbeat = match heartbeat_service.clone() { Ok(service) => service, Err(e) => { - error!("❌ Heartbeat service is not available: {e}"); + error!("❌ Heartbeat service is not available: {e:?}"); std::process::exit(1); } }; @@ -821,8 +820,13 @@ pub async fn execute_command( provider_ops.start_monitoring(provider_ops_cancellation); let pool_id = state.get_compute_pool_id(); - if let Err(err) = compute_node_ops.start_monitoring(cancellation_token.clone(), pool_id) - { + let node_monitor = NodeMonitor::new( + provider_wallet_instance.clone(), + node_wallet_instance.clone(), + contracts.clone(), + state.clone(), + ); + if let Err(err) = node_monitor.start_monitoring(cancellation_token.clone(), pool_id) { error!("❌ Failed to start node monitoring: {err}"); std::process::exit(1); } @@ -1031,7 +1035,6 @@ pub async fn execute_command( /* Initialize dependencies - services, contracts, operations */ - let contracts = ContractBuilder::new(provider_wallet_instance.provider()) .with_compute_registry() .with_ai_token() diff --git a/crates/worker/src/docker/taskbridge/bridge.rs b/crates/worker/src/docker/taskbridge/bridge.rs index 4765ef06..594bc62d 100644 --- a/crates/worker/src/docker/taskbridge/bridge.rs +++ b/crates/worker/src/docker/taskbridge/bridge.rs @@ -565,7 +565,7 @@ mod tests { "test_label2": 20.0, }); let sample_metric = serde_json::to_string(&data)?; - debug!("Sending {:?}", sample_metric); + debug!("Sending {sample_metric:?}"); let msg = format!("{}{}", sample_metric, "\n"); stream.write_all(msg.as_bytes()).await?; stream.flush().await?; @@ -616,7 +616,7 @@ mod tests { "output/input_flops": 2500.0, }); let sample_metric = serde_json::to_string(&json)?; - debug!("Sending {:?}", sample_metric); + debug!("Sending {sample_metric:?}"); let msg = format!("{}{}", sample_metric, "\n"); stream.write_all(msg.as_bytes()).await?; stream.flush().await?; @@ -626,8 +626,7 @@ mod tests 
{ let all_metrics = metrics_store.get_all_metrics().await; assert!( all_metrics.is_empty(), - "Expected metrics to be empty but found: {:?}", - all_metrics + "Expected metrics to be empty but found: {all_metrics:?}" ); bridge_handle.abort(); diff --git a/crates/worker/src/operations/heartbeat/service.rs b/crates/worker/src/operations/heartbeat/service.rs index 1b002cae..289e86af 100644 --- a/crates/worker/src/operations/heartbeat/service.rs +++ b/crates/worker/src/operations/heartbeat/service.rs @@ -24,7 +24,6 @@ pub(crate) struct HeartbeatService { docker_service: Arc, metrics_store: Arc, } - #[derive(Debug, Clone, thiserror::Error)] pub(crate) enum HeartbeatError { #[error("HTTP request failed")] @@ -32,6 +31,7 @@ pub(crate) enum HeartbeatError { #[error("Service initialization failed")] InitFailed, } + impl HeartbeatService { #[allow(clippy::too_many_arguments)] pub(crate) fn new( diff --git a/crates/worker/src/operations/mod.rs b/crates/worker/src/operations/mod.rs index 193b64ae..d684160a 100644 --- a/crates/worker/src/operations/mod.rs +++ b/crates/worker/src/operations/mod.rs @@ -1,3 +1,2 @@ -pub(crate) mod compute_node; pub(crate) mod heartbeat; -pub(crate) mod provider; +pub(crate) mod node_monitor; diff --git a/crates/worker/src/operations/compute_node.rs b/crates/worker/src/operations/node_monitor.rs similarity index 59% rename from crates/worker/src/operations/compute_node.rs rename to crates/worker/src/operations/node_monitor.rs index 00f147a7..af33d450 100644 --- a/crates/worker/src/operations/compute_node.rs +++ b/crates/worker/src/operations/node_monitor.rs @@ -1,5 +1,5 @@ -use crate::{console::Console, state::system_state::SystemState}; -use alloy::{primitives::utils::keccak256 as keccak, primitives::U256, signers::Signer}; +use crate::state::system_state::SystemState; +use alloy::primitives::U256; use anyhow::Result; use shared::web3::wallet::Wallet; use shared::web3::{contracts::core::builder::Contracts, wallet::WalletProvider}; @@ -7,17 +7,17 @@ 
use std::sync::Arc; use tokio::time::{sleep, Duration}; use tokio_util::sync::CancellationToken; -pub(crate) struct ComputeNodeOperations<'c> { - provider_wallet: &'c Wallet, - node_wallet: &'c Wallet, +pub(crate) struct NodeMonitor { + provider_wallet: Wallet, + node_wallet: Wallet, contracts: Contracts, system_state: Arc, } -impl<'c> ComputeNodeOperations<'c> { +impl NodeMonitor { pub(crate) fn new( - provider_wallet: &'c Wallet, - node_wallet: &'c Wallet, + provider_wallet: Wallet, + node_wallet: Wallet, contracts: Contracts, system_state: Arc, ) -> Self { @@ -43,11 +43,12 @@ impl<'c> ComputeNodeOperations<'c> { let mut last_claimable = None; let mut last_locked = None; let mut first_check = true; + tokio::spawn(async move { loop { tokio::select! { _ = cancellation_token.cancelled() => { - Console::info("Monitor", "Shutting down node status monitor..."); + log::info!("Shutting down node status monitor..."); break; } _ = async { @@ -55,16 +56,15 @@ impl<'c> ComputeNodeOperations<'c> { Ok((active, validated)) => { if first_check || active != last_active { if !first_check { - Console::info("🔄 Chain Sync - Pool membership changed", &format!("From {last_active} to {active}" - )); + log::info!("🔄 Chain Sync - Pool membership changed: From {last_active} to {active}"); } else { - Console::info("🔄 Chain Sync - Node pool membership", &format!("{active}")); + log::info!("🔄 Chain Sync - Node pool membership: {active}"); } last_active = active; } let is_running = system_state.is_running().await; if !active && is_running { - Console::warning("Node is not longer in pool, shutting down heartbeat..."); + log::warn!("Node is not longer in pool, shutting down heartbeat..."); if let Err(e) = system_state.set_running(false, None).await { log::error!("Failed to set running to false: {e:?}"); } @@ -72,10 +72,9 @@ impl<'c> ComputeNodeOperations<'c> { if first_check || validated != last_validated { if !first_check { - Console::info("🔄 Chain Sync - Validation changed", &format!("From 
{last_validated} to {validated}" - )); + log::info!("🔄 Chain Sync - Validation changed: From {last_validated} to {validated}"); } else { - Console::info("🔄 Chain Sync - Node validation", &format!("{validated}")); + log::info!("🔄 Chain Sync - Node validation: {validated}"); } last_validated = validated; } @@ -91,7 +90,7 @@ impl<'c> ComputeNodeOperations<'c> { last_locked = Some(locked); let claimable_formatted = claimable.to_string().parse::().unwrap_or(0.0) / 10f64.powf(18.0); let locked_formatted = locked.to_string().parse::().unwrap_or(0.0) / 10f64.powf(18.0); - Console::info("Rewards", &format!("{claimable_formatted} claimable, {locked_formatted} locked")); + log::info!("Rewards: {claimable_formatted} claimable, {locked_formatted} locked"); } } Err(e) => { @@ -113,55 +112,4 @@ impl<'c> ComputeNodeOperations<'c> { }); Ok(()) } - - pub(crate) async fn check_compute_node_exists( - &self, - ) -> Result> { - let compute_node = self - .contracts - .compute_registry - .get_node( - self.provider_wallet.wallet.default_signer().address(), - self.node_wallet.wallet.default_signer().address(), - ) - .await; - - match compute_node { - Ok(_) => Ok(true), - Err(_) => Ok(false), - } - } - - // Returns true if the compute node was added, false if it already exists - pub(crate) async fn add_compute_node( - &self, - compute_units: U256, - ) -> Result> { - Console::title("🔄 Adding compute node"); - - if self.check_compute_node_exists().await? { - return Ok(false); - } - - Console::progress("Adding compute node"); - let provider_address = self.provider_wallet.wallet.default_signer().address(); - let node_address = self.node_wallet.wallet.default_signer().address(); - let digest = keccak([provider_address.as_slice(), node_address.as_slice()].concat()); - - let signature = self - .node_wallet - .signer - .sign_message(digest.as_slice()) - .await? 
- .as_bytes(); - - // Create the signature bytes - let add_node_tx = self - .contracts - .prime_network - .add_compute_node(node_address, compute_units, signature.to_vec()) - .await?; - Console::success(&format!("Add node tx: {add_node_tx:?}")); - Ok(true) - } } diff --git a/crates/worker/src/p2p/mod.rs b/crates/worker/src/p2p/mod.rs index 94fe10a3..8d31d27a 100644 --- a/crates/worker/src/p2p/mod.rs +++ b/crates/worker/src/p2p/mod.rs @@ -1,7 +1,7 @@ use anyhow::Context as _; use anyhow::Result; use futures::stream::FuturesUnordered; -use p2p::InviteRequestUrl; +use operations::invite::{common::get_endpoint_from_url, worker::is_invite_expired}; use p2p::Node; use p2p::NodeBuilder; use p2p::PeerId; @@ -421,6 +421,11 @@ async fn handle_invite_request( ); } + // Check if invite has expired + if is_invite_expired(&req)? { + anyhow::bail!("invite has expired"); + } + let invite_bytes = hex::decode(&req.invite).context("failed to decode invite hex")?; if invite_bytes.len() < 65 { @@ -481,12 +486,7 @@ async fn handle_invite_request( } } - let heartbeat_endpoint = match req.url { - InviteRequestUrl::MasterIpPort(ip, port) => { - format!("http://{ip}:{port}/heartbeat") - } - InviteRequestUrl::MasterUrl(url) => format!("{url}/heartbeat"), - }; + let heartbeat_endpoint = get_endpoint_from_url(&req.url, "heartbeat"); context .heartbeat_service diff --git a/crates/worker/src/services/discovery.rs b/crates/worker/src/services/discovery.rs index 2088215c..e4440a0a 100644 --- a/crates/worker/src/services/discovery.rs +++ b/crates/worker/src/services/discovery.rs @@ -1,16 +1,14 @@ use anyhow::Result; use shared::models::node::Node; -use shared::security::request_signer::sign_request_with_nonce; use shared::web3::wallet::Wallet; pub(crate) struct DiscoveryService { wallet: Wallet, base_urls: Vec, - endpoint: String, } impl DiscoveryService { - pub(crate) fn new(wallet: Wallet, base_urls: Vec, endpoint: Option) -> Self { + pub(crate) fn new(wallet: Wallet, base_urls: Vec, _endpoint: 
Option) -> Self { let urls = if base_urls.is_empty() { vec!["http://localhost:8089".to_string()] } else { @@ -19,86 +17,13 @@ impl DiscoveryService { Self { wallet, base_urls: urls, - endpoint: endpoint.unwrap_or_else(|| "/api/nodes".to_string()), } } - async fn upload_to_single_discovery(&self, node_config: &Node, base_url: &str) -> Result<()> { - let request_data = serde_json::to_value(node_config)?; - - let signed_request = - sign_request_with_nonce(&self.endpoint, &self.wallet, Some(&request_data)) - .await - .map_err(|e| anyhow::anyhow!("{}", e))?; - - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - "x-address", - self.wallet - .wallet - .default_signer() - .address() - .to_string() - .parse() - .unwrap(), - ); - headers.insert("x-signature", signed_request.signature.parse().unwrap()); - let request_url = format!("{}{}", base_url, &self.endpoint); - let client = reqwest::Client::new(); - let response = client - .put(&request_url) - .headers(headers) - .json( - &signed_request - .data - .expect("Signed request data should always be present for discovery upload"), - ) - .send() - .await?; - - if !response.status().is_success() { - let status = response.status(); - let error_text = response - .text() - .await - .unwrap_or_else(|_| "No error message".to_string()); - return Err(anyhow::anyhow!( - "Error: Received response with status code {} from {}: {}", - status, - base_url, - error_text - )); - } - - Ok(()) - } - pub(crate) async fn upload_discovery_info(&self, node_config: &Node) -> Result<()> { - let mut last_error: Option = None; - - for base_url in &self.base_urls { - match self.upload_to_single_discovery(node_config, base_url).await { - Ok(_) => { - // Successfully uploaded to one discovery service, return immediately - return Ok(()); - } - Err(e) => { - last_error = Some(e.to_string()); - } - } - } + let node_data = serde_json::to_value(node_config)?; - // If we reach here, all discovery services failed - if let Some(error) = 
last_error { - Err(anyhow::anyhow!( - "Failed to upload to all discovery services. Last error: {}", - error - )) - } else { - Err(anyhow::anyhow!( - "Failed to upload to all discovery services" - )) - } + shared::discovery::upload_node_to_discovery(&self.base_urls, &node_data, &self.wallet).await } } @@ -107,7 +32,6 @@ impl Clone for DiscoveryService { Self { wallet: self.wallet.clone(), base_urls: self.base_urls.clone(), - endpoint: self.endpoint.clone(), } } }