From d44584c3856515bea6647b89a8d1fe58ff541562 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Fri, 11 Jul 2025 13:50:53 +0200 Subject: [PATCH 01/23] basic version of python sdk with provider and node registration --- Cargo.lock | 157 +++++++++- Cargo.toml | 11 +- Makefile | 9 + crates/dev-utils/examples/compute_pool.rs | 9 +- crates/dev-utils/examples/create_domain.rs | 2 +- crates/dev-utils/examples/eject_node.rs | 6 +- crates/dev-utils/examples/get_node_info.rs | 5 +- crates/dev-utils/examples/invalidate_work.rs | 2 +- crates/dev-utils/examples/mint_ai_token.rs | 4 +- .../examples/set_min_stake_amount.rs | 4 +- .../dev-utils/examples/start_compute_pool.rs | 2 +- crates/dev-utils/examples/submit_work.rs | 2 +- .../examples/test_concurrent_calls.rs | 14 +- crates/discovery/src/api/routes/node.rs | 16 +- crates/discovery/src/store/redis.rs | 4 +- .../orchestrator/src/api/routes/heartbeat.rs | 2 +- crates/orchestrator/src/api/routes/task.rs | 4 +- .../src/plugins/node_groups/tests.rs | 24 +- crates/orchestrator/src/scheduler/mod.rs | 4 +- crates/orchestrator/src/status_update/mod.rs | 35 ++- crates/orchestrator/src/store/core/redis.rs | 4 +- .../src/store/domains/heartbeat_store.rs | 2 +- .../src/store/domains/metrics_store.rs | 2 +- crates/prime-core/Cargo.toml | 31 ++ crates/prime-core/src/lib.rs | 1 + .../prime-core/src/operations/compute_node.rs | 92 ++++++ crates/prime-core/src/operations/mod.rs | 2 + .../src/operations/provider.rs | 158 +++++----- crates/prime-protocol-py/.gitignore | 24 ++ crates/prime-protocol-py/.python-version | 1 + crates/prime-protocol-py/Cargo.toml | 32 ++ crates/prime-protocol-py/Makefile | 46 +++ crates/prime-protocol-py/README.md | 46 +++ .../prime-protocol-py/examples/basic_usage.py | 24 ++ crates/prime-protocol-py/pyproject.toml | 37 +++ crates/prime-protocol-py/requirements-dev.txt | 3 + crates/prime-protocol-py/setup.sh | 16 + crates/prime-protocol-py/src/client.rs | 294 ++++++++++++++++++ crates/prime-protocol-py/src/error.rs | 21 ++ crates/prime-protocol-py/src/lib.rs | 62 ++++ crates/prime-protocol-py/tests/test_client.py | 29 ++ crates/prime-protocol-py/uv.lock | 7 + crates/shared/src/models/metric.rs | 2 +- .../src/security/auth_signature_middleware.rs | 9 +- crates/shared/src/security/request_signer.rs | 2 +- crates/shared/src/utils/google_cloud.rs | 24 +- crates/shared/src/utils/mod.rs | 2 +- .../implementations/compute_pool_contract.rs | 4 + crates/validator/src/store/redis.rs | 4 +- crates/validator/src/validators/hardware.rs | 2 +- .../synthetic_data/chain_operations.rs | 23 +- .../src/validators/synthetic_data/mod.rs | 17 +- .../validators/synthetic_data/tests/mod.rs | 62 ++-- .../src/validators/synthetic_data/toploc.rs | 5 +- crates/worker/Cargo.toml | 1 + .../src/checks/hardware/interconnect.rs | 2 +- crates/worker/src/checks/hardware/storage.rs | 4 +- crates/worker/src/checks/stun.rs | 2 +- crates/worker/src/cli/command.rs | 21 +- crates/worker/src/docker/taskbridge/bridge.rs | 7 +- .../src/operations/heartbeat/service.rs | 2 +- crates/worker/src/operations/mod.rs | 3 +- .../{compute_node.rs => node_monitor.rs} | 84 +---- 63 files changed, 1198 insertions(+), 334 deletions(-) create mode 100644 crates/prime-core/Cargo.toml create mode 100644 crates/prime-core/src/lib.rs create mode 100644 crates/prime-core/src/operations/compute_node.rs create mode 100644 crates/prime-core/src/operations/mod.rs rename crates/{worker => prime-core}/src/operations/provider.rs (67%) create mode 100644 crates/prime-protocol-py/.gitignore create mode 100644 
crates/prime-protocol-py/.python-version create mode 100644 crates/prime-protocol-py/Cargo.toml create mode 100644 crates/prime-protocol-py/Makefile create mode 100644 crates/prime-protocol-py/README.md create mode 100644 crates/prime-protocol-py/examples/basic_usage.py create mode 100644 crates/prime-protocol-py/pyproject.toml create mode 100644 crates/prime-protocol-py/requirements-dev.txt create mode 100755 crates/prime-protocol-py/setup.sh create mode 100644 crates/prime-protocol-py/src/client.rs create mode 100644 crates/prime-protocol-py/src/error.rs create mode 100644 crates/prime-protocol-py/src/lib.rs create mode 100644 crates/prime-protocol-py/tests/test_client.py create mode 100644 crates/prime-protocol-py/uv.lock rename crates/worker/src/operations/{compute_node.rs => node_monitor.rs} (59%) diff --git a/Cargo.lock b/Cargo.lock index 67fc79bd..47b10ada 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4215,6 +4215,12 @@ dependencies = [ "serde", ] +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + [[package]] name = "inout" version = "0.1.4" @@ -5489,6 +5495,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.17" @@ -5884,7 +5899,7 @@ dependencies = [ "bitflags 1.3.2", "cfg-if", "libc", - "memoffset", + "memoffset 0.7.1", "pin-utils", ] @@ -6681,6 +6696,47 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "prime-core" +version = "0.1.0" +dependencies = [ + "actix-web", + "alloy", + "alloy-provider", + "anyhow", + "env_logger", + "futures-util", + "hex", + "log", + "rand 0.8.5", + "redis", + "serde", + "serde_json", + "shared", + "subtle", + "tokio", + "tokio-util", + "url", + "uuid", +] + +[[package]] +name = "prime-protocol-py" +version = "0.1.0" +dependencies = [ + "alloy", + "alloy-provider", + "log", + "prime-core", + "pyo3", + "pyo3-log", + "shared", + "thiserror 1.0.69", + "tokio", + "tokio-test", + "url", +] + [[package]] name = "primeorder" version = "0.13.6" @@ -6865,6 +6921,79 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "pyo3" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" +dependencies = [ + "indoc", + "libc", + "memoffset 0.9.1", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-log" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + +[[package]] +name = "pyo3-macros" +version 
= "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8725c0a622b374d6cb051d11a0983786448f7785336139c3c94f5aa6bef7e50" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4109984c22491085343c05b0dbc54ddc405c3cf7b4374fc533f5c3313a572ccc" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.101", +] + [[package]] name = "quanta" version = "0.10.1" @@ -8641,6 +8770,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-lexicon" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" + [[package]] name = "tempfile" version = "3.14.0" @@ -8840,6 +8975,19 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-tungstenite" version = "0.24.0" @@ -9173,6 +9321,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + [[package]] name = "universal-hash" version = "0.5.1" @@ -10190,6 +10344,7 @@ dependencies = [ "log", "nvml-wrapper", "p2p", + "prime-core", "rand 0.8.5", "rand 0.9.1", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index 1bc9e2ac..15655fbc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,8 @@ members = [ "crates/orchestrator", "crates/p2p", "crates/dev-utils", + "crates/prime-protocol-py", + "crates/prime-core", ] resolver = "2" @@ -14,6 +16,7 @@ resolver = "2" shared = { path = "crates/shared" } p2p = { path = "crates/p2p" } +prime-core = { path = "crates/prime-core" } actix-web = "4.9.0" clap = { version = "4.5.27", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] } @@ -42,7 +45,6 @@ mockito = "1.7.0" iroh = "0.34.1" rand_v8 = { package = "rand", version = "0.8.5", features = ["std"] } rand_core_v6 = { package = "rand_core", version = "0.6.4", features = ["std"] } -ipld-core = "0.4" rust-ipfs = "0.14" cid = "0.11" tracing = "0.1.41" @@ -59,3 +61,10 @@ manual_let_else = "warn" [workspace.lints.rust] unreachable_pub = "warn" + +[workspace.metadata.rust-analyzer] +# Help rust-analyzer with proc-macros +procMacro.enable = true +procMacro.attributes.enable = true +# Use a separate target directory for rust-analyzer +targetDir = true diff --git a/Makefile b/Makefile index decd07f6..dfc0d0af 100644 --- a/Makefile +++ b/Makefile @@ -268,3 +268,12 @@ deregister-worker: set -a; source ${ENV_FILE}; set +a; \ cargo run --bin worker -- deregister --compute-pool-id $${WORKER_COMPUTE_POOL_ID} --private-key-provider $${PRIVATE_KEY_PROVIDER} --private-key-node $${PRIVATE_KEY_NODE} --rpc-url $${RPC_URL} +# Python Package +.PHONY: python-install +python-install: + @cd 
crates/prime-protocol-py && make install + +.PHONY: python-test +python-test: + @cd crates/prime-protocol-py && make test + diff --git a/crates/dev-utils/examples/compute_pool.rs b/crates/dev-utils/examples/compute_pool.rs index 2569980c..51658d59 100644 --- a/crates/dev-utils/examples/compute_pool.rs +++ b/crates/dev-utils/examples/compute_pool.rs @@ -68,17 +68,14 @@ async fn main() -> Result<()> { compute_limit, ) .await; - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); let rewards_distributor_address = contracts .compute_pool .get_reward_distributor_address(U256::from(0)) .await .unwrap(); - println!( - "Rewards distributor address: {:?}", - rewards_distributor_address - ); + println!("Rewards distributor address: {rewards_distributor_address:?}"); let rewards_distributor = RewardsDistributor::new( rewards_distributor_address, wallet.provider(), @@ -86,7 +83,7 @@ async fn main() -> Result<()> { ); let rate = U256::from(10000000000000000u64); let tx = rewards_distributor.set_reward_rate(rate).await; - println!("Setting reward rate: {:?}", tx); + println!("Setting reward rate: {tx:?}"); let reward_rate = rewards_distributor.get_reward_rate().await.unwrap(); println!( diff --git a/crates/dev-utils/examples/create_domain.rs b/crates/dev-utils/examples/create_domain.rs index 4365c764..d1da5ea2 100644 --- a/crates/dev-utils/examples/create_domain.rs +++ b/crates/dev-utils/examples/create_domain.rs @@ -59,6 +59,6 @@ async fn main() -> Result<()> { .await; println!("Creating domain: {}", args.domain_name); println!("Validation logic: {}", args.validation_logic); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/eject_node.rs b/crates/dev-utils/examples/eject_node.rs index e2ed03a3..142aa1cd 100644 --- a/crates/dev-utils/examples/eject_node.rs +++ b/crates/dev-utils/examples/eject_node.rs @@ -52,20 +52,20 @@ async fn main() -> Result<()> { .compute_registry .get_node(provider_address, node_address) .await; - println!("Node info: {:?}", node_info); + println!("Node info: {node_info:?}"); let tx = contracts .compute_pool .eject_node(args.pool_id, node_address) .await; println!("Ejected node {} from pool {}", args.node, args.pool_id); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); let node_info = contracts .compute_registry .get_node(provider_address, node_address) .await; - println!("Post ejection node info: {:?}", node_info); + println!("Post ejection node info: {node_info:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/get_node_info.rs b/crates/dev-utils/examples/get_node_info.rs index fec5f526..79c7c120 100644 --- a/crates/dev-utils/examples/get_node_info.rs +++ b/crates/dev-utils/examples/get_node_info.rs @@ -55,9 +55,6 @@ async fn main() -> Result<()> { .await .unwrap(); - println!( - "Node Active: {}, Validated: {}, In Pool: {}", - active, validated, is_node_in_pool - ); + println!("Node Active: {active}, Validated: {validated}, In Pool: {is_node_in_pool}"); Ok(()) } diff --git a/crates/dev-utils/examples/invalidate_work.rs b/crates/dev-utils/examples/invalidate_work.rs index 78154b07..c93c8cee 100644 --- a/crates/dev-utils/examples/invalidate_work.rs +++ b/crates/dev-utils/examples/invalidate_work.rs @@ -65,7 +65,7 @@ async fn main() -> Result<()> { "Invalidated work in pool {} with penalty {}", args.pool_id, args.penalty ); - println!("Transaction hash: {:?}", tx); + println!("Transaction hash: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/mint_ai_token.rs 
b/crates/dev-utils/examples/mint_ai_token.rs index 5e572b40..bc43b78d 100644 --- a/crates/dev-utils/examples/mint_ai_token.rs +++ b/crates/dev-utils/examples/mint_ai_token.rs @@ -45,9 +45,9 @@ async fn main() -> Result<()> { let amount = U256::from(args.amount) * Unit::ETHER.wei(); let tx = contracts.ai_token.mint(address, amount).await; println!("Minting to address: {}", args.address); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); let balance = contracts.ai_token.balance_of(address).await; - println!("Balance: {:?}", balance); + println!("Balance: {balance:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/set_min_stake_amount.rs b/crates/dev-utils/examples/set_min_stake_amount.rs index 82644e61..2858f5c7 100644 --- a/crates/dev-utils/examples/set_min_stake_amount.rs +++ b/crates/dev-utils/examples/set_min_stake_amount.rs @@ -36,13 +36,13 @@ async fn main() -> Result<()> { .unwrap(); let min_stake_amount = U256::from(args.min_stake_amount) * Unit::ETHER.wei(); - println!("Min stake amount: {}", min_stake_amount); + println!("Min stake amount: {min_stake_amount}"); let tx = contracts .prime_network .set_stake_minimum(min_stake_amount) .await; - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/start_compute_pool.rs b/crates/dev-utils/examples/start_compute_pool.rs index b11e2b2c..a94c0b6f 100644 --- a/crates/dev-utils/examples/start_compute_pool.rs +++ b/crates/dev-utils/examples/start_compute_pool.rs @@ -41,6 +41,6 @@ async fn main() -> Result<()> { .start_compute_pool(U256::from(args.pool_id)) .await; println!("Started compute pool with id: {}", args.pool_id); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/submit_work.rs b/crates/dev-utils/examples/submit_work.rs index aa3b489c..0fcf20d0 100644 --- a/crates/dev-utils/examples/submit_work.rs +++ b/crates/dev-utils/examples/submit_work.rs @@ -64,7 +64,7 @@ async fn main() -> Result<()> { "Submitted work for node {} in pool {}", args.node, args.pool_id ); - println!("Transaction hash: {:?}", tx); + println!("Transaction hash: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/test_concurrent_calls.rs b/crates/dev-utils/examples/test_concurrent_calls.rs index 47f7bbea..1bef230a 100644 --- a/crates/dev-utils/examples/test_concurrent_calls.rs +++ b/crates/dev-utils/examples/test_concurrent_calls.rs @@ -38,7 +38,7 @@ async fn main() -> Result<()> { let wallet = Arc::new(Wallet::new(&args.key, Url::parse(&args.rpc_url)?).unwrap()); let price = wallet.provider.get_gas_price().await?; - println!("Gas price: {:?}", price); + println!("Gas price: {price:?}"); let current_nonce = wallet .provider @@ -50,8 +50,8 @@ async fn main() -> Result<()> { .block_id(BlockId::Number(BlockNumberOrTag::Pending)) .await?; - println!("Pending nonce: {:?}", pending_nonce); - println!("Current nonce: {:?}", current_nonce); + println!("Pending nonce: {pending_nonce:?}"); + println!("Current nonce: {current_nonce:?}"); // Unfortunately have to build all contracts atm let contracts = Arc::new( @@ -67,7 +67,7 @@ async fn main() -> Result<()> { let address = Address::from_str(&args.address).unwrap(); let amount = U256::from(args.amount) * Unit::ETHER.wei(); let random = (rand::random::<u64>() % 10) + 1; - println!("Random: {:?}", random); + println!("Random: {random:?}"); let contracts_one = contracts.clone(); let wallet_one = wallet.clone(); @@ -80,7 +80,7 @@ async fn main() -> Result<()> { let tx = 
retry_call(mint_call, 5, wallet_one.provider(), None) .await .unwrap(); - println!("Transaction hash I: {:?}", tx); + println!("Transaction hash I: {tx:?}"); }); let contracts_two = contracts.clone(); @@ -93,11 +93,11 @@ async fn main() -> Result<()> { let tx = retry_call(mint_call_two, 5, wallet_two.provider(), None) .await .unwrap(); - println!("Transaction hash II: {:?}", tx); + println!("Transaction hash II: {tx:?}"); }); let balance = contracts.ai_token.balance_of(address).await.unwrap(); - println!("Balance: {:?}", balance); + println!("Balance: {balance:?}"); tokio::time::sleep(tokio::time::Duration::from_secs(40)).await; Ok(()) } diff --git a/crates/discovery/src/api/routes/node.rs b/crates/discovery/src/api/routes/node.rs index b2cf780f..aa6ca45a 100644 --- a/crates/discovery/src/api/routes/node.rs +++ b/crates/discovery/src/api/routes/node.rs @@ -465,12 +465,10 @@ mod tests { assert_eq!(body.data, "Node registered successfully"); let nodes = app_state.node_store.get_nodes().await; - let nodes = match nodes { - Ok(nodes) => nodes, - Err(_) => { - panic!("Error getting nodes"); - } + let Ok(nodes) = nodes else { + panic!("Error getting nodes"); }; + assert_eq!(nodes.len(), 1); assert_eq!(nodes[0].id, node.id); assert_eq!(nodes[0].last_updated, None); @@ -611,12 +609,10 @@ mod tests { assert_eq!(body.data, "Node registered successfully"); let nodes = app_state.node_store.get_nodes().await; - let nodes = match nodes { - Ok(nodes) => nodes, - Err(_) => { - panic!("Error getting nodes"); - } + let Ok(nodes) = nodes else { + panic!("Error getting nodes"); }; + assert_eq!(nodes.len(), 1); assert_eq!(nodes[0].id, node.id); } diff --git a/crates/discovery/src/store/redis.rs b/crates/discovery/src/store/redis.rs index 508815c2..c0a0c36b 100644 --- a/crates/discovery/src/store/redis.rs +++ b/crates/discovery/src/store/redis.rs @@ -45,8 +45,8 @@ impl RedisStore { _ => panic!("Expected TCP connection"), }; - let redis_url = format!("redis://{}:{}", host, port); - debug!("Starting test Redis server at {}", redis_url); + let redis_url = format!("redis://{host}:{port}"); + debug!("Starting test Redis server at {redis_url}"); // Add a small delay to ensure server is ready thread::sleep(Duration::from_millis(100)); diff --git a/crates/orchestrator/src/api/routes/heartbeat.rs b/crates/orchestrator/src/api/routes/heartbeat.rs index a8110e61..4d6261f9 100644 --- a/crates/orchestrator/src/api/routes/heartbeat.rs +++ b/crates/orchestrator/src/api/routes/heartbeat.rs @@ -404,7 +404,7 @@ mod tests { let task = match task.try_into() { Ok(task) => task, - Err(e) => panic!("Failed to convert TaskRequest to Task: {}", e), + Err(e) => panic!("Failed to convert TaskRequest to Task: {e}"), }; let _ = app_state.store_context.task_store.add_task(task).await; diff --git a/crates/orchestrator/src/api/routes/task.rs b/crates/orchestrator/src/api/routes/task.rs index 7cff4b6d..fa167dc7 100644 --- a/crates/orchestrator/src/api/routes/task.rs +++ b/crates/orchestrator/src/api/routes/task.rs @@ -315,8 +315,8 @@ mod tests { // Add tasks in sequence with delays for i in 1..=3 { let task: Task = TaskRequest { - image: format!("test{}", i), - name: format!("test{}", i), + image: format!("test{i}"), + name: format!("test{i}"), ..Default::default() } .try_into() diff --git a/crates/orchestrator/src/plugins/node_groups/tests.rs b/crates/orchestrator/src/plugins/node_groups/tests.rs index a7d73b36..5fc22430 100644 --- a/crates/orchestrator/src/plugins/node_groups/tests.rs +++ 
b/crates/orchestrator/src/plugins/node_groups/tests.rs @@ -276,9 +276,7 @@ async fn test_group_formation_with_multiple_configs() { let _ = plugin.try_form_new_groups().await; let mut conn = plugin.store.client.get_connection().unwrap(); - let groups: Vec<String> = conn - .keys(format!("{}*", GROUP_KEY_PREFIX).as_str()) - .unwrap(); + let groups: Vec<String> = conn.keys(format!("{GROUP_KEY_PREFIX}*").as_str()).unwrap(); assert_eq!(groups.len(), 2); // Verify group was created @@ -1102,7 +1100,7 @@ async fn test_node_cannot_be_in_multiple_groups() { ); // Get all group keys - let group_keys: Vec<String> = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec<String> = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); let group_copy = group_keys.clone(); // There should be exactly one group @@ -1167,7 +1165,7 @@ async fn test_node_cannot_be_in_multiple_groups() { let _ = plugin.try_form_new_groups().await; // Get updated group keys - let group_keys: Vec<String> = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec<String> = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); // There should now be exactly two groups assert_eq!( @@ -1544,7 +1542,7 @@ async fn test_task_observer() { let _ = store_context.task_store.add_task(task2.clone()).await; let _ = plugin.try_form_new_groups().await; let all_tasks = store_context.task_store.get_all_tasks().await.unwrap(); - println!("All tasks: {:?}", all_tasks); + println!("All tasks: {all_tasks:?}"); assert_eq!(all_tasks.len(), 2); assert!(all_tasks[0].id != all_tasks[1].id); let topologies = get_task_topologies(&task).unwrap(); @@ -1588,7 +1586,7 @@ async fn test_task_observer() { .unwrap(); assert!(group_3.is_some()); let all_tasks = store_context.task_store.get_all_tasks().await.unwrap(); - println!("All tasks: {:?}", all_tasks); + println!("All tasks: {all_tasks:?}"); assert_eq!(all_tasks.len(), 2); // Manually assign the first task to the group to test immediate dissolution let group_3_before = plugin @@ -1615,7 +1613,7 @@ async fn test_task_observer() { .get_node_group(&node_3.address.to_string()) .await .unwrap(); - println!("Group 3 after task deletion: {:?}", group_3); + println!("Group 3 after task deletion: {group_3:?}"); // With new behavior, group should be dissolved immediately when its assigned task is deleted assert!(group_3.is_none()); @@ -1833,7 +1831,7 @@ async fn test_group_formation_priority() { let nodes: Vec<_> = (1..=4) .map(|i| { create_test_node( - &format!("0x{}234567890123456789012345678901234567890", i), + &format!("0x{i}234567890123456789012345678901234567890"), NodeStatus::Healthy, None, ) @@ -1863,7 +1861,7 @@ async fn test_group_formation_priority() { // Verify: Should form one 3-node group + one 1-node group // NOT four 1-node groups let mut conn = plugin.store.client.get_connection().unwrap(); - let group_keys: Vec<String> = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec<String> = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); assert_eq!(group_keys.len(), 2, "Should form exactly 2 groups"); // Check group compositions @@ -1944,7 +1942,7 @@ async fn test_multiple_groups_same_configuration() { let nodes: Vec<_> = (1..=6) .map(|i| { create_test_node( - &format!("0x{}234567890123456789012345678901234567890", i), + &format!("0x{i}234567890123456789012345678901234567890"), NodeStatus::Healthy, None, ) @@ -1958,7 +1956,7 @@ async fn test_multiple_groups_same_configuration() { // Verify: Should create 3 groups of 2 nodes each let mut conn = plugin.store.client.get_connection().unwrap(); - let group_keys: Vec<String> = 
conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec<String> = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); assert_eq!(group_keys.len(), 3, "Should form exactly 3 groups"); // Verify all groups have exactly 2 nodes and same configuration @@ -2663,7 +2661,7 @@ async fn test_no_merge_when_policy_disabled() { // Create 3 nodes let nodes: Vec<_> = (1..=3) - .map(|i| create_test_node(&format!("0x{:040x}", i), NodeStatus::Healthy, None)) + .map(|i| create_test_node(&format!("0x{i:040x}"), NodeStatus::Healthy, None)) .collect(); for node in &nodes { diff --git a/crates/orchestrator/src/scheduler/mod.rs b/crates/orchestrator/src/scheduler/mod.rs index 711f313f..d5ffa506 100644 --- a/crates/orchestrator/src/scheduler/mod.rs +++ b/crates/orchestrator/src/scheduler/mod.rs @@ -144,12 +144,12 @@ mod tests { ); assert_eq!( env_vars.get("NODE_VAR").unwrap(), - &format!("node-{}", node_address) + &format!("node-{node_address}") ); // Check cmd replacement let cmd = returned_task.cmd.unwrap(); assert_eq!(cmd[0], format!("--task={}", task.id)); - assert_eq!(cmd[1], format!("--node={}", node_address)); + assert_eq!(cmd[1], format!("--node={node_address}")); } } diff --git a/crates/orchestrator/src/status_update/mod.rs b/crates/orchestrator/src/status_update/mod.rs index b2738488..67140cbc 100644 --- a/crates/orchestrator/src/status_update/mod.rs +++ b/crates/orchestrator/src/status_update/mod.rs @@ -372,6 +372,7 @@ async fn process_node( } #[cfg(test)] +#[allow(clippy::unused_async)] async fn is_node_in_pool(_: Contracts, _: u32, _: &OrchestratorNode) -> bool { true } @@ -433,7 +434,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let heartbeat = HeartbeatRequest { address: node.address.to_string(), @@ -451,7 +452,7 @@ mod tests { .beat(&heartbeat) .await { - error!("Heartbeat Error: {}", e); + error!("Heartbeat Error: {e}"); } let _ = updater.process_nodes().await; @@ -510,7 +511,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( @@ -563,7 +564,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( @@ -623,7 +624,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } if let Err(e) = app_state .store_context .heartbeat_store .set_unhealthy_counter(&node.address, 2) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); } let mode = ServerMode::Full; @@ -687,7 +688,7 @@ mod tests { .set_unhealthy_counter(&node.address, 2) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); }; let heartbeat = HeartbeatRequest { @@ -702,7 +703,7 @@ mod tests { .beat(&heartbeat) .await { - error!("Heartbeat Error: {}", e); + error!("Heartbeat Error: {e}"); } if let Err(e) = app_state .store_context .node_store .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; @@ -772,7 +773,7 @@ mod tests { .set_unhealthy_counter(&node1.address, 1) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); }; if let 
Err(e) = app_state .store_context .node_store .add_node(node1.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let node2 = OrchestratorNode { @@ -797,7 +798,7 @@ mod tests { .add_node(node2.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; @@ -873,7 +874,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } if let Err(e) = app_state .store_context .heartbeat_store .set_unhealthy_counter(&node.address, 2) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); } let mode = ServerMode::Full; @@ -926,7 +927,7 @@ mod tests { .beat(&heartbeat) .await { - error!("Heartbeat Error: {}", e); + error!("Heartbeat Error: {e}"); } sleep(Duration::from_secs(5)).await; @@ -960,7 +961,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( @@ -1029,7 +1030,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( diff --git a/crates/orchestrator/src/store/core/redis.rs b/crates/orchestrator/src/store/core/redis.rs index 79f57ce8..3b524b33 100644 --- a/crates/orchestrator/src/store/core/redis.rs +++ b/crates/orchestrator/src/store/core/redis.rs @@ -45,8 +45,8 @@ impl RedisStore { _ => panic!("Expected TCP connection"), }; - let redis_url = format!("redis://{}:{}", host, port); - debug!("Starting test Redis server at {}", redis_url); + let redis_url = format!("redis://{host}:{port}"); + debug!("Starting test Redis server at {redis_url}"); // Add a small delay to ensure server is ready thread::sleep(Duration::from_millis(100)); diff --git a/crates/orchestrator/src/store/domains/heartbeat_store.rs b/crates/orchestrator/src/store/domains/heartbeat_store.rs index b2f8138a..8bb43374 100644 --- a/crates/orchestrator/src/store/domains/heartbeat_store.rs +++ b/crates/orchestrator/src/store/domains/heartbeat_store.rs @@ -80,7 +80,7 @@ impl HeartbeatStore { .get_multiplexed_async_connection() .await .map_err(|_| anyhow!("Failed to get connection"))?; - let key = format!("{}:{}", ORCHESTRATOR_UNHEALTHY_COUNTER_KEY, address); + let key = format!("{ORCHESTRATOR_UNHEALTHY_COUNTER_KEY}:{address}"); con.set(key, counter.to_string()) .await .map_err(|_| anyhow!("Failed to set value")) diff --git a/crates/orchestrator/src/store/domains/metrics_store.rs b/crates/orchestrator/src/store/domains/metrics_store.rs index 1a0d79ac..5520860a 100644 --- a/crates/orchestrator/src/store/domains/metrics_store.rs +++ b/crates/orchestrator/src/store/domains/metrics_store.rs @@ -145,7 +145,7 @@ impl MetricsStore { task_id: &str, ) -> Result> { let mut con = self.redis.client.get_multiplexed_async_connection().await?; - let pattern = format!("{}:*", ORCHESTRATOR_NODE_METRICS_STORE); + let pattern = format!("{ORCHESTRATOR_NODE_METRICS_STORE}:*"); // Scan all node keys let mut iter: redis::AsyncIter<String> = con.scan_match(&pattern).await?; diff --git a/crates/prime-core/Cargo.toml b/crates/prime-core/Cargo.toml new file mode 100644 index 00000000..bfcef45e --- /dev/null +++ b/crates/prime-core/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "prime-core" +version = "0.1.0" +edition = "2021" + +[lints] 
+workspace = true + +[lib] +name = "prime_core" +path = "src/lib.rs" + +[dependencies] +shared = { workspace = true } +alloy = { workspace = true } +alloy-provider = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +url = { workspace = true } +actix-web = { workspace = true } +anyhow = { workspace = true } +futures-util = { workspace = true } +hex = { workspace = true } +uuid = { workspace = true } +log = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } +redis = { workspace = true, features = ["aio", "tokio-comp"] } +rand_v8 = { workspace = true } +env_logger = { workspace = true } +subtle = "2.6.1" diff --git a/crates/prime-core/src/lib.rs b/crates/prime-core/src/lib.rs new file mode 100644 index 00000000..1bf04f8a --- /dev/null +++ b/crates/prime-core/src/lib.rs @@ -0,0 +1 @@ +pub mod operations; diff --git a/crates/prime-core/src/operations/compute_node.rs b/crates/prime-core/src/operations/compute_node.rs new file mode 100644 index 00000000..c294291a --- /dev/null +++ b/crates/prime-core/src/operations/compute_node.rs @@ -0,0 +1,92 @@ +use alloy::{primitives::utils::keccak256 as keccak, primitives::U256, signers::Signer}; +use anyhow::Result; +use shared::web3::wallet::Wallet; +use shared::web3::{contracts::core::builder::Contracts, wallet::WalletProvider}; + +pub struct ComputeNodeOperations<'c> { + provider_wallet: &'c Wallet, + node_wallet: &'c Wallet, + contracts: Contracts, +} + +impl<'c> ComputeNodeOperations<'c> { + pub fn new( + provider_wallet: &'c Wallet, + node_wallet: &'c Wallet, + contracts: Contracts, + ) -> Self { + Self { + provider_wallet, + node_wallet, + contracts, + } + } + + pub async fn check_compute_node_exists(&self) -> Result<bool> { + let compute_node = self + .contracts + .compute_registry + .get_node( + self.provider_wallet.wallet.default_signer().address(), + self.node_wallet.wallet.default_signer().address(), + ) + .await; + + match compute_node { + Ok(_) => Ok(true), + Err(_) => Ok(false), + } + } + + // Returns true if the compute node was added, false if it already exists + pub async fn add_compute_node( + &self, + compute_units: U256, + ) -> Result<bool> { + log::info!("🔄 Adding compute node"); + + if self.check_compute_node_exists().await? { + return Ok(false); + } + + log::info!("Adding compute node"); + let provider_address = self.provider_wallet.wallet.default_signer().address(); + let node_address = self.node_wallet.wallet.default_signer().address(); + let digest = keccak([provider_address.as_slice(), node_address.as_slice()].concat()); + + let signature = self + .node_wallet + .signer + .sign_message(digest.as_slice()) + .await? + .as_bytes(); + + // Create the signature bytes + let add_node_tx = self + .contracts + .prime_network + .add_compute_node(node_address, compute_units, signature.to_vec()) + .await?; + log::info!("Add node tx: {add_node_tx:?}"); + Ok(true) + } + + pub async fn remove_compute_node(&self) -> Result<bool> { + log::info!("🔄 Removing compute node"); + + if !self.check_compute_node_exists().await? 
{ + return Ok(false); + } + + log::info!("Removing compute node"); + let provider_address = self.provider_wallet.wallet.default_signer().address(); + let node_address = self.node_wallet.wallet.default_signer().address(); + let remove_node_tx = self + .contracts + .prime_network + .remove_compute_node(provider_address, node_address) + .await?; + log::info!("Remove node tx: {remove_node_tx:?}"); + Ok(true) + } +} diff --git a/crates/prime-core/src/operations/mod.rs b/crates/prime-core/src/operations/mod.rs new file mode 100644 index 00000000..089315f5 --- /dev/null +++ b/crates/prime-core/src/operations/mod.rs @@ -0,0 +1,2 @@ +pub mod compute_node; +pub mod provider; diff --git a/crates/worker/src/operations/provider.rs b/crates/prime-core/src/operations/provider.rs similarity index 67% rename from crates/worker/src/operations/provider.rs rename to crates/prime-core/src/operations/provider.rs index fb8aba5f..c07f6189 100644 --- a/crates/worker/src/operations/provider.rs +++ b/crates/prime-core/src/operations/provider.rs @@ -1,4 +1,3 @@ -use crate::console::Console; use alloy::primitives::utils::format_ether; use alloy::primitives::{Address, U256}; use log::error; @@ -9,18 +8,14 @@ use std::{fmt, io}; use tokio::time::{sleep, Duration}; use tokio_util::sync::CancellationToken; -pub(crate) struct ProviderOperations { +pub struct ProviderOperations { wallet: Wallet, contracts: Contracts, auto_accept: bool, } impl ProviderOperations { - pub(crate) fn new( - wallet: Wallet, - contracts: Contracts, - auto_accept: bool, - ) -> Self { + pub fn new(wallet: Wallet, contracts: Contracts, auto_accept: bool) -> Self { Self { wallet, contracts, @@ -44,7 +39,7 @@ impl ProviderOperations { } } - pub(crate) fn start_monitoring(&self, cancellation_token: CancellationToken) { + pub fn start_monitoring(&self, cancellation_token: CancellationToken) { let provider_address = self.wallet.wallet.default_signer().address(); let contracts = self.contracts.clone(); @@ -58,12 +53,12 @@ impl ProviderOperations { loop { tokio::select! 
{ _ = cancellation_token.cancelled() => { - Console::info("Monitor", "Shutting down provider status monitor..."); + log::info!("Shutting down provider status monitor..."); break; } _ = async { let Some(stake_manager) = contracts.stake_manager.as_ref() else { - Console::user_error("Cannot start monitoring - stake manager not initialized"); + log::error!("Cannot start monitoring - stake manager not initialized"); return; }; match stake_manager.get_stake(provider_address).await { Ok(stake) => { if first_check || stake != last_stake { - Console::info("🔄 Chain Sync - Provider stake", &format_ether(stake)); + log::info!("🔄 Chain Sync - Provider stake: {}", format_ether(stake)); if !first_check { if stake < last_stake { - Console::warning(&format!("Stake decreased - possible slashing detected: From {} to {}", + log::warn!("Stake decreased - possible slashing detected: From {} to {}", format_ether(last_stake), format_ether(stake) - )); + ); if stake == U256::ZERO { - Console::warning("Stake is 0 - you might have to restart the node to increase your stake (if you still have balance left)"); + log::warn!("Stake is 0 - you might have to restart the node to increase your stake (if you still have balance left)"); } } else { - Console::info("🔄 Chain Sync - Stake changed", &format!("From {} to {}", + log::info!("🔄 Chain Sync - Stake increased: From {} to {}", format_ether(last_stake), format_ether(stake) - )); + ); } } last_stake = stake; @@ -102,13 +97,7 @@ impl ProviderOperations { match contracts.ai_token.balance_of(provider_address).await { Ok(balance) => { if first_check || balance != last_balance { - Console::info("🔄 Chain Sync - Balance", &format_ether(balance)); - if !first_check { - Console::info("🔄 Chain Sync - Balance changed", &format!("From {} to {}", - format_ether(last_balance), - format_ether(balance) - )); - } + log::info!("🔄 Chain Sync - Balance: {}", format_ether(balance)); last_balance = balance; } Some(balance) @@ -123,12 +112,12 @@ impl ProviderOperations { match contracts.compute_registry.get_provider(provider_address).await { Ok(provider) => { if first_check || provider.is_whitelisted != last_whitelist_status { - Console::info("🔄 Chain Sync - Whitelist status", &format!("{}", provider.is_whitelisted)); + log::info!("🔄 Chain Sync - Whitelist status: {}", provider.is_whitelisted); if !first_check { - Console::info("🔄 Chain Sync - Whitelist status changed", &format!("From {} to {}", + log::info!("🔄 Chain Sync - Whitelist status changed: {} -> {}", last_whitelist_status, provider.is_whitelisted - )); + ); } last_whitelist_status = provider.is_whitelisted; } @@ -146,7 +135,7 @@ impl ProviderOperations { }); } - pub(crate) async fn check_provider_exists(&self) -> Result<bool, ProviderError> { + pub async fn check_provider_exists(&self) -> Result<bool, ProviderError> { let address = self.wallet.wallet.default_signer().address(); let provider = self @@ -159,7 +148,7 @@ impl ProviderOperations { Ok(provider.provider_address != Address::default()) } - pub(crate) async fn check_provider_whitelisted(&self) -> Result<bool, ProviderError> { + pub async fn check_provider_whitelisted(&self) -> Result<bool, ProviderError> { let address = self.wallet.wallet.default_signer().address(); let provider = self @@ -171,29 +160,32 @@ impl ProviderOperations { Ok(provider.is_whitelisted) } - - pub(crate) async fn retry_register_provider( + pub async fn retry_register_provider( &self, stake: U256, max_attempts: u32, - cancellation_token: CancellationToken, + cancellation_token: Option<CancellationToken>, ) -> Result<(), ProviderError> { - 
Console::title("Registering Provider"); + log::info!("Registering Provider"); let mut attempts = 0; while attempts < max_attempts || max_attempts == 0 { - Console::progress("Registering provider..."); + log::info!("Registering provider..."); match self.register_provider(stake).await { Ok(_) => { return Ok(()); } Err(e) => match e { ProviderError::NotWhitelisted | ProviderError::InsufficientBalance => { - Console::info("Info", "Retrying in 10 seconds..."); - tokio::select! { - _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {} - _ = cancellation_token.cancelled() => { - return Err(e); + log::info!("Retrying in 10 seconds..."); + if let Some(ref token) = cancellation_token { + tokio::select! { + _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {} + _ = token.cancelled() => { + return Err(e); + } } + } else { + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; } attempts += 1; continue; @@ -206,7 +198,7 @@ impl ProviderOperations { Err(ProviderError::Other) } - pub(crate) async fn register_provider(&self, stake: U256) -> Result<(), ProviderError> { + pub async fn register_provider(&self, stake: U256) -> Result<(), ProviderError> { let address = self.wallet.wallet.default_signer().address(); let balance: U256 = self .contracts @@ -224,42 +216,39 @@ impl ProviderOperations { let provider_exists = self.check_provider_exists().await?; if !provider_exists { - Console::info("Balance", &format_ether(balance)); - Console::info( - "ETH Balance", + log::info!("Balance: {}", &format_ether(balance)); + log::info!( + "ETH Balance: {}", &format!("{} ETH", format_ether(U256::from(eth_balance))), ); if balance < stake { - Console::user_error(&format!( - "Insufficient balance for stake: {}", - format_ether(stake) - )); + log::error!("Insufficient balance for stake: {}", format_ether(stake)); return Err(ProviderError::InsufficientBalance); } if !self.prompt_user_confirmation(&format!( "Do you want to approve staking {}?", format_ether(stake) )) { - Console::info("Operation cancelled by user", "Staking approval declined"); + log::info!("Operation cancelled by user: Staking approval declined"); return Err(ProviderError::UserCancelled); } - Console::progress("Approving for Stake transaction"); + log::info!("Approving for Stake transaction"); self.contracts .ai_token .approve(stake) .await .map_err(|_| ProviderError::Other)?; - Console::progress("Registering Provider"); + log::info!("Registering Provider"); let Ok(register_tx) = self.contracts.prime_network.register_provider(stake).await else { return Err(ProviderError::Other); }; - Console::info("Registration tx", &format!("{register_tx:?}")); + log::info!("Registration tx: {}", &format!("{register_tx:?}")); } // Get provider details again - cleanup later - Console::progress("Getting provider details"); + log::info!("Getting provider details"); let _ = self .contracts .compute_registry @@ -270,32 +259,29 @@ impl ProviderOperations { let provider_exists = self.check_provider_exists().await?; if !provider_exists { - Console::info("Balance", &format_ether(balance)); - Console::info( - "ETH Balance", + log::info!("Balance: {}", &format_ether(balance)); + log::info!( + "ETH Balance: {}", &format!("{} ETH", format_ether(U256::from(eth_balance))), ); if balance < stake { - Console::user_error(&format!( - "Insufficient balance for stake: {}", - format_ether(stake) - )); + log::error!("Insufficient balance for stake: {}", format_ether(stake)); return Err(ProviderError::InsufficientBalance); } if 
!self.prompt_user_confirmation(&format!( "Do you want to approve staking {}?", format_ether(stake) )) { - Console::info("Operation cancelled by user", "Staking approval declined"); + log::info!("Operation cancelled by user: Staking approval declined"); return Err(ProviderError::UserCancelled); } - Console::progress("Approving Stake transaction"); + log::info!("Approving Stake transaction"); self.contracts.ai_token.approve(stake).await.map_err(|e| { error!("Failed to approve stake: {e}"); ProviderError::Other })?; - Console::progress("Registering Provider"); + log::info!("Registering Provider"); let register_tx = match self.contracts.prime_network.register_provider(stake).await { Ok(tx) => tx, Err(e) => { @@ -303,7 +289,7 @@ impl ProviderOperations { return Err(ProviderError::Other); } }; - Console::info("Registration tx", &format!("{register_tx:?}")); + log::info!("Registration tx: {register_tx:?}"); } let provider = self @@ -315,23 +301,23 @@ impl ProviderOperations { let provider_exists = provider.provider_address != Address::default(); if !provider_exists { - Console::user_error( - "Provider could not be registered. Please ensure your balance is high enough.", + log::error!( + "Provider could not be registered. Please ensure your balance is high enough." ); return Err(ProviderError::Other); } - Console::success("Provider registered"); + log::info!("Provider registered"); if !provider.is_whitelisted { - Console::user_error("Provider is not whitelisted yet."); + log::error!("Provider is not whitelisted yet."); return Err(ProviderError::NotWhitelisted); } Ok(()) } - pub(crate) async fn increase_stake(&self, additional_stake: U256) -> Result<(), ProviderError> { - Console::title("💰 Increasing Provider Stake"); + pub async fn increase_stake(&self, additional_stake: U256) -> Result<(), ProviderError> { + log::info!("💰 Increasing Provider Stake"); let address = self.wallet.wallet.default_signer().address(); let balance: U256 = self @@ -341,11 +327,14 @@ impl ProviderOperations { .await .map_err(|_| ProviderError::Other)?; - Console::info("Current Balance", &format_ether(balance)); - Console::info("Additional stake amount", &format_ether(additional_stake)); + log::info!("Current Balance: {}", &format_ether(balance)); + log::info!( + "Additional stake amount: {}", + &format_ether(additional_stake) + ); if balance < additional_stake { - Console::user_error("Insufficient balance for stake increase"); + log::error!("Insufficient balance for stake increase"); return Err(ProviderError::Other); } @@ -353,20 +342,20 @@ impl ProviderOperations { "Do you want to approve staking {} additional funds?", format_ether(additional_stake) )) { - Console::info("Operation cancelled by user", "Staking approval declined"); + log::info!("Operation cancelled by user: Staking approval declined"); return Err(ProviderError::UserCancelled); } - Console::progress("Approving additional stake"); + log::info!("Approving additional stake"); let approve_tx = self .contracts .ai_token .approve(additional_stake) .await .map_err(|_| ProviderError::Other)?; - Console::info("Transaction approved", &format!("{approve_tx:?}")); + log::info!("Transaction approved: {}", &format!("{approve_tx:?}")); - Console::progress("Increasing stake"); + log::info!("Increasing stake"); let stake_tx = match self.contracts.prime_network.stake(additional_stake).await { Ok(tx) => tx, Err(e) => { @@ -374,17 +363,15 @@ impl ProviderOperations { return Err(ProviderError::Other); } }; - Console::info( - "Stake increase transaction completed: ", - 
&format!("{stake_tx:?}"), + log::info!( + "Stake increase transaction completed: {}", + &format!("{stake_tx:?}") ); - Console::success("Provider stake increased successfully"); Ok(()) } - pub(crate) async fn reclaim_stake(&self, amount: U256) -> Result<(), ProviderError> { - Console::progress("Reclaiming stake"); + pub async fn reclaim_stake(&self, amount: U256) -> Result<(), ProviderError> { let reclaim_tx = match self.contracts.prime_network.reclaim_stake(amount).await { Ok(tx) => tx, Err(e) => { @@ -392,17 +379,16 @@ impl ProviderOperations { return Err(ProviderError::Other); } }; - Console::info( - "Stake reclaim transaction completed: ", - &format!("{reclaim_tx:?}"), + log::info!( + "Stake reclaim transaction completed: {}", + &format!("{reclaim_tx:?}") ); - Console::success("Provider stake reclaimed successfully"); Ok(()) } } #[derive(Debug)] -pub(crate) enum ProviderError { +pub enum ProviderError { NotWhitelisted, UserCancelled, Other, diff --git a/crates/prime-protocol-py/.gitignore b/crates/prime-protocol-py/.gitignore new file mode 100644 index 00000000..454f9f33 --- /dev/null +++ b/crates/prime-protocol-py/.gitignore @@ -0,0 +1,24 @@ +# Python +__pycache__/ +*.py[cod] +*.so +*.pyd +*.egg-info/ +dist/ + +# Virtual environments +.venv/ + +# Testing +.pytest_cache/ + +# IDE +.vscode/ +.idea/ + +# Rust/Maturin +target/ +Cargo.lock + +# OS +.DS_Store \ No newline at end of file diff --git a/crates/prime-protocol-py/.python-version b/crates/prime-protocol-py/.python-version new file mode 100644 index 00000000..4b7e4839 --- /dev/null +++ b/crates/prime-protocol-py/.python-version @@ -0,0 +1 @@ +3.11 \ No newline at end of file diff --git a/crates/prime-protocol-py/Cargo.toml b/crates/prime-protocol-py/Cargo.toml new file mode 100644 index 00000000..9441afe1 --- /dev/null +++ b/crates/prime-protocol-py/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "prime-protocol-py" +version = "0.1.0" +authors = ["Prime Protocol"] +edition = "2021" +rust-version = "1.70" + +[lib] +name = "primeprotocol" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.25.1", features = ["extension-module"] } +thiserror = "1.0" +shared = { workspace = true } +prime-core = { workspace = true } +alloy = { workspace = true } +alloy-provider = { workspace = true } +tokio = { version = "1.35", features = ["rt"] } +url = "2.5" +log = { workspace = true } +pyo3-log = "0.12.4" + +[dev-dependencies] +tokio-test = "0.4" + +[profile.release] +opt-level = 3 +lto = true +codegen-units = 1 +strip = true + diff --git a/crates/prime-protocol-py/Makefile b/crates/prime-protocol-py/Makefile new file mode 100644 index 00000000..fe1858d0 --- /dev/null +++ b/crates/prime-protocol-py/Makefile @@ -0,0 +1,46 @@ +.PHONY: install +install: + @command -v uv > /dev/null || (echo "Please install uv first: curl -LsSf https://astral.sh/uv/install.sh | sh" && exit 1) + @./setup.sh # Uses uv for fast package management + +.PHONY: build +build: + @source .venv/bin/activate && maturin develop + @source .venv/bin/activate && uv pip install --force-reinstall -e . 
+ +.PHONY: dev +dev: + @source .venv/bin/activate && maturin develop --watch + +.PHONY: build-release +build-release: + @source .venv/bin/activate && maturin build --release --strip + +.PHONY: test +test: + @source .venv/bin/activate && pytest tests/ -v + +.PHONY: example +example: + @source .venv/bin/activate && python examples/basic_usage.py + +.PHONY: clean +clean: + @rm -rf target/ dist/ *.egg-info .pytest_cache __pycache__ .venv/ + +.PHONY: clear-cache +clear-cache: + @uv cache clean + @echo "uv cache cleared" + +.PHONY: help +help: + @echo "Available commands:" + @echo " make install - Setup environment and install dependencies" + @echo " make build - Build development version" + @echo " make dev - Build with hot reload (watches for changes)" + @echo " make build-release - Build release wheel" + @echo " make test - Run tests" + @echo " make example - Run example script" + @echo " make clean - Clean build artifacts" + @echo " make clear-cache - Clear uv cache" \ No newline at end of file diff --git a/crates/prime-protocol-py/README.md b/crates/prime-protocol-py/README.md new file mode 100644 index 00000000..439218c3 --- /dev/null +++ b/crates/prime-protocol-py/README.md @@ -0,0 +1,46 @@ +# Prime Protocol Python Client + +Python bindings for checking if compute pools exist. + +## Build + +```bash +# Install uv (one-time) +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Setup and build +cd crates/prime-protocol-py +make install +``` + +## Usage + +```python +from primeprotocol import PrimeProtocolClient + +client = PrimeProtocolClient("http://localhost:8545") +exists = client.compute_pool_exists(0) +``` + +## Development + +```bash +make build # Build development version +make test # Run tests +make example # Run example +make clean # Clean artifacts +make help # Show all commands +``` + +## Installing in other projects + +```bash +# Build the wheel +make build-release + +# Install with uv (recommended) +uv pip install target/wheels/primeprotocol-*.whl + +# Or install directly from source +uv pip install /path/to/prime-protocol-py/ +``` \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/basic_usage.py b/crates/prime-protocol-py/examples/basic_usage.py new file mode 100644 index 00000000..639eccf7 --- /dev/null +++ b/crates/prime-protocol-py/examples/basic_usage.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +"""Example usage of the Prime Protocol Python client.""" + +import logging +import os +from primeprotocol import PrimeProtocolClient + +FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' +logging.basicConfig(format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + + +def main(): + rpc_url = os.getenv("RPC_URL", "http://localhost:8545") + pool_id = os.getenv("POOL_ID", 0) + private_key_provider = os.getenv("PRIVATE_KEY_PROVIDER", None) + private_key_node = os.getenv("PRIVATE_KEY_NODE", None) + + logging.info(f"Connecting to: {rpc_url}") + client = PrimeProtocolClient(pool_id, rpc_url, private_key_provider, private_key_node) + client.start() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/prime-protocol-py/pyproject.toml b/crates/prime-protocol-py/pyproject.toml new file mode 100644 index 00000000..9834d8b4 --- /dev/null +++ b/crates/prime-protocol-py/pyproject.toml @@ -0,0 +1,37 @@ +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[project] +name = "primeprotocol" +description = "Simple Python bindings for Prime Protocol client" +readme = "README.md" 
+requires-python = ">=3.8" +license = {text = "MIT"} +keywords = ["prime", "protocol"] +authors = [ + {name = "Prime Protocol", email = "jannik@primeintellect.ai"} +] +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dynamic = ["version"] + +[project.urls] +"Homepage" = "https://github.com/primeprotocol/protocol" +"Bug Tracker" = "https://github.com/primeprotocol/protocol/issues" + +[tool.maturin] +features = ["pyo3/extension-module"] +module-name = "primeprotocol" \ No newline at end of file diff --git a/crates/prime-protocol-py/requirements-dev.txt b/crates/prime-protocol-py/requirements-dev.txt new file mode 100644 index 00000000..f2af3c5d --- /dev/null +++ b/crates/prime-protocol-py/requirements-dev.txt @@ -0,0 +1,3 @@ +# Development dependencies +maturin>=1.0,<2.0 +pytest>=7.0 \ No newline at end of file diff --git a/crates/prime-protocol-py/setup.sh b/crates/prime-protocol-py/setup.sh new file mode 100755 index 00000000..7609b236 --- /dev/null +++ b/crates/prime-protocol-py/setup.sh @@ -0,0 +1,16 @@ +#!/bin/bash +set -e + +# Check if uv is installed +if ! command -v uv &> /dev/null; then + echo "Please install uv first: curl -LsSf https://astral.sh/uv/install.sh | sh" + exit 1 +fi + +# Setup environment +uv venv +source .venv/bin/activate +uv pip install -r requirements-dev.txt +maturin develop + +echo "Setup complete." 
\ No newline at end of file diff --git a/crates/prime-protocol-py/src/client.rs b/crates/prime-protocol-py/src/client.rs new file mode 100644 index 00000000..b4139b7b --- /dev/null +++ b/crates/prime-protocol-py/src/client.rs @@ -0,0 +1,294 @@ +use crate::error::{PrimeProtocolError, Result}; +use alloy::primitives::utils::format_ether; +use alloy::primitives::U256; +use prime_core::operations::compute_node::ComputeNodeOperations; +use prime_core::operations::provider::ProviderOperations; +use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; +use shared::web3::contracts::structs::compute_pool::{PoolInfo, PoolStatus}; +use shared::web3::wallet::{Wallet, WalletProvider}; +use url::Url; + +pub struct PrimeProtocolClientCore { + rpc_url: String, + compute_pool_id: u64, + private_key_provider: Option<String>, + private_key_node: Option<String>, + auto_accept_transactions: bool, + funding_retry_count: u32, +} + +impl PrimeProtocolClientCore { + pub fn new( + compute_pool_id: u64, + rpc_url: String, + private_key_provider: Option<String>, + private_key_node: Option<String>, + auto_accept_transactions: Option<bool>, + funding_retry_count: Option<u32>, + ) -> Result<Self> { + if rpc_url.is_empty() { + return Err(PrimeProtocolError::InvalidConfig( + "RPC URL cannot be empty".to_string(), + )); + } + + Url::parse(&rpc_url) + .map_err(|_| PrimeProtocolError::InvalidConfig("Invalid RPC URL format".to_string()))?; + + Ok(Self { + rpc_url, + compute_pool_id, + private_key_provider, + private_key_node, + auto_accept_transactions: auto_accept_transactions.unwrap_or(true), + funding_retry_count: funding_retry_count.unwrap_or(10), + }) + } + + pub async fn start_async(&self) -> Result<()> { + let (provider_wallet, node_wallet, contracts) = + self.initialize_blockchain_components().await?; + let pool_info = self.wait_for_active_pool(&contracts).await?; + + log::info!("Pool info: {:?}", pool_info); + + self.ensure_provider_registered(&provider_wallet, &contracts) + .await?; + self.ensure_compute_node_registered(&provider_wallet, &node_wallet, &contracts) + .await?; + + // TODO: Optional - run hardware check? + // TODO: p2p reachable? 
+ + Ok(()) + } + + async fn initialize_blockchain_components( + &self, + ) -> Result<(Wallet, Wallet, Contracts)> { + let private_key_provider = self.get_private_key_provider()?; + let private_key_node = self.get_private_key_node()?; + let rpc_url = Url::parse(&self.rpc_url).unwrap(); + + let provider_wallet = Wallet::new(&private_key_provider, rpc_url.clone()).map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to create provider wallet: {}", e)) + })?; + + let node_wallet = Wallet::new(&private_key_node, rpc_url.clone()).map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to create node wallet: {}", e)) + })?; + + let contracts = ContractBuilder::new(provider_wallet.provider()) + .with_compute_pool() + .with_compute_registry() + .with_ai_token() + .with_prime_network() + .with_stake_manager() + .build() + .map_err(|e| PrimeProtocolError::BlockchainError(e.to_string()))?; + + Ok((provider_wallet, node_wallet, contracts)) + } + + async fn wait_for_active_pool( + &self, + contracts: &Contracts, + ) -> Result { + loop { + match contracts + .compute_pool + .get_pool_info(U256::from(self.compute_pool_id)) + .await + { + Ok(pool) if pool.status == PoolStatus::ACTIVE => return Ok(pool), + Ok(_) => { + log::info!("Pool not active yet, waiting..."); + tokio::time::sleep(tokio::time::Duration::from_secs(15)).await; + } + Err(e) => { + return Err(PrimeProtocolError::BlockchainError(format!( + "Failed to get pool info: {}", + e + ))); + } + } + } + } + async fn ensure_provider_registered( + &self, + provider_wallet: &Wallet, + contracts: &Contracts, + ) -> Result<()> { + let provider_ops = ProviderOperations::new( + provider_wallet.clone(), + contracts.clone(), + self.auto_accept_transactions, + ); + + // Check if provider exists + let provider_exists = provider_ops.check_provider_exists().await.map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to check if provider exists: {}", + e + )) + })?; + + let Some(stake_manager) = contracts.stake_manager.as_ref() else { + return Err(PrimeProtocolError::BlockchainError( + "Stake manager not initialized".to_string(), + )); + }; + + // Check if provider is whitelisted + let is_whitelisted = provider_ops + .check_provider_whitelisted() + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to check provider whitelist status: {}", + e + )) + })?; + + // todo: revisit this + if provider_exists && is_whitelisted { + log::info!("Provider is registered and whitelisted"); + } else { + // For now, we'll use a default compute_units value - this should be configurable + let compute_units = U256::from(1); + + let required_stake = stake_manager + .calculate_stake(compute_units, U256::from(0)) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to calculate required stake: {}", + e + )) + })?; + + log::info!("Required stake: {}", format_ether(required_stake)); + + provider_ops + .retry_register_provider(required_stake, self.funding_retry_count, None) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to register provider: {}", + e + )) + })?; + + log::info!("Provider registered successfully"); + } + + // Get provider's current total compute and stake + let provider_total_compute = contracts + .compute_registry + .get_provider_total_compute(provider_wallet.wallet.default_signer().address()) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to get provider total compute: {}", + e + )) + })?; + + let 
provider_stake = stake_manager + .get_stake(provider_wallet.wallet.default_signer().address()) + .await + .unwrap_or_default(); + + // For now, we'll use a default compute_units value - this should be configurable + let compute_units = U256::from(1); + + let required_stake = stake_manager + .calculate_stake(compute_units, provider_total_compute) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to calculate required stake: {}", + e + )) + })?; + + if required_stake > provider_stake { + log::info!( + "Provider stake is less than required stake. Required: {} tokens, Current: {} tokens", + format_ether(required_stake), + format_ether(provider_stake) + ); + + provider_ops + .increase_stake(required_stake - provider_stake) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to increase stake: {}", e)) + })?; + + log::info!("Successfully increased stake"); + } + + Ok(()) + } + + async fn ensure_compute_node_registered( + &self, + provider_wallet: &Wallet, + node_wallet: &Wallet, + contracts: &Contracts, + ) -> Result<()> { + let compute_node_ops = + ComputeNodeOperations::new(provider_wallet, node_wallet, contracts.clone()); + + // Check if compute node exists + let compute_node_exists = + compute_node_ops + .check_compute_node_exists() + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to check if compute node exists: {}", + e + )) + })?; + + if compute_node_exists { + log::info!("Compute node is already registered"); + return Ok(()); + } + + // If compute node doesn't exist, register it + // For now, we'll use default compute specs - this should be configurable + compute_node_ops + .add_compute_node(U256::from(1)) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to register compute node: {}", + e + )) + })?; + + log::info!("Compute node registered successfully"); + Ok(()) + } + + fn get_private_key_provider(&self) -> Result { + match &self.private_key_provider { + Some(key) => Ok(key.clone()), + None => std::env::var("PRIVATE_KEY_PROVIDER").map_err(|_| { + PrimeProtocolError::InvalidConfig("PRIVATE_KEY_PROVIDER must be set".to_string()) + }), + } + } + + fn get_private_key_node(&self) -> Result { + match &self.private_key_node { + Some(key) => Ok(key.clone()), + None => std::env::var("PRIVATE_KEY_NODE").map_err(|_| { + PrimeProtocolError::InvalidConfig("PRIVATE_KEY_NODE must be set".to_string()) + }), + } + } +} diff --git a/crates/prime-protocol-py/src/error.rs b/crates/prime-protocol-py/src/error.rs new file mode 100644 index 00000000..cf561595 --- /dev/null +++ b/crates/prime-protocol-py/src/error.rs @@ -0,0 +1,21 @@ +use thiserror::Error; + +/// Result type alias for Prime Protocol operations +pub type Result = std::result::Result; + +/// Errors that can occur in the Prime Protocol client +#[derive(Debug, Error)] +pub enum PrimeProtocolError { + /// Invalid configuration provided + #[error("Invalid configuration: {0}")] + InvalidConfig(String), + + /// Blockchain interaction error + #[error("Blockchain error: {0}")] + BlockchainError(String), + + /// General runtime error + #[error("Runtime error: {0}")] + #[allow(dead_code)] + RuntimeError(String), +} diff --git a/crates/prime-protocol-py/src/lib.rs b/crates/prime-protocol-py/src/lib.rs new file mode 100644 index 00000000..faa72b0c --- /dev/null +++ b/crates/prime-protocol-py/src/lib.rs @@ -0,0 +1,62 @@ +use pyo3::prelude::*; + +mod client; +mod error; + +use client::PrimeProtocolClientCore; + +// todo: We need a 
manager + validator side to send messages
+
+/// Prime Protocol Python client
+#[pyclass]
+pub struct PrimeProtocolClient {
+    inner: PrimeProtocolClientCore,
+}
+
+#[pymethods]
+impl PrimeProtocolClient {
+    #[new]
+    #[pyo3(signature = (compute_pool_id, rpc_url, private_key_provider=None, private_key_node=None))]
+    pub fn new(
+        compute_pool_id: u64,
+        rpc_url: String,
+        private_key_provider: Option<String>,
+        private_key_node: Option<String>,
+    ) -> PyResult<Self> {
+        // todo: revisit default arguments here that are currently none
+        let inner = PrimeProtocolClientCore::new(
+            compute_pool_id,
+            rpc_url,
+            private_key_provider,
+            private_key_node,
+            None,
+            None,
+        )
+        .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
+
+        Ok(Self { inner })
+    }
+
+    pub fn start(&self) -> PyResult<()> {
+        // Create a new runtime for this call
+        let rt = tokio::runtime::Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+
+        // Run the async function
+        let result = rt.block_on(self.inner.start_async());
+
+        // Clean shutdown
+        rt.shutdown_background();
+
+        result.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
+    }
+}
+
+#[pymodule]
+fn primeprotocol(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    pyo3_log::init();
+    m.add_class::<PrimeProtocolClient>()?;
+    Ok(())
+}
diff --git a/crates/prime-protocol-py/tests/test_client.py b/crates/prime-protocol-py/tests/test_client.py
new file mode 100644
index 00000000..57b02400
--- /dev/null
+++ b/crates/prime-protocol-py/tests/test_client.py
@@ -0,0 +1,29 @@
+"""Basic tests for the Prime Protocol Python client."""
+
+import pytest
+from primeprotocol import PrimeProtocolClient
+
+
+def test_client_creation():
+    """Test that client can be created with a pool id and a valid RPC URL."""
+    client = PrimeProtocolClient(0, "http://localhost:8545")
+    assert client is not None
+
+
+def test_client_creation_with_empty_url():
+    """Test that client creation fails with an empty RPC URL."""
+    with pytest.raises(ValueError):
+        PrimeProtocolClient(0, "")
+
+
+def test_client_creation_with_invalid_url():
+    """Test that client creation fails with an invalid RPC URL."""
+    with pytest.raises(ValueError):
+        PrimeProtocolClient(0, "not-a-valid-url")
+
+
+def test_has_start_method():
+    """Test that the client exposes the start method."""
+    client = PrimeProtocolClient(0, "http://example.com:8545")
+    assert hasattr(client, 'start')
+    assert callable(getattr(client, 'start'))
\ No newline at end of file
diff --git a/crates/prime-protocol-py/uv.lock b/crates/prime-protocol-py/uv.lock
new file mode 100644
index 00000000..639a70ba
--- /dev/null
+++ b/crates/prime-protocol-py/uv.lock
@@ -0,0 +1,7 @@
+version = 1
+requires-python = ">=3.8"
+
+[[package]]
+name = "primeprotocol"
+version = "0.1.0"
+source = { editable = "."
} diff --git a/crates/shared/src/models/metric.rs b/crates/shared/src/models/metric.rs index 47b27f24..b85c4926 100644 --- a/crates/shared/src/models/metric.rs +++ b/crates/shared/src/models/metric.rs @@ -58,7 +58,7 @@ mod tests { let invalid_values = vec![(f64::INFINITY, "infinite value"), (f64::NAN, "NaN value")]; for (value, case) in invalid_values { let entry = MetricEntry::new(key.clone(), value); - assert!(entry.is_err(), "Should fail for {}", case); + assert!(entry.is_err(), "Should fail for {case}"); } } diff --git a/crates/shared/src/security/auth_signature_middleware.rs b/crates/shared/src/security/auth_signature_middleware.rs index 1c4c1e10..8ba7767e 100644 --- a/crates/shared/src/security/auth_signature_middleware.rs +++ b/crates/shared/src/security/auth_signature_middleware.rs @@ -634,10 +634,10 @@ mod tests { .await; log::info!("Address: {}", wallet.wallet.default_signer().address()); - log::info!("Signature: {}", signature); - log::info!("Nonce: {}", nonce); + log::info!("Signature: {signature}"); + log::info!("Nonce: {nonce}"); let req = test::TestRequest::get() - .uri(&format!("/test?nonce={}", nonce)) + .uri(&format!("/test?nonce={nonce}")) .insert_header(( "x-address", wallet.wallet.default_signer().address().to_string(), @@ -801,8 +801,7 @@ mod tests { // Create multiple addresses let addresses: Vec
<Address> = (0..5)
             .map(|i| {
-                Address::from_str(&format!("0x{}000000000000000000000000000000000000000", i))
-                    .unwrap()
+                Address::from_str(&format!("0x{i}000000000000000000000000000000000000000")).unwrap()
             })
             .collect();

diff --git a/crates/shared/src/security/request_signer.rs b/crates/shared/src/security/request_signer.rs
index ff3e9964..c5ea3605 100644
--- a/crates/shared/src/security/request_signer.rs
+++ b/crates/shared/src/security/request_signer.rs
@@ -143,7 +143,7 @@ mod tests {
         let signature = sign_request(endpoint, &wallet, Some(&empty_data))
             .await
             .unwrap();
-        println!("Signature: {}", signature);
+        println!("Signature: {signature}");
         assert!(signature.starts_with("0x"));
         assert_eq!(signature.len(), 132);
     }
diff --git a/crates/shared/src/utils/google_cloud.rs b/crates/shared/src/utils/google_cloud.rs
index 128259eb..72fae856 100644
--- a/crates/shared/src/utils/google_cloud.rs
+++ b/crates/shared/src/utils/google_cloud.rs
@@ -194,20 +194,14 @@ mod tests {
     #[tokio::test]
     async fn test_generate_mapping_file() {
         // Check if required environment variables are set
-        let bucket_name = match std::env::var("S3_BUCKET_NAME") {
-            Ok(name) => name,
-            Err(_) => {
-                println!("Skipping test: BUCKET_NAME not set");
-                return;
-            }
+        let Ok(bucket_name) = std::env::var("S3_BUCKET_NAME") else {
+            println!("Skipping test: BUCKET_NAME not set");
+            return;
         };
-        let credentials_base64 = match std::env::var("S3_CREDENTIALS") {
-            Ok(credentials) => credentials,
-            Err(_) => {
-                println!("Skipping test: S3_CREDENTIALS not set");
-                return;
-            }
+        let Ok(credentials_base64) = std::env::var("S3_CREDENTIALS") else {
+            println!("Skipping test: S3_CREDENTIALS not set");
+            return;
         };

         let storage = GcsStorageProvider::new(&bucket_name, &credentials_base64)
@@ -219,15 +213,15 @@
             .generate_mapping_file(&random_sha256, "run_1/file.parquet")
             .await
             .unwrap();
-        println!("mapping_content: {}", mapping_content);
-        println!("bucket_name: {}", bucket_name);
+        println!("mapping_content: {mapping_content}");
+        println!("bucket_name: {bucket_name}");

         let original_file_name = storage
             .resolve_mapping_for_sha(&random_sha256)
             .await
             .unwrap();
-        println!("original_file_name: {}", original_file_name);
+        println!("original_file_name: {original_file_name}");
         assert_eq!(original_file_name, "run_1/file.parquet");
     }
 }
diff --git a/crates/shared/src/utils/mod.rs b/crates/shared/src/utils/mod.rs
index d4e3f1c9..290f1ae5 100644
--- a/crates/shared/src/utils/mod.rs
+++ b/crates/shared/src/utils/mod.rs
@@ -119,7 +119,7 @@ mod tests {
         provider.add_mapping_file("sha256", "file.txt").await;
         provider.add_file("file.txt", "content").await;
         let map_file_link = provider.resolve_mapping_for_sha("sha256").await.unwrap();
-        println!("map_file_link: {}", map_file_link);
+        println!("map_file_link: {map_file_link}");

         assert_eq!(map_file_link, "file.txt");
         assert_eq!(
diff --git a/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs b/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs
index ff0a20ce..b52f96e2 100644
--- a/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs
+++ b/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs
@@ -29,6 +29,7 @@ impl ComputePool

{ .function("getComputePool", &[pool_id.into()])? .call() .await?; + let pool_info_tuple: &[DynSolValue] = pool_info_response.first().unwrap().as_tuple().unwrap(); @@ -60,6 +61,9 @@ impl ComputePool

{ _ => panic!("Unknown status value: {status}"), }; + println!("Mapped status: {mapped_status:?}"); + println!("Returning pool info"); + let pool_info = PoolInfo { pool_id, domain_id, diff --git a/crates/validator/src/store/redis.rs b/crates/validator/src/store/redis.rs index 508815c2..c0a0c36b 100644 --- a/crates/validator/src/store/redis.rs +++ b/crates/validator/src/store/redis.rs @@ -45,8 +45,8 @@ impl RedisStore { _ => panic!("Expected TCP connection"), }; - let redis_url = format!("redis://{}:{}", host, port); - debug!("Starting test Redis server at {}", redis_url); + let redis_url = format!("redis://{host}:{port}"); + debug!("Starting test Redis server at {redis_url}"); // Add a small delay to ensure server is ready thread::sleep(Duration::from_millis(100)); diff --git a/crates/validator/src/validators/hardware.rs b/crates/validator/src/validators/hardware.rs index 877861da..da5307e3 100644 --- a/crates/validator/src/validators/hardware.rs +++ b/crates/validator/src/validators/hardware.rs @@ -161,7 +161,7 @@ mod tests { let result = validator.validate_nodes(nodes).await; let elapsed = start_time.elapsed(); assert!(elapsed < std::time::Duration::from_secs(11)); - println!("Validation took: {:?}", elapsed); + println!("Validation took: {elapsed:?}"); assert!(result.is_ok()); } diff --git a/crates/validator/src/validators/synthetic_data/chain_operations.rs b/crates/validator/src/validators/synthetic_data/chain_operations.rs index 004c7e45..a0687d18 100644 --- a/crates/validator/src/validators/synthetic_data/chain_operations.rs +++ b/crates/validator/src/validators/synthetic_data/chain_operations.rs @@ -3,7 +3,7 @@ use super::*; impl SyntheticDataValidator { #[cfg(test)] pub fn soft_invalidate_work(&self, work_key: &str) -> Result<(), Error> { - info!("Soft invalidating work: {}", work_key); + info!("Soft invalidating work: {work_key}"); if self.disable_chain_invalidation { info!("Chain invalidation is disabled, skipping work soft invalidation"); @@ -54,7 +54,7 @@ impl SyntheticDataValidator { #[cfg(test)] pub fn invalidate_work(&self, work_key: &str) -> Result<(), Error> { - info!("Invalidating work: {}", work_key); + info!("Invalidating work: {work_key}"); if let Some(metrics) = &self.metrics { metrics.record_work_key_invalidation(); @@ -98,20 +98,27 @@ impl SyntheticDataValidator { } } } - + #[cfg(test)] + #[allow(clippy::unused_async)] pub async fn invalidate_according_to_invalidation_type( &self, work_key: &str, invalidation_type: InvalidationType, ) -> Result<(), Error> { match invalidation_type { - #[cfg(test)] InvalidationType::Soft => self.soft_invalidate_work(work_key), - #[cfg(not(test))] - InvalidationType::Soft => self.soft_invalidate_work(work_key).await, - #[cfg(test)] InvalidationType::Hard => self.invalidate_work(work_key), - #[cfg(not(test))] + } + } + + #[cfg(not(test))] + pub async fn invalidate_according_to_invalidation_type( + &self, + work_key: &str, + invalidation_type: InvalidationType, + ) -> Result<(), Error> { + match invalidation_type { + InvalidationType::Soft => self.soft_invalidate_work(work_key).await, InvalidationType::Hard => self.invalidate_work(work_key).await, } } diff --git a/crates/validator/src/validators/synthetic_data/mod.rs b/crates/validator/src/validators/synthetic_data/mod.rs index ce472c8b..bf8ce6e2 100644 --- a/crates/validator/src/validators/synthetic_data/mod.rs +++ b/crates/validator/src/validators/synthetic_data/mod.rs @@ -237,7 +237,7 @@ impl SyntheticDataValidator { let score: Option = con .zscore("incomplete_groups", group_key) 
.await - .map_err(|e| Error::msg(format!("Failed to check incomplete tracking: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to check incomplete tracking: {e}")))?; Ok(score.is_some()) } @@ -270,13 +270,10 @@ impl SyntheticDataValidator { let _: () = con .zadd("incomplete_groups", group_key, new_deadline) .await - .map_err(|e| { - Error::msg(format!("Failed to update incomplete group deadline: {}", e)) - })?; + .map_err(|e| Error::msg(format!("Failed to update incomplete group deadline: {e}")))?; debug!( - "Updated deadline for incomplete group {} to {} ({} minutes from now)", - group_key, new_deadline, minutes_from_now + "Updated deadline for incomplete group {group_key} to {new_deadline} ({minutes_from_now} minutes from now)" ); Ok(()) @@ -420,7 +417,7 @@ impl SyntheticDataValidator { let data: Option = con .get(key) .await - .map_err(|e| Error::msg(format!("Failed to get work validation status: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to get work validation status: {e}")))?; match data { Some(data) => { @@ -435,8 +432,7 @@ impl SyntheticDataValidator { reason: None, })), Err(e) => Err(Error::msg(format!( - "Failed to parse work validation data: {}", - e + "Failed to parse work validation data: {e}" ))), } } @@ -1576,8 +1572,7 @@ impl SyntheticDataValidator { .await { error!( - "Failed to update work validation status for {}: {}", - work_key, e + "Failed to update work validation status for {work_key}: {e}" ); } } diff --git a/crates/validator/src/validators/synthetic_data/tests/mod.rs b/crates/validator/src/validators/synthetic_data/tests/mod.rs index a589076f..48aaee85 100644 --- a/crates/validator/src/validators/synthetic_data/tests/mod.rs +++ b/crates/validator/src/validators/synthetic_data/tests/mod.rs @@ -34,7 +34,7 @@ fn setup_test_env() -> Result<(RedisStore, Contracts), Error> { "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97", url, ) - .map_err(|e| Error::msg(format!("Failed to create demo wallet: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to create demo wallet: {e}")))?; let contracts = ContractBuilder::new(demo_wallet.provider()) .with_compute_registry() @@ -45,7 +45,7 @@ fn setup_test_env() -> Result<(RedisStore, Contracts), Error> { .with_stake_manager() .with_synthetic_data_validator(Some(Address::ZERO)) .build() - .map_err(|e| Error::msg(format!("Failed to build contracts: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to build contracts: {e}")))?; Ok((store, contracts)) } @@ -197,8 +197,8 @@ async fn test_status_update() -> Result<(), Error> { ) .await .map_err(|e| { - error!("Failed to update work validation status: {}", e); - Error::msg(format!("Failed to update work validation status: {}", e)) + error!("Failed to update work validation status: {e}"); + Error::msg(format!("Failed to update work validation status: {e}")) })?; tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; @@ -206,8 +206,8 @@ async fn test_status_update() -> Result<(), Error> { .get_work_validation_status_from_redis("0x0000000000000000000000000000000000000000") .await .map_err(|e| { - error!("Failed to get work validation status: {}", e); - Error::msg(format!("Failed to get work validation status: {}", e)) + error!("Failed to get work validation status: {e}"); + Error::msg(format!("Failed to get work validation status: {e}")) })?; assert_eq!(status, Some(ValidationResult::Accept)); Ok(()) @@ -344,20 +344,20 @@ async fn test_group_e2e_accept() -> Result<(), Error> { let mock_storage = MockStorageProvider::new(); mock_storage 
.add_file( - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-1-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-1-0-0.parquet"), "file1", ) .await; mock_storage .add_mapping_file( FILE_SHA, - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-1-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-1-0-0.parquet"), ) .await; server .mock( "POST", - format!("/validategroup/dataset/samplingn-{}-1-0.parquet", GROUP_ID).as_str(), + format!("/validategroup/dataset/samplingn-{GROUP_ID}-1-0.parquet").as_str(), ) .match_body(mockito::Matcher::Json(serde_json::json!({ "file_shas": [FILE_SHA], @@ -371,7 +371,7 @@ async fn test_group_e2e_accept() -> Result<(), Error> { server .mock( "GET", - format!("/statusgroup/dataset/samplingn-{}-1-0.parquet", GROUP_ID).as_str(), + format!("/statusgroup/dataset/samplingn-{GROUP_ID}-1-0.parquet").as_str(), ) .with_status(200) .with_body(r#"{"status": "accept", "input_flops": 1, "output_flops": 1000}"#) @@ -463,7 +463,7 @@ async fn test_group_e2e_accept() -> Result<(), Error> { metrics_2.contains("validator_work_keys_to_process{pool_id=\"0\",validator_id=\"0\"} 0") ); assert!(metrics_2.contains("toploc_config_name=\"Qwen/Qwen0.6\"")); - assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{}\",pool_id=\"0\",result=\"match\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{GROUP_ID}\",pool_id=\"0\",result=\"match\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1"))); Ok(()) } @@ -490,32 +490,32 @@ async fn test_group_e2e_work_unit_mismatch() -> Result<(), Error> { let mock_storage = MockStorageProvider::new(); mock_storage .add_file( - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-0.parquet"), "file1", ) .await; mock_storage .add_file( - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-1.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-1.parquet"), "file2", ) .await; mock_storage .add_mapping_file( HONEST_FILE_SHA, - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-0.parquet"), ) .await; mock_storage .add_mapping_file( EXCESSIVE_FILE_SHA, - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-1.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-1.parquet"), ) .await; server .mock( "POST", - format!("/validategroup/dataset/samplingn-{}-2-0.parquet", GROUP_ID).as_str(), + format!("/validategroup/dataset/samplingn-{GROUP_ID}-2-0.parquet").as_str(), ) .match_body(mockito::Matcher::Json(serde_json::json!({ "file_shas": [HONEST_FILE_SHA, EXCESSIVE_FILE_SHA], @@ -529,7 +529,7 @@ async fn test_group_e2e_work_unit_mismatch() -> Result<(), Error> { server .mock( "GET", - format!("/statusgroup/dataset/samplingn-{}-2-0.parquet", GROUP_ID).as_str(), + format!("/statusgroup/dataset/samplingn-{GROUP_ID}-2-0.parquet").as_str(), ) .with_status(200) .with_body(r#"{"status": "accept", "input_flops": 1, "output_flops": 2000}"#) @@ -636,12 +636,12 @@ async fn test_group_e2e_work_unit_mismatch() -> Result<(), Error> { assert_eq!(plan_3.group_trigger_tasks.len(), 0); assert_eq!(plan_3.group_status_check_tasks.len(), 0); let metrics_2 = export_metrics().unwrap(); - 
assert!(metrics_2.contains(&format!("validator_group_validations_total{{group_id=\"{}\",pool_id=\"0\",result=\"accept\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics_2.contains(&format!("validator_group_validations_total{{group_id=\"{GROUP_ID}\",pool_id=\"0\",result=\"accept\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1"))); assert!( metrics_2.contains("validator_work_keys_to_process{pool_id=\"0\",validator_id=\"0\"} 0") ); assert!(metrics_2.contains("toploc_config_name=\"Qwen/Qwen0.6\"")); - assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{}\",pool_id=\"0\",result=\"mismatch\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{GROUP_ID}\",pool_id=\"0\",result=\"mismatch\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1"))); Ok(()) } @@ -734,26 +734,26 @@ async fn test_incomplete_group_recovery() -> Result<(), Error> { mock_storage .add_file( - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), "file1", ) .await; mock_storage .add_file( - &format!("TestModel/dataset/test-{}-2-0-1.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-1.parquet"), "file2", ) .await; mock_storage .add_mapping_file( FILE_SHA_1, - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), ) .await; mock_storage .add_mapping_file( FILE_SHA_2, - &format!("TestModel/dataset/test-{}-2-0-1.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-1.parquet"), ) .await; @@ -800,7 +800,7 @@ async fn test_incomplete_group_recovery() -> Result<(), Error> { assert!(group.is_none(), "Group should be incomplete"); // Check that the incomplete group is being tracked - let group_key = format!("group:{}:2:0", GROUP_ID); + let group_key = format!("group:{GROUP_ID}:2:0"); let is_tracked = validator .is_group_being_tracked_as_incomplete(&group_key) .await?; @@ -847,14 +847,14 @@ async fn test_expired_incomplete_group_soft_invalidation() -> Result<(), Error> mock_storage .add_file( - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), "file1", ) .await; mock_storage .add_mapping_file( FILE_SHA_1, - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), ) .await; @@ -902,7 +902,7 @@ async fn test_expired_incomplete_group_soft_invalidation() -> Result<(), Error> // Manually expire the incomplete group tracking by removing it and simulating expiry // In a real test, you would wait for the actual expiry, but for testing we simulate it - let group_key = format!("group:{}:2:0", GROUP_ID); + let group_key = format!("group:{GROUP_ID}:2:0"); validator.track_incomplete_group(&group_key).await?; // Process groups past grace period (this would normally find groups past deadline) @@ -936,7 +936,7 @@ async fn test_expired_incomplete_group_soft_invalidation() -> Result<(), Error> assert_eq!(key_status, Some(ValidationResult::IncompleteGroup)); let metrics = export_metrics().unwrap(); - assert!(metrics.contains(&format!("validator_work_keys_soft_invalidated_total{{group_key=\"group:{}:2:0\",pool_id=\"0\",validator_id=\"0\"}} 1", GROUP_ID))); + 
assert!(metrics.contains(&format!("validator_work_keys_soft_invalidated_total{{group_key=\"group:{GROUP_ID}:2:0\",pool_id=\"0\",validator_id=\"0\"}} 1"))); Ok(()) } @@ -952,14 +952,14 @@ async fn test_incomplete_group_status_tracking() -> Result<(), Error> { mock_storage .add_file( - &format!("TestModel/dataset/test-{}-3-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-3-0-0.parquet"), "file1", ) .await; mock_storage .add_mapping_file( FILE_SHA_1, - &format!("TestModel/dataset/test-{}-3-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-3-0-0.parquet"), ) .await; @@ -1006,7 +1006,7 @@ async fn test_incomplete_group_status_tracking() -> Result<(), Error> { // Manually process groups past grace period to simulate what would happen // after the grace period expires (we simulate this since we can't wait in tests) - let group_key = format!("group:{}:3:0", GROUP_ID); + let group_key = format!("group:{GROUP_ID}:3:0"); // Manually add the group to tracking and then process it validator.track_incomplete_group(&group_key).await?; diff --git a/crates/validator/src/validators/synthetic_data/toploc.rs b/crates/validator/src/validators/synthetic_data/toploc.rs index 33d9f57f..f5641533 100644 --- a/crates/validator/src/validators/synthetic_data/toploc.rs +++ b/crates/validator/src/validators/synthetic_data/toploc.rs @@ -689,8 +689,7 @@ mod tests { Some(expected_idx) => { assert!( matched, - "Expected file {} to match config {}", - test_file, expected_idx + "Expected file {test_file} to match config {expected_idx}" ); assert_eq!( matched_idx, @@ -701,7 +700,7 @@ mod tests { expected_idx ); } - None => assert!(!matched, "File {} should not match any config", test_file), + None => assert!(!matched, "File {test_file} should not match any config"), } } } diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml index eb041cad..bd35ca32 100644 --- a/crates/worker/Cargo.toml +++ b/crates/worker/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] shared = { workspace = true } p2p = { workspace = true } +prime-core = { workspace = true} actix-web = { workspace = true } alloy = { workspace = true } diff --git a/crates/worker/src/checks/hardware/interconnect.rs b/crates/worker/src/checks/hardware/interconnect.rs index 21725686..d87d1819 100644 --- a/crates/worker/src/checks/hardware/interconnect.rs +++ b/crates/worker/src/checks/hardware/interconnect.rs @@ -78,7 +78,7 @@ mod tests { #[tokio::test] async fn test_check_speeds() { let result = InterconnectCheck::check_speeds().await; - println!("Test Result: {:?}", result); + println!("Test Result: {result:?}"); // Verify the result is Ok and contains expected tuple structure assert!(result.is_ok()); diff --git a/crates/worker/src/checks/hardware/storage.rs b/crates/worker/src/checks/hardware/storage.rs index 9509e731..8360993b 100644 --- a/crates/worker/src/checks/hardware/storage.rs +++ b/crates/worker/src/checks/hardware/storage.rs @@ -216,7 +216,7 @@ fn test_or_create_app_directory(path: &str) -> bool { } #[cfg(not(target_os = "linux"))] -pub fn find_largest_storage() -> Option { +pub(crate) fn find_largest_storage() -> Option { None } @@ -233,7 +233,7 @@ pub(crate) fn get_available_space(path: &str) -> Option { } #[cfg(not(target_os = "linux"))] -pub fn get_available_space(_path: &str) -> Option { +pub(crate) fn get_available_space(_path: &str) -> Option { None } diff --git a/crates/worker/src/checks/stun.rs b/crates/worker/src/checks/stun.rs index 5830b49e..734f2795 100644 --- 
a/crates/worker/src/checks/stun.rs +++ b/crates/worker/src/checks/stun.rs @@ -139,7 +139,7 @@ mod tests { async fn test_get_public_ip() { let stun_check = StunCheck::new(Duration::from_secs(5), 0); let public_ip = stun_check.get_public_ip().await.unwrap(); - println!("Public IP: {}", public_ip); + println!("Public IP: {public_ip}"); assert!(!public_ip.is_empty()); } } diff --git a/crates/worker/src/cli/command.rs b/crates/worker/src/cli/command.rs index 1e9e5825..f53f2762 100644 --- a/crates/worker/src/cli/command.rs +++ b/crates/worker/src/cli/command.rs @@ -6,9 +6,8 @@ use crate::console::Console; use crate::docker::taskbridge::TaskBridge; use crate::docker::DockerService; use crate::metrics::store::MetricsStore; -use crate::operations::compute_node::ComputeNodeOperations; use crate::operations::heartbeat::service::HeartbeatService; -use crate::operations::provider::ProviderOperations; +use crate::operations::node_monitor::NodeMonitor; use crate::services::discovery::DiscoveryService; use crate::services::discovery_updater::DiscoveryUpdater; use crate::state::system_state::SystemState; @@ -20,6 +19,8 @@ use alloy::signers::local::PrivateKeySigner; use alloy::signers::Signer; use clap::{Parser, Subcommand}; use log::{error, info}; +use prime_core::operations::compute_node::ComputeNodeOperations; +use prime_core::operations::provider::ProviderOperations; use shared::models::node::ComputeRequirements; use shared::models::node::Node; use shared::web3::contracts::core::builder::ContractBuilder; @@ -294,12 +295,10 @@ pub async fn execute_command( let provider_ops_cancellation = cancellation_token.clone(); - let compute_node_state = state.clone(); let compute_node_ops = ComputeNodeOperations::new( &provider_wallet_instance, &node_wallet_instance, contracts.clone(), - compute_node_state, ); let discovery_urls = vec![discovery_url @@ -606,7 +605,7 @@ pub async fn execute_command( .retry_register_provider( required_stake, *funding_retry_count, - cancellation_token.clone(), + Some(cancellation_token.clone()), ) .await { @@ -709,7 +708,7 @@ pub async fn execute_command( let heartbeat = match heartbeat_service.clone() { Ok(service) => service, Err(e) => { - error!("❌ Heartbeat service is not available: {e}"); + error!("❌ Heartbeat service is not available: {e:?}"); std::process::exit(1); } }; @@ -821,8 +820,13 @@ pub async fn execute_command( provider_ops.start_monitoring(provider_ops_cancellation); let pool_id = state.get_compute_pool_id(); - if let Err(err) = compute_node_ops.start_monitoring(cancellation_token.clone(), pool_id) - { + let node_monitor = NodeMonitor::new( + provider_wallet_instance.clone(), + node_wallet_instance.clone(), + contracts.clone(), + state.clone(), + ); + if let Err(err) = node_monitor.start_monitoring(cancellation_token.clone(), pool_id) { error!("❌ Failed to start node monitoring: {err}"); std::process::exit(1); } @@ -1031,7 +1035,6 @@ pub async fn execute_command( /* Initialize dependencies - services, contracts, operations */ - let contracts = ContractBuilder::new(provider_wallet_instance.provider()) .with_compute_registry() .with_ai_token() diff --git a/crates/worker/src/docker/taskbridge/bridge.rs b/crates/worker/src/docker/taskbridge/bridge.rs index 4765ef06..594bc62d 100644 --- a/crates/worker/src/docker/taskbridge/bridge.rs +++ b/crates/worker/src/docker/taskbridge/bridge.rs @@ -565,7 +565,7 @@ mod tests { "test_label2": 20.0, }); let sample_metric = serde_json::to_string(&data)?; - debug!("Sending {:?}", sample_metric); + debug!("Sending {sample_metric:?}"); let 
msg = format!("{}{}", sample_metric, "\n"); stream.write_all(msg.as_bytes()).await?; stream.flush().await?; @@ -616,7 +616,7 @@ mod tests { "output/input_flops": 2500.0, }); let sample_metric = serde_json::to_string(&json)?; - debug!("Sending {:?}", sample_metric); + debug!("Sending {sample_metric:?}"); let msg = format!("{}{}", sample_metric, "\n"); stream.write_all(msg.as_bytes()).await?; stream.flush().await?; @@ -626,8 +626,7 @@ mod tests { let all_metrics = metrics_store.get_all_metrics().await; assert!( all_metrics.is_empty(), - "Expected metrics to be empty but found: {:?}", - all_metrics + "Expected metrics to be empty but found: {all_metrics:?}" ); bridge_handle.abort(); diff --git a/crates/worker/src/operations/heartbeat/service.rs b/crates/worker/src/operations/heartbeat/service.rs index 1b002cae..289e86af 100644 --- a/crates/worker/src/operations/heartbeat/service.rs +++ b/crates/worker/src/operations/heartbeat/service.rs @@ -24,7 +24,6 @@ pub(crate) struct HeartbeatService { docker_service: Arc, metrics_store: Arc, } - #[derive(Debug, Clone, thiserror::Error)] pub(crate) enum HeartbeatError { #[error("HTTP request failed")] @@ -32,6 +31,7 @@ pub(crate) enum HeartbeatError { #[error("Service initialization failed")] InitFailed, } + impl HeartbeatService { #[allow(clippy::too_many_arguments)] pub(crate) fn new( diff --git a/crates/worker/src/operations/mod.rs b/crates/worker/src/operations/mod.rs index 193b64ae..d684160a 100644 --- a/crates/worker/src/operations/mod.rs +++ b/crates/worker/src/operations/mod.rs @@ -1,3 +1,2 @@ -pub(crate) mod compute_node; pub(crate) mod heartbeat; -pub(crate) mod provider; +pub(crate) mod node_monitor; diff --git a/crates/worker/src/operations/compute_node.rs b/crates/worker/src/operations/node_monitor.rs similarity index 59% rename from crates/worker/src/operations/compute_node.rs rename to crates/worker/src/operations/node_monitor.rs index 00f147a7..af33d450 100644 --- a/crates/worker/src/operations/compute_node.rs +++ b/crates/worker/src/operations/node_monitor.rs @@ -1,5 +1,5 @@ -use crate::{console::Console, state::system_state::SystemState}; -use alloy::{primitives::utils::keccak256 as keccak, primitives::U256, signers::Signer}; +use crate::state::system_state::SystemState; +use alloy::primitives::U256; use anyhow::Result; use shared::web3::wallet::Wallet; use shared::web3::{contracts::core::builder::Contracts, wallet::WalletProvider}; @@ -7,17 +7,17 @@ use std::sync::Arc; use tokio::time::{sleep, Duration}; use tokio_util::sync::CancellationToken; -pub(crate) struct ComputeNodeOperations<'c> { - provider_wallet: &'c Wallet, - node_wallet: &'c Wallet, +pub(crate) struct NodeMonitor { + provider_wallet: Wallet, + node_wallet: Wallet, contracts: Contracts, system_state: Arc, } -impl<'c> ComputeNodeOperations<'c> { +impl NodeMonitor { pub(crate) fn new( - provider_wallet: &'c Wallet, - node_wallet: &'c Wallet, + provider_wallet: Wallet, + node_wallet: Wallet, contracts: Contracts, system_state: Arc, ) -> Self { @@ -43,11 +43,12 @@ impl<'c> ComputeNodeOperations<'c> { let mut last_claimable = None; let mut last_locked = None; let mut first_check = true; + tokio::spawn(async move { loop { tokio::select! 
{
                    _ = cancellation_token.cancelled() => {
-                        Console::info("Monitor", "Shutting down node status monitor...");
+                        log::info!("Shutting down node status monitor...");
                         break;
                     }
                     _ = async {
@@ -55,16 +56,15 @@ impl<'c> ComputeNodeOperations<'c> {
                             Ok((active, validated)) => {
                                 if first_check || active != last_active {
                                     if !first_check {
-                                        Console::info("🔄 Chain Sync - Pool membership changed", &format!("From {last_active} to {active}"
-                                        ));
+                                        log::info!("🔄 Chain Sync - Pool membership changed: From {last_active} to {active}");
                                     } else {
-                                        Console::info("🔄 Chain Sync - Node pool membership", &format!("{active}"));
+                                        log::info!("🔄 Chain Sync - Node pool membership: {active}");
                                     }
                                     last_active = active;
                                 }
                                 let is_running = system_state.is_running().await;
                                 if !active && is_running {
-                                    Console::warning("Node is not longer in pool, shutting down heartbeat...");
+                                    log::warn!("Node is no longer in pool, shutting down heartbeat...");
                                     if let Err(e) = system_state.set_running(false, None).await {
                                         log::error!("Failed to set running to false: {e:?}");
                                     }
@@ -72,10 +72,9 @@ impl<'c> ComputeNodeOperations<'c> {

                                 if first_check || validated != last_validated {
                                     if !first_check {
-                                        Console::info("🔄 Chain Sync - Validation changed", &format!("From {last_validated} to {validated}"
-                                        ));
+                                        log::info!("🔄 Chain Sync - Validation changed: From {last_validated} to {validated}");
                                     } else {
-                                        Console::info("🔄 Chain Sync - Node validation", &format!("{validated}"));
+                                        log::info!("🔄 Chain Sync - Node validation: {validated}");
                                     }
                                     last_validated = validated;
                                 }
@@ -91,7 +90,7 @@ impl<'c> ComputeNodeOperations<'c> {
                                         last_locked = Some(locked);
                                         let claimable_formatted = claimable.to_string().parse::<f64>().unwrap_or(0.0) / 10f64.powf(18.0);
                                         let locked_formatted = locked.to_string().parse::<f64>().unwrap_or(0.0) / 10f64.powf(18.0);
-                                        Console::info("Rewards", &format!("{claimable_formatted} claimable, {locked_formatted} locked"));
+                                        log::info!("Rewards: {claimable_formatted} claimable, {locked_formatted} locked");
                                     }
                                 }
                                 Err(e) => {
@@ -113,55 +112,4 @@ impl<'c> ComputeNodeOperations<'c> {
         });
         Ok(())
     }
-
-    pub(crate) async fn check_compute_node_exists(
-        &self,
-    ) -> Result<bool, Box<dyn std::error::Error>> {
-        let compute_node = self
-            .contracts
-            .compute_registry
-            .get_node(
-                self.provider_wallet.wallet.default_signer().address(),
-                self.node_wallet.wallet.default_signer().address(),
-            )
-            .await;
-
-        match compute_node {
-            Ok(_) => Ok(true),
-            Err(_) => Ok(false),
-        }
-    }
-
-    // Returns true if the compute node was added, false if it already exists
-    pub(crate) async fn add_compute_node(
-        &self,
-        compute_units: U256,
-    ) -> Result<bool, Box<dyn std::error::Error>> {
-        Console::title("🔄 Adding compute node");
-
-        if self.check_compute_node_exists().await? {
-            return Ok(false);
-        }
-
-        Console::progress("Adding compute node");
-        let provider_address = self.provider_wallet.wallet.default_signer().address();
-        let node_address = self.node_wallet.wallet.default_signer().address();
-        let digest = keccak([provider_address.as_slice(), node_address.as_slice()].concat());
-
-        let signature = self
-            .node_wallet
-            .signer
-            .sign_message(digest.as_slice())
-            .await?
- .as_bytes(); - - // Create the signature bytes - let add_node_tx = self - .contracts - .prime_network - .add_compute_node(node_address, compute_units, signature.to_vec()) - .await?; - Console::success(&format!("Add node tx: {add_node_tx:?}")); - Ok(true) - } } From 6141616e7eca9938bd3d4de5e1b5044f17257b7d Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Fri, 11 Jul 2025 16:44:13 +0200 Subject: [PATCH 02/23] basic message queue with mock data --- Cargo.lock | 36 +++ crates/prime-protocol-py/Cargo.toml | 6 +- crates/prime-protocol-py/Makefile | 33 +- crates/prime-protocol-py/README.md | 54 +++- .../prime-protocol-py/examples/basic_usage.py | 98 +++++- crates/prime-protocol-py/src/lib.rs | 108 ++++++- crates/prime-protocol-py/src/message_queue.rs | 160 ++++++++++ .../src/utils/json_parser.rs | 8 + crates/prime-protocol-py/src/utils/mod.rs | 1 + .../src/{client.rs => worker.rs} | 291 ++++++++++++------ 10 files changed, 653 insertions(+), 142 deletions(-) create mode 100644 crates/prime-protocol-py/src/message_queue.rs create mode 100644 crates/prime-protocol-py/src/utils/json_parser.rs create mode 100644 crates/prime-protocol-py/src/utils/mod.rs rename crates/prime-protocol-py/src/{client.rs => worker.rs} (50%) diff --git a/Cargo.lock b/Cargo.lock index 47b10ada..41c5f51c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6730,7 +6730,11 @@ dependencies = [ "prime-core", "pyo3", "pyo3-log", + "pythonize", + "serde", + "serde_json", "shared", + "test-log", "thiserror 1.0.69", "tokio", "tokio-test", @@ -6994,6 +6998,16 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "pythonize" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597907139a488b22573158793aa7539df36ae863eba300c75f3a0d65fc475e27" +dependencies = [ + "pyo3", + "serde", +] + [[package]] name = "quanta" version = "0.10.1" @@ -8789,6 +8803,28 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "test-log" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/crates/prime-protocol-py/Cargo.toml b/crates/prime-protocol-py/Cargo.toml index 9441afe1..cbb7b513 100644 --- a/crates/prime-protocol-py/Cargo.toml +++ b/crates/prime-protocol-py/Cargo.toml @@ -16,12 +16,16 @@ shared = { workspace = true } prime-core = { workspace = true } alloy = { workspace = true } alloy-provider = { workspace = true } -tokio = { version = "1.35", features = ["rt"] } +tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "sync", "time", "macros"] } url = "2.5" log = { workspace = true } pyo3-log = "0.12.4" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +pythonize = "0.25" [dev-dependencies] +test-log = "0.2" tokio-test = "0.4" [profile.release] diff --git a/crates/prime-protocol-py/Makefile b/crates/prime-protocol-py/Makefile index fe1858d0..dfb10ac9 100644 --- a/crates/prime-protocol-py/Makefile +++ b/crates/prime-protocol-py/Makefile @@ -4,43 +4,18 @@ install: @./setup.sh # Uses uv for fast package management .PHONY: build -build: 
+build: install + @uv cache clean @source .venv/bin/activate && maturin develop @source .venv/bin/activate && uv pip install --force-reinstall -e . -.PHONY: dev -dev: - @source .venv/bin/activate && maturin develop --watch - -.PHONY: build-release -build-release: - @source .venv/bin/activate && maturin build --release --strip - -.PHONY: test -test: - @source .venv/bin/activate && pytest tests/ -v - -.PHONY: example -example: - @source .venv/bin/activate && python examples/basic_usage.py - .PHONY: clean clean: @rm -rf target/ dist/ *.egg-info .pytest_cache __pycache__ .venv/ -.PHONY: clear-cache -clear-cache: - @uv cache clean - @echo "uv cache cleared" - .PHONY: help help: @echo "Available commands:" @echo " make install - Setup environment and install dependencies" - @echo " make build - Build development version" - @echo " make dev - Build with hot reload (watches for changes)" - @echo " make build-release - Build release wheel" - @echo " make test - Run tests" - @echo " make example - Run example script" - @echo " make clean - Clean build artifacts" - @echo " make clear-cache - Clear uv cache" \ No newline at end of file + @echo " make build - Build development version (includes install and cache clear)" + @echo " make clean - Clean build artifacts" \ No newline at end of file diff --git a/crates/prime-protocol-py/README.md b/crates/prime-protocol-py/README.md index 439218c3..b72b39db 100644 --- a/crates/prime-protocol-py/README.md +++ b/crates/prime-protocol-py/README.md @@ -1,7 +1,5 @@ # Prime Protocol Python Client -Python bindings for checking if compute pools exist. - ## Build ```bash @@ -15,13 +13,59 @@ make install ## Usage +### Worker Client with Message Queue + +The Worker Client provides a message queue system for handling P2P messages from pool owners and validators. Messages are processed in a FIFO (First-In-First-Out) manner. + ```python -from primeprotocol import PrimeProtocolClient +from primeprotocol import WorkerClient +import asyncio + +# Initialize the worker client +client = WorkerClient( + compute_pool_id=1, + rpc_url="http://localhost:8545", + private_key_provider="your_provider_key", + private_key_node="your_node_key", +) + +# Start the client (registers on-chain and starts message listener) +client.start() -client = PrimeProtocolClient("http://localhost:8545") -exists = client.compute_pool_exists(0) +# Poll for messages in your application loop +async def process_messages(): + while True: + # Get next message from pool owner queue + pool_msg = client.get_pool_owner_message() + if pool_msg: + print(f"Pool owner message: {pool_msg}") + # Process the message... + + # Get next message from validator queue + validator_msg = client.get_validator_message() + if validator_msg: + print(f"Validator message: {validator_msg}") + # Process the message... + + await asyncio.sleep(0.1) + +# Run the message processing loop +asyncio.run(process_messages()) + +# Gracefully shutdown +client.stop() ``` +### Message Queue Features + +- **Background Listener**: Rust protocol listens for P2P messages in the background +- **FIFO Queue**: Messages are processed in the order they are received +- **Message Types**: Separate queues for pool owner, validator, and system messages +- **Mock Mode**: Currently generates mock messages for testing (P2P integration coming soon) +- **Thread-Safe**: Safe to use from async Python code + +See `examples/message_queue_example.py` for a complete working example. 
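+### Message Payloads
+
+The exact payload schema is still evolving. As a sketch of what the queues currently carry, the mock listener on the Rust side emits messages whose `content` is handed to Python as a plain dict; the field names below are taken from the mock generator and are expected to change once real P2P messages land:
+
+```python
+# Shape of a mock pool owner message (illustrative values)
+pool_owner_msg = {
+    "type": "inference_request",
+    "task_id": "task_0",
+    "prompt": "Test prompt 0",
+}
+
+# Shape of a mock validator message (illustrative values)
+validator_msg = {
+    "type": "validation_request",
+    "task_id": "validation_0",
+}
+```
+
+Handlers should branch on the `type` field, as `examples/basic_usage.py` does.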
+ ## Development ```bash diff --git a/crates/prime-protocol-py/examples/basic_usage.py b/crates/prime-protocol-py/examples/basic_usage.py index 639eccf7..66572db7 100644 --- a/crates/prime-protocol-py/examples/basic_usage.py +++ b/crates/prime-protocol-py/examples/basic_usage.py @@ -1,13 +1,70 @@ #!/usr/bin/env python3 """Example usage of the Prime Protocol Python client.""" +import asyncio import logging import os -from primeprotocol import PrimeProtocolClient +import signal +import sys +import time +from typing import Dict, Any, Optional +from primeprotocol import WorkerClient FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' logging.basicConfig(format=FORMAT) -logging.getLogger().setLevel(logging.INFO) +logging.getLogger().setLevel(logging.DEBUG) + + +def handle_pool_owner_message(message: Dict[str, Any]) -> None: + """Handle messages from pool owner""" + logging.info(f"Received message from pool owner: {message}") + + if message.get("type") == "inference_request": + prompt = message.get("prompt", "") + # Simulate processing the inference request + response = f"Processed: {prompt}" + + logging.info(f"Processing inference request: {prompt}") + logging.info(f"Generated response: {response}") + + # In a real implementation, you would send the response back + # client.send_response({"type": "inference_response", "result": response}) + else: + logging.info("Sending PONG response") + # client.send_response("PONG") + + +def handle_validator_message(message: Dict[str, Any]) -> None: + """Handle messages from validator""" + logging.info(f"Received message from validator: {message}") + + if message.get("type") == "inference_request": + prompt = message.get("prompt", "") + # Simulate processing the inference request + response = f"Validated: {prompt}" + + logging.info(f"Processing validation request: {prompt}") + logging.info(f"Generated response: {response}") + + # In a real implementation, you would send the response back + # client.send_response({"type": "inference_response", "result": response}) + + +def check_for_messages(client: WorkerClient) -> None: + """Check for new messages from pool owner and validator""" + try: + # Check for pool owner messages + pool_owner_message = client.get_pool_owner_message() + if pool_owner_message: + handle_pool_owner_message(pool_owner_message) + + # Check for validator messages + validator_message = client.get_validator_message() + if validator_message: + handle_validator_message(validator_message) + + except Exception as e: + logging.error(f"Error checking for messages: {e}") def main(): @@ -17,8 +74,41 @@ def main(): private_key_node = os.getenv("PRIVATE_KEY_NODE", None) logging.info(f"Connecting to: {rpc_url}") - client = PrimeProtocolClient(pool_id, rpc_url, private_key_provider, private_key_node) - client.start() + client = WorkerClient(pool_id, rpc_url, private_key_provider, private_key_node) + + def signal_handler(sig, frame): + logging.info("Received interrupt signal, shutting down gracefully...") + try: + client.stop() + logging.info("Client stopped successfully") + except Exception as e: + logging.error(f"Error during shutdown: {e}") + sys.exit(0) + + # Register signal handler for Ctrl+C + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + client.start() + logging.info("Setup completed. Starting message polling loop...") + print("Worker client started. Polling for messages. 
Press Ctrl+C to stop.") + + # Message polling loop + while True: + check_for_messages(client) + time.sleep(0.1) # Small delay to prevent busy waiting + + except KeyboardInterrupt: + logging.info("Keyboard interrupt received") + signal_handler(signal.SIGINT, None) + except Exception as e: + logging.error(f"Unexpected error: {e}") + try: + client.stop() + except: + pass + sys.exit(1) if __name__ == "__main__": main() \ No newline at end of file diff --git a/crates/prime-protocol-py/src/lib.rs b/crates/prime-protocol-py/src/lib.rs index faa72b0c..b332a9e0 100644 --- a/crates/prime-protocol-py/src/lib.rs +++ b/crates/prime-protocol-py/src/lib.rs @@ -1,20 +1,21 @@ use pyo3::prelude::*; -mod client; mod error; +mod message_queue; +mod utils; +mod worker; -use client::PrimeProtocolClientCore; +use worker::WorkerClientCore; -// todo: We need a manager + validator side to send messages - -/// Prime Protocol Python client +/// Prime Protocol Worker Client - for compute nodes that execute tasks #[pyclass] -pub struct PrimeProtocolClient { - inner: PrimeProtocolClientCore, +pub struct WorkerClient { + inner: WorkerClientCore, + runtime: Option, } #[pymethods] -impl PrimeProtocolClient { +impl WorkerClient { #[new] #[pyo3(signature = (compute_pool_id, rpc_url, private_key_provider=None, private_key_node=None))] pub fn new( @@ -23,8 +24,7 @@ impl PrimeProtocolClient { private_key_provider: Option, private_key_node: Option, ) -> PyResult { - // todo: revisit default arguments here that are currently none - let inner = PrimeProtocolClientCore::new( + let inner = WorkerClientCore::new( compute_pool_id, rpc_url, private_key_provider, @@ -34,29 +34,105 @@ impl PrimeProtocolClient { ) .map_err(|e| PyErr::new::(e.to_string()))?; - Ok(Self { inner }) + Ok(Self { + inner, + runtime: None, + }) } - pub fn start(&self) -> PyResult<()> { + pub fn start(&mut self) -> PyResult<()> { // Create a new runtime for this call - let rt = tokio::runtime::Builder::new_current_thread() + let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .map_err(|e| PyErr::new::(e.to_string()))?; // Run the async function let result = rt.block_on(self.inner.start_async()); + println!("system start completed"); - // Clean shutdown - rt.shutdown_background(); + // Store the runtime for future use + self.runtime = Some(rt); result.map_err(|e| PyErr::new::(e.to_string())) } + + pub fn get_pool_owner_message(&self) -> PyResult> { + if let Some(rt) = self.runtime.as_ref() { + Ok(rt.block_on(self.inner.get_message_queue().get_pool_owner_message())) + } else { + Err(PyErr::new::( + "Client not started. Call start() first.".to_string(), + )) + } + } + + pub fn get_validator_message(&self) -> PyResult> { + if let Some(rt) = self.runtime.as_ref() { + Ok(rt.block_on(self.inner.get_message_queue().get_validator_message())) + } else { + Err(PyErr::new::( + "Client not started. 
Call start() first.".to_string(), + )) + } + } + + pub fn stop(&mut self) -> PyResult<()> { + if let Some(rt) = self.runtime.as_ref() { + rt.block_on(self.inner.stop_async()) + .map_err(|e| PyErr::new::(e.to_string()))?; + } + + // Clean up the runtime + if let Some(rt) = self.runtime.take() { + rt.shutdown_background(); + } + + Ok(()) + } +} + +/// Prime Protocol Orchestrator Client - for managing and distributing tasks +#[pyclass] +pub struct OrchestratorClient { + // TODO: Implement orchestrator-specific functionality +} + +#[pymethods] +impl OrchestratorClient { + #[new] + #[pyo3(signature = (rpc_url, private_key=None))] + pub fn new(rpc_url: String, private_key: Option) -> PyResult { + // TODO: Implement orchestrator initialization + let _ = rpc_url; + let _ = private_key; + Ok(Self {}) + } +} + +/// Prime Protocol Validator Client - for validating task results +#[pyclass] +pub struct ValidatorClient { + // TODO: Implement validator-specific functionality +} + +#[pymethods] +impl ValidatorClient { + #[new] + #[pyo3(signature = (rpc_url, private_key=None))] + pub fn new(rpc_url: String, private_key: Option) -> PyResult { + // TODO: Implement validator initialization + let _ = rpc_url; + let _ = private_key; + Ok(Self {}) + } } #[pymodule] fn primeprotocol(m: &Bound<'_, PyModule>) -> PyResult<()> { pyo3_log::init(); - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/crates/prime-protocol-py/src/message_queue.rs b/crates/prime-protocol-py/src/message_queue.rs new file mode 100644 index 00000000..9af9a687 --- /dev/null +++ b/crates/prime-protocol-py/src/message_queue.rs @@ -0,0 +1,160 @@ +use pyo3::prelude::*; +use serde::{Deserialize, Serialize}; +use std::collections::VecDeque; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::sync::Mutex; +use tokio::time::{interval, Duration}; + +use crate::utils::json_parser::json_to_pyobject; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub message_type: MessageType, + pub content: serde_json::Value, + pub timestamp: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum MessageType { + PoolOwner, + Validator, +} + +#[derive(Clone)] +pub struct MessageQueue { + pool_owner_queue: Arc>>, + validator_queue: Arc>>, + shutdown_tx: Arc>>>, +} + +impl MessageQueue { + pub fn new() -> Self { + Self { + pool_owner_queue: Arc::new(Mutex::new(VecDeque::new())), + validator_queue: Arc::new(Mutex::new(VecDeque::new())), + shutdown_tx: Arc::new(Mutex::new(None)), + } + } + + /// Start the background message listener + pub(crate) async fn start_listener(&self) -> Result<(), String> { + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); + + // Store the shutdown sender + { + let mut tx_guard = self.shutdown_tx.lock().await; + *tx_guard = Some(shutdown_tx); + } + + let pool_owner_queue = self.pool_owner_queue.clone(); + let validator_queue = self.validator_queue.clone(); + + // Spawn background task to simulate incoming p2p messages + tokio::spawn(async move { + let mut ticker = interval(Duration::from_secs(5)); + let mut counter = 0u64; + + loop { + tokio::select! 
{ + _ = ticker.tick() => { + // Mock pool owner messages + if counter % 2 == 0 { + let message = Message { + message_type: MessageType::PoolOwner, + content: serde_json::json!({ + "type": "inference_request", + "task_id": format!("task_{}", counter), + "prompt": format!("Test prompt {}", counter), + }), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + + let mut queue = pool_owner_queue.lock().await; + queue.push_back(message); + log::debug!("Added mock pool owner message to queue"); + } + + // Mock validator messages + if counter % 3 == 0 { + let message = Message { + message_type: MessageType::Validator, + content: serde_json::json!({ + "type": "validation_request", + "task_id": format!("validation_{}", counter), + }), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + + let mut queue = validator_queue.lock().await; + queue.push_back(message); + log::debug!("Added mock validator message to queue"); + } + + counter += 1; + } + _ = shutdown_rx.recv() => { + log::info!("Message listener shutting down"); + break; + } + } + } + }); + + Ok(()) + } + + /// Stop the background listener + #[allow(unused)] + pub(crate) async fn stop_listener(&self) -> Result<(), String> { + if let Some(tx) = self.shutdown_tx.lock().await.take() { + let _ = tx.send(()).await; + } + Ok(()) + } + /// Get the next message from the pool owner queue + pub(crate) async fn get_pool_owner_message(&self) -> Option { + let mut queue = self.pool_owner_queue.lock().await; + queue + .pop_front() + .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + } + + /// Get the next message from the validator queue + pub(crate) async fn get_validator_message(&self) -> Option { + let mut queue = self.validator_queue.lock().await; + queue + .pop_front() + .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + } + + /// Push a message to the appropriate queue (for testing or internal use) + #[allow(unused)] + pub(crate) async fn push_message(&self, message: Message) -> Result<(), String> { + match message.message_type { + MessageType::PoolOwner => { + let mut queue = self.pool_owner_queue.lock().await; + queue.push_back(message); + } + MessageType::Validator => { + let mut queue = self.validator_queue.lock().await; + queue.push_back(message); + } + } + Ok(()) + } + + /// Get queue sizes for monitoring + #[allow(unused)] + pub(crate) async fn get_queue_sizes(&self) -> (usize, usize) { + let pool_owner_size = self.pool_owner_queue.lock().await.len(); + let validator_size = self.validator_queue.lock().await.len(); + (pool_owner_size, validator_size) + } +} diff --git a/crates/prime-protocol-py/src/utils/json_parser.rs b/crates/prime-protocol-py/src/utils/json_parser.rs new file mode 100644 index 00000000..b5ed4aa2 --- /dev/null +++ b/crates/prime-protocol-py/src/utils/json_parser.rs @@ -0,0 +1,8 @@ +use pyo3::prelude::*; +use pythonize::pythonize; + +/// Convert a serde_json::Value to a Python object +pub fn json_to_pyobject(py: Python, value: &serde_json::Value) -> PyObject { + // pythonize handles all the conversion automatically! 
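+ // pythonize maps JSON objects to dicts, arrays to lists, strings/numbers/
+ // bools to their Python counterparts, and null to None.
+ // Caveat: the unwrap() below panics if conversion fails; values built by
+ // serde_json are expected to always convert cleanly.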
+ pythonize(py, value).unwrap().into() +} diff --git a/crates/prime-protocol-py/src/utils/mod.rs b/crates/prime-protocol-py/src/utils/mod.rs new file mode 100644 index 00000000..3e9394ce --- /dev/null +++ b/crates/prime-protocol-py/src/utils/mod.rs @@ -0,0 +1 @@ +pub mod json_parser; diff --git a/crates/prime-protocol-py/src/client.rs b/crates/prime-protocol-py/src/worker.rs similarity index 50% rename from crates/prime-protocol-py/src/client.rs rename to crates/prime-protocol-py/src/worker.rs index b4139b7b..d7459ba4 100644 --- a/crates/prime-protocol-py/src/client.rs +++ b/crates/prime-protocol-py/src/worker.rs @@ -1,23 +1,26 @@ use crate::error::{PrimeProtocolError, Result}; +use crate::message_queue::MessageQueue; use alloy::primitives::utils::format_ether; -use alloy::primitives::U256; +use alloy::primitives::{Address, U256}; use prime_core::operations::compute_node::ComputeNodeOperations; use prime_core::operations::provider::ProviderOperations; use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; use shared::web3::contracts::structs::compute_pool::PoolStatus; use shared::web3::wallet::{Wallet, WalletProvider}; +use std::sync::Arc; use url::Url; -pub struct PrimeProtocolClientCore { +pub struct WorkerClientCore { rpc_url: String, compute_pool_id: u64, private_key_provider: Option, private_key_node: Option, auto_accept_transactions: bool, funding_retry_count: u32, + message_queue: Arc, } -impl PrimeProtocolClientCore { +impl WorkerClientCore { pub fn new( compute_pool_id: u64, rpc_url: String, @@ -42,6 +45,7 @@ impl PrimeProtocolClientCore { private_key_node, auto_accept_transactions: auto_accept_transactions.unwrap_or(true), funding_retry_count: funding_retry_count.unwrap_or(10), + message_queue: Arc::new(MessageQueue::new()), }) } @@ -50,15 +54,23 @@ impl PrimeProtocolClientCore { self.initialize_blockchain_components().await?; let pool_info = self.wait_for_active_pool(&contracts).await?; - log::info!("Pool info: {:?}", pool_info); - + log::debug!("Pool info: {:?}", pool_info); + log::debug!("Checking provider"); self.ensure_provider_registered(&provider_wallet, &contracts) .await?; + log::debug!("Checking compute node"); self.ensure_compute_node_registered(&provider_wallet, &node_wallet, &contracts) .await?; - // TODO: Optional - run hardware check? - // TODO: p2p reachable? 
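+ // Typical driving sequence from Python (a sketch; names match the pyo3
+ // bindings in lib.rs):
+ //   client = WorkerClient(pool_id, rpc_url, provider_key, node_key)
+ //   client.start()                    # blocks on this start_async()
+ //   client.get_pool_owner_message()   # drains the queue filled below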
+ log::debug!("blockchain components initialized"); + log::debug!("starting queues"); + + // Start the message queue listener + self.message_queue.start_listener().await.map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Failed to start message listener: {}", e)) + })?; + + log::debug!("Message queue listener started"); Ok(()) } @@ -114,6 +126,7 @@ impl PrimeProtocolClientCore { } } } + async fn ensure_provider_registered( &self, provider_wallet: &Wallet, @@ -125,22 +138,33 @@ impl PrimeProtocolClientCore { self.auto_accept_transactions, ); - // Check if provider exists - let provider_exists = provider_ops.check_provider_exists().await.map_err(|e| { + let provider_exists = self.check_provider_exists(&provider_ops).await?; + let is_whitelisted = self.check_provider_whitelisted(&provider_ops).await?; + + if provider_exists && is_whitelisted { + log::info!("Provider is registered and whitelisted"); + } else { + self.register_provider_if_needed(&provider_ops, contracts) + .await?; + } + + self.ensure_adequate_stake(&provider_ops, provider_wallet, contracts) + .await?; + + Ok(()) + } + + async fn check_provider_exists(&self, provider_ops: &ProviderOperations) -> Result { + provider_ops.check_provider_exists().await.map_err(|e| { PrimeProtocolError::BlockchainError(format!( "Failed to check if provider exists: {}", e )) - })?; - - let Some(stake_manager) = contracts.stake_manager.as_ref() else { - return Err(PrimeProtocolError::BlockchainError( - "Stake manager not initialized".to_string(), - )); - }; + }) + } - // Check if provider is whitelisted - let is_whitelisted = provider_ops + async fn check_provider_whitelisted(&self, provider_ops: &ProviderOperations) -> Result { + provider_ops .check_provider_whitelisted() .await .map_err(|e| { @@ -148,59 +172,58 @@ impl PrimeProtocolClientCore { "Failed to check provider whitelist status: {}", e )) - })?; - - // todo: revisit this - if provider_exists && is_whitelisted { - log::info!("Provider is registered and whitelisted"); - } else { - // For now, we'll use a default compute_units value - this should be configurable - let compute_units = U256::from(1); - - let required_stake = stake_manager - .calculate_stake(compute_units, U256::from(0)) - .await - .map_err(|e| { - PrimeProtocolError::BlockchainError(format!( - "Failed to calculate required stake: {}", - e - )) - })?; - - log::info!("Required stake: {}", format_ether(required_stake)); - - provider_ops - .retry_register_provider(required_stake, self.funding_retry_count, None) - .await - .map_err(|e| { - PrimeProtocolError::BlockchainError(format!( - "Failed to register provider: {}", - e - )) - })?; + }) + } - log::info!("Provider registered successfully"); - } + async fn register_provider_if_needed( + &self, + provider_ops: &ProviderOperations, + contracts: &Contracts, + ) -> Result<()> { + let stake_manager = contracts.stake_manager.as_ref().ok_or_else(|| { + PrimeProtocolError::BlockchainError("Stake manager not initialized".to_string()) + })?; + let compute_units = U256::from(1); // TODO: Make configurable - // Get provider's current total compute and stake - let provider_total_compute = contracts - .compute_registry - .get_provider_total_compute(provider_wallet.wallet.default_signer().address()) + let required_stake = stake_manager + .calculate_stake(compute_units, U256::from(0)) .await .map_err(|e| { PrimeProtocolError::BlockchainError(format!( - "Failed to get provider total compute: {}", + "Failed to calculate required stake: {}", e )) })?; - let provider_stake = stake_manager - 
.get_stake(provider_wallet.wallet.default_signer().address()) + log::info!("Required stake: {}", format_ether(required_stake)); + + provider_ops + .retry_register_provider(required_stake, self.funding_retry_count, None) .await - .unwrap_or_default(); + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to register provider: {}", e)) + })?; + + log::info!("Provider registered successfully"); + Ok(()) + } + + async fn ensure_adequate_stake( + &self, + provider_ops: &ProviderOperations, + provider_wallet: &Wallet, + contracts: &Contracts, + ) -> Result<()> { + let stake_manager = contracts.stake_manager.as_ref().ok_or_else(|| { + PrimeProtocolError::BlockchainError("Stake manager not initialized".to_string()) + })?; + let provider_address = provider_wallet.wallet.default_signer().address(); - // For now, we'll use a default compute_units value - this should be configurable - let compute_units = U256::from(1); + let provider_total_compute = self + .get_provider_total_compute(contracts, provider_address) + .await?; + let provider_stake = self.get_provider_stake(contracts, provider_address).await; + let compute_units = U256::from(1); // TODO: Make configurable let required_stake = stake_manager .calculate_stake(compute_units, provider_total_compute) @@ -213,22 +236,65 @@ impl PrimeProtocolClientCore { })?; if required_stake > provider_stake { - log::info!( - "Provider stake is less than required stake. Required: {} tokens, Current: {} tokens", - format_ether(required_stake), - format_ether(provider_stake) - ); - - provider_ops - .increase_stake(required_stake - provider_stake) - .await - .map_err(|e| { - PrimeProtocolError::BlockchainError(format!("Failed to increase stake: {}", e)) - })?; + self.increase_provider_stake(provider_ops, required_stake, provider_stake) + .await?; + } + + Ok(()) + } + + async fn get_provider_total_compute( + &self, + contracts: &Contracts, + provider_address: Address, + ) -> Result { + contracts + .compute_registry + .get_provider_total_compute(provider_address) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to get provider total compute: {}", + e + )) + }) + } - log::info!("Successfully increased stake"); + async fn get_provider_stake( + &self, + contracts: &Contracts, + provider_address: Address, + ) -> U256 { + let stake_manager = contracts.stake_manager.as_ref(); + match stake_manager { + Some(manager) => manager + .get_stake(provider_address) + .await + .unwrap_or_default(), + None => U256::from(0), } + } + async fn increase_provider_stake( + &self, + provider_ops: &ProviderOperations, + required_stake: U256, + current_stake: U256, + ) -> Result<()> { + log::info!( + "Provider stake is less than required stake. 
Required: {} tokens, Current: {} tokens", + format_ether(required_stake), + format_ether(current_stake) + ); + + provider_ops + .increase_stake(required_stake - current_stake) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to increase stake: {}", e)) + })?; + + log::info!("Successfully increased stake"); Ok(()) } @@ -241,27 +307,40 @@ impl PrimeProtocolClientCore { let compute_node_ops = ComputeNodeOperations::new(provider_wallet, node_wallet, contracts.clone()); - // Check if compute node exists - let compute_node_exists = - compute_node_ops - .check_compute_node_exists() - .await - .map_err(|e| { - PrimeProtocolError::BlockchainError(format!( - "Failed to check if compute node exists: {}", - e - )) - })?; + let compute_node_exists = self.check_compute_node_exists(&compute_node_ops).await?; if compute_node_exists { log::info!("Compute node is already registered"); return Ok(()); } - // If compute node doesn't exist, register it - // For now, we'll use default compute specs - this should be configurable + self.register_compute_node(&compute_node_ops).await?; + Ok(()) + } + + async fn check_compute_node_exists( + &self, + compute_node_ops: &ComputeNodeOperations<'_>, + ) -> Result { compute_node_ops - .add_compute_node(U256::from(1)) + .check_compute_node_exists() + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to check if compute node exists: {}", + e + )) + }) + } + + async fn register_compute_node( + &self, + compute_node_ops: &ComputeNodeOperations<'_>, + ) -> Result<()> { + let compute_units = U256::from(1); // TODO: Make configurable + + compute_node_ops + .add_compute_node(compute_units) .await .map_err(|e| { PrimeProtocolError::BlockchainError(format!( @@ -291,4 +370,42 @@ impl PrimeProtocolClientCore { }), } } + + /// Get the shared message queue instance + pub fn get_message_queue(&self) -> Arc { + self.message_queue.clone() + } + + /// Stop the message queue listener + pub async fn stop_async(&self) -> Result<()> { + self.message_queue.stop_listener().await.map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Failed to stop message listener: {}", e)) + })?; + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use test_log::test; + + #[test(tokio::test)] + async fn test_start_async() { + // standard anvil blockchain keys for local testing + let node_key = "0x7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6"; + let provider_key = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a"; + + // todo: currently still have to make up the local blockchain incl. smart contract deployments + let worker = WorkerClientCore::new( + 0, + "http://localhost:8545".to_string(), + Some(provider_key.to_string()), + Some(node_key.to_string()), + None, + None, + ) + .unwrap(); + worker.start_async().await.unwrap(); + } } From 5ac1a8e45edb43fe82c55915b2a925ff937fc0d6 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Fri, 11 Jul 2025 18:26:13 +0200 Subject: [PATCH 03/23] restructure python sdk lib to have pyo bindings in sep. 
modules --- crates/prime-protocol-py/src/lib.rs | 129 +----------------- .../prime-protocol-py/src/orchestrator/mod.rs | 19 +++ crates/prime-protocol-py/src/validator/mod.rs | 19 +++ .../src/{worker.rs => worker/client.rs} | 2 +- .../src/{ => worker}/message_queue.rs | 0 crates/prime-protocol-py/src/worker/mod.rs | 88 ++++++++++++ .../src/checks/hardware/storage.rs:236:1 | 0 7 files changed, 132 insertions(+), 125 deletions(-) create mode 100644 crates/prime-protocol-py/src/orchestrator/mod.rs create mode 100644 crates/prime-protocol-py/src/validator/mod.rs rename crates/prime-protocol-py/src/{worker.rs => worker/client.rs} (99%) rename crates/prime-protocol-py/src/{ => worker}/message_queue.rs (100%) create mode 100644 crates/prime-protocol-py/src/worker/mod.rs create mode 100644 crates/worker/src/checks/hardware/storage.rs:236:1 diff --git a/crates/prime-protocol-py/src/lib.rs b/crates/prime-protocol-py/src/lib.rs index b332a9e0..0715c33a 100644 --- a/crates/prime-protocol-py/src/lib.rs +++ b/crates/prime-protocol-py/src/lib.rs @@ -1,133 +1,14 @@ +use crate::orchestrator::OrchestratorClient; +use crate::validator::ValidatorClient; +use crate::worker::WorkerClient; use pyo3::prelude::*; mod error; -mod message_queue; +mod orchestrator; mod utils; +mod validator; mod worker; -use worker::WorkerClientCore; - -/// Prime Protocol Worker Client - for compute nodes that execute tasks -#[pyclass] -pub struct WorkerClient { - inner: WorkerClientCore, - runtime: Option, -} - -#[pymethods] -impl WorkerClient { - #[new] - #[pyo3(signature = (compute_pool_id, rpc_url, private_key_provider=None, private_key_node=None))] - pub fn new( - compute_pool_id: u64, - rpc_url: String, - private_key_provider: Option, - private_key_node: Option, - ) -> PyResult { - let inner = WorkerClientCore::new( - compute_pool_id, - rpc_url, - private_key_provider, - private_key_node, - None, - None, - ) - .map_err(|e| PyErr::new::(e.to_string()))?; - - Ok(Self { - inner, - runtime: None, - }) - } - - pub fn start(&mut self) -> PyResult<()> { - // Create a new runtime for this call - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .map_err(|e| PyErr::new::(e.to_string()))?; - - // Run the async function - let result = rt.block_on(self.inner.start_async()); - println!("system start completed"); - - // Store the runtime for future use - self.runtime = Some(rt); - - result.map_err(|e| PyErr::new::(e.to_string())) - } - - pub fn get_pool_owner_message(&self) -> PyResult> { - if let Some(rt) = self.runtime.as_ref() { - Ok(rt.block_on(self.inner.get_message_queue().get_pool_owner_message())) - } else { - Err(PyErr::new::( - "Client not started. Call start() first.".to_string(), - )) - } - } - - pub fn get_validator_message(&self) -> PyResult> { - if let Some(rt) = self.runtime.as_ref() { - Ok(rt.block_on(self.inner.get_message_queue().get_validator_message())) - } else { - Err(PyErr::new::( - "Client not started. 
Call start() first.".to_string(), - )) - } - } - - pub fn stop(&mut self) -> PyResult<()> { - if let Some(rt) = self.runtime.as_ref() { - rt.block_on(self.inner.stop_async()) - .map_err(|e| PyErr::new::(e.to_string()))?; - } - - // Clean up the runtime - if let Some(rt) = self.runtime.take() { - rt.shutdown_background(); - } - - Ok(()) - } -} - -/// Prime Protocol Orchestrator Client - for managing and distributing tasks -#[pyclass] -pub struct OrchestratorClient { - // TODO: Implement orchestrator-specific functionality -} - -#[pymethods] -impl OrchestratorClient { - #[new] - #[pyo3(signature = (rpc_url, private_key=None))] - pub fn new(rpc_url: String, private_key: Option) -> PyResult { - // TODO: Implement orchestrator initialization - let _ = rpc_url; - let _ = private_key; - Ok(Self {}) - } -} - -/// Prime Protocol Validator Client - for validating task results -#[pyclass] -pub struct ValidatorClient { - // TODO: Implement validator-specific functionality -} - -#[pymethods] -impl ValidatorClient { - #[new] - #[pyo3(signature = (rpc_url, private_key=None))] - pub fn new(rpc_url: String, private_key: Option) -> PyResult { - // TODO: Implement validator initialization - let _ = rpc_url; - let _ = private_key; - Ok(Self {}) - } -} - #[pymodule] fn primeprotocol(m: &Bound<'_, PyModule>) -> PyResult<()> { pyo3_log::init(); diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/prime-protocol-py/src/orchestrator/mod.rs new file mode 100644 index 00000000..39f4d915 --- /dev/null +++ b/crates/prime-protocol-py/src/orchestrator/mod.rs @@ -0,0 +1,19 @@ +use pyo3::prelude::*; + +/// Prime Protocol Orchestrator Client - for managing and distributing tasks +#[pyclass] +pub struct OrchestratorClient { + // TODO: Implement orchestrator-specific functionality +} + +#[pymethods] +impl OrchestratorClient { + #[new] + #[pyo3(signature = (rpc_url, private_key=None))] + pub fn new(rpc_url: String, private_key: Option) -> PyResult { + // TODO: Implement orchestrator initialization + let _ = rpc_url; + let _ = private_key; + Ok(Self {}) + } +} diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/prime-protocol-py/src/validator/mod.rs new file mode 100644 index 00000000..ed02939c --- /dev/null +++ b/crates/prime-protocol-py/src/validator/mod.rs @@ -0,0 +1,19 @@ +use pyo3::prelude::*; + +/// Prime Protocol Validator Client - for validating task results +#[pyclass] +pub(crate) struct ValidatorClient { + // TODO: Implement validator-specific functionality +} + +#[pymethods] +impl ValidatorClient { + #[new] + #[pyo3(signature = (rpc_url, private_key=None))] + pub fn new(rpc_url: String, private_key: Option) -> PyResult { + // TODO: Implement validator initialization + let _ = rpc_url; + let _ = private_key; + Ok(Self {}) + } +} diff --git a/crates/prime-protocol-py/src/worker.rs b/crates/prime-protocol-py/src/worker/client.rs similarity index 99% rename from crates/prime-protocol-py/src/worker.rs rename to crates/prime-protocol-py/src/worker/client.rs index d7459ba4..e15ed7c6 100644 --- a/crates/prime-protocol-py/src/worker.rs +++ b/crates/prime-protocol-py/src/worker/client.rs @@ -1,5 +1,5 @@ use crate::error::{PrimeProtocolError, Result}; -use crate::message_queue::MessageQueue; +use crate::worker::message_queue::MessageQueue; use alloy::primitives::utils::format_ether; use alloy::primitives::{Address, U256}; use prime_core::operations::compute_node::ComputeNodeOperations; diff --git a/crates/prime-protocol-py/src/message_queue.rs 
b/crates/prime-protocol-py/src/worker/message_queue.rs similarity index 100% rename from crates/prime-protocol-py/src/message_queue.rs rename to crates/prime-protocol-py/src/worker/message_queue.rs diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/prime-protocol-py/src/worker/mod.rs new file mode 100644 index 00000000..02f7634f --- /dev/null +++ b/crates/prime-protocol-py/src/worker/mod.rs @@ -0,0 +1,88 @@ +use pyo3::prelude::*; +mod client; +pub(crate) mod message_queue; +pub(crate) use client::WorkerClientCore; + +/// Prime Protocol Worker Client - for compute nodes that execute tasks +#[pyclass] +pub(crate) struct WorkerClient { + inner: WorkerClientCore, + runtime: Option, +} + +#[pymethods] +impl WorkerClient { + #[new] + #[pyo3(signature = (compute_pool_id, rpc_url, private_key_provider=None, private_key_node=None))] + pub fn new( + compute_pool_id: u64, + rpc_url: String, + private_key_provider: Option, + private_key_node: Option, + ) -> PyResult { + let inner = WorkerClientCore::new( + compute_pool_id, + rpc_url, + private_key_provider, + private_key_node, + None, + None, + ) + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(Self { + inner, + runtime: None, + }) + } + + pub fn start(&mut self) -> PyResult<()> { + // Create a new runtime for this call + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .map_err(|e| PyErr::new::(e.to_string()))?; + + // Run the async function + let result = rt.block_on(self.inner.start_async()); + + // Store the runtime for future use + self.runtime = Some(rt); + + result.map_err(|e| PyErr::new::(e.to_string())) + } + + pub fn get_pool_owner_message(&self) -> PyResult> { + if let Some(rt) = self.runtime.as_ref() { + Ok(rt.block_on(self.inner.get_message_queue().get_pool_owner_message())) + } else { + Err(PyErr::new::( + "Client not started. Call start() first.".to_string(), + )) + } + } + + pub fn get_validator_message(&self) -> PyResult> { + if let Some(rt) = self.runtime.as_ref() { + Ok(rt.block_on(self.inner.get_message_queue().get_validator_message())) + } else { + Err(PyErr::new::( + "Client not started. Call start() first.".to_string(), + )) + } + } + + pub fn stop(&mut self) -> PyResult<()> { + if let Some(rt) = self.runtime.as_ref() { + rt.block_on(self.inner.stop_async()) + .map_err(|e| PyErr::new::(e.to_string()))?; + } + + // Clean up the runtime + if let Some(rt) = self.runtime.take() { + rt.shutdown_background(); + } + + Ok(()) + } +} diff --git a/crates/worker/src/checks/hardware/storage.rs:236:1 b/crates/worker/src/checks/hardware/storage.rs:236:1 new file mode 100644 index 00000000..e69de29b From 47bc4f47fb5a439fdb7d458952586ab8ca12b14d Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Fri, 11 Jul 2025 18:58:36 +0200 Subject: [PATCH 04/23] fix async gil issues, add bootstrap cmd to Makefile --- Makefile | 12 ++++ .../prime-protocol-py/examples/basic_usage.py | 18 +++-- crates/prime-protocol-py/src/worker/client.rs | 70 +++++++++++++------ crates/prime-protocol-py/src/worker/mod.rs | 22 +++--- 4 files changed, 87 insertions(+), 35 deletions(-) diff --git a/Makefile b/Makefile index dfc0d0af..5de39578 100644 --- a/Makefile +++ b/Makefile @@ -97,6 +97,18 @@ up: @# Attach to session @tmux attach-session -t prime-dev +# Start Docker services and deploy contracts only +.PHONY: bootstrap +bootstrap: + @echo "Starting Docker services and deploying contracts..." 
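+ @# Typical flow: `make bootstrap` once per fresh environment, then `make up`
+ @# for the full dev session (assumes Docker and the deploy scripts are present)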
+ @# Start Docker services + @docker compose up -d reth redis --wait --wait-timeout 180 + @# Deploy contracts + @cd smart-contracts && sh deploy.sh && sh deploy_work_validation.sh && cd .. + @# Run setup + @$(MAKE) setup + @echo "Bootstrap complete - Docker services running and contracts deployed" + # Stop development environment .PHONY: down down: diff --git a/crates/prime-protocol-py/examples/basic_usage.py b/crates/prime-protocol-py/examples/basic_usage.py index 66572db7..02b19bd9 100644 --- a/crates/prime-protocol-py/examples/basic_usage.py +++ b/crates/prime-protocol-py/examples/basic_usage.py @@ -12,7 +12,7 @@ FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' logging.basicConfig(format=FORMAT) -logging.getLogger().setLevel(logging.DEBUG) +logging.getLogger().setLevel(logging.INFO) def handle_pool_owner_message(message: Dict[str, Any]) -> None: @@ -85,22 +85,30 @@ def signal_handler(sig, frame): logging.error(f"Error during shutdown: {e}") sys.exit(0) - # Register signal handler for Ctrl+C + # Register signal handler for Ctrl+C before starting client signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: + logging.info("Starting client... (Press Ctrl+C to interrupt)") client.start() logging.info("Setup completed. Starting message polling loop...") print("Worker client started. Polling for messages. Press Ctrl+C to stop.") # Message polling loop while True: - check_for_messages(client) - time.sleep(0.1) # Small delay to prevent busy waiting + try: + check_for_messages(client) + time.sleep(0.1) # Small delay to prevent busy waiting + except KeyboardInterrupt: + # Handle Ctrl+C during message polling + logging.info("Keyboard interrupt received during polling") + signal_handler(signal.SIGINT, None) + break except KeyboardInterrupt: - logging.info("Keyboard interrupt received") + # Handle Ctrl+C during client startup + logging.info("Keyboard interrupt received during startup") signal_handler(signal.SIGINT, None) except Exception as e: logging.error(f"Unexpected error: {e}") diff --git a/crates/prime-protocol-py/src/worker/client.rs b/crates/prime-protocol-py/src/worker/client.rs index e15ed7c6..db30c0b4 100644 --- a/crates/prime-protocol-py/src/worker/client.rs +++ b/crates/prime-protocol-py/src/worker/client.rs @@ -197,12 +197,23 @@ impl WorkerClientCore { log::info!("Required stake: {}", format_ether(required_stake)); - provider_ops - .retry_register_provider(required_stake, self.funding_retry_count, None) - .await - .map_err(|e| { - PrimeProtocolError::BlockchainError(format!("Failed to register provider: {}", e)) - })?; + // Add timeout to prevent hanging on blockchain operations + let register_future = + provider_ops.retry_register_provider(required_stake, self.funding_retry_count, None); + + tokio::time::timeout( + tokio::time::Duration::from_secs(300), // 5 minute timeout + register_future, + ) + .await + .map_err(|_| { + PrimeProtocolError::BlockchainError( + "Provider registration timed out after 5 minutes".to_string(), + ) + })? 
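+ // Note: on timeout only this client-side future is dropped; a transaction
+ // that was already broadcast may still be mined on-chain afterwards.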
+ .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to register provider: {}", e)) + })?; log::info!("Provider registered successfully"); Ok(()) @@ -287,12 +298,22 @@ impl WorkerClientCore { format_ether(current_stake) ); - provider_ops - .increase_stake(required_stake - current_stake) - .await - .map_err(|e| { - PrimeProtocolError::BlockchainError(format!("Failed to increase stake: {}", e)) - })?; + // Add timeout to prevent hanging on stake increase operations + let stake_future = provider_ops.increase_stake(required_stake - current_stake); + + tokio::time::timeout( + tokio::time::Duration::from_secs(300), // 5 minute timeout + stake_future, + ) + .await + .map_err(|_| { + PrimeProtocolError::BlockchainError( + "Stake increase timed out after 5 minutes".to_string(), + ) + })? + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to increase stake: {}", e)) + })?; log::info!("Successfully increased stake"); Ok(()) @@ -339,15 +360,22 @@ impl WorkerClientCore { ) -> Result<()> { let compute_units = U256::from(1); // TODO: Make configurable - compute_node_ops - .add_compute_node(compute_units) - .await - .map_err(|e| { - PrimeProtocolError::BlockchainError(format!( - "Failed to register compute node: {}", - e - )) - })?; + // Add timeout to prevent hanging on compute node registration + let register_future = compute_node_ops.add_compute_node(compute_units); + + tokio::time::timeout( + tokio::time::Duration::from_secs(300), // 5 minute timeout + register_future, + ) + .await + .map_err(|_| { + PrimeProtocolError::BlockchainError( + "Compute node registration timed out after 5 minutes".to_string(), + ) + })? + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to register compute node: {}", e)) + })?; log::info!("Compute node registered successfully"); Ok(()) diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/prime-protocol-py/src/worker/mod.rs index 02f7634f..b28e2216 100644 --- a/crates/prime-protocol-py/src/worker/mod.rs +++ b/crates/prime-protocol-py/src/worker/mod.rs @@ -36,15 +36,15 @@ impl WorkerClient { }) } - pub fn start(&mut self) -> PyResult<()> { + pub fn start(&mut self, py: Python) -> PyResult<()> { // Create a new runtime for this call let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .map_err(|e| PyErr::new::(e.to_string()))?; - // Run the async function - let result = rt.block_on(self.inner.start_async()); + // Run the async function with GIL released + let result = py.allow_threads(|| rt.block_on(self.inner.start_async())); // Store the runtime for future use self.runtime = Some(rt); @@ -52,9 +52,11 @@ impl WorkerClient { result.map_err(|e| PyErr::new::(e.to_string())) } - pub fn get_pool_owner_message(&self) -> PyResult> { + pub fn get_pool_owner_message(&self, py: Python) -> PyResult> { if let Some(rt) = self.runtime.as_ref() { - Ok(rt.block_on(self.inner.get_message_queue().get_pool_owner_message())) + Ok(py.allow_threads(|| { + rt.block_on(self.inner.get_message_queue().get_pool_owner_message()) + })) } else { Err(PyErr::new::( "Client not started. 
Call start() first.".to_string(), @@ -62,9 +64,11 @@ impl WorkerClient { } } - pub fn get_validator_message(&self) -> PyResult> { + pub fn get_validator_message(&self, py: Python) -> PyResult> { if let Some(rt) = self.runtime.as_ref() { - Ok(rt.block_on(self.inner.get_message_queue().get_validator_message())) + Ok(py.allow_threads(|| { + rt.block_on(self.inner.get_message_queue().get_validator_message()) + })) } else { Err(PyErr::new::( "Client not started. Call start() first.".to_string(), @@ -72,9 +76,9 @@ impl WorkerClient { } } - pub fn stop(&mut self) -> PyResult<()> { + pub fn stop(&mut self, py: Python) -> PyResult<()> { if let Some(rt) = self.runtime.as_ref() { - rt.block_on(self.inner.stop_async()) + py.allow_threads(|| rt.block_on(self.inner.stop_async())) .map_err(|e| PyErr::new::(e.to_string()))?; } From 58d2aec148eb12b9cdf2811289432bcbbe0b7fa1 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Fri, 11 Jul 2025 19:34:26 +0200 Subject: [PATCH 05/23] cleanup message queue setup --- .../prime-protocol-py/src/orchestrator/mod.rs | 36 ++++ .../src/utils/message_queue.rs | 152 +++++++++++++++++ crates/prime-protocol-py/src/utils/mod.rs | 3 +- .../src/validator/message_queue.rs | 46 +++++ crates/prime-protocol-py/src/validator/mod.rs | 93 ++++++++++- .../src/worker/message_queue.rs | 158 +++++------------- crates/prime-protocol-py/src/worker/mod.rs | 1 - 7 files changed, 366 insertions(+), 123 deletions(-) create mode 100644 crates/prime-protocol-py/src/utils/message_queue.rs create mode 100644 crates/prime-protocol-py/src/validator/message_queue.rs diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/prime-protocol-py/src/orchestrator/mod.rs index 39f4d915..c610ea6f 100644 --- a/crates/prime-protocol-py/src/orchestrator/mod.rs +++ b/crates/prime-protocol-py/src/orchestrator/mod.rs @@ -16,4 +16,40 @@ impl OrchestratorClient { let _ = private_key; Ok(Self {}) } + + pub fn list_validated_nodes(&self) -> PyResult> { + // TODO: Implement orchestrator node listing + Ok(vec![]) + } + + pub fn list_nodes_from_chain(&self) -> PyResult> { + // TODO: Implement orchestrator node listing from chain + Ok(vec![]) + } + + // pub fn get_node_details(&self, node_id: String) -> PyResult> { + // // TODO: Implement orchestrator node details fetching + // Ok(None) + // } + + // pub fn get_node_details_from_chain(&self, node_id: String) -> PyResult> { + // // TODO: Implement orchestrator node details fetching from chain + // Ok(None) + // } + + // pub fn send_invite_to_node(&self, node_id: String) -> PyResult<()> { + // // TODO: Implement orchestrator node invite sending + // Ok(()) + // } + + // pub fn send_request_to_node(&self, node_id: String, request: String) -> PyResult<()> { + // // TODO: Implement orchestrator node request sending + // Ok(()) + // } + + // // TODO: Sender of this message? 
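+ // Sketch of the intended Python-side usage once these are implemented
+ // (today they are stubs or commented out):
+ //   orch = OrchestratorClient(rpc_url, private_key)
+ //   orch.list_validated_nodes()   # currently returns []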
+ // pub fn read_message(&self) -> PyResult> { + // // TODO: Implement orchestrator message reading + // Ok(None) + // } } diff --git a/crates/prime-protocol-py/src/utils/message_queue.rs b/crates/prime-protocol-py/src/utils/message_queue.rs new file mode 100644 index 00000000..43153cb1 --- /dev/null +++ b/crates/prime-protocol-py/src/utils/message_queue.rs @@ -0,0 +1,152 @@ +use pyo3::prelude::*; +use serde::{Deserialize, Serialize}; +use std::collections::VecDeque; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::sync::Mutex; +use tokio::time::{interval, Duration}; + +use crate::utils::json_parser::json_to_pyobject; + +/// Generic message that can be sent between components +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub content: serde_json::Value, + pub timestamp: u64, + pub sender: Option, +} + +/// Simple message queue for handling messages +#[derive(Clone)] +pub struct MessageQueue { + queue: Arc>>, + max_size: Option, + shutdown_tx: Arc>>>, +} + +impl MessageQueue { + /// Create a new message queue + pub fn new(max_size: Option) -> Self { + Self { + queue: Arc::new(Mutex::new(VecDeque::new())), + max_size, + shutdown_tx: Arc::new(Mutex::new(None)), + } + } + + /// Push a message to the queue + pub async fn push_message(&self, message: Message) -> Result<(), String> { + let mut queue = self.queue.lock().await; + + // Check max size if configured + if let Some(max_size) = self.max_size { + if queue.len() >= max_size { + return Err(format!("Queue is full (max size: {})", max_size)); + } + } + + queue.push_back(message); + Ok(()) + } + + /// Get the next message from the queue + pub async fn get_message(&self) -> Option { + let mut queue = self.queue.lock().await; + + queue + .pop_front() + .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + } + + /// Get all messages from the queue (draining it) + pub async fn get_all_messages(&self) -> Vec { + let mut queue = self.queue.lock().await; + + let messages: Vec = queue.drain(..).collect(); + messages + .into_iter() + .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + .collect() + } + + /// Peek at the next message without removing it + pub async fn peek_message(&self) -> Option { + let queue = self.queue.lock().await; + + queue + .front() + .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + } + + /// Get the size of the queue + pub async fn get_queue_size(&self) -> usize { + let queue = self.queue.lock().await; + queue.len() + } + + /// Clear the queue + pub async fn clear(&self) -> Result<(), String> { + let mut queue = self.queue.lock().await; + queue.clear(); + Ok(()) + } + + /// Start a mock message listener (for testing/development) + pub async fn start_mock_listener(&self, frequency: u64) -> Result<(), String> { + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); + + // Store the shutdown sender + { + let mut tx_guard = self.shutdown_tx.lock().await; + *tx_guard = Some(shutdown_tx); + } + + let queue_clone = self.queue.clone(); + + // Spawn background task to simulate incoming messages + tokio::spawn(async move { + let mut ticker = interval(Duration::from_secs(1)); + let mut counter = 0u64; + + loop { + tokio::select! 
{ + _ = ticker.tick() => { + if counter % frequency == 0 { + let message = Message { + content: serde_json::json!({ + "type": "mock_message", + "id": format!("mock_{}", counter), + "data": format!("Mock data #{}", counter), + }), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + sender: Some("mock_listener".to_string()), + }; + + let mut queue = queue_clone.lock().await; + queue.push_back(message); + log::debug!("Added mock message to queue"); + } + counter += 1; + } + _ = shutdown_rx.recv() => { + log::info!("Mock message listener shutting down"); + break; + } + } + } + }); + + Ok(()) + } + + /// Stop the mock listener + pub async fn stop_listener(&self) -> Result<(), String> { + if let Some(tx) = self.shutdown_tx.lock().await.take() { + let _ = tx.send(()).await; + } + Ok(()) + } +} diff --git a/crates/prime-protocol-py/src/utils/mod.rs b/crates/prime-protocol-py/src/utils/mod.rs index 3e9394ce..da6afad7 100644 --- a/crates/prime-protocol-py/src/utils/mod.rs +++ b/crates/prime-protocol-py/src/utils/mod.rs @@ -1 +1,2 @@ -pub mod json_parser; +pub(crate) mod json_parser; +pub(crate) mod message_queue; diff --git a/crates/prime-protocol-py/src/validator/message_queue.rs b/crates/prime-protocol-py/src/validator/message_queue.rs new file mode 100644 index 00000000..72f1b468 --- /dev/null +++ b/crates/prime-protocol-py/src/validator/message_queue.rs @@ -0,0 +1,46 @@ +use crate::utils::message_queue::{Message, MessageQueue as GenericMessageQueue}; +use pyo3::prelude::*; + +/// Validator-specific message queue for incoming validation results +#[derive(Clone)] +pub struct MessageQueue { + inner: GenericMessageQueue, +} + +impl MessageQueue { + /// Create a new validator message queue for validation results + pub fn new() -> Self { + let inner = GenericMessageQueue::new(None); + + Self { inner } + } + + /// Get the next validation result from nodes + pub async fn get_validation_result(&self) -> Option { + self.inner.get_message().await + } + + /// Push a validation result (for testing or internal use) + pub async fn push_validation_result(&self, content: serde_json::Value) -> Result<(), String> { + let message = Message { + content, + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + sender: None, // Will be set to the node ID when implemented + }; + + self.inner.push_message(message).await + } + + /// Get the number of pending validation results + pub async fn get_queue_size(&self) -> usize { + self.inner.get_queue_size().await + } + + /// Clear all validation results (use with caution) + pub async fn clear(&self) -> Result<(), String> { + self.inner.clear().await + } +} diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/prime-protocol-py/src/validator/mod.rs index ed02939c..6890e799 100644 --- a/crates/prime-protocol-py/src/validator/mod.rs +++ b/crates/prime-protocol-py/src/validator/mod.rs @@ -1,9 +1,29 @@ use pyo3::prelude::*; +pub(crate) mod message_queue; +use self::message_queue::MessageQueue; + +/// Node details for validator operations +#[pyclass] +#[derive(Clone)] +pub(crate) struct NodeDetails { + #[pyo3(get)] + pub address: String, +} + +#[pymethods] +impl NodeDetails { + #[new] + pub fn new(address: String) -> Self { + Self { address } + } +} + /// Prime Protocol Validator Client - for validating task results #[pyclass] pub(crate) struct ValidatorClient { - // TODO: Implement validator-specific functionality + message_queue: MessageQueue, + runtime: 
Option, } #[pymethods] @@ -14,6 +34,75 @@ impl ValidatorClient { // TODO: Implement validator initialization let _ = rpc_url; let _ = private_key; - Ok(Self {}) + + Ok(Self { + message_queue: MessageQueue::new(), + runtime: None, + }) + } + + /// Initialize the validator client and start listening for messages + pub fn start(&mut self, py: Python) -> PyResult<()> { + // Create a new runtime for this validator + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .map_err(|e| PyErr::new::(e.to_string()))?; + + // Store the runtime for future use + self.runtime = Some(rt); + + Ok(()) + } + + pub fn list_nodes(&self) -> PyResult> { + // TODO: Implement validator node listing from chain that are not yet validated + Ok(vec![]) + } + + pub fn fetch_node_details(&self, node_id: String) -> PyResult> { + // TODO: Implement validator node details fetching + Ok(None) + } + + pub fn mark_node_as_validated(&self, node_id: String) -> PyResult<()> { + // TODO: Implement validator node marking as validated + Ok(()) + } + + pub fn send_request_to_node(&self, node_id: String, request: String) -> PyResult<()> { + // TODO: Implement validator node request sending + Ok(()) + } + + pub fn send_request_to_node_address( + &self, + node_address: String, + request: String, + ) -> PyResult<()> { + // TODO: Implement validator node request sending to specific address + let _ = node_address; + let _ = request; + Ok(()) + } + + /// Get the latest validation result from the internal message queue + pub fn get_latest_message(&self, py: Python) -> PyResult> { + if let Some(rt) = self.runtime.as_ref() { + Ok(py.allow_threads(|| rt.block_on(self.message_queue.get_validation_result()))) + } else { + Err(PyErr::new::( + "Validator not started. Call start() first.".to_string(), + )) + } + } + + /// Get the number of pending validation results + pub fn get_queue_size(&self, py: Python) -> PyResult { + if let Some(rt) = self.runtime.as_ref() { + Ok(py.allow_threads(|| rt.block_on(self.message_queue.get_queue_size()))) + } else { + Ok(0) + } } } diff --git a/crates/prime-protocol-py/src/worker/message_queue.rs b/crates/prime-protocol-py/src/worker/message_queue.rs index 9af9a687..167fde05 100644 --- a/crates/prime-protocol-py/src/worker/message_queue.rs +++ b/crates/prime-protocol-py/src/worker/message_queue.rs @@ -1,160 +1,80 @@ +use crate::utils::message_queue::{Message, MessageQueue as GenericMessageQueue}; use pyo3::prelude::*; -use serde::{Deserialize, Serialize}; -use std::collections::VecDeque; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::sync::Mutex; -use tokio::time::{interval, Duration}; -use crate::utils::json_parser::json_to_pyobject; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Message { - pub message_type: MessageType, - pub content: serde_json::Value, - pub timestamp: u64, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum MessageType { +/// Queue types for the worker message queue +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum QueueType { PoolOwner, Validator, } +/// Worker-specific message queue with predefined queue types #[derive(Clone)] pub struct MessageQueue { - pool_owner_queue: Arc>>, - validator_queue: Arc>>, - shutdown_tx: Arc>>>, + pool_owner_queue: GenericMessageQueue, + validator_queue: GenericMessageQueue, } impl MessageQueue { + /// Create a new worker message queue with pool_owner and validator queues pub fn new() -> Self { Self { - pool_owner_queue: Arc::new(Mutex::new(VecDeque::new())), - validator_queue: 
Arc::new(Mutex::new(VecDeque::new())), - shutdown_tx: Arc::new(Mutex::new(None)), + pool_owner_queue: GenericMessageQueue::new(None), + validator_queue: GenericMessageQueue::new(None), } } - /// Start the background message listener + /// Start the background message listener for worker pub(crate) async fn start_listener(&self) -> Result<(), String> { - let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - - // Store the shutdown sender - { - let mut tx_guard = self.shutdown_tx.lock().await; - *tx_guard = Some(shutdown_tx); - } - - let pool_owner_queue = self.pool_owner_queue.clone(); - let validator_queue = self.validator_queue.clone(); - - // Spawn background task to simulate incoming p2p messages - tokio::spawn(async move { - let mut ticker = interval(Duration::from_secs(5)); - let mut counter = 0u64; - - loop { - tokio::select! { - _ = ticker.tick() => { - // Mock pool owner messages - if counter % 2 == 0 { - let message = Message { - message_type: MessageType::PoolOwner, - content: serde_json::json!({ - "type": "inference_request", - "task_id": format!("task_{}", counter), - "prompt": format!("Test prompt {}", counter), - }), - timestamp: std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs(), - }; - - let mut queue = pool_owner_queue.lock().await; - queue.push_back(message); - log::debug!("Added mock pool owner message to queue"); - } - - // Mock validator messages - if counter % 3 == 0 { - let message = Message { - message_type: MessageType::Validator, - content: serde_json::json!({ - "type": "validation_request", - "task_id": format!("validation_{}", counter), - }), - timestamp: std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs(), - }; - - let mut queue = validator_queue.lock().await; - queue.push_back(message); - log::debug!("Added mock validator message to queue"); - } - - counter += 1; - } - _ = shutdown_rx.recv() => { - log::info!("Message listener shutting down"); - break; - } - } - } - }); - + // Start mock listeners with different frequencies + // pool_owner messages every 2 seconds, validator messages every 3 seconds + self.pool_owner_queue.start_mock_listener(2).await?; + self.validator_queue.start_mock_listener(3).await?; Ok(()) } /// Stop the background listener - #[allow(unused)] pub(crate) async fn stop_listener(&self) -> Result<(), String> { - if let Some(tx) = self.shutdown_tx.lock().await.take() { - let _ = tx.send(()).await; - } + self.pool_owner_queue.stop_listener().await?; + self.validator_queue.stop_listener().await?; Ok(()) } + /// Get the next message from the pool owner queue pub(crate) async fn get_pool_owner_message(&self) -> Option { - let mut queue = self.pool_owner_queue.lock().await; - queue - .pop_front() - .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + self.pool_owner_queue.get_message().await } /// Get the next message from the validator queue pub(crate) async fn get_validator_message(&self) -> Option { - let mut queue = self.validator_queue.lock().await; - queue - .pop_front() - .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + self.validator_queue.get_message().await } /// Push a message to the appropriate queue (for testing or internal use) - #[allow(unused)] - pub(crate) async fn push_message(&self, message: Message) -> Result<(), String> { - match message.message_type { - MessageType::PoolOwner => { - let mut queue = self.pool_owner_queue.lock().await; - queue.push_back(message); - } - MessageType::Validator 
=> { - let mut queue = self.validator_queue.lock().await; - queue.push_back(message); - } + pub(crate) async fn push_message( + &self, + queue_type: QueueType, + content: serde_json::Value, + ) -> Result<(), String> { + let message = Message { + content, + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + sender: Some("worker".to_string()), + }; + + match queue_type { + QueueType::PoolOwner => self.pool_owner_queue.push_message(message).await, + QueueType::Validator => self.validator_queue.push_message(message).await, } - Ok(()) } /// Get queue sizes for monitoring - #[allow(unused)] pub(crate) async fn get_queue_sizes(&self) -> (usize, usize) { - let pool_owner_size = self.pool_owner_queue.lock().await.len(); - let validator_size = self.validator_queue.lock().await.len(); + let pool_owner_size = self.pool_owner_queue.get_queue_size().await; + let validator_size = self.validator_queue.get_queue_size().await; (pool_owner_size, validator_size) } } diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/prime-protocol-py/src/worker/mod.rs index b28e2216..a308df12 100644 --- a/crates/prime-protocol-py/src/worker/mod.rs +++ b/crates/prime-protocol-py/src/worker/mod.rs @@ -2,7 +2,6 @@ use pyo3::prelude::*; mod client; pub(crate) mod message_queue; pub(crate) use client::WorkerClientCore; - /// Prime Protocol Worker Client - for compute nodes that execute tasks #[pyclass] pub(crate) struct WorkerClient { From 71ee224d58e83a25ebaf49e08acada0c9e18cdd5 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 12:21:40 +0200 Subject: [PATCH 06/23] integrate basic p2p messaging between nodes, auth flow with auth manager --- Cargo.lock | 6 + crates/p2p/src/lib.rs | 2 + crates/p2p/src/message/mod.rs | 4 +- crates/prime-protocol-py/Cargo.toml | 6 + .../prime-protocol-py/examples/basic_usage.py | 33 +- .../src/utils/message_queue.rs | 152 ---- crates/prime-protocol-py/src/utils/mod.rs | 1 - .../src/validator/message_queue.rs | 46 -- crates/prime-protocol-py/src/validator/mod.rs | 36 +- crates/prime-protocol-py/src/worker/auth.rs | 284 ++++++++ .../src/worker/blockchain.rs | 269 +++++++ crates/prime-protocol-py/src/worker/client.rs | 683 +++++++++--------- .../prime-protocol-py/src/worker/constants.rs | 26 + .../src/worker/message_processor.rs | 270 +++++++ .../src/worker/message_queue.rs | 80 -- crates/prime-protocol-py/src/worker/mod.rs | 166 ++++- crates/prime-protocol-py/src/worker/p2p.rs | 492 +++++++++++++ crates/shared/src/discovery/README.md | 60 ++ crates/shared/src/discovery/mod.rs | 127 ++++ crates/shared/src/p2p/service.rs | 2 +- 20 files changed, 2074 insertions(+), 671 deletions(-) delete mode 100644 crates/prime-protocol-py/src/utils/message_queue.rs delete mode 100644 crates/prime-protocol-py/src/validator/message_queue.rs create mode 100644 crates/prime-protocol-py/src/worker/auth.rs create mode 100644 crates/prime-protocol-py/src/worker/blockchain.rs create mode 100644 crates/prime-protocol-py/src/worker/constants.rs create mode 100644 crates/prime-protocol-py/src/worker/message_processor.rs delete mode 100644 crates/prime-protocol-py/src/worker/message_queue.rs create mode 100644 crates/prime-protocol-py/src/worker/p2p.rs create mode 100644 crates/shared/src/discovery/README.md create mode 100644 crates/shared/src/discovery/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 41c5f51c..918c65c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6726,11 +6726,16 @@ version = "0.1.0" dependencies = [ "alloy", 
"alloy-provider", + "anyhow", + "futures", + "hex", "log", + "p2p", "prime-core", "pyo3", "pyo3-log", "pythonize", + "rand 0.8.5", "serde", "serde_json", "shared", @@ -6738,6 +6743,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-test", + "tokio-util", "url", ] diff --git a/crates/p2p/src/lib.rs b/crates/p2p/src/lib.rs index f5bc648c..6542c7eb 100644 --- a/crates/p2p/src/lib.rs +++ b/crates/p2p/src/lib.rs @@ -274,6 +274,8 @@ impl NodeBuilder { cancellation_token, } = self; + println!("multi addrs: {:?}", listen_addrs); + let keypair = keypair.unwrap_or(identity::Keypair::generate_ed25519()); let peer_id = keypair.public().to_peer_id(); diff --git a/crates/p2p/src/message/mod.rs b/crates/p2p/src/message/mod.rs index 74b09c5a..5d4431d8 100644 --- a/crates/p2p/src/message/mod.rs +++ b/crates/p2p/src/message/mod.rs @@ -229,7 +229,7 @@ impl From for Response { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GeneralRequest { - data: Vec, + pub data: Vec, } impl From for Request { @@ -240,7 +240,7 @@ impl From for Request { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GeneralResponse { - data: Vec, + pub data: Vec, } impl From for Response { diff --git a/crates/prime-protocol-py/Cargo.toml b/crates/prime-protocol-py/Cargo.toml index cbb7b513..8f4d96be 100644 --- a/crates/prime-protocol-py/Cargo.toml +++ b/crates/prime-protocol-py/Cargo.toml @@ -14,6 +14,7 @@ pyo3 = { version = "0.25.1", features = ["extension-module"] } thiserror = "1.0" shared = { workspace = true } prime-core = { workspace = true } +p2p = { workspace = true } alloy = { workspace = true } alloy-provider = { workspace = true } tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "sync", "time", "macros"] } @@ -23,6 +24,11 @@ pyo3-log = "0.12.4" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" pythonize = "0.25" +futures = { workspace = true } +anyhow = { workspace = true } +tokio-util = { workspace = true } +rand = { version = "0.8", features = ["std"] } +hex = "0.4" [dev-dependencies] test-log = "0.2" diff --git a/crates/prime-protocol-py/examples/basic_usage.py b/crates/prime-protocol-py/examples/basic_usage.py index 02b19bd9..dd72b548 100644 --- a/crates/prime-protocol-py/examples/basic_usage.py +++ b/crates/prime-protocol-py/examples/basic_usage.py @@ -74,7 +74,12 @@ def main(): private_key_node = os.getenv("PRIVATE_KEY_NODE", None) logging.info(f"Connecting to: {rpc_url}") - client = WorkerClient(pool_id, rpc_url, private_key_provider, private_key_node) + + peer_id = os.getenv("PEER_ID", "12D3KooWELi4p1oR3QBSYiq1rvPpyjbkiQVhQJqCobBBUS7C6JrX") + port = int(os.getenv("PORT", 8003)) + peer_port = int(os.getenv("PEER_PORT", port-1)) + send_message_to_peer = os.getenv("SEND_MESSAGE_TO_PEER", "True").lower() == "true" + client = WorkerClient(pool_id, rpc_url, private_key_provider, private_key_node, port) def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down gracefully...") @@ -92,13 +97,37 @@ def signal_handler(sig, frame): try: logging.info("Starting client... (Press Ctrl+C to interrupt)") client.start() + + my_peer_id = client.get_own_peer_id() + logging.info(f"My Peer ID: {my_peer_id}") + + time.sleep(5) + peer_multi_addr = f"/ip4/127.0.0.1/tcp/{peer_port}" + + if send_message_to_peer: + print(f"Sending message to peer: {peer_id} on {peer_multi_addr}") + client.send_message(peer_id, b"Hello, world!", [peer_multi_addr]) + logging.info("Setup completed. Starting message polling loop...") print("Worker client started. 
Polling for messages. Press Ctrl+C to stop.") # Message polling loop while True: try: - check_for_messages(client) + message = client.get_next_message() + if message: + logging.info(f"Received full message: {message}") + logging.info(f"Received message from peer {message['peer_id']}") + if message.get('sender_address'): + logging.info(f"Sender Ethereum address: {message['sender_address']}") + + msg_data = message.get('message', {}) + if msg_data.get('type') == 'general': + data = bytes(msg_data.get('data', [])) + logging.info(f"Message data: {data}") + else: + logging.info(f"Message type: {msg_data.get('type')}") + time.sleep(0.1) # Small delay to prevent busy waiting except KeyboardInterrupt: # Handle Ctrl+C during message polling diff --git a/crates/prime-protocol-py/src/utils/message_queue.rs b/crates/prime-protocol-py/src/utils/message_queue.rs deleted file mode 100644 index 43153cb1..00000000 --- a/crates/prime-protocol-py/src/utils/message_queue.rs +++ /dev/null @@ -1,152 +0,0 @@ -use pyo3::prelude::*; -use serde::{Deserialize, Serialize}; -use std::collections::VecDeque; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::sync::Mutex; -use tokio::time::{interval, Duration}; - -use crate::utils::json_parser::json_to_pyobject; - -/// Generic message that can be sent between components -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Message { - pub content: serde_json::Value, - pub timestamp: u64, - pub sender: Option, -} - -/// Simple message queue for handling messages -#[derive(Clone)] -pub struct MessageQueue { - queue: Arc>>, - max_size: Option, - shutdown_tx: Arc>>>, -} - -impl MessageQueue { - /// Create a new message queue - pub fn new(max_size: Option) -> Self { - Self { - queue: Arc::new(Mutex::new(VecDeque::new())), - max_size, - shutdown_tx: Arc::new(Mutex::new(None)), - } - } - - /// Push a message to the queue - pub async fn push_message(&self, message: Message) -> Result<(), String> { - let mut queue = self.queue.lock().await; - - // Check max size if configured - if let Some(max_size) = self.max_size { - if queue.len() >= max_size { - return Err(format!("Queue is full (max size: {})", max_size)); - } - } - - queue.push_back(message); - Ok(()) - } - - /// Get the next message from the queue - pub async fn get_message(&self) -> Option { - let mut queue = self.queue.lock().await; - - queue - .pop_front() - .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) - } - - /// Get all messages from the queue (draining it) - pub async fn get_all_messages(&self) -> Vec { - let mut queue = self.queue.lock().await; - - let messages: Vec = queue.drain(..).collect(); - messages - .into_iter() - .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) - .collect() - } - - /// Peek at the next message without removing it - pub async fn peek_message(&self) -> Option { - let queue = self.queue.lock().await; - - queue - .front() - .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) - } - - /// Get the size of the queue - pub async fn get_queue_size(&self) -> usize { - let queue = self.queue.lock().await; - queue.len() - } - - /// Clear the queue - pub async fn clear(&self) -> Result<(), String> { - let mut queue = self.queue.lock().await; - queue.clear(); - Ok(()) - } - - /// Start a mock message listener (for testing/development) - pub async fn start_mock_listener(&self, frequency: u64) -> Result<(), String> { - let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - - // Store the shutdown sender - { - let mut 
-            let mut tx_guard = self.shutdown_tx.lock().await;
-            *tx_guard = Some(shutdown_tx);
-        }
-
-        let queue_clone = self.queue.clone();
-
-        // Spawn background task to simulate incoming messages
-        tokio::spawn(async move {
-            let mut ticker = interval(Duration::from_secs(1));
-            let mut counter = 0u64;
-
-            loop {
-                tokio::select! {
-                    _ = ticker.tick() => {
-                        if counter % frequency == 0 {
-                            let message = Message {
-                                content: serde_json::json!({
-                                    "type": "mock_message",
-                                    "id": format!("mock_{}", counter),
-                                    "data": format!("Mock data #{}", counter),
-                                }),
-                                timestamp: std::time::SystemTime::now()
-                                    .duration_since(std::time::UNIX_EPOCH)
-                                    .unwrap()
-                                    .as_secs(),
-                                sender: Some("mock_listener".to_string()),
-                            };
-
-                            let mut queue = queue_clone.lock().await;
-                            queue.push_back(message);
-                            log::debug!("Added mock message to queue");
-                        }
-                        counter += 1;
-                    }
-                    _ = shutdown_rx.recv() => {
-                        log::info!("Mock message listener shutting down");
-                        break;
-                    }
-                }
-            }
-        });
-
-        Ok(())
-    }
-
-    /// Stop the mock listener
-    pub async fn stop_listener(&self) -> Result<(), String> {
-        if let Some(tx) = self.shutdown_tx.lock().await.take() {
-            let _ = tx.send(()).await;
-        }
-        Ok(())
-    }
-}
diff --git a/crates/prime-protocol-py/src/utils/mod.rs b/crates/prime-protocol-py/src/utils/mod.rs
index da6afad7..0ab14864 100644
--- a/crates/prime-protocol-py/src/utils/mod.rs
+++ b/crates/prime-protocol-py/src/utils/mod.rs
@@ -1,2 +1 @@
 pub(crate) mod json_parser;
-pub(crate) mod message_queue;
diff --git a/crates/prime-protocol-py/src/validator/message_queue.rs b/crates/prime-protocol-py/src/validator/message_queue.rs
deleted file mode 100644
index 72f1b468..00000000
--- a/crates/prime-protocol-py/src/validator/message_queue.rs
+++ /dev/null
@@ -1,46 +0,0 @@
-use crate::utils::message_queue::{Message, MessageQueue as GenericMessageQueue};
-use pyo3::prelude::*;
-
-/// Validator-specific message queue for incoming validation results
-#[derive(Clone)]
-pub struct MessageQueue {
-    inner: GenericMessageQueue,
-}
-
-impl MessageQueue {
-    /// Create a new validator message queue for validation results
-    pub fn new() -> Self {
-        let inner = GenericMessageQueue::new(None);
-
-        Self { inner }
-    }
-
-    /// Get the next validation result from nodes
-    pub async fn get_validation_result(&self) -> Option<PyObject> {
-        self.inner.get_message().await
-    }
-
-    /// Push a validation result (for testing or internal use)
-    pub async fn push_validation_result(&self, content: serde_json::Value) -> Result<(), String> {
-        let message = Message {
-            content,
-            timestamp: std::time::SystemTime::now()
-                .duration_since(std::time::UNIX_EPOCH)
-                .unwrap()
-                .as_secs(),
-            sender: None, // Will be set to the node ID when implemented
-        };
-
-        self.inner.push_message(message).await
-    }
-
-    /// Get the number of pending validation results
-    pub async fn get_queue_size(&self) -> usize {
-        self.inner.get_queue_size().await
-    }
-
-    /// Clear all validation results (use with caution)
-    pub async fn clear(&self) -> Result<(), String> {
-        self.inner.clear().await
-    }
-}
diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/prime-protocol-py/src/validator/mod.rs
index 6890e799..ff649e32 100644
--- a/crates/prime-protocol-py/src/validator/mod.rs
+++ b/crates/prime-protocol-py/src/validator/mod.rs
@@ -1,8 +1,5 @@
 use pyo3::prelude::*;
 
-pub(crate) mod message_queue;
-use self::message_queue::MessageQueue;
-
 /// Node details for validator operations
 #[pyclass]
 #[derive(Clone)]
@@ -22,7 +19,6 @@ impl NodeDetails {
 
 /// Prime Protocol Validator Client - for validating task results
 #[pyclass]
 pub(crate) struct ValidatorClient {
-    message_queue: MessageQueue,
     runtime: Option<tokio::runtime::Runtime>,
 }
 
@@ -35,14 +31,11 @@ impl ValidatorClient {
         let _ = rpc_url;
         let _ = private_key;
 
-        Ok(Self {
-            message_queue: MessageQueue::new(),
-            runtime: None,
-        })
+        Ok(Self { runtime: None })
     }
 
     /// Initialize the validator client and start listening for messages
-    pub fn start(&mut self, py: Python) -> PyResult<()> {
+    pub fn start(&mut self, _py: Python) -> PyResult<()> {
         // Create a new runtime for this validator
         let rt = tokio::runtime::Builder::new_multi_thread()
             .enable_all()
@@ -60,17 +53,17 @@ impl ValidatorClient {
         Ok(vec![])
     }
 
-    pub fn fetch_node_details(&self, node_id: String) -> PyResult<Option<NodeDetails>> {
+    pub fn fetch_node_details(&self, _node_id: String) -> PyResult<Option<NodeDetails>> {
         // TODO: Implement validator node details fetching
         Ok(None)
     }
 
-    pub fn mark_node_as_validated(&self, node_id: String) -> PyResult<()> {
+    pub fn mark_node_as_validated(&self, _node_id: String) -> PyResult<()> {
         // TODO: Implement validator node marking as validated
         Ok(())
     }
 
-    pub fn send_request_to_node(&self, node_id: String, request: String) -> PyResult<()> {
+    pub fn send_request_to_node(&self, _node_id: String, _request: String) -> PyResult<()> {
         // TODO: Implement validator node request sending
         Ok(())
     }
@@ -86,23 +79,8 @@ impl ValidatorClient {
         Ok(())
     }
 
-    /// Get the latest validation result from the internal message queue
-    pub fn get_latest_message(&self, py: Python) -> PyResult<Option<PyObject>> {
-        if let Some(rt) = self.runtime.as_ref() {
-            Ok(py.allow_threads(|| rt.block_on(self.message_queue.get_validation_result())))
-        } else {
-            Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
-                "Validator not started. Call start() first.".to_string(),
-            ))
-        }
-    }
-
-    /// Get the number of pending validation results
-    pub fn get_queue_size(&self, py: Python) -> PyResult<usize> {
-        if let Some(rt) = self.runtime.as_ref() {
-            Ok(py.allow_threads(|| rt.block_on(self.message_queue.get_queue_size())))
-        } else {
-            Ok(0)
-        }
+    pub fn get_queue_size(&self) -> usize {
+        0
     }
 }
diff --git a/crates/prime-protocol-py/src/worker/auth.rs b/crates/prime-protocol-py/src/worker/auth.rs
new file mode 100644
index 00000000..1c31ca25
--- /dev/null
+++ b/crates/prime-protocol-py/src/worker/auth.rs
@@ -0,0 +1,284 @@
+use crate::error::{PrimeProtocolError, Result};
+use crate::worker::p2p::Message;
+use alloy::primitives::{Address, Signature};
+use rand::Rng;
+use shared::security::request_signer::sign_message;
+use shared::web3::wallet::Wallet;
+use std::collections::{HashMap, HashSet};
+use std::str::FromStr;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+
+/*
+ * Authentication Flow:
+ *
+ * This module implements a mutual authentication protocol between peers using cryptographic signatures.
+ * The flow works as follows:
+ *
+ * 1. INITIATION: Peer A wants to send a message to Peer B but they're not authenticated
+ *    - Peer A generates a random challenge (32 bytes, hex-encoded)
+ *    - Peer A stores the challenge and queues the original message
+ *    - Peer A sends AuthenticationInitiation{challenge} to Peer B
+ *
+ * 2. RESPONSE: Peer B receives the initiation
+ *    - Peer B signs Peer A's challenge with their private key
+ *    - Peer B generates their own challenge for Peer A
+ *    - Peer B sends AuthenticationResponse{challenge: B's challenge, signature: signed A's challenge}
+ *
+ * 3. SOLUTION: Peer A receives the response
+ *    - Peer A verifies Peer B's signature against the original challenge to get B's wallet address
+ *    - Peer A signs Peer B's challenge with their private key
+ *    - Peer A marks Peer B as authenticated
+ *    - Peer A sends AuthenticationSolution{signature: signed B's challenge}
+ *    - Peer A sends the originally queued message
+ *
+ * 4. COMPLETION: Peer B receives the solution
+ *    - Peer B verifies Peer A's signature against their challenge to get A's wallet address
+ *    - Peer B marks Peer A as authenticated
+ *    - Both peers can now exchange messages freely
+ *
+ * Security Properties:
+ * - Mutual authentication: Both peers prove they control their private keys
+ * - Replay protection: Each challenge is randomly generated
+ * - Address verification: Signature recovery provides cryptographic proof of wallet ownership
+ */
+
+/// Represents an ongoing authentication challenge with a peer
+#[derive(Debug)]
+pub struct OngoingAuthChallenge {
+    pub peer_wallet_address: Address,
+    pub auth_challenge_request_message: String,
+    pub outgoing_message: Message,
+    pub their_challenge: Option<String>, // The challenge we received from them that we need to sign
+}
+
+/// Manages authentication state and operations
+pub struct AuthenticationManager {
+    /// Track ongoing authentication requests (when we initiate)
+    ongoing_auth_requests: Arc<RwLock<HashMap<String, OngoingAuthChallenge>>>,
+    /// Track outgoing challenges (when they initiate and we respond)
+    outgoing_challenges: Arc<RwLock<HashMap<String, String>>>,
+    /// Track peers we're responding to (to prevent initiating auth with them)
+    responding_to_peers: Arc<RwLock<HashSet<String>>>,
+    /// Queue messages for peers we're authenticating with as responder
+    responder_message_queue: Arc<RwLock<HashMap<String, Vec<Message>>>>,
+    /// Track authenticated peers
+    authenticated_peers: Arc<RwLock<HashSet<String>>>,
+    /// Our wallet for signing
+    node_wallet: Arc<Wallet>,
+}
+
+impl AuthenticationManager {
+    pub fn new(node_wallet: Arc<Wallet>) -> Self {
+        Self {
+            ongoing_auth_requests: Arc::new(RwLock::new(HashMap::new())),
+            outgoing_challenges: Arc::new(RwLock::new(HashMap::new())),
+            responding_to_peers: Arc::new(RwLock::new(HashSet::new())),
+            responder_message_queue: Arc::new(RwLock::new(HashMap::new())),
+            authenticated_peers: Arc::new(RwLock::new(HashSet::new())),
+            node_wallet,
+        }
+    }
+
+    /// Check if a peer is already authenticated
+    pub async fn is_authenticated(&self, peer_id: &str) -> bool {
+        self.authenticated_peers.read().await.contains(peer_id)
+    }
+
+    /// Mark a peer as authenticated
+    pub async fn mark_authenticated(&self, peer_id: String) {
+        self.authenticated_peers.write().await.insert(peer_id);
+    }
+
+    /// Check our role in ongoing authentication
+    pub async fn get_auth_role(&self, peer_id: &str) -> Option<String> {
+        if self
+            .ongoing_auth_requests
+            .read()
+            .await
+            .contains_key(peer_id)
+        {
+            Some("initiator".to_string())
+        } else if self.responding_to_peers.read().await.contains(peer_id) {
+            Some("responder".to_string())
+        } else {
+            None
+        }
+    }
+
+    /// Queue a message for a peer we're responding to
+    pub async fn queue_message_as_responder(
+        &self,
+        peer_id: String,
+        message: Message,
+    ) -> Result<()> {
+        let mut queue = self.responder_message_queue.write().await;
+        queue.entry(peer_id).or_insert_with(Vec::new).push(message);
+        Ok(())
+    }
+
+    /// Start authentication with a peer
+    pub async fn start_authentication(
+        &self,
+        peer_id: String,
+        outgoing_message: Message,
+    ) -> Result<String> {
+        // Generate authentication challenge
+        let challenge_bytes: [u8; 32] = rand::thread_rng().gen();
+        let auth_challenge_message = hex::encode(challenge_bytes);
+
+        // Store the ongoing auth challenge
+        let mut ongoing_auth = self.ongoing_auth_requests.write().await;
+        ongoing_auth.insert(
+            peer_id,
+            OngoingAuthChallenge {
+                peer_wallet_address: Address::ZERO, // Will be updated when we get their signature
+                auth_challenge_request_message: auth_challenge_message.clone(),
+                outgoing_message,
+                their_challenge: None,
+            },
+        );
+
+        Ok(auth_challenge_message)
+    }
+
+    /// Handle incoming authentication initiation
+    pub async fn handle_auth_initiation(
+        &self,
+        peer_id: &str,
+        challenge: &str,
+    ) -> Result<(String, String)> {
+        // Mark that we're responding to this peer
+        self.responding_to_peers
+            .write()
+            .await
+            .insert(peer_id.to_string());
+
+        // Sign the challenge
+        let signature = sign_message(challenge, &self.node_wallet)
+            .await
+            .map_err(|e| {
+                PrimeProtocolError::BlockchainError(format!("Failed to sign challenge: {}", e))
+            })?;
+
+        // Generate our own challenge for the peer
+        let our_challenge_bytes: [u8; 32] = rand::thread_rng().gen();
+        let our_challenge = hex::encode(our_challenge_bytes);
+
+        // Store the challenge we're sending so we can verify the signature later
+        self.outgoing_challenges
+            .write()
+            .await
+            .insert(peer_id.to_string(), our_challenge.clone());
+
+        Ok((our_challenge, signature))
+    }
+
+    /// Handle authentication response from peer
+    pub async fn handle_auth_response(
+        &self,
+        peer_id: &str,
+        their_challenge: &str,
+        their_signature: &str,
+    ) -> Result<(String, Option<Message>)> {
+        // Verify we have an ongoing auth request for this peer
+        let mut ongoing_auth = self.ongoing_auth_requests.write().await;
+        let auth_challenge = ongoing_auth.get_mut(peer_id).ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig(format!(
+                "No ongoing auth request for peer {}",
+                peer_id
+            ))
+        })?;
+
+        // Verify their signature to get their address
+        let parsed_signature = Signature::from_str(their_signature).map_err(|e| {
+            PrimeProtocolError::InvalidConfig(format!("Invalid signature format: {}", e))
+        })?;
+
+        // Recover the peer's address from their signature
+        let recovered_address = parsed_signature
+            .recover_address_from_msg(&auth_challenge.auth_challenge_request_message)
+            .map_err(|e| {
+                PrimeProtocolError::InvalidConfig(format!("Failed to recover address: {}", e))
+            })?;
+
+        // Update the peer's wallet address and store their challenge
+        auth_challenge.peer_wallet_address = recovered_address;
+        auth_challenge.their_challenge = Some(their_challenge.to_string());
+        log::debug!("Recovered peer address: {}", recovered_address);
+
+        // Sign their challenge
+        let our_signature = sign_message(their_challenge, &self.node_wallet)
+            .await
+            .map_err(|e| {
+                PrimeProtocolError::BlockchainError(format!("Failed to sign challenge: {}", e))
+            })?;
+
+        // Mark peer as authenticated
+        self.mark_authenticated(peer_id.to_string()).await;
+
+        // Get the queued message to send after auth
+        let queued_message = ongoing_auth
+            .remove(peer_id)
+            .map(|auth| auth.outgoing_message);
+
+        Ok((our_signature, queued_message))
+    }
+
+    /// Handle authentication solution from peer
+    pub async fn handle_auth_solution(
+        &self,
+        peer_id: &str,
+        signature: &str,
+    ) -> Result<(Address, Vec<Message>)> {
+        // Get the challenge we sent to this peer
+        let mut outgoing_challenges = self.outgoing_challenges.write().await;
+        let challenge = outgoing_challenges.remove(peer_id).ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig(format!(
+                "No outgoing challenge found for peer {}",
+                peer_id
+            ))
+        })?;
+
+        // Parse and verify the signature
+        let parsed_signature = Signature::from_str(signature).map_err(|e| {
+            PrimeProtocolError::InvalidConfig(format!("Invalid signature format: {}", e))
+        })?;
+
+        // Recover the peer's address from their signature of our challenge
+        let recovered_address = parsed_signature
+            .recover_address_from_msg(&challenge)
+            .map_err(|e| {
+                PrimeProtocolError::InvalidConfig(format!("Failed to recover address: {}", e))
+            })?;
+
+        log::debug!(
+            "Verified auth solution from peer {} with address {}",
+            peer_id,
+            recovered_address
+        );
+        self.mark_authenticated(peer_id.to_string()).await;
+
+        // Clean up responding state
+        self.responding_to_peers.write().await.remove(peer_id);
+
+        // Get any queued messages to send now that we're authenticated
+        let queued_messages = self
+            .responder_message_queue
+            .write()
+            .await
+            .remove(peer_id)
+            .unwrap_or_default();
+
+        Ok((recovered_address, queued_messages))
+    }
+
+    /// Get wallet address
+    pub fn wallet_address(&self) -> String {
+        self.node_wallet
+            .wallet
+            .default_signer()
+            .address()
+            .to_string()
+    }
+}
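The signature-recovery step that both peers rely on can be exercised in isolation. A sketch of the round trip using the same `sign_message` helper and alloy recovery call as the module above; the Anvil key and local RPC URL are placeholder test values and the error mapping is simplified:

    use std::str::FromStr;

    use alloy::primitives::Signature;
    use shared::security::request_signer::sign_message;
    use shared::web3::wallet::Wallet;
    use url::Url;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let rpc_url = Url::parse("http://localhost:8545")?;
        let wallet = Wallet::new(
            "0x7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6",
            rpc_url,
        )
        .map_err(|e| anyhow::anyhow!("wallet: {e}"))?;

        // The initiator's random challenge (32 bytes, hex-encoded).
        let challenge = hex::encode(rand::random::<[u8; 32]>());

        // The responder signs the challenge ...
        let signature = sign_message(&challenge, &wallet)
            .await
            .map_err(|e| anyhow::anyhow!("sign: {e}"))?;

        // ... and the initiator recovers the responder's wallet address from the
        // signature alone; a successful recovery is what marks the peer as
        // authenticated.
        let recovered = Signature::from_str(&signature)?.recover_address_from_msg(&challenge)?;
        println!("peer wallet address: {recovered}");
        Ok(())
    }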
diff --git a/crates/prime-protocol-py/src/worker/blockchain.rs b/crates/prime-protocol-py/src/worker/blockchain.rs
new file mode 100644
index 00000000..a3df3155
--- /dev/null
+++ b/crates/prime-protocol-py/src/worker/blockchain.rs
@@ -0,0 +1,269 @@
+use alloy::primitives::utils::format_ether;
+use alloy::primitives::U256;
+use anyhow::{Context, Result};
+use prime_core::operations::compute_node::ComputeNodeOperations;
+use prime_core::operations::provider::ProviderOperations;
+use shared::web3::contracts::core::builder::{ContractBuilder, Contracts};
+use shared::web3::contracts::structs::compute_pool::{PoolInfo, PoolStatus};
+use shared::web3::wallet::{Wallet, WalletProvider};
+use url::Url;
+
+use crate::worker::constants::{BLOCKCHAIN_OPERATION_TIMEOUT, DEFAULT_COMPUTE_UNITS};
+
+/// Configuration for blockchain operations
+pub struct BlockchainConfig {
+    pub rpc_url: String,
+    pub compute_pool_id: u64,
+    pub private_key_provider: String,
+    pub private_key_node: String,
+    pub auto_accept_transactions: bool,
+    pub funding_retry_count: u32,
+}
+
+/// Handles all blockchain-related operations for the worker
+pub struct BlockchainService {
+    config: BlockchainConfig,
+    provider_wallet: Option<Wallet>,
+    node_wallet: Option<Wallet>,
+}
+
+impl BlockchainService {
+    pub fn new(config: BlockchainConfig) -> Result<Self> {
+        // Validate RPC URL
+        Url::parse(&config.rpc_url).context("Invalid RPC URL format")?;
+
+        Ok(Self {
+            config,
+            provider_wallet: None,
+            node_wallet: None,
+        })
+    }
+
+    /// Get the node wallet (used for authentication)
+    pub fn node_wallet(&self) -> Option<&Wallet> {
+        self.node_wallet.as_ref()
+    }
+
+    /// Initialize blockchain components and ensure the node is properly registered
+    pub async fn initialize(&mut self) -> Result<()> {
+        let (provider_wallet, node_wallet, contracts) = self.create_wallets_and_contracts().await?;
+
+        // Store the wallets
+        self.provider_wallet = Some(provider_wallet.clone());
+        self.node_wallet = Some(node_wallet.clone());
+
+        self.wait_for_active_pool(&contracts).await?;
+        self.ensure_provider_registered(&provider_wallet, &contracts)
+            .await?;
+        self.ensure_compute_node_registered(&provider_wallet, &node_wallet, &contracts)
+            .await?;
+
+        Ok(())
+    }
+
+    async fn create_wallets_and_contracts(
+        &self,
+    ) -> Result<(Wallet, Wallet, Contracts)> {
+        let rpc_url = Url::parse(&self.config.rpc_url)?;
+
+        let provider_wallet = Wallet::new(&self.config.private_key_provider, rpc_url.clone())
+            .map_err(|e| anyhow::anyhow!("Failed to create provider wallet: {}", e))?;
+
+        let node_wallet = Wallet::new(&self.config.private_key_node, rpc_url.clone())
+            .map_err(|e| anyhow::anyhow!("Failed to create node wallet: {}", e))?;
+
+        let contracts = ContractBuilder::new(provider_wallet.provider())
+            .with_compute_pool()
+            .with_compute_registry()
+            .with_ai_token()
+            .with_prime_network()
+            .with_stake_manager()
+            .build()
+            .context("Failed to build contracts")?;
+
+        Ok((provider_wallet, node_wallet, contracts))
+    }
+
+    async fn wait_for_active_pool(
+        &self,
+        contracts: &Contracts,
+    ) -> Result<PoolInfo> {
+        loop {
+            match contracts
+                .compute_pool
+                .get_pool_info(U256::from(self.config.compute_pool_id))
+                .await
+            {
+                Ok(pool) if pool.status == PoolStatus::ACTIVE => {
+                    log::info!("Pool {} is active", self.config.compute_pool_id);
+                    return Ok(pool);
+                }
+                Ok(pool) => {
+                    log::info!(
+                        "Pool {} is not active yet (status: {:?}), waiting...",
+                        self.config.compute_pool_id,
+                        pool.status
+                    );
+                    tokio::time::sleep(crate::worker::constants::POOL_STATUS_CHECK_INTERVAL).await;
+                }
+                Err(e) => {
+                    return Err(anyhow::anyhow!("Failed to get pool info: {}", e));
+                }
+            }
+        }
+    }
+
+    async fn ensure_provider_registered(
+        &self,
+        provider_wallet: &Wallet,
+        contracts: &Contracts,
+    ) -> Result<()> {
+        let provider_ops = ProviderOperations::new(
+            provider_wallet.clone(),
+            contracts.clone(),
+            self.config.auto_accept_transactions,
+        );
+
+        let provider_exists = provider_ops
+            .check_provider_exists()
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to check if provider exists: {}", e))?;
+
+        let is_whitelisted = provider_ops
+            .check_provider_whitelisted()
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to check provider whitelist status: {}", e))?;
+
+        if !provider_exists || !is_whitelisted {
+            self.register_provider(&provider_ops, contracts).await?;
+        } else {
+            log::info!("Provider is already registered and whitelisted");
+        }
+
+        self.ensure_adequate_stake(&provider_ops, provider_wallet, contracts)
+            .await?;
+
+        Ok(())
+    }
+
+    async fn register_provider(
+        &self,
+        provider_ops: &ProviderOperations,
+        contracts: &Contracts,
+    ) -> Result<()> {
+        let stake_manager = contracts
+            .stake_manager
+            .as_ref()
+            .context("Stake manager not initialized")?;
+
+        let compute_units = U256::from(DEFAULT_COMPUTE_UNITS);
+        let required_stake = stake_manager
+            .calculate_stake(compute_units, U256::from(0))
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to calculate required stake: {}", e))?;
+
+        log::info!(
+            "Required stake for registration: {}",
+            format_ether(required_stake)
+        );
+
+        tokio::time::timeout(
+            BLOCKCHAIN_OPERATION_TIMEOUT,
+            provider_ops.retry_register_provider(
+                required_stake,
+                self.config.funding_retry_count,
+                None,
+            ),
+        )
+        .await
+        .context("Provider registration timed out")?
+        .map_err(|e| anyhow::anyhow!("Failed to register provider: {}", e))?;
+
+        log::info!("Provider registered successfully");
+        Ok(())
+    }
+
+    async fn ensure_adequate_stake(
+        &self,
+        provider_ops: &ProviderOperations,
+        provider_wallet: &Wallet,
+        contracts: &Contracts,
+    ) -> Result<()> {
+        let stake_manager = contracts
+            .stake_manager
+            .as_ref()
+            .context("Stake manager not initialized")?;
+
+        let provider_address = provider_wallet.wallet.default_signer().address();
+
+        let provider_total_compute = contracts
+            .compute_registry
+            .get_provider_total_compute(provider_address)
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to get provider total compute: {}", e))?;
+
+        let provider_stake = stake_manager
+            .get_stake(provider_address)
+            .await
+            .unwrap_or_default();
+
+        let compute_units = U256::from(DEFAULT_COMPUTE_UNITS);
+        let required_stake = stake_manager
+            .calculate_stake(compute_units, provider_total_compute)
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to calculate required stake: {}", e))?;
+
+        if required_stake > provider_stake {
+            log::info!(
+                "Increasing provider stake. Required: {} tokens, Current: {} tokens",
+                format_ether(required_stake),
+                format_ether(provider_stake)
+            );
+
+            tokio::time::timeout(
+                BLOCKCHAIN_OPERATION_TIMEOUT,
+                provider_ops.increase_stake(required_stake - provider_stake),
+            )
+            .await
+            .context("Stake increase timed out")?
+            .map_err(|e| anyhow::anyhow!("Failed to increase stake: {}", e))?;
+
+            log::info!("Successfully increased stake");
+        }
+
+        Ok(())
+    }
+
+    async fn ensure_compute_node_registered(
+        &self,
+        provider_wallet: &Wallet,
+        node_wallet: &Wallet,
+        contracts: &Contracts,
+    ) -> Result<()> {
+        let compute_node_ops =
+            ComputeNodeOperations::new(provider_wallet, node_wallet, contracts.clone());
+
+        let compute_node_exists = compute_node_ops
+            .check_compute_node_exists()
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to check if compute node exists: {}", e))?;
+
+        if compute_node_exists {
+            log::info!("Compute node is already registered");
+            return Ok(());
+        }
+
+        let compute_units = U256::from(DEFAULT_COMPUTE_UNITS);
+
+        tokio::time::timeout(
+            BLOCKCHAIN_OPERATION_TIMEOUT,
+            compute_node_ops.add_compute_node(compute_units),
+        )
+        .await
+        .context("Compute node registration timed out")?
+        .map_err(|e| anyhow::anyhow!("Failed to register compute node: {}", e))?;
+
+        log::info!("Compute node registered successfully");
+        Ok(())
+    }
+}
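Every chain call above is wrapped in `tokio::time::timeout`, so a stalled RPC endpoint cannot wedge worker start-up: the outer `Result` signals an elapsed deadline, the inner one the operation itself. The guard pattern in isolation (self-contained sketch; the sleep stands in for an RPC round trip):

    use std::time::Duration;
    use tokio::time::{sleep, timeout};

    const BLOCKCHAIN_OPERATION_TIMEOUT: Duration = Duration::from_secs(300);

    async fn slow_contract_call() -> Result<u64, String> {
        sleep(Duration::from_millis(50)).await; // stands in for an RPC round trip
        Ok(42)
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let value = timeout(BLOCKCHAIN_OPERATION_TIMEOUT, slow_contract_call())
            .await
            .map_err(|_| anyhow::anyhow!("operation timed out"))? // deadline elapsed
            .map_err(|e| anyhow::anyhow!("operation failed: {e}"))?; // call itself failed
        assert_eq!(value, 42);
        Ok(())
    }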
diff --git a/crates/prime-protocol-py/src/worker/client.rs b/crates/prime-protocol-py/src/worker/client.rs
index db30c0b4..08a0c824 100644
--- a/crates/prime-protocol-py/src/worker/client.rs
+++ b/crates/prime-protocol-py/src/worker/client.rs
@@ -1,26 +1,57 @@
 use crate::error::{PrimeProtocolError, Result};
-use crate::worker::message_queue::MessageQueue;
-use alloy::primitives::utils::format_ether;
-use alloy::primitives::{Address, U256};
-use prime_core::operations::compute_node::ComputeNodeOperations;
-use prime_core::operations::provider::ProviderOperations;
-use shared::web3::contracts::core::builder::{ContractBuilder, Contracts};
-use shared::web3::contracts::structs::compute_pool::PoolStatus;
-use shared::web3::wallet::{Wallet, WalletProvider};
+use crate::worker::auth::AuthenticationManager;
+use crate::worker::blockchain::{BlockchainConfig, BlockchainService};
+use crate::worker::constants::{
+    DEFAULT_FUNDING_RETRY_COUNT, MESSAGE_QUEUE_TIMEOUT, P2P_SHUTDOWN_TIMEOUT,
+};
+use crate::worker::message_processor::MessageProcessor;
+use crate::worker::p2p::{Message, MessageType, Service as P2PService};
+use p2p::{Keypair, PeerId};
 use std::sync::Arc;
+use tokio::sync::mpsc::{Receiver, Sender};
+use tokio::sync::Mutex;
+use tokio::task::JoinHandle;
+use tokio_util::sync::CancellationToken;
 use url::Url;
 
+/// Core worker client that handles P2P networking and blockchain operations
 pub struct WorkerClientCore {
+    config: ClientConfig,
+    blockchain_service: Option<BlockchainService>,
+    p2p_state: P2PState,
+    auth_manager: Option<Arc<AuthenticationManager>>,
+    cancellation_token: CancellationToken,
+    // Message processing
+    user_message_tx: Option<Sender<Message>>,
+    user_message_rx: Option<Arc<Mutex<Receiver<Message>>>>,
+    message_processor_handle: Option<JoinHandle<()>>,
+}
+
+/// Configuration for the worker client
+struct ClientConfig {
     rpc_url: String,
     compute_pool_id: u64,
     private_key_provider: Option<String>,
     private_key_node: Option<String>,
    auto_accept_transactions: bool,
     funding_retry_count: u32,
-    message_queue: Arc<MessageQueue>,
+    p2p_port: u16,
+}
+
+/// P2P networking state
+struct P2PState {
+    keypair: Keypair,
+    peer_id: Option<PeerId>,
+    outbound_tx: Option<Arc<Mutex<Sender<Message>>>>,
+    message_queue_rx: Option<Arc<Mutex<Receiver<Message>>>>,
+    handle: Option<JoinHandle<Result<()>>>,
+    authenticated_peers:
+        Option<Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>>,
 }
 
 impl WorkerClientCore {
+    /// Create a new worker client
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         compute_pool_id: u64,
         rpc_url: String,
@@ -28,7 +59,10 @@ impl WorkerClientCore {
         private_key_node: Option<String>,
         auto_accept_transactions: Option<bool>,
         funding_retry_count: Option<u32>,
+        cancellation_token: CancellationToken,
+        p2p_port: u16,
     ) -> Result<Self> {
+        // Validate inputs
         if rpc_url.is_empty() {
             return Err(PrimeProtocolError::InvalidConfig(
                 "RPC URL cannot be empty".to_string(),
@@ -38,402 +72,401 @@ impl WorkerClientCore {
         Url::parse(&rpc_url)
             .map_err(|_| PrimeProtocolError::InvalidConfig("Invalid RPC URL format".to_string()))?;
 
-        Ok(Self {
+        let config = ClientConfig {
             rpc_url,
             compute_pool_id,
             private_key_provider,
             private_key_node,
             auto_accept_transactions: auto_accept_transactions.unwrap_or(true),
-            funding_retry_count: funding_retry_count.unwrap_or(10),
-            message_queue: Arc::new(MessageQueue::new()),
+            funding_retry_count: funding_retry_count.unwrap_or(DEFAULT_FUNDING_RETRY_COUNT),
+            p2p_port,
+        };
+
+        let p2p_state = P2PState {
+            keypair: Keypair::generate_ed25519(),
+            peer_id: None,
+            outbound_tx: None,
+            message_queue_rx: None,
+            handle: None,
+            authenticated_peers: None,
+        };
+
+        // Create user message channel
+        let (user_message_tx, user_message_rx) = tokio::sync::mpsc::channel::<Message>(1000);
+
+        Ok(Self {
+            config,
+            blockchain_service: None,
+            p2p_state,
+            auth_manager: None,
+            cancellation_token,
+            user_message_tx: Some(user_message_tx),
+            user_message_rx: Some(Arc::new(Mutex::new(user_message_rx))),
+            message_processor_handle: None,
         })
     }
 
-    pub async fn start_async(&self) -> Result<()> {
-        let (provider_wallet, node_wallet, contracts) =
-            self.initialize_blockchain_components().await?;
-        let pool_info = self.wait_for_active_pool(&contracts).await?;
+    /// Start the worker client asynchronously
+    pub async fn start_async(&mut self) -> Result<()> {
+        log::info!("Starting worker client...");
 
-        log::debug!("Pool info: {:?}", pool_info);
-        log::debug!("Checking provider");
-        self.ensure_provider_registered(&provider_wallet, &contracts)
-            .await?;
-        log::debug!("Checking compute node");
-        self.ensure_compute_node_registered(&provider_wallet, &node_wallet, &contracts)
-            .await?;
+        // Initialize blockchain components
+        self.initialize_blockchain().await?;
 
-        log::debug!("blockchain components initialized");
-        log::debug!("starting queues");
+        // Initialize authentication manager
+        self.initialize_auth_manager()?;
 
-        // Start the message queue listener
-        self.message_queue.start_listener().await.map_err(|e| {
-            PrimeProtocolError::InvalidConfig(format!("Failed to start message listener: {}", e))
-        })?;
+        // Start P2P networking
+        self.start_p2p_service().await?;
 
-        log::debug!("Message queue listener started");
+        // Start message processor
+        self.start_message_processor().await?;
 
+        log::info!("Worker client started successfully");
         Ok(())
     }
 
-    async fn initialize_blockchain_components(
-        &self,
-    ) -> Result<(Wallet, Wallet, Contracts)> {
-        let private_key_provider = self.get_private_key_provider()?;
-        let private_key_node = self.get_private_key_node()?;
-        let rpc_url = Url::parse(&self.rpc_url).unwrap();
-
-        let provider_wallet = Wallet::new(&private_key_provider, rpc_url.clone()).map_err(|e| {
-            PrimeProtocolError::BlockchainError(format!("Failed to create provider wallet: {}", e))
-        })?;
-
-        let node_wallet = Wallet::new(&private_key_node, rpc_url.clone()).map_err(|e| {
-            PrimeProtocolError::BlockchainError(format!("Failed to create node wallet: {}", e))
-        })?;
-
-        let contracts = ContractBuilder::new(provider_wallet.provider())
-            .with_compute_pool()
-            .with_compute_registry()
-            .with_ai_token()
-            .with_prime_network()
-            .with_stake_manager()
-            .build()
-            .map_err(|e| PrimeProtocolError::BlockchainError(e.to_string()))?;
+    /// Stop the worker client and clean up resources
+    pub async fn stop_async(&mut self) -> Result<()> {
+        log::info!("Stopping worker client...");
 
-        Ok((provider_wallet, node_wallet, contracts))
-    }
+        // Cancel all background tasks
+        self.cancellation_token.cancel();
 
-    async fn wait_for_active_pool(
-        &self,
-        contracts: &Contracts,
-    ) -> Result<PoolInfo> {
-        loop {
-            match contracts
-                .compute_pool
-                .get_pool_info(U256::from(self.compute_pool_id))
-                .await
-            {
-                Ok(pool) if pool.status == PoolStatus::ACTIVE => return Ok(pool),
-                Ok(_) => {
-                    log::info!("Pool not active yet, waiting...");
-                    tokio::time::sleep(tokio::time::Duration::from_secs(15)).await;
-                }
-                Err(e) => {
-                    return Err(PrimeProtocolError::BlockchainError(format!(
-                        "Failed to get pool info: {}",
-                        e
-                    )));
-                }
-            }
+        // Stop message processor
+        if let Some(handle) = self.message_processor_handle.take() {
+            handle.abort();
         }
-    }
 
-    async fn ensure_provider_registered(
-        &self,
-        provider_wallet: &Wallet,
-        contracts: &Contracts,
-    ) -> Result<()> {
-        let provider_ops = ProviderOperations::new(
-            provider_wallet.clone(),
-            contracts.clone(),
-            self.auto_accept_transactions,
-        );
-
-        let provider_exists = self.check_provider_exists(&provider_ops).await?;
-        let is_whitelisted = self.check_provider_whitelisted(&provider_ops).await?;
-
-        if provider_exists && is_whitelisted {
-            log::info!("Provider is registered and whitelisted");
-        } else {
-            self.register_provider_if_needed(&provider_ops, contracts)
-                .await?;
+        // Wait for P2P service to shut down
+        if let Some(handle) = self.p2p_state.handle.take() {
+            match tokio::time::timeout(P2P_SHUTDOWN_TIMEOUT, handle).await {
+                Ok(Ok(_)) => log::info!("P2P service shut down gracefully"),
+                Ok(Err(e)) => log::error!("P2P service error during shutdown: {:?}", e),
+                Err(_) => log::warn!("P2P service shutdown timed out"),
+            }
         }
 
-        self.ensure_adequate_stake(&provider_ops, provider_wallet, contracts)
-            .await?;
-
+        log::info!("Worker client stopped");
         Ok(())
     }
 
-    async fn check_provider_exists(&self, provider_ops: &ProviderOperations) -> Result<bool> {
-        provider_ops.check_provider_exists().await.map_err(|e| {
-            PrimeProtocolError::BlockchainError(format!(
-                "Failed to check if provider exists: {}",
-                e
-            ))
-        })
+    /// Get the peer ID of this node
+    pub fn get_peer_id(&self) -> Option<PeerId> {
+        self.p2p_state.peer_id
     }
 
-    async fn check_provider_whitelisted(&self, provider_ops: &ProviderOperations) -> Result<bool> {
-        provider_ops
-            .check_provider_whitelisted()
+    /// Get the next message from the P2P network (only returns general messages)
+    pub async fn get_next_message(&self) -> Option<Message> {
+        let rx = self.user_message_rx.as_ref()?;
+
+        tokio::time::timeout(MESSAGE_QUEUE_TIMEOUT, rx.lock().await.recv())
             .await
-            .map_err(|e| {
-                PrimeProtocolError::BlockchainError(format!(
-                    "Failed to check provider whitelist status: {}",
-                    e
-                ))
-            })
+            .ok()
+            .flatten()
     }
 
-    async fn register_provider_if_needed(
-        &self,
-        provider_ops: &ProviderOperations,
-        contracts: &Contracts,
-    ) -> Result<()> {
-        let stake_manager = contracts.stake_manager.as_ref().ok_or_else(|| {
-            PrimeProtocolError::BlockchainError("Stake manager not initialized".to_string())
-        })?;
-        let compute_units = U256::from(1); // TODO: Make configurable
+    /// Send a message to a peer
+    pub async fn send_message(&self, message: Message) -> Result<()> {
+        log::debug!("Sending message to peer: {}", message.peer_id);
 
-        let required_stake = stake_manager
-            .calculate_stake(compute_units, U256::from(0))
-            .await
-            .map_err(|e| {
-                PrimeProtocolError::BlockchainError(format!(
-                    "Failed to calculate required stake: {}",
-                    e
-                ))
-            })?;
-
-        log::info!("Required stake: {}", format_ether(required_stake));
-
-        // Add timeout to prevent hanging on blockchain operations
-        let register_future =
-            provider_ops.retry_register_provider(required_stake, self.funding_retry_count, None);
-
-        tokio::time::timeout(
-            tokio::time::Duration::from_secs(300), // 5 minute timeout
-            register_future,
-        )
-        .await
-        .map_err(|_| {
-            PrimeProtocolError::BlockchainError(
-                "Provider registration timed out after 5 minutes".to_string(),
-            )
-        })?
-        .map_err(|e| {
-            PrimeProtocolError::BlockchainError(format!("Failed to register provider: {}", e))
+        let auth_manager = self.auth_manager.as_ref().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("Authentication manager not initialized".to_string())
         })?;
 
-        log::info!("Provider registered successfully");
-        Ok(())
-    }
-
-    async fn ensure_adequate_stake(
-        &self,
-        provider_ops: &ProviderOperations,
-        provider_wallet: &Wallet,
-        contracts: &Contracts,
-    ) -> Result<()> {
-        let stake_manager = contracts.stake_manager.as_ref().ok_or_else(|| {
-            PrimeProtocolError::BlockchainError("Stake manager not initialized".to_string())
+        let tx = self.p2p_state.outbound_tx.as_ref().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("P2P service not initialized".to_string())
         })?;
 
-        let provider_address = provider_wallet.wallet.default_signer().address();
-        let provider_total_compute = self
-            .get_provider_total_compute(contracts, provider_address)
-            .await?;
-        let provider_stake = self.get_provider_stake(contracts, provider_address).await;
-        let compute_units = U256::from(1); // TODO: Make configurable
+        // Check if we're already authenticated with this peer
+        if auth_manager.is_authenticated(&message.peer_id).await {
+            log::debug!(
+                "Already authenticated with peer {}, sending message directly",
+                message.peer_id
+            );
+            return tx.lock().await.send(message).await.map_err(|e| {
+                PrimeProtocolError::InvalidConfig(format!("Failed to send message: {}", e))
+            });
+        }
 
-        let required_stake = stake_manager
-            .calculate_stake(compute_units, provider_total_compute)
-            .await
-            .map_err(|e| {
-                PrimeProtocolError::BlockchainError(format!(
-                    "Failed to calculate required stake: {}",
-                    e
-                ))
-            })?;
-
-        if required_stake > provider_stake {
-            self.increase_provider_stake(provider_ops, required_stake, provider_stake)
-                .await?;
+        // Not authenticated yet, check if we have ongoing authentication
+        log::debug!("Not authenticated with peer {}", message.peer_id);
+
+        // Check if there's already an ongoing auth request
+        if let Some(role) = auth_manager.get_auth_role(&message.peer_id).await {
+            match role.as_str() {
+                "initiator" => {
+                    return Err(PrimeProtocolError::InvalidConfig(format!(
+                        "Already initiated authentication with peer {}",
+                        message.peer_id
+                    )));
+                }
+                "responder" => {
+                    // We're responding to their auth, queue the message
+                    log::debug!(
+                        "Queuing message for peer {} (we're responding to their auth)",
+                        message.peer_id
+                    );
+                    return auth_manager
+                        .queue_message_as_responder(message.peer_id.clone(), message)
+                        .await;
+                }
+                _ => {}
+            }
         }
 
-        Ok(())
-    }
+        // Extract fields we need before moving the message
+        let peer_id = message.peer_id.clone();
+        let multiaddrs = message.multiaddrs.clone();
 
-    async fn get_provider_total_compute(
-        &self,
-        contracts: &Contracts,
-        provider_address: Address,
-    ) -> Result<U256> {
-        contracts
-            .compute_registry
-            .get_provider_total_compute(provider_address)
-            .await
-            .map_err(|e| {
-                PrimeProtocolError::BlockchainError(format!(
-                    "Failed to get provider total compute: {}",
-                    e
-                ))
-            })
-    }
+        // Start authentication (takes ownership of message)
+        let auth_challenge = auth_manager
+            .start_authentication(peer_id.clone(), message)
+            .await?;
 
-    async fn get_provider_stake(
-        &self,
-        contracts: &Contracts,
-        provider_address: Address,
-    ) -> U256 {
-        let stake_manager = contracts.stake_manager.as_ref();
-        match stake_manager {
-            Some(manager) => manager
-                .get_stake(provider_address)
-                .await
-                .unwrap_or_default(),
-            None => U256::from(0),
-        }
+        // Send authentication initiation
+        let auth_message = Message {
+            message_type: MessageType::AuthenticationInitiation {
+                challenge: auth_challenge,
+            },
+            peer_id,
+            multiaddrs,
+            sender_address: Some(auth_manager.wallet_address()),
+            response_tx: None,
+        };
+
+        tx.lock().await.send(auth_message).await.map_err(|e| {
+            PrimeProtocolError::InvalidConfig(format!("Failed to send auth message: {}", e))
+        })
     }
 
-    async fn increase_provider_stake(
-        &self,
-        provider_ops: &ProviderOperations,
-        required_stake: U256,
-        current_stake: U256,
-    ) -> Result<()> {
-        log::info!(
-            "Provider stake is less than required stake. Required: {} tokens, Current: {} tokens",
-            format_ether(required_stake),
-            format_ether(current_stake)
-        );
+    // Private helper methods
 
-        // Add timeout to prevent hanging on stake increase operations
-        let stake_future = provider_ops.increase_stake(required_stake - current_stake);
+    async fn initialize_blockchain(&mut self) -> Result<()> {
+        let blockchain_config = BlockchainConfig {
+            rpc_url: self.config.rpc_url.clone(),
+            compute_pool_id: self.config.compute_pool_id,
+            private_key_provider: self.get_private_key_provider()?,
+            private_key_node: self.get_private_key_node()?,
+            auto_accept_transactions: self.config.auto_accept_transactions,
+            funding_retry_count: self.config.funding_retry_count,
+        };
 
-        tokio::time::timeout(
-            tokio::time::Duration::from_secs(300), // 5 minute timeout
-            stake_future,
-        )
-        .await
-        .map_err(|_| {
-            PrimeProtocolError::BlockchainError(
-                "Stake increase timed out after 5 minutes".to_string(),
-            )
-        })?
-        .map_err(|e| {
-            PrimeProtocolError::BlockchainError(format!("Failed to increase stake: {}", e))
+        // Create blockchain service - wallets are created internally
+        let mut blockchain_service = BlockchainService::new(blockchain_config).map_err(|e| {
+            PrimeProtocolError::BlockchainError(format!(
+                "Failed to create blockchain service: {}",
+                e
+            ))
         })?;
 
-        log::info!("Successfully increased stake");
+        blockchain_service.initialize().await.map_err(|e| {
+            PrimeProtocolError::BlockchainError(format!("Failed to initialize blockchain: {}", e))
+        })?;
+
+        self.blockchain_service = Some(blockchain_service);
         Ok(())
     }
 
-    async fn ensure_compute_node_registered(
-        &self,
-        provider_wallet: &Wallet,
-        node_wallet: &Wallet,
-        contracts: &Contracts,
-    ) -> Result<()> {
-        let compute_node_ops =
-            ComputeNodeOperations::new(provider_wallet, node_wallet, contracts.clone());
-
-        let compute_node_exists = self.check_compute_node_exists(&compute_node_ops).await?;
+    fn initialize_auth_manager(&mut self) -> Result<()> {
+        let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string())
+        })?;
 
-        if compute_node_exists {
-            log::info!("Compute node is already registered");
-            return Ok(());
-        }
+        let node_wallet = blockchain_service.node_wallet().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("Node wallet not initialized".to_string())
+        })?;
 
-        self.register_compute_node(&compute_node_ops).await?;
+        self.auth_manager = Some(Arc::new(AuthenticationManager::new(Arc::new(
+            node_wallet.clone(),
+        ))));
 
         Ok(())
     }
 
-    async fn check_compute_node_exists(
-        &self,
-        compute_node_ops: &ComputeNodeOperations<'_>,
-    ) -> Result<bool> {
-        compute_node_ops
-            .check_compute_node_exists()
-            .await
-            .map_err(|e| {
-                PrimeProtocolError::BlockchainError(format!(
-                    "Failed to check if compute node exists: {}",
-                    e
-                ))
-            })
-    }
-
-    async fn register_compute_node(
-        &self,
-        compute_node_ops: &ComputeNodeOperations<'_>,
-    ) -> Result<()> {
-        let compute_units = U256::from(1); // TODO: Make configurable
-
-        // Add timeout to prevent hanging on compute node registration
-        let register_future = compute_node_ops.add_compute_node(compute_units);
+    async fn start_p2p_service(&mut self) -> Result<()> {
+        // Get wallet address from auth manager
+        let wallet_address = self.auth_manager.as_ref().map(|auth| auth.wallet_address());
 
-        tokio::time::timeout(
-            tokio::time::Duration::from_secs(300), // 5 minute timeout
-            register_future,
+        let (p2p_service, outbound_tx, message_queue_rx, authenticated_peers) = P2PService::new(
+            self.p2p_state.keypair.clone(),
+            self.config.p2p_port,
+            self.cancellation_token.clone(),
+            wallet_address,
         )
-        .await
-        .map_err(|_| {
-            PrimeProtocolError::BlockchainError(
-                "Compute node registration timed out after 5 minutes".to_string(),
-            )
-        })?
         .map_err(|e| {
-            PrimeProtocolError::BlockchainError(format!("Failed to register compute node: {}", e))
+            PrimeProtocolError::InvalidConfig(format!("Failed to create P2P service: {}", e))
         })?;
 
-        log::info!("Compute node registered successfully");
+        self.p2p_state.peer_id = Some(p2p_service.node.peer_id());
+        self.p2p_state.outbound_tx = Some(Arc::new(Mutex::new(outbound_tx)));
+        self.p2p_state.message_queue_rx = Some(Arc::new(Mutex::new(message_queue_rx)));
+        self.p2p_state.authenticated_peers = Some(authenticated_peers);
+
+        log::info!(
+            "P2P service initialized with peer ID: {:?}",
+            self.p2p_state.peer_id
+        );
+
+        self.p2p_state.handle = Some(tokio::task::spawn(p2p_service.run()));
         Ok(())
     }
 
-    fn get_private_key_provider(&self) -> Result<String> {
-        match &self.private_key_provider {
-            Some(key) => Ok(key.clone()),
-            None => std::env::var("PRIVATE_KEY_PROVIDER").map_err(|_| {
-                PrimeProtocolError::InvalidConfig("PRIVATE_KEY_PROVIDER must be set".to_string())
-            }),
-        }
-    }
+    async fn start_message_processor(&mut self) -> Result<()> {
+        let message_queue_rx = self
+            .p2p_state
+            .message_queue_rx
+            .as_ref()
+            .ok_or_else(|| {
+                PrimeProtocolError::InvalidConfig("P2P service not initialized".to_string())
+            })?
+            .clone();
+
+        let user_message_tx = self
+            .user_message_tx
+            .as_ref()
+            .ok_or_else(|| {
+                PrimeProtocolError::InvalidConfig(
+                    "User message channel not initialized".to_string(),
+                )
+            })?
+            .clone();
+
+        let auth_manager = self
+            .auth_manager
+            .as_ref()
+            .ok_or_else(|| {
+                PrimeProtocolError::InvalidConfig(
+                    "Authentication manager not initialized".to_string(),
+                )
+            })?
+            .clone();
+
+        let outbound_tx = self
+            .p2p_state
+            .outbound_tx
+            .as_ref()
+            .ok_or_else(|| {
+                PrimeProtocolError::InvalidConfig("P2P service not initialized".to_string())
+            })?
+            .clone();
+
+        let authenticated_peers = self
+            .p2p_state
+            .authenticated_peers
+            .as_ref()
+            .ok_or_else(|| {
+                PrimeProtocolError::InvalidConfig(
+                    "Authenticated peers map not initialized".to_string(),
+                )
+            })?
+            .clone();
+
+        let message_processor = MessageProcessor::new(
+            auth_manager,
+            message_queue_rx,
+            user_message_tx,
+            outbound_tx,
+            authenticated_peers,
+            self.cancellation_token.clone(),
+        );
 
-    fn get_private_key_node(&self) -> Result<String> {
-        match &self.private_key_node {
-            Some(key) => Ok(key.clone()),
-            None => std::env::var("PRIVATE_KEY_NODE").map_err(|_| {
-                PrimeProtocolError::InvalidConfig("PRIVATE_KEY_NODE must be set".to_string())
-            }),
-        }
+        self.message_processor_handle = Some(tokio::task::spawn(message_processor.run()));
+        Ok(())
     }
 
-    /// Get the shared message queue instance
-    pub fn get_message_queue(&self) -> Arc<MessageQueue> {
-        self.message_queue.clone()
+    fn get_private_key_provider(&self) -> Result<String> {
+        self.config
+            .private_key_provider
+            .clone()
+            .or_else(|| std::env::var("PRIVATE_KEY_PROVIDER").ok())
+            .ok_or_else(|| {
+                PrimeProtocolError::InvalidConfig(
+                    "PRIVATE_KEY_PROVIDER must be set either as parameter or environment variable"
+                        .to_string(),
+                )
+            })
     }
 
-    /// Stop the message queue listener
-    pub async fn stop_async(&self) -> Result<()> {
-        self.message_queue.stop_listener().await.map_err(|e| {
-            PrimeProtocolError::InvalidConfig(format!("Failed to stop message listener: {}", e))
-        })?;
-        Ok(())
+    fn get_private_key_node(&self) -> Result<String> {
+        self.config
+            .private_key_node
+            .clone()
+            .or_else(|| std::env::var("PRIVATE_KEY_NODE").ok())
+            .ok_or_else(|| {
+                PrimeProtocolError::InvalidConfig(
+                    "PRIVATE_KEY_NODE must be set either as parameter or environment variable"
+                        .to_string(),
+                )
+            })
     }
 }
 
 #[cfg(test)]
-mod test {
+mod tests {
     use super::*;
     use test_log::test;
 
-    #[test(tokio::test)]
-    async fn test_start_async() {
-        // standard anvil blockchain keys for local testing
+    fn create_test_config() -> (String, String) {
+        // Standard Anvil blockchain keys for local testing
         let node_key = "0x7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6";
         let provider_key = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a";
+        (node_key.to_string(), provider_key.to_string())
+    }
+
+    #[test]
+    fn test_client_creation() {
+        let (node_key, provider_key) = create_test_config();
+        let cancellation_token = CancellationToken::new();
 
-        // todo: currently still have to make up the local blockchain incl. smart contract deployments
-        let worker = WorkerClientCore::new(
+        let result = WorkerClientCore::new(
             0,
             "http://localhost:8545".to_string(),
-            Some(provider_key.to_string()),
-            Some(node_key.to_string()),
+            Some(provider_key),
+            Some(node_key),
             None,
             None,
-        )
-        .unwrap();
-        worker.start_async().await.unwrap();
+            cancellation_token,
+            8000,
+        );
+
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_invalid_rpc_url() {
+        let (node_key, provider_key) = create_test_config();
+        let cancellation_token = CancellationToken::new();
+
+        let result = WorkerClientCore::new(
+            0,
+            "invalid-url".to_string(),
+            Some(provider_key),
+            Some(node_key),
+            None,
+            None,
+            cancellation_token,
+            8000,
+        );
+
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_empty_rpc_url() {
+        let (node_key, provider_key) = create_test_config();
+        let cancellation_token = CancellationToken::new();
+
+        let result = WorkerClientCore::new(
+            0,
+            "".to_string(),
+            Some(provider_key),
+            Some(node_key),
+            None,
+            None,
+            cancellation_token,
+            8000,
+        );
+
+        assert!(result.is_err());
+    }
 }
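Taken together, the rewritten core client has a simple lifecycle: construct with a `CancellationToken`, `start_async` (blockchain, auth manager, P2P, processor), poll `get_next_message`, then `stop_async`. A hedged sketch of driving it from async Rust inside this crate (assumes a local Anvil chain with the contracts deployed, as the test comment above notes):

    use tokio_util::sync::CancellationToken;

    async fn run_worker() -> anyhow::Result<()> {
        let token = CancellationToken::new();
        let mut client = WorkerClientCore::new(
            0,                                   // compute pool id
            "http://localhost:8545".to_string(), // local Anvil node
            None,                                // falls back to PRIVATE_KEY_PROVIDER
            None,                                // falls back to PRIVATE_KEY_NODE
            None,                                // auto-accept transactions (defaults to true)
            None,                                // funding retries (defaults to 10)
            token,
            8000,                                // p2p port
        )?;

        client.start_async().await?;
        if let Some(message) = client.get_next_message().await {
            log::info!("first message from peer {}", message.peer_id);
        }
        client.stop_async().await?;
        Ok(())
    }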
diff --git a/crates/prime-protocol-py/src/worker/constants.rs b/crates/prime-protocol-py/src/worker/constants.rs
new file mode 100644
index 00000000..800b5e49
--- /dev/null
+++ b/crates/prime-protocol-py/src/worker/constants.rs
@@ -0,0 +1,26 @@
+use std::time::Duration;
+
+/// Default P2P port for worker nodes
+pub const DEFAULT_P2P_PORT: u16 = 8000;
+
+/// Default number of retries for funding operations
+pub const DEFAULT_FUNDING_RETRY_COUNT: u32 = 10;
+
+/// Default compute units for node registration
+pub const DEFAULT_COMPUTE_UNITS: u64 = 1;
+
+/// Timeout for blockchain operations
+pub const BLOCKCHAIN_OPERATION_TIMEOUT: Duration = Duration::from_secs(300);
+
+/// Timeout for message queue operations
+pub const MESSAGE_QUEUE_TIMEOUT: Duration = Duration::from_millis(100);
+
+/// Pool status check interval
+pub const POOL_STATUS_CHECK_INTERVAL: Duration = Duration::from_secs(15);
+
+/// P2P shutdown timeout
+pub const P2P_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
+
+/// Channel sizes
+pub const P2P_CHANNEL_SIZE: usize = 100;
+pub const MESSAGE_QUEUE_CHANNEL_SIZE: usize = 300;
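`MESSAGE_QUEUE_TIMEOUT` is what keeps both `get_next_message` and the processor loop below responsive: every `recv` is bounded to 100 ms, so shutdown and cancellation are observed on the next idle tick instead of blocking indefinitely. The idiom reduced to a runnable toy:

    use std::time::Duration;
    use tokio::sync::mpsc;
    use tokio::time::timeout;

    const MESSAGE_QUEUE_TIMEOUT: Duration = Duration::from_millis(100);

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<String>(8);
        tx.send("hello".to_string()).await.unwrap();
        drop(tx); // close the channel so the demo terminates

        loop {
            match timeout(MESSAGE_QUEUE_TIMEOUT, rx.recv()).await {
                Ok(Some(msg)) => println!("got: {msg}"),
                Ok(None) => break,  // channel closed and drained
                Err(_) => continue, // idle tick: a real loop checks the CancellationToken here
            }
        }
    }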
diff --git a/crates/prime-protocol-py/src/worker/message_processor.rs b/crates/prime-protocol-py/src/worker/message_processor.rs
new file mode 100644
index 00000000..4cf40424
--- /dev/null
+++ b/crates/prime-protocol-py/src/worker/message_processor.rs
@@ -0,0 +1,270 @@
+use crate::error::Result;
+use crate::worker::auth::AuthenticationManager;
+use crate::worker::p2p::{Message, MessageType};
+use std::collections::HashMap;
+use std::str::FromStr;
+use std::sync::Arc;
+use tokio::sync::{
+    mpsc::{Receiver, Sender},
+    Mutex, RwLock,
+};
+use tokio_util::sync::CancellationToken;
+
+/// Handles processing of incoming P2P messages
+pub struct MessageProcessor {
+    auth_manager: Arc<AuthenticationManager>,
+    message_queue_rx: Arc<Mutex<Receiver<Message>>>,
+    user_message_tx: Sender<Message>,
+    outbound_tx: Arc<Mutex<Sender<Message>>>,
+    authenticated_peers: Arc<RwLock<HashMap<String, String>>>,
+    cancellation_token: CancellationToken,
+}
+
+impl MessageProcessor {
+    pub fn new(
+        auth_manager: Arc<AuthenticationManager>,
+        message_queue_rx: Arc<Mutex<Receiver<Message>>>,
+        user_message_tx: Sender<Message>,
+        outbound_tx: Arc<Mutex<Sender<Message>>>,
+        authenticated_peers: Arc<RwLock<HashMap<String, String>>>,
+        cancellation_token: CancellationToken,
+    ) -> Self {
+        Self {
+            auth_manager,
+            message_queue_rx,
+            user_message_tx,
+            outbound_tx,
+            authenticated_peers,
+            cancellation_token,
+        }
+    }
+
+    /// Run the message processing loop
+    pub async fn run(self) {
+        loop {
+            tokio::select! {
+                _ = self.cancellation_token.cancelled() => {
+                    log::info!("Message processor shutting down");
+                    break;
+                }
+                message_result = async {
+                    let mut rx = self.message_queue_rx.lock().await;
+                    tokio::time::timeout(
+                        crate::worker::constants::MESSAGE_QUEUE_TIMEOUT,
+                        rx.recv()
+                    ).await
+                } => {
+                    let message = match message_result {
+                        Ok(Some(msg)) => msg,
+                        Ok(None) => {
+                            log::debug!("Message queue closed");
+                            break;
+                        }
+                        Err(_) => continue, // Timeout, continue loop
+                    };
+
+                    if let Err(e) = self.process_message(message).await {
+                        log::error!("Failed to process message: {}", e);
+                    }
+                }
+            }
+        }
+    }
+
+    /// Process a single message
+    async fn process_message(&self, message: Message) -> Result<()> {
+        let Message {
+            message_type,
+            peer_id,
+            multiaddrs,
+            sender_address,
+            response_tx,
+        } = message;
+
+        match message_type {
+            MessageType::AuthenticationInitiation { challenge } => {
+                if let Some(tx) = response_tx {
+                    self.handle_auth_initiation_with_response(
+                        peer_id,
+                        multiaddrs,
+                        sender_address,
+                        challenge,
+                        tx,
+                    )
+                    .await
+                } else {
+                    log::error!("AuthenticationInitiation received without response tx");
+                    Ok(())
+                }
+            }
+            MessageType::AuthenticationResponse {
+                challenge,
+                signature,
+            } => {
+                let msg = Message {
+                    message_type: MessageType::AuthenticationResponse {
+                        challenge: challenge.clone(),
+                        signature: signature.clone(),
+                    },
+                    peer_id,
+                    multiaddrs,
+                    sender_address,
+                    response_tx: None,
+                };
+                self.handle_auth_response(msg, challenge, signature).await
+            }
+            MessageType::AuthenticationSolution { signature } => {
+                if let Some(tx) = response_tx {
+                    self.handle_auth_solution_with_response(
+                        peer_id,
+                        multiaddrs,
+                        sender_address,
+                        signature,
+                        tx,
+                    )
+                    .await
+                } else {
+                    log::error!("AuthenticationSolution received without response tx");
+                    Ok(())
+                }
+            }
+            MessageType::General { data } => {
+                // Forward general messages to user
+                let msg = Message {
+                    message_type: MessageType::General { data },
+                    peer_id,
+                    multiaddrs,
+                    sender_address,
+                    response_tx: None,
+                };
+                self.user_message_tx.send(msg).await.map_err(|e| {
+                    crate::error::PrimeProtocolError::InvalidConfig(format!(
+                        "Failed to forward message to user: {}",
+                        e
+                    ))
+                })
+            }
+        }
+    }
+
+    async fn handle_auth_initiation_with_response(
+        &self,
+        peer_id: String,
+        _multiaddrs: Vec<String>,
+        _sender_address: Option<String>,
+        challenge: String,
+        response_tx: tokio::sync::oneshot::Sender<p2p::Response>,
+    ) -> Result<()> {
+        let (our_challenge, signature) = self
+            .auth_manager
+            .handle_auth_initiation(&peer_id, &challenge)
+            .await?;
+
+        // Send authentication response via the one-shot channel
+        let response = p2p::AuthenticationInitiationResponse {
+            message: our_challenge,
+            signature,
+        };
+
+        response_tx.send(response.into()).map_err(|_| {
+            crate::error::PrimeProtocolError::InvalidConfig(
+                "Failed to send auth response: receiver dropped".to_string(),
+            )
+        })
+    }
+
+    async fn handle_auth_response(
+        &self,
+        message: Message,
+        challenge: String,
+        signature: String,
+    ) -> Result<()> {
+        let (our_signature, queued_message) = self
+            .auth_manager
+            .handle_auth_response(&message.peer_id, &challenge, &signature)
+            .await?;
+
+        // Extract and store the peer's wallet address from their signature
+        if let Ok(parsed_signature) = alloy::primitives::Signature::from_str(&signature) {
+            if let Ok(recovered_address) = parsed_signature.recover_address_from_msg(&challenge) {
+                self.authenticated_peers
+                    .write()
+                    .await
+                    .insert(message.peer_id.clone(), recovered_address.to_string());
+            }
+        }
+
+        // Send authentication solution
+        let solution = Message {
+            message_type: MessageType::AuthenticationSolution {
+                signature: our_signature,
+            },
+            peer_id: message.peer_id.clone(),
+            multiaddrs: message.multiaddrs,
+            sender_address: Some(self.auth_manager.wallet_address()),
+            response_tx: None,
+        };
+
+        self.outbound_tx
+            .lock()
+            .await
+            .send(solution)
+            .await
+            .map_err(|e| {
+                crate::error::PrimeProtocolError::InvalidConfig(format!(
+                    "Failed to send auth solution: {}",
+                    e
+                ))
+            })?;
+
+        // Send the queued message if any
+        if let Some(msg) = queued_message {
+            self.outbound_tx.lock().await.send(msg).await.map_err(|e| {
+                crate::error::PrimeProtocolError::InvalidConfig(format!(
+                    "Failed to send queued message: {}",
+                    e
+                ))
+            })?;
+        }
+
+        Ok(())
+    }
+
+    async fn handle_auth_solution_with_response(
+        &self,
+        peer_id: String,
+        _multiaddrs: Vec<String>,
+        _sender_address: Option<String>,
+        signature: String,
+        response_tx: tokio::sync::oneshot::Sender<p2p::Response>,
+    ) -> Result<()> {
+        let (peer_address, queued_messages) = self
+            .auth_manager
+            .handle_auth_solution(&peer_id, &signature)
+            .await?;
+
+        // Store the peer's wallet address for future message handling
+        self.authenticated_peers
+            .write()
+            .await
+            .insert(peer_id.clone(), peer_address.to_string());
+
+        // Send any queued messages now that we're authenticated
+        for msg in queued_messages {
+            self.outbound_tx.lock().await.send(msg).await.map_err(|e| {
+                crate::error::PrimeProtocolError::InvalidConfig(format!(
+                    "Failed to send queued message: {}",
+                    e
+                ))
+            })?;
+        }
+
+        // Send the response
+        let response = p2p::AuthenticationSolutionResponse::Granted;
+        response_tx.send(response.into()).map_err(|_| {
+            crate::error::PrimeProtocolError::InvalidConfig(
+                "Failed to send auth solution response: receiver dropped".to_string(),
+            )
+        })
+    }
+}
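Note how both auth handlers answer through the message's embedded `response_tx` rather than by queuing a new outbound message: the reply rides back on the same request/response exchange. The one-shot pattern on its own, with plain strings standing in for the p2p request and response types:

    use tokio::sync::{mpsc, oneshot};

    struct Request {
        body: String,
        response_tx: oneshot::Sender<String>,
    }

    #[tokio::main]
    async fn main() {
        let (req_tx, mut req_rx) = mpsc::channel::<Request>(8);

        // Processor task: answer each request on its embedded one-shot sender.
        let processor = tokio::spawn(async move {
            while let Some(req) = req_rx.recv().await {
                let _ = req.response_tx.send(format!("ack: {}", req.body));
            }
        });

        let (response_tx, response_rx) = oneshot::channel();
        req_tx
            .send(Request { body: "auth".into(), response_tx })
            .await
            .unwrap();
        assert_eq!(response_rx.await.unwrap(), "ack: auth");

        drop(req_tx); // close the channel so the processor task exits
        processor.await.unwrap();
    }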
diff --git a/crates/prime-protocol-py/src/worker/message_queue.rs b/crates/prime-protocol-py/src/worker/message_queue.rs
deleted file mode 100644
index 167fde05..00000000
--- a/crates/prime-protocol-py/src/worker/message_queue.rs
+++ /dev/null
@@ -1,80 +0,0 @@
-use crate::utils::message_queue::{Message, MessageQueue as GenericMessageQueue};
-use pyo3::prelude::*;
-
-/// Queue types for the worker message queue
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub enum QueueType {
-    PoolOwner,
-    Validator,
-}
-
-/// Worker-specific message queue with predefined queue types
-#[derive(Clone)]
-pub struct MessageQueue {
-    pool_owner_queue: GenericMessageQueue,
-    validator_queue: GenericMessageQueue,
-}
-
-impl MessageQueue {
-    /// Create a new worker message queue with pool_owner and validator queues
-    pub fn new() -> Self {
-        Self {
-            pool_owner_queue: GenericMessageQueue::new(None),
-            validator_queue: GenericMessageQueue::new(None),
-        }
-    }
-
-    /// Start the background message listener for worker
-    pub(crate) async fn start_listener(&self) -> Result<(), String> {
-        // Start mock listeners with different frequencies
-        // pool_owner messages every 2 seconds, validator messages every 3 seconds
-        self.pool_owner_queue.start_mock_listener(2).await?;
-        self.validator_queue.start_mock_listener(3).await?;
-        Ok(())
-    }
-
-    /// Stop the background listener
-    pub(crate) async fn stop_listener(&self) -> Result<(), String> {
-        self.pool_owner_queue.stop_listener().await?;
-        self.validator_queue.stop_listener().await?;
-        Ok(())
-    }
-
-    /// Get the next message from the pool owner queue
-    pub(crate) async fn get_pool_owner_message(&self) -> Option<PyObject> {
-        self.pool_owner_queue.get_message().await
-    }
-
-    /// Get the next message from the validator queue
-    pub(crate) async fn get_validator_message(&self) -> Option<PyObject> {
-        self.validator_queue.get_message().await
-    }
-
-    /// Push a message to the appropriate queue (for testing or internal use)
-    pub(crate) async fn push_message(
-        &self,
-        queue_type: QueueType,
-        content: serde_json::Value,
-    ) -> Result<(), String> {
-        let message = Message {
-            content,
-            timestamp: std::time::SystemTime::now()
-                .duration_since(std::time::UNIX_EPOCH)
-                .unwrap()
-                .as_secs(),
-            sender: Some("worker".to_string()),
-        };
-
-        match queue_type {
-            QueueType::PoolOwner => self.pool_owner_queue.push_message(message).await,
-            QueueType::Validator => self.validator_queue.push_message(message).await,
-        }
-    }
-
-    /// Get queue sizes for monitoring
-    pub(crate) async fn get_queue_sizes(&self) -> (usize, usize) {
-        let pool_owner_size = self.pool_owner_queue.get_queue_size().await;
-        let validator_size = self.validator_queue.get_queue_size().await;
-        (pool_owner_size, validator_size)
-    }
-}
diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/prime-protocol-py/src/worker/mod.rs
index a308df12..78853ac2 100644
--- a/crates/prime-protocol-py/src/worker/mod.rs
+++ b/crates/prime-protocol-py/src/worker/mod.rs
@@ -1,24 +1,38 @@
 use pyo3::prelude::*;
+
+mod auth;
+mod blockchain;
 mod client;
-pub(crate) mod message_queue;
+mod constants;
+mod message_processor;
+mod p2p;
+
 pub(crate) use client::WorkerClientCore;
+use tokio_util::sync::CancellationToken;
+
+use crate::worker::p2p::Message;
+use constants::DEFAULT_P2P_PORT;
+
 /// Prime Protocol Worker Client - for compute nodes that execute tasks
 #[pyclass]
 pub(crate) struct WorkerClient {
     inner: WorkerClientCore,
     runtime: Option<tokio::runtime::Runtime>,
+    cancellation_token: CancellationToken,
 }
 
 #[pymethods]
 impl WorkerClient {
     #[new]
-    #[pyo3(signature = (compute_pool_id, rpc_url, private_key_provider=None, private_key_node=None))]
+    #[pyo3(signature = (compute_pool_id, rpc_url, private_key_provider=None, private_key_node=None, p2p_port=DEFAULT_P2P_PORT))]
     pub fn new(
         compute_pool_id: u64,
         rpc_url: String,
         private_key_provider: Option<String>,
         private_key_node: Option<String>,
+        p2p_port: u16,
     ) -> PyResult<Self> {
+        let cancellation_token = CancellationToken::new();
         let inner = WorkerClientCore::new(
             compute_pool_id,
             rpc_url,
@@ -26,59 +40,78 @@ impl WorkerClient {
             private_key_node,
             None,
             None,
+            cancellation_token.clone(),
+            p2p_port,
         )
-        .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+        .map_err(to_py_err)?;
 
         Ok(Self {
             inner,
             runtime: None,
+            cancellation_token,
         })
     }
 
+    /// Start the worker client
     pub fn start(&mut self, py: Python) -> PyResult<()> {
-        // Create a new runtime for this call
-        let rt = tokio::runtime::Builder::new_multi_thread()
-            .enable_all()
-            .build()
-            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+        if self.runtime.is_some() {
+            return Err(to_py_runtime_err("Client already started"));
+        }
 
-        // Run the async function with GIL released
+        let rt = create_runtime()?;
         let result = py.allow_threads(|| rt.block_on(self.inner.start_async()));
 
-        // Store the runtime for future use
         self.runtime = Some(rt);
+        result.map_err(to_py_err)
+    }
 
-        result.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
+    /// Get the next message from the P2P network
+    pub fn get_next_message(&self, py: Python) -> PyResult<Option<PyObject>> {
+        let rt = self.ensure_runtime()?;
+
+        Ok(py.allow_threads(|| {
+            rt.block_on(self.inner.get_next_message())
+                .map(message_to_pyobject)
+        }))
     }
 
-    pub fn get_pool_owner_message(&self, py: Python) -> PyResult<Option<PyObject>> {
-        if let Some(rt) = self.runtime.as_ref() {
-            Ok(py.allow_threads(|| {
rt.block_on(self.inner.get_message_queue().get_pool_owner_message()) - })) - } else { - Err(PyErr::new::( - "Client not started. Call start() first.".to_string(), - )) - } + /// Send a message to a peer + pub fn send_message( + &self, + peer_id: String, + data: Vec, + multiaddrs: Vec, + py: Python, + ) -> PyResult<()> { + let rt = self.ensure_runtime()?; + + let message = Message { + message_type: p2p::MessageType::General { data }, + peer_id, + multiaddrs, + sender_address: None, // Will be filled from our wallet automatically + response_tx: None, + }; + + py.allow_threads(|| rt.block_on(self.inner.send_message(message))) + .map_err(to_py_err) } - pub fn get_validator_message(&self, py: Python) -> PyResult> { - if let Some(rt) = self.runtime.as_ref() { - Ok(py.allow_threads(|| { - rt.block_on(self.inner.get_message_queue().get_validator_message()) - })) - } else { - Err(PyErr::new::( - "Client not started. Call start() first.".to_string(), - )) - } + /// Get this node's peer ID + pub fn get_own_peer_id(&self) -> PyResult> { + self.ensure_runtime()?; + Ok(self.inner.get_peer_id().map(|id| id.to_string())) } + /// Stop the worker client and clean up resources pub fn stop(&mut self, py: Python) -> PyResult<()> { + // Cancel all background tasks + self.cancellation_token.cancel(); + if let Some(rt) = self.runtime.as_ref() { - py.allow_threads(|| rt.block_on(self.inner.stop_async())) - .map_err(|e| PyErr::new::(e.to_string()))?; + let inner = &mut self.inner; + py.allow_threads(|| rt.block_on(inner.stop_async())) + .map_err(to_py_err)?; } // Clean up the runtime @@ -89,3 +122,70 @@ impl WorkerClient { Ok(()) } } + +// Helper methods +impl WorkerClient { + fn ensure_runtime(&self) -> PyResult<&tokio::runtime::Runtime> { + self.runtime + .as_ref() + .ok_or_else(|| to_py_runtime_err("Client not started. 
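+
+// The dict handed to Python by message_to_pyobject below has the shape:
+//   {
+//       "message": {"type": ..., ...},   // "general", "auth_initiation",
+//                                        // "auth_response" or "auth_solution"
+//       "peer_id": str,
+//       "multiaddrs": list[str],
+//       "sender_address": str | None,
+//   }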
+fn message_to_pyobject(message: Message) -> PyObject {
+    let message_data = match message.message_type {
+        p2p::MessageType::General { data } => {
+            serde_json::json!({
+                "type": "general",
+                "data": data,
+            })
+        }
+        p2p::MessageType::AuthenticationInitiation { challenge } => {
+            serde_json::json!({
+                "type": "auth_initiation",
+                "challenge": challenge,
+            })
+        }
+        p2p::MessageType::AuthenticationResponse {
+            challenge,
+            signature,
+        } => {
+            serde_json::json!({
+                "type": "auth_response",
+                "challenge": challenge,
+                "signature": signature,
+            })
+        }
+        p2p::MessageType::AuthenticationSolution { signature } => {
+            serde_json::json!({
+                "type": "auth_solution",
+                "signature": signature,
+            })
+        }
+    };
+
+    let json_value = serde_json::json!({
+        "message": message_data,
+        "peer_id": message.peer_id,
+        "multiaddrs": message.multiaddrs,
+        "sender_address": message.sender_address,
+    });
+
+    Python::with_gil(|py| crate::utils::json_parser::json_to_pyobject(py, &json_value))
+}
diff --git a/crates/prime-protocol-py/src/worker/p2p.rs b/crates/prime-protocol-py/src/worker/p2p.rs
new file mode 100644
index 00000000..1bf9d449
--- /dev/null
+++ b/crates/prime-protocol-py/src/worker/p2p.rs
@@ -0,0 +1,492 @@
+use anyhow::{Context, Result};
+use p2p::{IncomingMessage, Keypair, Node, NodeBuilder, OutgoingMessage, PeerId, Protocols};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::str::FromStr;
+use std::sync::Arc;
+use tokio::sync::{
+    mpsc::{Receiver, Sender},
+    RwLock,
+};
+use tokio_util::sync::CancellationToken;
+
+use crate::worker::constants::{MESSAGE_QUEUE_CHANNEL_SIZE, P2P_CHANNEL_SIZE};
+
+// Type alias for the complex return type of Service::new
+type ServiceNewResult = Result<(
+    Service,
+    Sender<Message>,
+    Receiver<Message>,
+    Arc<RwLock<HashMap<String, String>>>,
+)>;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum MessageType {
+    General {
+        data: Vec<u8>,
+    },
+    AuthenticationInitiation {
+        challenge: String,
+    },
+    AuthenticationResponse {
+        challenge: String,
+        signature: String,
+    },
+    AuthenticationSolution {
+        signature: String,
+    },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Message {
+    pub message_type: MessageType,
+    pub peer_id: String,
+    pub multiaddrs: Vec<String>,
+    pub sender_address: Option<String>, // Ethereum address of the sender
+    #[serde(skip)]
+    pub response_tx: Option<tokio::sync::oneshot::Sender<p2p::Response>>, // For sending responses to auth requests
+}
+
+pub struct Service {
+    p2p_rx: Receiver<Message>,
+    outbound_message_tx: Sender<OutgoingMessage>,
+    incoming_message_rx: Receiver<IncomingMessage>,
+    message_queue_tx: Sender<Message>,
+    pub node: Node,
+    cancellation_token: CancellationToken,
+    wallet_address: Option<String>, // Our wallet address for authentication
+    // Map peer_id to their wallet address after authentication
+    authenticated_peers: Arc<RwLock<HashMap<String, String>>>,
+}
+
+impl Service {
+    pub(crate) fn new(
+        keypair: Keypair,
+        port: u16,
+        cancellation_token: CancellationToken,
+        wallet_address: Option<String>,
+    ) -> ServiceNewResult {
+        // Channels for application <-> P2P service communication
+        let (p2p_tx, p2p_rx) = tokio::sync::mpsc::channel(P2P_CHANNEL_SIZE);
+        let (message_queue_tx, message_queue_rx) =
+            tokio::sync::mpsc::channel(MESSAGE_QUEUE_CHANNEL_SIZE);
+
+        let protocols = Protocols::new().with_general().with_authentication();
+
+        let listen_addr = format!("/ip4/0.0.0.0/tcp/{}", port)
+            .parse()
+            .context("Failed to parse listen address")?;
+
+        let (node, incoming_messages_rx, outgoing_messages_tx) = NodeBuilder::new()
+            .with_keypair(keypair)
+            .with_port(port)
+            .with_listen_addr(listen_addr)
+            .with_protocols(protocols)
+            .with_cancellation_token(cancellation_token.clone())
+            .try_build()
+            .context("Failed to create P2P node")?;
+
+        let authenticated_peers = Arc::new(RwLock::new(HashMap::new()));
+
+        Ok((
+            Self {
+                p2p_rx,
+                outbound_message_tx: outgoing_messages_tx,
+                incoming_message_rx: incoming_messages_rx,
+                message_queue_tx,
+                node,
+                cancellation_token,
+                wallet_address,
+                authenticated_peers: authenticated_peers.clone(),
+            },
+            p2p_tx,
+            message_queue_rx,
+            authenticated_peers,
+        ))
+    }
+
+    async fn handle_outgoing_message(
+        mut message: Message,
+        outgoing_message_tx: &Sender<OutgoingMessage>,
+        wallet_address: &Option<String>,
+    ) -> Result<()> {
+        // Add our wallet address if not already set
+        if message.sender_address.is_none() {
+            message.sender_address = wallet_address.clone();
+        }
+
+        let req = match &message.message_type {
+            MessageType::General { data } => {
+                p2p::Request::General(p2p::GeneralRequest { data: data.clone() })
+            }
+            MessageType::AuthenticationInitiation { challenge } => p2p::Request::Authentication(
+                p2p::AuthenticationRequest::Initiation(p2p::AuthenticationInitiationRequest {
+                    message: challenge.clone(),
+                }),
+            ),
+            MessageType::AuthenticationResponse {
+                challenge: _,
+                signature: _,
+            } => {
+                // This message type should not be sent as a request
+                // It should be handled via response channels
+                log::error!(
+                    "AuthenticationResponse should be sent via response channel, not as a request"
+                );
+                return Ok(());
+            }
+            MessageType::AuthenticationSolution { signature } => p2p::Request::Authentication(
+                p2p::AuthenticationRequest::Solution(p2p::AuthenticationSolutionRequest {
+                    signature: signature.clone(),
+                }),
+            ),
+        };
+
+        let peer_id = PeerId::from_str(&message.peer_id).context("Failed to parse peer ID")?;
+
+        let multiaddrs = message
+            .multiaddrs
+            .iter()
+            .map(|addr| addr.parse())
+            .collect::<Result<Vec<_>, _>>()
+            .context("Failed to parse multiaddresses")?;
+
+        log::debug!(
+            "Sending message to peer: {}, multiaddrs: {:?}",
+            peer_id,
+            multiaddrs
+        );
+
+        outgoing_message_tx
+            .send(req.into_outgoing_message(peer_id, multiaddrs))
+            .await
+            .context("Failed to send outgoing message")?;
+
+        Ok(())
+    }
+
+    pub(crate) async fn run(self) -> Result<()> {
+        // Extract all necessary fields before the async move block
+        let node = self.node;
+        let node_cancel_token = self.cancellation_token.clone();
+        let mut p2p_rx = self.p2p_rx;
+        let outbound_message_tx = self.outbound_message_tx;
+        let mut incoming_message_rx = self.incoming_message_rx;
+        let message_queue_tx = self.message_queue_tx;
+        let cancellation_token = self.cancellation_token;
+        let wallet_address = self.wallet_address;
+        let authenticated_peers = self.authenticated_peers;
+
+        // Start the P2P node in a separate task
+        let node_handle = tokio::spawn(async move {
+            tokio::select! {
+                result = node.run() => {
+                    if let Err(e) = result {
+                        log::error!("P2P node error: {}", e);
+                    }
+                }
+                _ = node_cancel_token.cancelled() => {
+                    log::info!("P2P node shutdown requested");
+                }
+            }
+        });
+
+        log::info!("P2P service started");
+
+        // Main event loop
+        loop {
+            tokio::select! {
+                _ = cancellation_token.cancelled() => {
+                    log::info!("P2P service shutdown requested");
+                    break;
+                }
+
+                // Handle outgoing messages from application
+                Some(message) = p2p_rx.recv() => {
+                    if let Err(e) = Self::handle_outgoing_message(message, &outbound_message_tx, &wallet_address).await {
+                        log::error!("Failed to handle outgoing message: {}", e);
+                    }
+                }
+
+                // Handle incoming messages from network
+                Some(incoming) = incoming_message_rx.recv() => {
+                    let peer_id = incoming.peer;
+                    match incoming.message {
+                        p2p::Libp2pIncomingMessage::Request {
+                            request_id,
+                            request,
+                            channel,
+                        } => {
+                            if let Err(e) = Self::handle_incoming_request_static(
+                                &message_queue_tx,
+                                &outbound_message_tx,
+                                &authenticated_peers,
+                                peer_id,
+                                request_id,
+                                request,
+                                channel
+                            ).await {
+                                log::error!("Failed to handle incoming request: {}", e);
+                            }
+                        }
+                        p2p::Libp2pIncomingMessage::Response {
+                            request_id,
+                            response,
+                        } => {
+                            if let Err(e) = Self::handle_incoming_response_static(
+                                &message_queue_tx,
+                                peer_id,
+                                request_id,
+                                response
+                            ).await {
+                                log::error!("Failed to handle incoming response: {}", e);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        // Wait for the node task to complete
+        if let Err(e) = node_handle.await {
+            log::error!("P2P node task error: {}", e);
+        }
+
+        log::info!("P2P service shutdown complete");
+        Ok(())
+    }
+
+    async fn handle_incoming_request_static<T>(
+        message_queue_tx: &Sender<Message>,
+        outbound_message_tx: &Sender<OutgoingMessage>,
+        authenticated_peers: &Arc<RwLock<HashMap<String, String>>>,
+        peer_id: PeerId,
+        request_id: T,
+        request: p2p::Request,
+        channel: p2p::ResponseChannel<p2p::Response>,
+    ) -> Result<()>
+    where
+        T: std::fmt::Debug,
+    {
+        log::debug!(
+            "Received request from peer: {} (ID: {:?})",
+            peer_id,
+            request_id
+        );
+
+        match request {
+            p2p::Request::General(p2p::GeneralRequest { data }) => {
+                log::debug!(
+                    "Processing GeneralRequest with {} bytes from {}",
+                    data.len(),
+                    peer_id
+                );
+
+                // Check if peer is authenticated
+                let sender_address = {
+                    let auth_peers = authenticated_peers.read().await;
+                    match auth_peers.get(&peer_id.to_string()) {
+                        Some(address) => address.clone(),
+                        None => {
+                            log::warn!("Rejecting message from unauthenticated peer: {}", peer_id);
+                            // Send error response
+                            let response = p2p::Response::General(p2p::GeneralResponse {
+                                data: b"ERROR: Not authenticated".to_vec(),
+                            });
+                            outbound_message_tx
+                                .send(response.into_outgoing_message(channel))
+                                .await
+                                .context("Failed to send error response")?;
+                            return Ok(());
+                        }
+                    }
+                };
+
+                // Forward message to application
+                let message = Message {
+                    message_type: MessageType::General { data: data.clone() },
+                    peer_id: peer_id.to_string(),
+                    multiaddrs: vec![], // TODO: Extract multiaddrs from peer info
+                    sender_address: Some(sender_address),
+                    response_tx: None, // General messages don't need response channels
+                };
+
+                if let Err(e) = message_queue_tx.send(message).await {
+                    log::error!("Failed to forward message to application: {}", e);
+                }
+
+                // Send acknowledgment
+                let response = p2p::Response::General(p2p::GeneralResponse {
+                    data: b"ACK".to_vec(),
+                });
+
+                outbound_message_tx
+                    .send(response.into_outgoing_message(channel))
+                    .await
+                    .context("Failed to send acknowledgment")?;
+            }
+            p2p::Request::Authentication(auth_req) => {
+                log::debug!("Processing Authentication request from {}", peer_id);
+
+                match auth_req {
+                    p2p::AuthenticationRequest::Initiation(init_req) => {
+                        // Create a one-shot channel for the response
+                        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
+
+                        // Forward to application for handling
+                        let message = Message {
+                            message_type: MessageType::AuthenticationInitiation {
+                                challenge: init_req.message,
+                            },
+                            peer_id: peer_id.to_string(),
+                            multiaddrs: vec![],
+                            sender_address: None,
+                            response_tx: Some(response_tx), // Pass the sender for response
+                        };
+
+                        if let Err(e) = message_queue_tx.send(message).await {
+                            log::error!("Failed to forward auth initiation to application: {}", e);
+                            return Ok(());
+                        }
+
+                        // Wait for the response from the message processor
+                        match response_rx.await {
+                            Ok(response) => {
+                                outbound_message_tx
+                                    .send(response.into_outgoing_message(channel))
+                                    .await
+                                    .context("Failed to send auth response")?;
+                            }
+                            Err(_) => {
+                                log::error!("Response channel closed for auth initiation");
+                            }
+                        }
+                    }
+                    p2p::AuthenticationRequest::Solution(sol_req) => {
+                        // Create a one-shot channel for the response
+                        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
+
+                        // Forward to application for handling
+                        let message = Message {
+                            message_type: MessageType::AuthenticationSolution {
+                                signature: sol_req.signature,
+                            },
+                            peer_id: peer_id.to_string(),
+                            multiaddrs: vec![],
+                            sender_address: None,
+                            response_tx: Some(response_tx), // Pass the sender for response
+                        };
+
+                        if let Err(e) = message_queue_tx.send(message).await {
+                            log::error!("Failed to forward auth solution to application: {}", e);
+                            return Ok(());
+                        }
+
+                        // Wait for the response from the message processor
+                        match response_rx.await {
+                            Ok(response) => {
+                                outbound_message_tx
+                                    .send(response.into_outgoing_message(channel))
+                                    .await
+                                    .context("Failed to send auth solution response")?;
+                            }
+                            Err(_) => {
+                                log::error!("Response channel closed for auth solution");
+                            }
+                        }
+                    }
+                }
+            }
+            _ => {
+                log::warn!("Received unsupported request type: {:?}", request);
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn handle_incoming_response_static<T>(
+        message_queue_tx: &Sender<Message>,
+        peer_id: PeerId,
+        request_id: T,
+        response: p2p::Response,
+    ) -> Result<()>
+    where
+        T: std::fmt::Debug,
+    {
+        log::debug!(
+            "Received response from peer: {} (ID: {:?})",
+            peer_id,
+            request_id
+        );
+
+        match response {
+            p2p::Response::General(p2p::GeneralResponse { data }) => {
+                log::debug!("General response received: {} bytes", data.len());
+            }
+            p2p::Response::Authentication(auth_resp) => {
+                log::debug!("Authentication response received from {}", peer_id);
+
+                match auth_resp {
+                    p2p::AuthenticationResponse::Initiation(init_resp) => {
+                        // Forward to application
+                        let message = Message {
+                            message_type: MessageType::AuthenticationResponse {
+                                challenge: init_resp.message,
+                                signature: init_resp.signature,
+                            },
+                            peer_id: peer_id.to_string(),
+                            multiaddrs: vec![],
+                            sender_address: None,
+                            response_tx: None,
+                        };
+
+                        if let Err(e) = message_queue_tx.send(message).await {
+                            log::error!("Failed to forward auth response to application: {}", e);
+                        }
+                    }
+                    p2p::AuthenticationResponse::Solution(sol_resp) => {
+                        log::debug!("Auth solution response: {:?}", sol_resp);
+                        // This is handled by the authentication flow
+                    }
+                }
+            }
+            _ => {
+                log::debug!("Received other response type: {:?}", response);
+            }
+        }
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_message_serialization() {
+        let message = Message {
+            message_type: MessageType::General {
+                data: vec![1, 2, 3, 4],
+            },
+            peer_id: "12D3KooWExample".to_string(),
+            multiaddrs: vec!["/ip4/127.0.0.1/tcp/4001".to_string()],
+            sender_address: Some("0x1234567890123456789012345678901234567890".to_string()),
+            response_tx: None,
+        };
+
+        let json = serde_json::to_string(&message).unwrap();
+        let deserialized: Message = serde_json::from_str(&json).unwrap();
+
+        match (&message.message_type, &deserialized.message_type) {
+            (MessageType::General { data: data1 }, MessageType::General { data: data2 }) => {
+                assert_eq!(data1, data2);
+            }
+            _ => panic!("Message types don't match"),
+        }
+        assert_eq!(message.peer_id, deserialized.peer_id);
+        assert_eq!(message.multiaddrs, deserialized.multiaddrs);
+        assert_eq!(message.sender_address, deserialized.sender_address);
+    }
+}
diff --git a/crates/shared/src/discovery/README.md b/crates/shared/src/discovery/README.md
new file mode 100644
index 00000000..25fe8f59
--- /dev/null
+++ b/crates/shared/src/discovery/README.md
@@ -0,0 +1,60 @@
+# Shared Discovery Utilities
+
+This module provides shared utilities for interacting with the Prime Protocol discovery service.
+
+## Overview
+
+The discovery service is a temporary solution while we migrate to Kademlia DHT. It allows nodes to register their information (IP, port, compute specs) and enables validators and orchestrators to discover nodes.
+
+## Functions
+
+### `fetch_nodes_from_discovery_url`
+Fetches nodes from a single discovery URL with proper authentication.
+
+### `fetch_nodes_from_discovery_urls`
+Fetches nodes from multiple discovery URLs with automatic deduplication.
+
+### `fetch_pool_nodes_from_discovery`
+Convenience function to fetch nodes for a specific compute pool.
+
+### `fetch_validator_nodes_from_discovery`
+Convenience function to fetch all validator-accessible nodes.
+
+## Usage
+
+### From Rust (Validator/Orchestrator)
+
+```rust
+use shared::discovery::fetch_validator_nodes_from_discovery;
+
+let nodes = fetch_validator_nodes_from_discovery(
+    &discovery_urls,
+    &wallet
+).await?;
+```
+
+### From Python (Worker)
+
+```python
+from prime_protocol import WorkerClient
+
+worker = WorkerClient(compute_pool_id=1, rpc_url="http://localhost:8545")
+
+# Configure discovery
+worker.set_discovery_urls(["http://localhost:8089"])
+worker.set_node_config(
+    ip="192.168.1.100",
+    port=8080
+)
+
+# Upload to discovery
+worker.start()
+worker.upload_to_discovery()
+```
+
+## Migration Plan
+
+This is a temporary solution. Once Kademlia DHT is fully integrated:
+1. Discovery uploads will be replaced with DHT announcements
+2. Discovery queries will be replaced with DHT lookups
+3. The discovery service endpoints can be deprecated
\ No newline at end of file
diff --git a/crates/shared/src/discovery/mod.rs b/crates/shared/src/discovery/mod.rs
new file mode 100644
index 00000000..11f33ce1
--- /dev/null
+++ b/crates/shared/src/discovery/mod.rs
@@ -0,0 +1,127 @@
+use crate::models::api::ApiResponse;
+use crate::models::node::DiscoveryNode;
+use crate::security::request_signer::sign_request_with_nonce;
+use crate::web3::wallet::Wallet;
+use anyhow::{Context, Result};
+use log::{debug, error};
+use std::time::Duration;
+
+/// Fetch nodes from a single discovery URL
+pub async fn fetch_nodes_from_discovery_url(
+    discovery_url: &str,
+    route: &str,
+    wallet: &Wallet,
+) -> Result<Vec<DiscoveryNode>> {
+    let address = wallet
+        .wallet
+        .default_signer()
+        .address()
+        .to_string();
+
+    let signature = sign_request_with_nonce(route, wallet, None)
+        .await
+        .map_err(|e| anyhow::anyhow!("{}", e))?;
+
+    let mut headers = reqwest::header::HeaderMap::new();
+    headers.insert(
+        "x-address",
+        reqwest::header::HeaderValue::from_str(&address)
+            .context("Failed to create address header")?,
+    );
+    headers.insert(
+        "x-signature",
+        reqwest::header::HeaderValue::from_str(&signature.signature)
+            .context("Failed to create signature header")?,
+    );
+
+    debug!("Fetching nodes from: {discovery_url}{route}");
+    let response = reqwest::Client::new()
+        .get(format!("{discovery_url}{route}"))
+        .query(&[("nonce", signature.nonce)])
+        .headers(headers)
+        .timeout(Duration::from_secs(10))
+        .send()
+        .await
+        .context("Failed to fetch nodes")?;
+
+    let response_text = response
+        .text()
+        .await
+        .context("Failed to get response text")?;
+
+    let parsed_response: ApiResponse<Vec<DiscoveryNode>> =
+        serde_json::from_str(&response_text).context("Failed to parse response")?;
+
+    if !parsed_response.success {
+        error!("Failed to fetch nodes from {discovery_url}: {parsed_response:?}");
+        return Ok(vec![]);
+    }
+
+    Ok(parsed_response.data)
+}
+
+/// Fetch nodes from multiple discovery URLs with deduplication
+pub async fn fetch_nodes_from_discovery_urls(
+    discovery_urls: &[String],
+    route: &str,
+    wallet: &Wallet,
+) -> Result<Vec<DiscoveryNode>> {
+    let mut all_nodes = Vec::new();
+    let mut any_success = false;
+
+    for discovery_url in discovery_urls {
+        match fetch_nodes_from_discovery_url(discovery_url, route, wallet).await {
+            Ok(nodes) => {
+                debug!(
+                    "Successfully fetched {} nodes from {}",
+                    nodes.len(),
+                    discovery_url
+                );
+                all_nodes.extend(nodes);
+                any_success = true;
+            }
+            Err(e) => {
+                error!("Failed to fetch nodes from {discovery_url}: {e:#}");
+            }
+        }
+    }
+
+    if !any_success {
+        error!("Failed to fetch nodes from all discovery services");
+        return Ok(vec![]);
+    }
+
+    // Remove duplicates based on node ID
+    let mut unique_nodes = Vec::new();
+    let mut seen_ids = std::collections::HashSet::new();
+    for node in all_nodes {
+        if seen_ids.insert(node.node.id.clone()) {
+            unique_nodes.push(node);
+        }
+    }
+
+    debug!(
+        "Total unique nodes after deduplication: {}",
+        unique_nodes.len()
+    );
+    Ok(unique_nodes)
+}
+
+/// Fetch nodes for a specific pool from discovery
+pub async fn fetch_pool_nodes_from_discovery(
+    discovery_urls: &[String],
+    compute_pool_id: u32,
+    wallet: &Wallet,
+) -> Result<Vec<DiscoveryNode>> {
+    let route = format!("/api/pool/{}", compute_pool_id);
+    fetch_nodes_from_discovery_urls(discovery_urls, &route, wallet).await
+}
+
+/// Fetch all validator-accessible nodes from discovery
+pub async fn fetch_validator_nodes_from_discovery(
+    discovery_urls: &[String],
+    wallet: &Wallet,
+) -> Result<Vec<DiscoveryNode>> {
+    let route = "/api/validator";
+    
fetch_nodes_from_discovery_urls(discovery_urls, &route, wallet).await +} \ No newline at end of file diff --git a/crates/shared/src/p2p/service.rs b/crates/shared/src/p2p/service.rs index bf776009..d2208b6e 100644 --- a/crates/shared/src/p2p/service.rs +++ b/crates/shared/src/p2p/service.rs @@ -365,7 +365,7 @@ async fn handle_validation_authentication_response( let ongoing_auth_requests = context.ongoing_auth_requests.read().await; let Some(ongoing_challenge) = ongoing_auth_requests.get(&from) else { bail!( - "no ongoing hardware challenge for peer {from}, cannot handle ValidatorAuthenticationInitiationResponse" + "no ongoing challenge for peer {from}, cannot handle ValidatorAuthenticationInitiationResponse" ); }; From 2b66c0bc0f3bea4ffb28f2370047a29ccf9642c8 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 12:25:44 +0200 Subject: [PATCH 07/23] clippy, add integration tests --- crates/p2p/src/lib.rs | 2 +- .../tests/integration/test_auth_flow.py | 147 ++++++++++++++++ .../tests/integration/test_worker.rs | 166 ++++++++++++++++++ 3 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 crates/prime-protocol-py/tests/integration/test_auth_flow.py create mode 100644 crates/prime-protocol-py/tests/integration/test_worker.rs diff --git a/crates/p2p/src/lib.rs b/crates/p2p/src/lib.rs index 6542c7eb..228660e9 100644 --- a/crates/p2p/src/lib.rs +++ b/crates/p2p/src/lib.rs @@ -274,7 +274,7 @@ impl NodeBuilder { cancellation_token, } = self; - println!("multi addrs: {:?}", listen_addrs); + println!("multi addrs: {listen_addrs:?}"); let keypair = keypair.unwrap_or(identity::Keypair::generate_ed25519()); let peer_id = keypair.public().to_peer_id(); diff --git a/crates/prime-protocol-py/tests/integration/test_auth_flow.py b/crates/prime-protocol-py/tests/integration/test_auth_flow.py new file mode 100644 index 00000000..54d20137 --- /dev/null +++ b/crates/prime-protocol-py/tests/integration/test_auth_flow.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +""" +Integration test for the authentication flow between two nodes. + +This test demonstrates: +1. Two nodes starting up with their own wallets +2. First message triggers authentication +3. Addresses are extracted from signatures +4. 
Subsequent messages are sent directly +""" + +import asyncio +import pytest +import time +from primeprotocol import WorkerClient +import logging + +# Set up logging +logging.basicConfig(level=logging.DEBUG) + + +@pytest.fixture +def setup_test_environment(): + """Set up test environment with Anvil keys""" + # Anvil test keys + node_a_key = "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80" + node_b_key = "0x59c6995e998f97a5a0044966f0945389dc9e86dae88c7a8412f4603b6b78690d" + provider_key = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a" + + return { + "rpc_url": "http://localhost:8545", + "pool_id": 0, + "node_a": {"node_key": node_a_key, "provider_key": provider_key, "port": 9001}, + "node_b": {"node_key": node_b_key, "provider_key": provider_key, "port": 9002}, + } + + +@pytest.mark.asyncio +async def test_authentication_flow(setup_test_environment): + """Test the full authentication flow between two nodes""" + env = setup_test_environment + + # Create two worker clients + client_a = WorkerClient( + compute_pool_id=env["pool_id"], + rpc_url=env["rpc_url"], + private_key_provider=env["node_a"]["provider_key"], + private_key_node=env["node_a"]["node_key"], + p2p_port=env["node_a"]["port"] + ) + + client_b = WorkerClient( + compute_pool_id=env["pool_id"], + rpc_url=env["rpc_url"], + private_key_provider=env["node_b"]["provider_key"], + private_key_node=env["node_b"]["node_key"], + p2p_port=env["node_b"]["port"] + ) + + try: + # Start both clients + logging.info("Starting client A...") + client_a.start() + peer_a_id = client_a.get_own_peer_id() + logging.info(f"Client A started with peer ID: {peer_a_id}") + + logging.info("Starting client B...") + client_b.start() + peer_b_id = client_b.get_own_peer_id() + logging.info(f"Client B started with peer ID: {peer_b_id}") + + # Give nodes time to start + await asyncio.sleep(2) + + # Test 1: First message from A to B (triggers authentication) + logging.info("Test 1: Sending first message from A to B...") + client_a.send_message( + peer_id=peer_b_id, + data=b"Hello from A! This is the first message.", + multiaddrs=[f"/ip4/127.0.0.1/tcp/{env['node_b']['port']}"] + ) + + # Check if B receives the message + message = None + for _ in range(50): # Poll for up to 5 seconds + message = client_b.get_next_message() + if message: + break + await asyncio.sleep(0.1) + + assert message is not None, "Client B did not receive message" + assert message["message"]["type"] == "general" + assert bytes(message["message"]["data"]) == b"Hello from A! This is the first message." 
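+        # The sender address below is recovered from the signature exchanged during
+        # the authentication handshake that the first message triggers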
+ assert message["sender_address"] is not None # Should have sender's address + logging.info(f"Client B received message with sender address: {message['sender_address']}") + + # Test 2: Second message from A to B (should be direct, no auth) + logging.info("Test 2: Sending second message from A to B...") + client_a.send_message( + peer_id=peer_b_id, + data=b"Second message - should be direct", + multiaddrs=[f"/ip4/127.0.0.1/tcp/{env['node_b']['port']}"] + ) + + # Should receive quickly since already authenticated + message = None + for _ in range(20): # Should be faster + message = client_b.get_next_message() + if message: + break + await asyncio.sleep(0.1) + + assert message is not None, "Client B did not receive second message" + assert bytes(message["message"]["data"]) == b"Second message - should be direct" + + # Test 3: Message from B to A (first message in this direction) + logging.info("Test 3: Sending message from B to A...") + client_b.send_message( + peer_id=peer_a_id, + data=b"Hello from B!", + multiaddrs=[f"/ip4/127.0.0.1/tcp/{env['node_a']['port']}"] + ) + + # Check if A receives the message + message = None + for _ in range(50): + message = client_a.get_next_message() + if message: + break + await asyncio.sleep(0.1) + + assert message is not None, "Client A did not receive message" + assert bytes(message["message"]["data"]) == b"Hello from B!" + assert message["sender_address"] is not None + logging.info(f"Client A received message with sender address: {message['sender_address']}") + + logging.info("All tests passed! Authentication flow working correctly.") + + finally: + # Clean up + client_a.stop() + client_b.stop() + + +if __name__ == "__main__": + # Run the test + asyncio.run(test_authentication_flow({})) \ No newline at end of file diff --git a/crates/prime-protocol-py/tests/integration/test_worker.rs b/crates/prime-protocol-py/tests/integration/test_worker.rs new file mode 100644 index 00000000..22e77446 --- /dev/null +++ b/crates/prime-protocol-py/tests/integration/test_worker.rs @@ -0,0 +1,166 @@ +#[cfg(test)] +mod worker_integration_tests { + use prime_protocol_py::worker::{Message, WorkerClientCore}; + use test_log::test; + use tokio_util::sync::CancellationToken; + + // Note: These tests require a running local blockchain with deployed contracts + // Run with: cargo test --test integration -- --ignored + + #[test(tokio::test)] + #[ignore = "requires local blockchain setup"] + async fn test_worker_full_lifecycle() { + // Standard Anvil test keys + let node_key = "0x7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6"; + let provider_key = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a"; + let cancellation_token = CancellationToken::new(); + + let mut worker = WorkerClientCore::new( + 0, // compute_pool_id + "http://localhost:8545".to_string(), + Some(provider_key.to_string()), + Some(node_key.to_string()), + Some(true), // auto_accept_transactions + Some(5), // funding_retry_count + cancellation_token.clone(), + 8000, // p2p_port + ) + .expect("Failed to create worker"); + + // Start the worker + worker + .start_async() + .await + .expect("Failed to start worker"); + + // Verify peer ID was assigned + let peer_id = worker.get_peer_id(); + assert!(peer_id.is_some()); + println!("Worker started with peer ID: {:?}", peer_id); + + // Test message sending to self + let test_message = Message { + data: b"Hello, self!".to_vec(), + peer_id: peer_id.unwrap().to_string(), + multiaddrs: vec![], + }; + + worker + .send_message(test_message) + 
.await + .expect("Failed to send message"); + + // Give some time for the message to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + // Try to receive the message + let received = worker.get_next_message().await; + assert!(received.is_some()); + + let msg = received.unwrap(); + assert_eq!(msg.data, b"Hello, self!"); + + // Clean shutdown + worker + .stop_async() + .await + .expect("Failed to stop worker"); + } + + #[test(tokio::test)] + #[ignore = "requires local blockchain setup"] + async fn test_multiple_workers_communication() { + let node_key1 = "0x7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6"; + let provider_key1 = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a"; + + let node_key2 = "0x47e179ec197488593b187f80a00eb0da91f1b9d0b13f8733639f19c30a34926a"; + let provider_key2 = "0x8b3a350cf5c34c9194ca85829a2df0ec3153be0318b5e2d3348e872092edffba"; + + let cancel_token1 = CancellationToken::new(); + let cancel_token2 = CancellationToken::new(); + + // Create two workers + let mut worker1 = WorkerClientCore::new( + 0, + "http://localhost:8545".to_string(), + Some(provider_key1.to_string()), + Some(node_key1.to_string()), + Some(true), + Some(5), + cancel_token1.clone(), + 8001, + ) + .expect("Failed to create worker1"); + + let mut worker2 = WorkerClientCore::new( + 0, + "http://localhost:8545".to_string(), + Some(provider_key2.to_string()), + Some(node_key2.to_string()), + Some(true), + Some(5), + cancel_token2.clone(), + 8002, + ) + .expect("Failed to create worker2"); + + // Start both workers + worker1.start_async().await.expect("Failed to start worker1"); + worker2.start_async().await.expect("Failed to start worker2"); + + let peer_id1 = worker1.get_peer_id().unwrap(); + let peer_id2 = worker2.get_peer_id().unwrap(); + + println!("Worker1 peer ID: {}", peer_id1); + println!("Worker2 peer ID: {}", peer_id2); + + // Worker1 sends message to Worker2 + let message = Message { + data: b"Hello from Worker1!".to_vec(), + peer_id: peer_id2.to_string(), + multiaddrs: vec!["/ip4/127.0.0.1/tcp/8002".to_string()], + }; + + worker1 + .send_message(message) + .await + .expect("Failed to send message from worker1"); + + // Give time for message delivery + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + // Worker2 should receive the message + let received = worker2.get_next_message().await; + assert!(received.is_some()); + + let msg = received.unwrap(); + assert_eq!(msg.data, b"Hello from Worker1!"); + assert_eq!(msg.peer_id, peer_id1.to_string()); + + // Clean shutdown + worker1.stop_async().await.expect("Failed to stop worker1"); + worker2.stop_async().await.expect("Failed to stop worker2"); + } + + #[test(tokio::test)] + async fn test_worker_without_blockchain() { + // Test that worker fails gracefully when blockchain is not available + let cancellation_token = CancellationToken::new(); + + let mut worker = WorkerClientCore::new( + 0, + "http://localhost:9999".to_string(), // Non-existent RPC + Some("0x1234".to_string()), + Some("0x5678".to_string()), + Some(true), + Some(1), // Low retry count for faster test + cancellation_token, + 8003, + ) + .expect("Failed to create worker"); + + // Starting should fail due to blockchain connection issues + let result = worker.start_async().await; + assert!(result.is_err()); + } +} \ No newline at end of file From 720dab3f76c75d65d53f3279ddf38141965c5da2 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 13:30:19 +0200 Subject: [PATCH 08/23] add 
discovery to shared, use discovery service in protocol sdk until we have dht support --- Cargo.lock | 1 + Makefile | 2 +- crates/prime-protocol-py/README.md | 90 -------------- .../prime-protocol-py/examples/basic_usage.py | 2 + .../src/worker/blockchain.rs | 5 + crates/prime-protocol-py/src/worker/client.rs | 61 ++++++++++ .../prime-protocol-py/src/worker/discovery.rs | 55 +++++++++ crates/prime-protocol-py/src/worker/mod.rs | 42 +++++++ .../tests/integration/test_worker.rs | 3 - crates/shared/Cargo.toml | 1 + crates/shared/src/discovery/mod.rs | 111 ++++++++++++++++-- crates/shared/src/lib.rs | 1 + crates/worker/src/services/discovery.rs | 89 ++------------ 13 files changed, 281 insertions(+), 182 deletions(-) delete mode 100644 crates/prime-protocol-py/README.md create mode 100644 crates/prime-protocol-py/src/worker/discovery.rs diff --git a/Cargo.lock b/Cargo.lock index 918c65c6..b6e846d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8315,6 +8315,7 @@ dependencies = [ "rand 0.9.1", "redis", "regex", + "reqwest", "serde", "serde_json", "subtle", diff --git a/Makefile b/Makefile index 5de39578..fb4b62be 100644 --- a/Makefile +++ b/Makefile @@ -102,7 +102,7 @@ up: bootstrap: @echo "Starting Docker services and deploying contracts..." @# Start Docker services - @docker compose up -d reth redis --wait --wait-timeout 180 + @docker compose up -d reth redis discovery --wait --wait-timeout 180 @# Deploy contracts @cd smart-contracts && sh deploy.sh && sh deploy_work_validation.sh && cd .. @# Run setup diff --git a/crates/prime-protocol-py/README.md b/crates/prime-protocol-py/README.md deleted file mode 100644 index b72b39db..00000000 --- a/crates/prime-protocol-py/README.md +++ /dev/null @@ -1,90 +0,0 @@ -# Prime Protocol Python Client - -## Build - -```bash -# Install uv (one-time) -curl -LsSf https://astral.sh/uv/install.sh | sh - -# Setup and build -cd crates/prime-protocol-py -make install -``` - -## Usage - -### Worker Client with Message Queue - -The Worker Client provides a message queue system for handling P2P messages from pool owners and validators. Messages are processed in a FIFO (First-In-First-Out) manner. - -```python -from primeprotocol import WorkerClient -import asyncio - -# Initialize the worker client -client = WorkerClient( - compute_pool_id=1, - rpc_url="http://localhost:8545", - private_key_provider="your_provider_key", - private_key_node="your_node_key", -) - -# Start the client (registers on-chain and starts message listener) -client.start() - -# Poll for messages in your application loop -async def process_messages(): - while True: - # Get next message from pool owner queue - pool_msg = client.get_pool_owner_message() - if pool_msg: - print(f"Pool owner message: {pool_msg}") - # Process the message... - - # Get next message from validator queue - validator_msg = client.get_validator_message() - if validator_msg: - print(f"Validator message: {validator_msg}") - # Process the message... 
-
-        await asyncio.sleep(0.1)
-
-# Run the message processing loop
-asyncio.run(process_messages())
-
-# Gracefully shutdown
-client.stop()
-```
-
-### Message Queue Features
-
-- **Background Listener**: Rust protocol listens for P2P messages in the background
-- **FIFO Queue**: Messages are processed in the order they are received
-- **Message Types**: Separate queues for pool owner, validator, and system messages
-- **Mock Mode**: Currently generates mock messages for testing (P2P integration coming soon)
-- **Thread-Safe**: Safe to use from async Python code
-
-See `examples/message_queue_example.py` for a complete working example.
-
-## Development
-
-```bash
-make build          # Build development version
-make test           # Run tests
-make example        # Run example
-make clean          # Clean artifacts
-make help           # Show all commands
-```
-
-## Installing in other projects
-
-```bash
-# Build the wheel
-make build-release
-
-# Install with uv (recommended)
-uv pip install target/wheels/primeprotocol-*.whl
-
-# Or install directly from source
-uv pip install /path/to/prime-protocol-py/
-```
\ No newline at end of file
diff --git a/crates/prime-protocol-py/examples/basic_usage.py b/crates/prime-protocol-py/examples/basic_usage.py
index dd72b548..365f883d 100644
--- a/crates/prime-protocol-py/examples/basic_usage.py
+++ b/crates/prime-protocol-py/examples/basic_usage.py
@@ -98,6 +98,8 @@ def signal_handler(sig, frame):
     logging.info("Starting client... (Press Ctrl+C to interrupt)")
     client.start()
 
+    client.upload_to_discovery("127.0.0.1", None)
+
     my_peer_id = client.get_own_peer_id()
     logging.info(f"My Peer ID: {my_peer_id}")
 
diff --git a/crates/prime-protocol-py/src/worker/blockchain.rs b/crates/prime-protocol-py/src/worker/blockchain.rs
index a3df3155..e1d5479f 100644
--- a/crates/prime-protocol-py/src/worker/blockchain.rs
+++ b/crates/prime-protocol-py/src/worker/blockchain.rs
@@ -44,6 +44,11 @@ impl BlockchainService {
         self.node_wallet.as_ref()
     }
 
+    /// Get the provider wallet
+    pub fn provider_wallet(&self) -> Option<&Wallet> {
+        self.provider_wallet.as_ref()
+    }
+
     /// Initialize blockchain components and ensure the node is properly registered
     pub async fn initialize(&mut self) -> Result<()> {
         let (provider_wallet, node_wallet, contracts) = self.create_wallets_and_contracts().await?;
diff --git a/crates/prime-protocol-py/src/worker/client.rs b/crates/prime-protocol-py/src/worker/client.rs
index 08a0c824..497cd261 100644
--- a/crates/prime-protocol-py/src/worker/client.rs
+++ b/crates/prime-protocol-py/src/worker/client.rs
@@ -374,6 +374,67 @@ impl WorkerClientCore {
         Ok(())
     }
 
+    /// Get the provider's Ethereum address
+    pub fn get_provider_address(&self) -> Result<String> {
+        let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string())
+        })?;
+        let provider_wallet = blockchain_service.provider_wallet().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("Provider wallet not initialized".to_string())
+        })?;
+        Ok(provider_wallet
+            .wallet
+            .default_signer()
+            .address()
+            .to_string())
+    }
+
+    // todo: move to blockchain service
+    pub fn get_node_address(&self) -> Result<String> {
+        let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string())
+        })?;
+        let node_wallet = blockchain_service.node_wallet().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("Node wallet not initialized".to_string())
+        })?;
+        Ok(node_wallet.wallet.default_signer().address().to_string())
+    }
+
+    /// Get the compute pool ID
+    pub fn get_compute_pool_id(&self) -> u64 {
+        self.config.compute_pool_id
+    }
+
+    /// Get the listening multiaddresses from the P2P service
+    pub async fn get_listening_addresses(&self) -> Vec<String> {
+        // For now, return a simple localhost address with the configured port
+        // In the future, this could query the actual P2P service for its listen addresses
+        vec![format!("/ip4/0.0.0.0/tcp/{}", self.config.p2p_port)]
+    }
+
+    /// Upload node information to discovery services
+    pub async fn upload_to_discovery(
+        &self,
+        node_info: &crate::worker::discovery::SimpleNode,
+        discovery_urls: &[String],
+    ) -> Result<()> {
+        let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string())
+        })?;
+
+        let node_wallet = blockchain_service.node_wallet().ok_or_else(|| {
+            PrimeProtocolError::InvalidConfig("Provider wallet not initialized".to_string())
+        })?;
+
+        let node_json = node_info.to_json_value();
+
+        shared::discovery::upload_node_to_discovery(discovery_urls, &node_json, node_wallet)
+            .await
+            .map_err(|e| {
+                PrimeProtocolError::RuntimeError(format!("Failed to upload to discovery: {}", e))
+            })
+    }
+
     fn get_private_key_provider(&self) -> Result<String> {
         self.config
             .private_key_provider
diff --git a/crates/prime-protocol-py/src/worker/discovery.rs b/crates/prime-protocol-py/src/worker/discovery.rs
new file mode 100644
index 00000000..9be96eb3
--- /dev/null
+++ b/crates/prime-protocol-py/src/worker/discovery.rs
@@ -0,0 +1,55 @@
+use serde::{Deserialize, Serialize};
+
+/// Simple node information for discovery upload
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SimpleNode {
+    /// Unique identifier for the node
+    pub id: String,
+    /// The external IP address of the node
+    pub ip_address: String,
+    /// The port for the node (usually 0 for dynamic port allocation)
+    pub port: u16,
+    /// The Ethereum address of the node
+    pub provider_address: String,
+    /// The compute pool ID this node belongs to
+    pub compute_pool_id: u32,
+    /// The P2P peer ID
+    pub worker_p2p_id: String,
+    /// The P2P multiaddresses for connecting to this node
+    pub worker_p2p_addresses: Vec<String>,
+}
+
+impl SimpleNode {
+    pub fn new(
+        ip_address: String,
+        port: u16,
+        provider_wallet_address: String,
+        node_wallet_address: String,
+        compute_pool_id: u32,
+        peer_id: String,
+        multi_addresses: Vec<String>,
+    ) -> Self {
+        Self {
+            id: node_wallet_address,
+            ip_address,
+            port,
+            provider_address: provider_wallet_address,
+            compute_pool_id,
+            worker_p2p_id: peer_id,
+            worker_p2p_addresses: multi_addresses,
+        }
+    }
+
+    /// Convert to serde_json::Value for upload
+    pub fn to_json_value(&self) -> serde_json::Value {
+        serde_json::json!({
+            "id": self.id,
+            "ip_address": self.ip_address,
+            "port": self.port,
+            "provider_address": self.provider_address,
+            "compute_pool_id": self.compute_pool_id,
+            "worker_p2p_id": self.worker_p2p_id,
+            "worker_p2p_addresses": self.worker_p2p_addresses,
+        })
+    }
+}
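For reference, a record built by `SimpleNode::to_json_value` above serializes to JSON shaped like the following sketch; all values here are made-up examples:

```python
# Hypothetical payload shape for the discovery upload (example values only)
node_record = {
    "id": "0x70997970C51812dc3A010C7d01b50e0d17dc79C8",   # node wallet address
    "ip_address": "203.0.113.10",
    "port": 0,                                            # 0 = dynamic port allocation
    "provider_address": "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266",
    "compute_pool_id": 0,
    "worker_p2p_id": "12D3KooWExample",
    "worker_p2p_addresses": ["/ip4/0.0.0.0/tcp/8000"],
}
```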
diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/prime-protocol-py/src/worker/mod.rs
index 78853ac2..13d6150a 100644
--- a/crates/prime-protocol-py/src/worker/mod.rs
+++ b/crates/prime-protocol-py/src/worker/mod.rs
@@ -4,6 +4,7 @@ mod auth;
 mod blockchain;
 mod client;
 mod constants;
+mod discovery;
 mod message_processor;
 mod p2p;
 
@@ -103,6 +104,47 @@ impl WorkerClient {
         Ok(self.inner.get_peer_id().map(|id| id.to_string()))
     }
 
+    /// Upload node information to discovery services
+    ///
+    /// Args:
+    ///     ip_address: External IP address of the node
+    ///     discovery_urls: List of discovery service URLs (defaults to ["http://localhost:8089"])
+    pub fn upload_to_discovery(
+        &self,
+        ip_address: String,
+        discovery_urls: Option<Vec<String>>,
+        py: Python,
+    ) -> PyResult<()> {
+        let rt = self.ensure_runtime()?;
+
+        // Get the peer ID and multiaddresses from the P2P service
+        let peer_id = self
+            .inner
+            .get_peer_id()
+            .ok_or_else(|| to_py_err("P2P service not started. Cannot get peer ID"))?
+            .to_string();
+
+        let multi_addresses =
+            py.allow_threads(|| rt.block_on(self.inner.get_listening_addresses()));
+        // Create simple node info (port 0 indicates dynamic port allocation)
+        let node_info = discovery::SimpleNode::new(
+            ip_address,
+            0, // Port 0 for dynamic allocation
+            self.inner.get_provider_address().map_err(to_py_err)?,
+            self.inner.get_node_address().map_err(to_py_err)?,
+            self.inner.get_compute_pool_id() as u32,
+            peer_id,
+            multi_addresses,
+        );
+
+        // Use default discovery URLs if none provided
+        let urls = discovery_urls.unwrap_or_else(|| vec!["http://localhost:8089".to_string()]);
+
+        // Upload to discovery
+        py.allow_threads(|| rt.block_on(self.inner.upload_to_discovery(&node_info, &urls)))
+            .map_err(to_py_err)
+    }
+
     /// Stop the worker client and clean up resources
     pub fn stop(&mut self, py: Python) -> PyResult<()> {
         // Cancel all background tasks
diff --git a/crates/prime-protocol-py/tests/integration/test_worker.rs b/crates/prime-protocol-py/tests/integration/test_worker.rs
index 22e77446..ed4e7a4e 100644
--- a/crates/prime-protocol-py/tests/integration/test_worker.rs
+++ b/crates/prime-protocol-py/tests/integration/test_worker.rs
@@ -36,7 +36,6 @@ mod worker_integration_tests {
         // Verify peer ID was assigned
         let peer_id = worker.get_peer_id();
         assert!(peer_id.is_some());
-        println!("Worker started with peer ID: {:?}", peer_id);
 
         // Test message sending to self
         let test_message = Message {
@@ -111,8 +110,6 @@ mod worker_integration_tests {
         let peer_id1 = worker1.get_peer_id().unwrap();
         let peer_id2 = worker2.get_peer_id().unwrap();
 
-        println!("Worker1 peer ID: {}", peer_id1);
-        println!("Worker2 peer ID: {}", peer_id2);
 
         // Worker1 sends message to Worker2
         let message = Message {
diff --git a/crates/shared/Cargo.toml b/crates/shared/Cargo.toml
index 4d3a8760..a1b5f189 100644
--- a/crates/shared/Cargo.toml
+++ b/crates/shared/Cargo.toml
@@ -44,3 +44,4 @@ subtle = "2.6.1"
 utoipa = { version = "5.3.0", features = ["actix_extras", "chrono", "uuid"] }
 futures = { workspace = true }
 tokio-util = { workspace = true }
+reqwest = { workspace = true }
diff --git a/crates/shared/src/discovery/mod.rs b/crates/shared/src/discovery/mod.rs
index 11f33ce1..6e254cc2 100644
--- a/crates/shared/src/discovery/mod.rs
+++ b/crates/shared/src/discovery/mod.rs
@@ -12,11 +12,7 @@ pub async fn fetch_nodes_from_discovery_url(
     route: &str,
     wallet: &Wallet,
 ) -> Result<Vec<DiscoveryNode>> {
-    let address = wallet
-        .wallet
-        .default_signer()
-        .address()
-        .to_string();
+    let address = wallet.wallet.default_signer().address().to_string();
 
     let signature = sign_request_with_nonce(route, wallet, None)
         .await
@@ -113,7 +109,7 @@ pub async fn fetch_pool_nodes_from_discovery(
     compute_pool_id: u32,
     wallet: &Wallet,
 ) -> Result<Vec<DiscoveryNode>> {
-    let route = format!("/api/pool/{}", compute_pool_id);
+    let route = format!("/api/pool/{compute_pool_id}");
     fetch_nodes_from_discovery_urls(discovery_urls, &route, wallet).await
 }
 
@@ -123,5 +119,104 @@ pub async fn fetch_validator_nodes_from_discovery(
     wallet: &Wallet,
 ) -> Result<Vec<DiscoveryNode>> {
     let route = "/api/validator";
-    fetch_nodes_from_discovery_urls(discovery_urls, &route, wallet).await
-}
\ No newline at end of file
+    fetch_nodes_from_discovery_urls(discovery_urls, route, wallet).await
+}
+
+/// Upload node information to discovery services
+///
+/// This function attempts to upload node information to all provided discovery URLs.
+/// It returns Ok(()) if at least one upload succeeds.
+pub async fn upload_node_to_discovery(
+    discovery_urls: &[String],
+    node_data: &serde_json::Value,
+    wallet: &Wallet,
+) -> Result<()> {
+    let endpoint = "/api/nodes";
+    let mut last_error: Option<String> = None;
+
+    for discovery_url in discovery_urls {
+        match upload_to_single_discovery(discovery_url, endpoint, node_data, wallet).await {
+            Ok(_) => {
+                debug!("Successfully uploaded node info to {discovery_url}");
+                return Ok(());
+            }
+            Err(e) => {
+                error!("Failed to upload to {discovery_url}: {e}");
+                last_error = Some(e.to_string());
+            }
+        }
+    }
+
+    // If we reach here, all discovery services failed
+    if let Some(error) = last_error {
+        Err(anyhow::anyhow!(
+            "Failed to upload to all discovery services. Last error: {}",
+            error
+        ))
+    } else {
+        Err(anyhow::anyhow!(
+            "Failed to upload to all discovery services"
+        ))
+    }
+}
+
+async fn upload_to_single_discovery(
+    base_url: &str,
+    endpoint: &str,
+    node_data: &serde_json::Value,
+    wallet: &Wallet,
+) -> Result<()> {
+    let signed_request = sign_request_with_nonce(endpoint, wallet, Some(node_data))
+        .await
+        .map_err(|e| anyhow::anyhow!("{}", e))?;
+
+    let mut headers = reqwest::header::HeaderMap::new();
+    headers.insert(
+        "x-address",
+        wallet
+            .wallet
+            .default_signer()
+            .address()
+            .to_string()
+            .parse()
+            .context("Failed to parse address header")?,
+    );
+    headers.insert(
+        "x-signature",
+        signed_request
+            .signature
+            .parse()
+            .context("Failed to parse signature header")?,
+    );
+
+    let request_url = format!("{base_url}{endpoint}");
+    let response = reqwest::Client::new()
+        .put(&request_url)
+        .headers(headers)
+        .json(
+            &signed_request
+                .data
+                .expect("Signed request data should always be present for discovery upload"),
+        )
+        .timeout(Duration::from_secs(10))
+        .send()
+        .await
+        .context("Failed to send request")?;
+
+    if !response.status().is_success() {
+        let status = response.status();
+        let error_text = response
+            .text()
+            .await
+            .unwrap_or_else(|_| "No error message".to_string());
+        return Err(anyhow::anyhow!(
+            "Error: Received response with status code {} from {}: {}",
+            status,
+            base_url,
+            error_text
+        ));
+    }
+
+    Ok(())
+}
diff --git a/crates/shared/src/lib.rs b/crates/shared/src/lib.rs
index 5ce256ed..daa51c09 100644
--- a/crates/shared/src/lib.rs
+++ b/crates/shared/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod discovery;
 pub mod models;
 pub mod p2p;
 pub mod security;
diff --git a/crates/worker/src/services/discovery.rs b/crates/worker/src/services/discovery.rs
index 2088215c..a0e60f79 100644
--- a/crates/worker/src/services/discovery.rs
+++ b/crates/worker/src/services/discovery.rs
@@ -1,16 +1,14 @@
 use anyhow::Result;
 use shared::models::node::Node;
-use shared::security::request_signer::sign_request_with_nonce;
 use shared::web3::wallet::Wallet;
 
 pub(crate) struct DiscoveryService {
     wallet: Wallet,
     base_urls: Vec<String>,
-    endpoint: String,
 }
 
 impl DiscoveryService {
-    pub(crate) fn new(wallet: Wallet, base_urls: Vec<String>, endpoint: Option<String>) -> Self {
+    pub(crate) fn new(wallet: Wallet, base_urls: Vec<String>, _endpoint: Option<String>) -> Self {
         let urls = if base_urls.is_empty() {
             vec!["http://localhost:8089".to_string()]
         } else {
@@ -19,86 +17,18 @@ impl DiscoveryService {
         Self {
             wallet,
             base_urls: urls,
-            endpoint: endpoint.unwrap_or_else(|| "/api/nodes".to_string()),
         }
     }
 
-    async fn upload_to_single_discovery(&self, node_config: &Node, base_url: &str) -> Result<()> {
-        let request_data = serde_json::to_value(node_config)?;
-
-        let signed_request =
-            sign_request_with_nonce(&self.endpoint, &self.wallet, Some(&request_data))
-                .await
-                .map_err(|e| anyhow::anyhow!("{}", e))?;
-
-        let mut headers = reqwest::header::HeaderMap::new();
-        headers.insert(
-            "x-address",
-            self.wallet
-                .wallet
-                .default_signer()
-                .address()
-                .to_string()
-                .parse()
-                .unwrap(),
-        );
-        headers.insert("x-signature", signed_request.signature.parse().unwrap());
-        let request_url = format!("{}{}", base_url, &self.endpoint);
-        let client = reqwest::Client::new();
-        let response = client
-            .put(&request_url)
-            .headers(headers)
-            .json(
-                &signed_request
-                    .data
-                    .expect("Signed request data should always be present for discovery upload"),
-            )
-            .send()
-            .await?;
-
-        if !response.status().is_success() {
-            let status = response.status();
-            let error_text = response
-                .text()
-                .await
-                .unwrap_or_else(|_| "No error message".to_string());
-            return Err(anyhow::anyhow!(
-                "Error: Received response with status code {} from {}: {}",
-                status,
-                base_url,
-                error_text
-            ));
-        }
-
-        Ok(())
-    }
-
     pub(crate) async fn upload_discovery_info(&self, node_config: &Node) -> Result<()> {
-        let mut last_error: Option<String> = None;
-
-        for base_url in &self.base_urls {
-            match self.upload_to_single_discovery(node_config, base_url).await {
-                Ok(_) => {
-                    // Successfully uploaded to one discovery service, return immediately
-                    return Ok(());
-                }
-                Err(e) => {
-                    last_error = Some(e.to_string());
-                }
-            }
-        }
-
-        // If we reach here, all discovery services failed
-        if let Some(error) = last_error {
-            Err(anyhow::anyhow!(
-                "Failed to upload to all discovery services.
Last error: {}", - error - )) - } else { - Err(anyhow::anyhow!( - "Failed to upload to all discovery services" - )) - } + let node_data = serde_json::to_value(node_config)?; + + shared::discovery::upload_node_to_discovery( + &self.base_urls, + &node_data, + &self.wallet, + ) + .await } } @@ -107,7 +37,6 @@ impl Clone for DiscoveryService { Self { wallet: self.wallet.clone(), base_urls: self.base_urls.clone(), - endpoint: self.endpoint.clone(), } } } From de4a89e96e46e59c0f5d45b6684f1f68a762636f Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 14:06:04 +0200 Subject: [PATCH 09/23] implement basic validator sdk functionality to list nodes from discovery --- Cargo.lock | 142 +++++++-- crates/prime-protocol-py/Cargo.toml | 1 + crates/prime-protocol-py/README.md | 122 ++++++++ .../examples/validator_list_nodes.py | 93 ++++++ crates/prime-protocol-py/src/validator/mod.rs | 280 +++++++++++++++--- .../prime-protocol-py/tests/test_validator.py | 139 +++++++++ 6 files changed, 709 insertions(+), 68 deletions(-) create mode 100644 crates/prime-protocol-py/README.md create mode 100644 crates/prime-protocol-py/examples/validator_list_nodes.py create mode 100644 crates/prime-protocol-py/tests/test_validator.py diff --git a/Cargo.lock b/Cargo.lock index b6e846d4..339c175c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -640,7 +640,7 @@ dependencies = [ "lru 0.13.0", "parking_lot 0.12.3", "pin-project", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "thiserror 2.0.12", @@ -709,7 +709,7 @@ dependencies = [ "async-stream", "futures", "pin-project", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "tokio", @@ -980,7 +980,7 @@ checksum = "21238d7ce1425bdb42269624b59a4fcc1c744bcfcb3cc762cbb251761c45d488" dependencies = [ "alloy-json-rpc", "alloy-transport", - "reqwest", + "reqwest 0.12.15", "serde_json", "tower", "tracing", @@ -2615,7 +2615,7 @@ dependencies = [ "log", "redis", "redis-test", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "shared", @@ -3357,7 +3357,7 @@ dependencies = [ "google-cloud-token", "home", "jsonwebtoken", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "thiserror 1.0.69", @@ -3377,9 +3377,9 @@ dependencies = [ "base64 0.22.1", "derive_builder", "http 1.3.1", - "reqwest", + "reqwest 0.12.15", "rustls", - "rustls-pemfile", + "rustls-pemfile 2.2.0", "serde", "serde_json", "thiserror 2.0.12", @@ -3393,7 +3393,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d901aeb453fd80e51d64df4ee005014f6cf39f2d736dd64f7239c132d9d39a6a" dependencies = [ - "reqwest", + "reqwest 0.12.15", "thiserror 1.0.69", "tokio", ] @@ -3418,7 +3418,7 @@ dependencies = [ "percent-encoding", "pkcs8", "regex", - "reqwest", + "reqwest 0.12.15", "reqwest-middleware", "ring 0.17.14", "serde", @@ -3854,6 +3854,19 @@ dependencies = [ "webpki-roots 0.26.9", ] +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper 0.14.32", + "native-tls", + "tokio", + "tokio-native-tls", +] + [[package]] name = "hyper-tls" version = "0.6.0" @@ -4116,7 +4129,7 @@ dependencies = [ "netlink-proto", "netlink-sys", "rtnetlink 0.13.1", - "system-configuration", + "system-configuration 0.6.1", "tokio", "windows 0.52.0", ] @@ -4335,7 +4348,7 @@ dependencies = [ "portmapper", "rand 0.8.5", "rcgen 0.13.2", - "reqwest", + "reqwest 0.12.15", "ring 0.17.14", "rustls", 
"rustls-webpki 0.102.8", @@ -4468,7 +4481,7 @@ dependencies = [ "pkarr", "postcard", "rand 0.8.5", - "reqwest", + "reqwest 0.12.15", "rustls", "rustls-webpki 0.102.8", "serde", @@ -5775,7 +5788,7 @@ dependencies = [ "netlink-packet-route 0.17.1", "netlink-sys", "once_cell", - "system-configuration", + "system-configuration 0.6.1", "windows-sys 0.52.0", ] @@ -6195,7 +6208,7 @@ dependencies = [ "rand 0.9.1", "redis", "redis-test", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "shared", @@ -6736,6 +6749,7 @@ dependencies = [ "pyo3-log", "pythonize", "rand 0.8.5", + "reqwest 0.11.27", "serde", "serde_json", "shared", @@ -7395,6 +7409,46 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "reqwest" +version = "0.11.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" +dependencies = [ + "base64 0.21.7", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper-tls 0.5.0", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls-pemfile 1.0.4", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 0.1.2", + "system-configuration 0.5.1", + "tokio", + "tokio-native-tls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winreg", +] + [[package]] name = "reqwest" version = "0.12.15" @@ -7413,7 +7467,7 @@ dependencies = [ "http-body-util", "hyper 1.6.0", "hyper-rustls", - "hyper-tls", + "hyper-tls 0.6.0", "hyper-util", "ipnet", "js-sys", @@ -7426,13 +7480,13 @@ dependencies = [ "pin-project-lite", "quinn", "rustls", - "rustls-pemfile", + "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", - "system-configuration", + "sync_wrapper 1.0.2", + "system-configuration 0.6.1", "tokio", "tokio-native-tls", "tokio-rustls", @@ -7457,7 +7511,7 @@ dependencies = [ "anyhow", "async-trait", "http 1.3.1", - "reqwest", + "reqwest 0.12.15", "serde", "thiserror 1.0.69", "tower-service", @@ -7836,6 +7890,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + [[package]] name = "rustls-pemfile" version = "2.2.0" @@ -8315,7 +8378,7 @@ dependencies = [ "rand 0.9.1", "redis", "regex", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "subtle", @@ -8723,6 +8786,12 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "sync_wrapper" version = "1.0.2" @@ -8758,6 +8827,17 @@ dependencies = [ "windows 0.52.0", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys 0.5.0", +] + [[package]] name = "system-configuration" version = "0.6.1" @@ -8766,7 +8846,17 @@ checksum = 
"3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.9.0", "core-foundation", - "system-configuration-sys", + "system-configuration-sys 0.6.0", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", ] [[package]] @@ -9120,7 +9210,7 @@ dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper", + "sync_wrapper 1.0.2", "tokio", "tower-layer", "tower-service", @@ -9201,7 +9291,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3beec919fbdf99d719de8eda6adae3281f8a5b71ae40431f44dc7423053d34" dependencies = [ "loki-api", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "snap", @@ -9500,7 +9590,7 @@ dependencies = [ "base64 0.22.1", "mime_guess", "regex", - "reqwest", + "reqwest 0.12.15", "rust-embed", "serde", "serde_json", @@ -9548,7 +9638,7 @@ dependencies = [ "redis", "redis-test", "regex", - "reqwest", + "reqwest 0.12.15", "serde", "serde_json", "shared", @@ -10390,7 +10480,7 @@ dependencies = [ "prime-core", "rand 0.8.5", "rand 0.9.1", - "reqwest", + "reqwest 0.12.15", "rust-ipfs", "serde", "serde_json", diff --git a/crates/prime-protocol-py/Cargo.toml b/crates/prime-protocol-py/Cargo.toml index 8f4d96be..971480e3 100644 --- a/crates/prime-protocol-py/Cargo.toml +++ b/crates/prime-protocol-py/Cargo.toml @@ -29,6 +29,7 @@ anyhow = { workspace = true } tokio-util = { workspace = true } rand = { version = "0.8", features = ["std"] } hex = "0.4" +reqwest = { version = "0.11", features = ["json"] } [dev-dependencies] test-log = "0.2" diff --git a/crates/prime-protocol-py/README.md b/crates/prime-protocol-py/README.md new file mode 100644 index 00000000..cd458fc9 --- /dev/null +++ b/crates/prime-protocol-py/README.md @@ -0,0 +1,122 @@ +# Prime Protocol Python SDK + +Python bindings for the Prime Protocol, providing easy access to worker, orchestrator, and validator functionalities. + +## Installation + +```bash +# Clone the repository +git clone https://github.com/primeprotocol/protocol.git +cd protocol/crates/prime-protocol-py + +# Build and install the package +pip install maturin +maturin develop +``` + +## Validator Client + +The Validator Client allows validators to interact with the Prime Protocol network, particularly for listing and validating nodes. 
+ +### Basic Usage + +```python +from primeprotocol import ValidatorClient + +# Initialize the validator client +validator = ValidatorClient( + rpc_url="http://localhost:8545", + private_key="YOUR_PRIVATE_KEY", + discovery_urls=["http://localhost:8089"] # Can specify multiple discovery services +) + +# List all non-validated nodes +non_validated_nodes = validator.list_non_validated_nodes() + +for node in non_validated_nodes: + print(f"Node ID: {node.id}") + print(f"Provider: {node.provider_address}") + print(f"Address: {node.ip_address}:{node.port}") + print(f"Active: {node.is_active}") + print(f"Whitelisted: {node.is_provider_whitelisted}") + print() + +# Get count of non-validated nodes +count = validator.get_non_validated_count() +print(f"Total non-validated nodes: {count}") + +# Get all nodes as dictionaries (includes compute specs) +all_nodes = validator.list_all_nodes_dict() +for node in all_nodes: + if not node['is_validated']: + print(f"Node {node['id']} needs validation") + + # Access compute specs if available + if 'node' in node and 'compute_specs' in node['node']: + specs = node['node']['compute_specs'] + if specs: + print(f" RAM: {specs.get('ram_mb', 'N/A')} MB") + print(f" Storage: {specs.get('storage_gb', 'N/A')} GB") +``` + +### Node Details + +Each node returned by `list_non_validated_nodes()` has the following attributes: + +- `id`: Unique identifier for the node +- `provider_address`: On-chain address of the node provider +- `ip_address`: IP address of the node +- `port`: Port number for node communication +- `compute_pool_id`: ID of the compute pool the node belongs to +- `is_validated`: Whether the node has been validated +- `is_active`: Whether the node is currently active +- `is_provider_whitelisted`: Whether the provider is whitelisted +- `is_blacklisted`: Whether the node is blacklisted +- `worker_p2p_id`: P2P identifier (optional) +- `last_updated`: Last update timestamp (optional) +- `created_at`: Creation timestamp (optional) + +### Environment Variables + +The validator client can be configured using environment variables: + +```bash +export RPC_URL="http://localhost:8545" +export VALIDATOR_PRIVATE_KEY="your_private_key_here" +export DISCOVERY_URLS="http://localhost:8089,http://backup-discovery:8089" +``` + +### Example Scripts + +See the `examples/` directory for complete examples: +- `validator_list_nodes.py`: Lists all non-validated nodes with detailed output + +## Worker Client + +[Documentation for WorkerClient...] + +## Orchestrator Client + +[Documentation for OrchestratorClient...] 
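+
+## Example: Choosing Validation Candidates
+
+A short sketch that combines the calls documented above. It assumes the same
+constructor arguments as Basic Usage and the node attributes listed under Node
+Details; the filtering policy here is only an illustration, not a protocol rule.
+
+```python
+import os
+from primeprotocol import ValidatorClient
+
+validator = ValidatorClient(
+    rpc_url=os.getenv("RPC_URL", "http://localhost:8545"),
+    private_key=os.environ["VALIDATOR_PRIVATE_KEY"],
+    discovery_urls=[u.strip() for u in os.getenv("DISCOVERY_URLS", "http://localhost:8089").split(",")],
+)
+
+# Drop nodes we would never validate before doing any further work.
+candidates = [
+    node
+    for node in validator.list_non_validated_nodes()
+    if node.is_active and node.is_provider_whitelisted and not node.is_blacklisted
+]
+
+for node in candidates:
+    print(f"would validate {node.id} at {node.ip_address}:{node.port}")
+```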
+ +## Development + +### Running Tests + +```bash +# Install development dependencies +pip install -r requirements-dev.txt + +# Run tests +pytest tests/ +``` + +### Building from Source + +```bash +# Build in release mode +maturin build --release + +# Build and install locally +maturin develop +``` \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/validator_list_nodes.py b/crates/prime-protocol-py/examples/validator_list_nodes.py new file mode 100644 index 00000000..c8398c92 --- /dev/null +++ b/crates/prime-protocol-py/examples/validator_list_nodes.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +"""Example usage of the Prime Protocol Validator Client to list non-validated nodes.""" + +import os +import logging +from typing import List +from primeprotocol import ValidatorClient + +# Configure logging +FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' +logging.basicConfig(format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + + +def print_node_summary(nodes: List) -> None: + """Print a summary of nodes""" + print(f"\nTotal nodes found: {len(nodes)}") + + if not nodes: + print("No non-validated nodes found.") + return + + print("\nNon-validated nodes:") + print("-" * 80) + + for idx, node in enumerate(nodes, 1): + print(f"\n{idx}. Node ID: {node.id}") + print(f" Provider Address: {node.provider_address}") + print(f" IP: {node.ip_address}:{node.port}") + print(f" Compute Pool ID: {node.compute_pool_id}") + print(f" Active: {node.is_active}") + print(f" Whitelisted: {node.is_provider_whitelisted}") + print(f" Blacklisted: {node.is_blacklisted}") + + if node.worker_p2p_id: + print(f" P2P ID: {node.worker_p2p_id}") + + if node.created_at: + print(f" Created At: {node.created_at}") + + if node.last_updated: + print(f" Last Updated: {node.last_updated}") + + +def main(): + # Get configuration from environment variables + rpc_url = os.getenv("RPC_URL", "http://localhost:8545") + private_key = os.getenv("VALIDATOR_PRIVATE_KEY") + discovery_urls_str = os.getenv("DISCOVERY_URLS", "http://localhost:8089") + discovery_urls = [url.strip() for url in discovery_urls_str.split(",")] + + if not private_key: + print("Error: VALIDATOR_PRIVATE_KEY environment variable is required") + return + + try: + # Initialize the validator client + print(f"Initializing validator client...") + print(f"RPC URL: {rpc_url}") + print(f"Discovery URLs: {discovery_urls}") + + validator = ValidatorClient( + rpc_url=rpc_url, + private_key=private_key, + discovery_urls=discovery_urls + ) + + # List all non-validated nodes + print("\nFetching non-validated nodes from discovery service...") + non_validated_nodes = validator.list_non_validated_nodes() + + # Print summary + print_node_summary(non_validated_nodes) + + # You can also get all nodes as dictionaries for more flexibility + print("\n\nFetching all nodes as dictionaries...") + all_nodes = validator.list_all_nodes_dict() + + # Count validated vs non-validated + validated_count = sum(1 for node in all_nodes if node['is_validated']) + non_validated_count = len(all_nodes) - validated_count + + print(f"\nTotal nodes: {len(all_nodes)}") + print(f"Validated: {validated_count}") + print(f"Non-validated: {non_validated_count}") + + except Exception as e: + logging.error(f"Error: {e}") + raise + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/prime-protocol-py/src/validator/mod.rs index ff649e32..cf636cba 100644 --- 
a/crates/prime-protocol-py/src/validator/mod.rs
+++ b/crates/prime-protocol-py/src/validator/mod.rs
@@ -1,86 +1,282 @@
 use pyo3::prelude::*;
+use pythonize::pythonize;
+use shared::models::node::DiscoveryNode;
+use shared::security::request_signer::sign_request_with_nonce;
+use shared::web3::wallet::Wallet;
+use std::time::Duration;
+use url::Url;
 
 /// Node details for validator operations
 #[pyclass]
 #[derive(Clone)]
 pub(crate) struct NodeDetails {
     #[pyo3(get)]
-    pub address: String,
+    pub id: String,
+    #[pyo3(get)]
+    pub provider_address: String,
+    #[pyo3(get)]
+    pub ip_address: String,
+    #[pyo3(get)]
+    pub port: u16,
+    #[pyo3(get)]
+    pub compute_pool_id: u32,
+    #[pyo3(get)]
+    pub is_validated: bool,
+    #[pyo3(get)]
+    pub is_active: bool,
+    #[pyo3(get)]
+    pub is_provider_whitelisted: bool,
+    #[pyo3(get)]
+    pub is_blacklisted: bool,
+    #[pyo3(get)]
+    pub worker_p2p_id: Option<String>,
+    #[pyo3(get)]
+    pub last_updated: Option<String>,
+    #[pyo3(get)]
+    pub created_at: Option<String>,
+}
+
+impl From<DiscoveryNode> for NodeDetails {
+    fn from(node: DiscoveryNode) -> Self {
+        Self {
+            id: node.node.id,
+            provider_address: node.node.provider_address,
+            ip_address: node.node.ip_address,
+            port: node.node.port,
+            compute_pool_id: node.node.compute_pool_id,
+            is_validated: node.is_validated,
+            is_active: node.is_active,
+            is_provider_whitelisted: node.is_provider_whitelisted,
+            is_blacklisted: node.is_blacklisted,
+            worker_p2p_id: node.node.worker_p2p_id,
+            last_updated: node.last_updated.map(|dt| dt.to_rfc3339()),
+            created_at: node.created_at.map(|dt| dt.to_rfc3339()),
+        }
+    }
 }
 
 #[pymethods]
 impl NodeDetails {
-    #[new]
-    pub fn new(address: String) -> Self {
-        Self { address }
+    /// Get compute specs as a Python dictionary
+    pub fn get_compute_specs(&self, py: Python) -> PyResult<PyObject> {
+        // This would need access to the original DiscoveryNode's compute_specs
+        // For now returning None
+        Ok(py.None())
+    }
+
+    /// Get location as a Python dictionary
+    pub fn get_location(&self, py: Python) -> PyResult<PyObject> {
+        // This would need access to the original DiscoveryNode's location
+        // For now returning None
+        Ok(py.None())
     }
 }
 
-/// Prime Protocol Validator Client - for validating task results
+/// Prime Protocol Validator Client - for validating nodes and tasks
 #[pyclass]
 pub(crate) struct ValidatorClient {
     runtime: Option<tokio::runtime::Runtime>,
+    wallet: Option<Wallet>,
+    discovery_urls: Vec<String>,
 }
 
 #[pymethods]
 impl ValidatorClient {
     #[new]
-    #[pyo3(signature = (rpc_url, private_key=None))]
-    pub fn new(rpc_url: String, private_key: Option<String>) -> PyResult<Self> {
-        // TODO: Implement validator initialization
-        let _ = rpc_url;
-        let _ = private_key;
-
-        Ok(Self { runtime: None })
-    }
+    #[pyo3(signature = (rpc_url, private_key, discovery_urls=vec!["http://localhost:8089".to_string()]))]
+    pub fn new(
+        rpc_url: String,
+        private_key: String,
+        discovery_urls: Vec<String>,
+    ) -> PyResult<Self> {
+        let rpc_url = Url::parse(&rpc_url).map_err(|e| {
+            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid RPC URL: {}", e))
+        })?;
+        let wallet = Wallet::new(&private_key, rpc_url)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
 
-    /// Initialize the validator client and start listening for messages
-    pub fn start(&mut self, _py: Python) -> PyResult<()> {
-        // Create a new runtime for this validator
-        let rt = tokio::runtime::Builder::new_multi_thread()
+        let runtime = tokio::runtime::Builder::new_multi_thread()
             .enable_all()
             .build()
             .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
 
-        // Store the runtime for future use
-        self.runtime = Some(rt);
+        Ok(Self {
+            runtime: Some(runtime),
+            wallet: Some(wallet),
+            discovery_urls,
+        })
+    }
 
-        Ok(())
+    /// List all nodes that are not validated yet
+    pub fn list_non_validated_nodes(&self, py: Python) -> PyResult<Vec<NodeDetails>> {
+        let rt = self.get_or_create_runtime()?;
+        let wallet = self.wallet.as_ref().ok_or_else(|| {
+            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Wallet not initialized")
+        })?;
+
+        let discovery_urls = self.discovery_urls.clone();
+
+        // Release the GIL while performing async operations
+        py.allow_threads(|| {
+            rt.block_on(async {
+                self.fetch_non_validated_nodes(&discovery_urls, wallet)
+                    .await
+            })
+        })
     }
 
-    pub fn list_nodes(&self) -> PyResult<Vec<NodeDetails>> {
-        // TODO: Implement validator node listing from chain that are not yet validated
-        Ok(vec![])
+    /// List all nodes with their details as Python dictionaries
+    pub fn list_all_nodes_dict(&self, py: Python) -> PyResult<Vec<PyObject>> {
+        let rt = self.get_or_create_runtime()?;
+        let wallet = self.wallet.as_ref().ok_or_else(|| {
+            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Wallet not initialized")
+        })?;
+
+        let discovery_urls = self.discovery_urls.clone();
+
+        // Release the GIL while performing async operations
+        let nodes = py.allow_threads(|| {
+            rt.block_on(async { self.fetch_all_nodes(&discovery_urls, wallet).await })
+        })?;
+
+        // Convert to Python dictionaries
+        let result: Result<Vec<_>, _> =
+            nodes.into_iter().map(|node| pythonize(py, &node)).collect();
+
+        let python_objects =
+            result.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+
+        // Convert Bound to Py
+        let py_objects: Vec<PyObject> = python_objects
+            .into_iter()
+            .map(|bound| bound.into())
+            .collect();
+
+        Ok(py_objects)
     }
 
-    pub fn fetch_node_details(&self, _node_id: String) -> PyResult<Option<NodeDetails>> {
-        // TODO: Implement validator node details fetching
-        Ok(None)
+    /// Get the number of non-validated nodes
+    pub fn get_non_validated_count(&self, py: Python) -> PyResult<usize> {
+        let nodes = self.list_non_validated_nodes(py)?;
+        Ok(nodes.len())
     }
 
-    pub fn mark_node_as_validated(&self, _node_id: String) -> PyResult<()> {
-        // TODO: Implement validator node marking as validated
+    /// Initialize the validator client
+    pub fn start(&mut self, _py: Python) -> PyResult<()> {
+        self.get_or_create_runtime()?;
         Ok(())
     }
+}
 
-    pub fn send_request_to_node(&self, _node_id: String, _request: String) -> PyResult<()> {
-        // TODO: Implement validator node request sending
-        Ok(())
+// Private implementation methods
+impl ValidatorClient {
+    fn get_or_create_runtime(&self) -> PyResult<&tokio::runtime::Runtime> {
+        if let Some(ref rt) = self.runtime {
+            Ok(rt)
+        } else {
+            Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                "Runtime not initialized. Call start() first.",
+            ))
+        }
     }
 
-    pub fn send_request_to_node_address(
+    async fn fetch_non_validated_nodes(
         &self,
-        node_address: String,
-        request: String,
-    ) -> PyResult<()> {
-        // TODO: Implement validator node request sending to specific address
-        let _ = node_address;
-        let _ = request;
-        Ok(())
+        discovery_urls: &[String],
+        wallet: &Wallet,
+    ) -> PyResult<Vec<NodeDetails>> {
+        let nodes = self.fetch_all_nodes(discovery_urls, wallet).await?;
+        Ok(nodes
+            .into_iter()
+            .filter(|node| !node.is_validated)
+            .map(NodeDetails::from)
+            .collect())
+    }
+
+    async fn fetch_all_nodes(
+        &self,
+        discovery_urls: &[String],
+        wallet: &Wallet,
+    ) -> PyResult<Vec<DiscoveryNode>> {
+        let mut all_nodes = Vec::new();
+        let mut any_success = false;
+
+        for discovery_url in discovery_urls {
+            match self
+                .fetch_nodes_from_discovery_url(discovery_url, wallet)
+                .await
+            {
+                Ok(nodes) => {
+                    all_nodes.extend(nodes);
+                    any_success = true;
+                }
+                Err(e) => {
+                    // Log error but continue with other discovery services
+                    log::error!("Failed to fetch nodes from {}: {}", discovery_url, e);
+                }
+            }
+        }
+
+        if !any_success {
+            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                "Failed to fetch nodes from all discovery services",
+            ));
+        }
+
+        // Remove duplicates based on node ID
+        let mut unique_nodes = Vec::new();
+        let mut seen_ids = std::collections::HashSet::new();
+        for node in all_nodes {
+            if seen_ids.insert(node.node.id.clone()) {
+                unique_nodes.push(node);
+            }
+        }
+
+        Ok(unique_nodes)
     }
 
-    /// Get the number of pending validation results
-    pub fn get_queue_size(&self) -> usize {
-        0
+    async fn fetch_nodes_from_discovery_url(
+        &self,
+        discovery_url: &str,
+        wallet: &Wallet,
+    ) -> Result<Vec<DiscoveryNode>, anyhow::Error> {
+        let address = wallet.wallet.default_signer().address().to_string();
+
+        let discovery_route = "/api/validator";
+        let signature = sign_request_with_nonce(discovery_route, wallet, None)
+            .await
+            .map_err(|e| anyhow::anyhow!("{}", e))?;
+
+        let mut headers = reqwest::header::HeaderMap::new();
+        headers.insert(
+            "x-address",
+            reqwest::header::HeaderValue::from_str(&address)?,
+        );
+        headers.insert(
+            "x-signature",
+            reqwest::header::HeaderValue::from_str(&signature.signature)?,
+        );
+
+        let response = reqwest::Client::new()
+            .get(format!("{}{}", discovery_url, discovery_route))
+            .query(&[("nonce", signature.nonce)])
+            .headers(headers)
+            .timeout(Duration::from_secs(10))
+            .send()
+            .await?;
+
+        let response_text = response.text().await?;
+        let parsed_response: shared::models::api::ApiResponse<Vec<DiscoveryNode>> =
+            serde_json::from_str(&response_text)?;
+
+        if !parsed_response.success {
+            return Err(anyhow::anyhow!(
+                "Failed to fetch nodes from {}: {:?}",
+                discovery_url,
+                parsed_response
+            ));
+        }
+
+        Ok(parsed_response.data)
     }
 }
diff --git a/crates/prime-protocol-py/tests/test_validator.py b/crates/prime-protocol-py/tests/test_validator.py
new file mode 100644
index 00000000..e15edeab
--- /dev/null
+++ b/crates/prime-protocol-py/tests/test_validator.py
@@ -0,0 +1,139 @@
+"""Tests for the validator client."""
+
+import pytest
+import os
+from unittest.mock import patch, Mock
+
+
+def test_validator_client_creation():
+    """Test creating a validator client requires proper parameters."""
+    # Mock the primeprotocol module since it might not be built
+    with patch('primeprotocol.ValidatorClient') as MockValidator:
+        mock_instance = Mock()
+        MockValidator.return_value = mock_instance
+
+        # Test creation with required parameters
+        rpc_url = "http://localhost:8545"
+        private_key = "0x1234567890abcdef"
+        discovery_urls = ["http://localhost:8089"]
+
+        from primeprotocol 
import ValidatorClient + validator = ValidatorClient(rpc_url, private_key, discovery_urls) + + # Verify the constructor was called with correct parameters + MockValidator.assert_called_once_with(rpc_url, private_key, discovery_urls) + + +def test_list_non_validated_nodes(): + """Test listing non-validated nodes.""" + with patch('primeprotocol.ValidatorClient') as MockValidator: + # Create mock node data + mock_node1 = Mock() + mock_node1.id = "node1" + mock_node1.provider_address = "0xabc123" + mock_node1.ip_address = "192.168.1.1" + mock_node1.port = 8080 + mock_node1.compute_pool_id = 1 + mock_node1.is_validated = False + mock_node1.is_active = True + mock_node1.is_provider_whitelisted = True + mock_node1.is_blacklisted = False + mock_node1.worker_p2p_id = "p2p_id_1" + mock_node1.created_at = "2024-01-01T00:00:00Z" + mock_node1.last_updated = "2024-01-02T00:00:00Z" + + mock_instance = Mock() + mock_instance.list_non_validated_nodes.return_value = [mock_node1] + MockValidator.return_value = mock_instance + + from primeprotocol import ValidatorClient + validator = ValidatorClient("http://localhost:8545", "0x123", ["http://localhost:8089"]) + + # Get non-validated nodes + nodes = validator.list_non_validated_nodes() + + # Verify results + assert len(nodes) == 1 + assert nodes[0].id == "node1" + assert nodes[0].is_validated == False + assert nodes[0].is_active == True + + +def test_list_all_nodes_dict(): + """Test listing all nodes as dictionaries.""" + with patch('primeprotocol.ValidatorClient') as MockValidator: + # Create mock response + mock_nodes = [ + { + 'id': 'node1', + 'provider_address': '0xabc123', + 'is_validated': False, + 'is_active': True, + 'node': { + 'compute_specs': { + 'gpu': { + 'count': 4, + 'model': 'NVIDIA A100', + 'memory_mb': 40000 + }, + 'cpu': { + 'cores': 32 + }, + 'ram_mb': 128000, + 'storage_gb': 1000 + } + } + }, + { + 'id': 'node2', + 'provider_address': '0xdef456', + 'is_validated': True, + 'is_active': True, + 'node': { + 'compute_specs': None + } + } + ] + + mock_instance = Mock() + mock_instance.list_all_nodes_dict.return_value = mock_nodes + MockValidator.return_value = mock_instance + + from primeprotocol import ValidatorClient + validator = ValidatorClient("http://localhost:8545", "0x123", ["http://localhost:8089"]) + + # Get all nodes + nodes = validator.list_all_nodes_dict() + + # Verify results + assert len(nodes) == 2 + assert nodes[0]['is_validated'] == False + assert nodes[1]['is_validated'] == True + + # Check compute specs + assert nodes[0]['node']['compute_specs']['gpu']['count'] == 4 + assert nodes[0]['node']['compute_specs']['gpu']['model'] == 'NVIDIA A100' + + +def test_get_non_validated_count(): + """Test getting count of non-validated nodes.""" + with patch('primeprotocol.ValidatorClient') as MockValidator: + mock_instance = Mock() + # Mock list_non_validated_nodes to return 3 nodes + mock_nodes = [Mock() for _ in range(3)] + mock_instance.list_non_validated_nodes.return_value = mock_nodes + mock_instance.get_non_validated_count.return_value = 3 + MockValidator.return_value = mock_instance + + from primeprotocol import ValidatorClient + validator = ValidatorClient("http://localhost:8545", "0x123", ["http://localhost:8089"]) + + # Get count + count = validator.get_non_validated_count() + + # Verify result + assert count == 3 + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file From 82ae84ef91bb73c09eda58c55324d58c5d7b1d3b Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 14:58:37 +0200 
Subject: [PATCH 10/23] move p2p logic outside of py-sdk worker for sharing with validator + orchestrator --- .../src/{worker => }/constants.rs | 0 crates/prime-protocol-py/src/lib.rs | 2 + .../src/{worker => p2p_handler}/auth.rs | 2 +- .../message_processor.rs | 8 +- .../src/{worker/p2p.rs => p2p_handler/mod.rs} | 5 +- .../src/worker/blockchain.rs | 4 +- crates/prime-protocol-py/src/worker/client.rs | 10 +- crates/prime-protocol-py/src/worker/mod.rs | 20 +-- .../tests/integration/test_auth_flow.py | 147 ------------------ crates/shared/src/discovery/mod.rs | 1 - crates/worker/src/services/discovery.rs | 7 +- 11 files changed, 27 insertions(+), 179 deletions(-) rename crates/prime-protocol-py/src/{worker => }/constants.rs (100%) rename crates/prime-protocol-py/src/{worker => p2p_handler}/auth.rs (99%) rename crates/prime-protocol-py/src/{worker => p2p_handler}/message_processor.rs (97%) rename crates/prime-protocol-py/src/{worker/p2p.rs => p2p_handler/mod.rs} (99%) delete mode 100644 crates/prime-protocol-py/tests/integration/test_auth_flow.py diff --git a/crates/prime-protocol-py/src/worker/constants.rs b/crates/prime-protocol-py/src/constants.rs similarity index 100% rename from crates/prime-protocol-py/src/worker/constants.rs rename to crates/prime-protocol-py/src/constants.rs diff --git a/crates/prime-protocol-py/src/lib.rs b/crates/prime-protocol-py/src/lib.rs index 0715c33a..b5c6bb2b 100644 --- a/crates/prime-protocol-py/src/lib.rs +++ b/crates/prime-protocol-py/src/lib.rs @@ -3,8 +3,10 @@ use crate::validator::ValidatorClient; use crate::worker::WorkerClient; use pyo3::prelude::*; +mod constants; mod error; mod orchestrator; +mod p2p_handler; mod utils; mod validator; mod worker; diff --git a/crates/prime-protocol-py/src/worker/auth.rs b/crates/prime-protocol-py/src/p2p_handler/auth.rs similarity index 99% rename from crates/prime-protocol-py/src/worker/auth.rs rename to crates/prime-protocol-py/src/p2p_handler/auth.rs index 1c31ca25..f8ecaa70 100644 --- a/crates/prime-protocol-py/src/worker/auth.rs +++ b/crates/prime-protocol-py/src/p2p_handler/auth.rs @@ -1,5 +1,5 @@ use crate::error::{PrimeProtocolError, Result}; -use crate::worker::p2p::Message; +use crate::p2p_handler::Message; use alloy::primitives::{Address, Signature}; use rand::Rng; use shared::security::request_signer::sign_message; diff --git a/crates/prime-protocol-py/src/worker/message_processor.rs b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs similarity index 97% rename from crates/prime-protocol-py/src/worker/message_processor.rs rename to crates/prime-protocol-py/src/p2p_handler/message_processor.rs index 4cf40424..543cfde3 100644 --- a/crates/prime-protocol-py/src/worker/message_processor.rs +++ b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs @@ -1,6 +1,6 @@ use crate::error::Result; -use crate::worker::auth::AuthenticationManager; -use crate::worker::p2p::{Message, MessageType}; +use crate::p2p_handler::auth::AuthenticationManager; +use crate::p2p_handler::{Message, MessageType}; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -50,7 +50,7 @@ impl MessageProcessor { message_result = async { let mut rx = self.message_queue_rx.lock().await; tokio::time::timeout( - crate::worker::constants::MESSAGE_QUEUE_TIMEOUT, + crate::constants::MESSAGE_QUEUE_TIMEOUT, rx.recv() ).await } => { @@ -63,6 +63,8 @@ impl MessageProcessor { Err(_) => continue, // Timeout, continue loop }; + log::debug!("Received message: {:?}", message); + if let Err(e) = 
self.process_message(message).await { log::error!("Failed to process message: {}", e); } diff --git a/crates/prime-protocol-py/src/worker/p2p.rs b/crates/prime-protocol-py/src/p2p_handler/mod.rs similarity index 99% rename from crates/prime-protocol-py/src/worker/p2p.rs rename to crates/prime-protocol-py/src/p2p_handler/mod.rs index 1bf9d449..6952aac5 100644 --- a/crates/prime-protocol-py/src/worker/p2p.rs +++ b/crates/prime-protocol-py/src/p2p_handler/mod.rs @@ -10,7 +10,10 @@ use tokio::sync::{ }; use tokio_util::sync::CancellationToken; -use crate::worker::constants::{MESSAGE_QUEUE_CHANNEL_SIZE, P2P_CHANNEL_SIZE}; +use crate::constants::{MESSAGE_QUEUE_CHANNEL_SIZE, P2P_CHANNEL_SIZE}; + +pub(crate) mod auth; +pub(crate) mod message_processor; // Type alias for the complex return type of Service::new type ServiceNewResult = Result<( diff --git a/crates/prime-protocol-py/src/worker/blockchain.rs b/crates/prime-protocol-py/src/worker/blockchain.rs index e1d5479f..c03d4e09 100644 --- a/crates/prime-protocol-py/src/worker/blockchain.rs +++ b/crates/prime-protocol-py/src/worker/blockchain.rs @@ -8,7 +8,7 @@ use shared::web3::contracts::structs::compute_pool::PoolStatus; use shared::web3::wallet::{Wallet, WalletProvider}; use url::Url; -use crate::worker::constants::{BLOCKCHAIN_OPERATION_TIMEOUT, DEFAULT_COMPUTE_UNITS}; +use crate::constants::{BLOCKCHAIN_OPERATION_TIMEOUT, DEFAULT_COMPUTE_UNITS}; /// Configuration for blockchain operations pub struct BlockchainConfig { @@ -109,7 +109,7 @@ impl BlockchainService { self.config.compute_pool_id, pool.status ); - tokio::time::sleep(crate::worker::constants::POOL_STATUS_CHECK_INTERVAL).await; + tokio::time::sleep(crate::constants::POOL_STATUS_CHECK_INTERVAL).await; } Err(e) => { return Err(anyhow::anyhow!("Failed to get pool info: {}", e)); diff --git a/crates/prime-protocol-py/src/worker/client.rs b/crates/prime-protocol-py/src/worker/client.rs index 497cd261..73b91196 100644 --- a/crates/prime-protocol-py/src/worker/client.rs +++ b/crates/prime-protocol-py/src/worker/client.rs @@ -1,11 +1,9 @@ +use crate::constants::{DEFAULT_FUNDING_RETRY_COUNT, MESSAGE_QUEUE_TIMEOUT, P2P_SHUTDOWN_TIMEOUT}; use crate::error::{PrimeProtocolError, Result}; -use crate::worker::auth::AuthenticationManager; +use crate::p2p_handler::auth::AuthenticationManager; +use crate::p2p_handler::message_processor::MessageProcessor; use crate::worker::blockchain::{BlockchainConfig, BlockchainService}; -use crate::worker::constants::{ - DEFAULT_FUNDING_RETRY_COUNT, MESSAGE_QUEUE_TIMEOUT, P2P_SHUTDOWN_TIMEOUT, -}; -use crate::worker::message_processor::MessageProcessor; -use crate::worker::p2p::{Message, MessageType, Service as P2PService}; +use crate::worker::p2p_handler::{Message, MessageType, Service as P2PService}; use p2p::{Keypair, PeerId}; use std::sync::Arc; use tokio::sync::mpsc::{Receiver, Sender}; diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/prime-protocol-py/src/worker/mod.rs index 13d6150a..aaeb3cb4 100644 --- a/crates/prime-protocol-py/src/worker/mod.rs +++ b/crates/prime-protocol-py/src/worker/mod.rs @@ -1,19 +1,15 @@ use pyo3::prelude::*; -mod auth; mod blockchain; mod client; -mod constants; mod discovery; -mod message_processor; -mod p2p; +use crate::constants::DEFAULT_P2P_PORT; +use crate::p2p_handler; +use crate::p2p_handler::Message; pub(crate) use client::WorkerClientCore; use tokio_util::sync::CancellationToken; -use crate::worker::p2p::Message; -use constants::DEFAULT_P2P_PORT; - /// Prime Protocol Worker Client - for compute nodes that 
execute tasks #[pyclass] pub(crate) struct WorkerClient { @@ -87,7 +83,7 @@ impl WorkerClient { let rt = self.ensure_runtime()?; let message = Message { - message_type: p2p::MessageType::General { data }, + message_type: p2p_handler::MessageType::General { data }, peer_id, multiaddrs, sender_address: None, // Will be filled from our wallet automatically @@ -192,19 +188,19 @@ fn to_py_runtime_err(msg: &str) -> PyErr { fn message_to_pyobject(message: Message) -> PyObject { let message_data = match message.message_type { - p2p::MessageType::General { data } => { + p2p_handler::MessageType::General { data } => { serde_json::json!({ "type": "general", "data": data, }) } - p2p::MessageType::AuthenticationInitiation { challenge } => { + p2p_handler::MessageType::AuthenticationInitiation { challenge } => { serde_json::json!({ "type": "auth_initiation", "challenge": challenge, }) } - p2p::MessageType::AuthenticationResponse { + p2p_handler::MessageType::AuthenticationResponse { challenge, signature, } => { @@ -214,7 +210,7 @@ fn message_to_pyobject(message: Message) -> PyObject { "signature": signature, }) } - p2p::MessageType::AuthenticationSolution { signature } => { + p2p_handler::MessageType::AuthenticationSolution { signature } => { serde_json::json!({ "type": "auth_solution", "signature": signature, diff --git a/crates/prime-protocol-py/tests/integration/test_auth_flow.py b/crates/prime-protocol-py/tests/integration/test_auth_flow.py deleted file mode 100644 index 54d20137..00000000 --- a/crates/prime-protocol-py/tests/integration/test_auth_flow.py +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env python3 -""" -Integration test for the authentication flow between two nodes. - -This test demonstrates: -1. Two nodes starting up with their own wallets -2. First message triggers authentication -3. Addresses are extracted from signatures -4. 
Subsequent messages are sent directly -""" - -import asyncio -import pytest -import time -from primeprotocol import WorkerClient -import logging - -# Set up logging -logging.basicConfig(level=logging.DEBUG) - - -@pytest.fixture -def setup_test_environment(): - """Set up test environment with Anvil keys""" - # Anvil test keys - node_a_key = "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80" - node_b_key = "0x59c6995e998f97a5a0044966f0945389dc9e86dae88c7a8412f4603b6b78690d" - provider_key = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a" - - return { - "rpc_url": "http://localhost:8545", - "pool_id": 0, - "node_a": {"node_key": node_a_key, "provider_key": provider_key, "port": 9001}, - "node_b": {"node_key": node_b_key, "provider_key": provider_key, "port": 9002}, - } - - -@pytest.mark.asyncio -async def test_authentication_flow(setup_test_environment): - """Test the full authentication flow between two nodes""" - env = setup_test_environment - - # Create two worker clients - client_a = WorkerClient( - compute_pool_id=env["pool_id"], - rpc_url=env["rpc_url"], - private_key_provider=env["node_a"]["provider_key"], - private_key_node=env["node_a"]["node_key"], - p2p_port=env["node_a"]["port"] - ) - - client_b = WorkerClient( - compute_pool_id=env["pool_id"], - rpc_url=env["rpc_url"], - private_key_provider=env["node_b"]["provider_key"], - private_key_node=env["node_b"]["node_key"], - p2p_port=env["node_b"]["port"] - ) - - try: - # Start both clients - logging.info("Starting client A...") - client_a.start() - peer_a_id = client_a.get_own_peer_id() - logging.info(f"Client A started with peer ID: {peer_a_id}") - - logging.info("Starting client B...") - client_b.start() - peer_b_id = client_b.get_own_peer_id() - logging.info(f"Client B started with peer ID: {peer_b_id}") - - # Give nodes time to start - await asyncio.sleep(2) - - # Test 1: First message from A to B (triggers authentication) - logging.info("Test 1: Sending first message from A to B...") - client_a.send_message( - peer_id=peer_b_id, - data=b"Hello from A! This is the first message.", - multiaddrs=[f"/ip4/127.0.0.1/tcp/{env['node_b']['port']}"] - ) - - # Check if B receives the message - message = None - for _ in range(50): # Poll for up to 5 seconds - message = client_b.get_next_message() - if message: - break - await asyncio.sleep(0.1) - - assert message is not None, "Client B did not receive message" - assert message["message"]["type"] == "general" - assert bytes(message["message"]["data"]) == b"Hello from A! This is the first message." 
- assert message["sender_address"] is not None # Should have sender's address - logging.info(f"Client B received message with sender address: {message['sender_address']}") - - # Test 2: Second message from A to B (should be direct, no auth) - logging.info("Test 2: Sending second message from A to B...") - client_a.send_message( - peer_id=peer_b_id, - data=b"Second message - should be direct", - multiaddrs=[f"/ip4/127.0.0.1/tcp/{env['node_b']['port']}"] - ) - - # Should receive quickly since already authenticated - message = None - for _ in range(20): # Should be faster - message = client_b.get_next_message() - if message: - break - await asyncio.sleep(0.1) - - assert message is not None, "Client B did not receive second message" - assert bytes(message["message"]["data"]) == b"Second message - should be direct" - - # Test 3: Message from B to A (first message in this direction) - logging.info("Test 3: Sending message from B to A...") - client_b.send_message( - peer_id=peer_a_id, - data=b"Hello from B!", - multiaddrs=[f"/ip4/127.0.0.1/tcp/{env['node_a']['port']}"] - ) - - # Check if A receives the message - message = None - for _ in range(50): - message = client_a.get_next_message() - if message: - break - await asyncio.sleep(0.1) - - assert message is not None, "Client A did not receive message" - assert bytes(message["message"]["data"]) == b"Hello from B!" - assert message["sender_address"] is not None - logging.info(f"Client A received message with sender address: {message['sender_address']}") - - logging.info("All tests passed! Authentication flow working correctly.") - - finally: - # Clean up - client_a.stop() - client_b.stop() - - -if __name__ == "__main__": - # Run the test - asyncio.run(test_authentication_flow({})) \ No newline at end of file diff --git a/crates/shared/src/discovery/mod.rs b/crates/shared/src/discovery/mod.rs index 6e254cc2..af281e6c 100644 --- a/crates/shared/src/discovery/mod.rs +++ b/crates/shared/src/discovery/mod.rs @@ -166,7 +166,6 @@ async fn upload_to_single_discovery( node_data: &serde_json::Value, wallet: &Wallet, ) -> Result<()> { - let signed_request = sign_request_with_nonce(endpoint, wallet, Some(node_data)) .await .map_err(|e| anyhow::anyhow!("{}", e))?; diff --git a/crates/worker/src/services/discovery.rs b/crates/worker/src/services/discovery.rs index a0e60f79..e4440a0a 100644 --- a/crates/worker/src/services/discovery.rs +++ b/crates/worker/src/services/discovery.rs @@ -23,12 +23,7 @@ impl DiscoveryService { pub(crate) async fn upload_discovery_info(&self, node_config: &Node) -> Result<()> { let node_data = serde_json::to_value(node_config)?; - shared::discovery::upload_node_to_discovery( - &self.base_urls, - &node_data, - &self.wallet, - ) - .await + shared::discovery::upload_node_to_discovery(&self.base_urls, &node_data, &self.wallet).await } } From 86bdde1052079cbd784e52797bb5a9172c877560 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 15:25:35 +0200 Subject: [PATCH 11/23] share p2p functionality between all three components to easily send messages --- crates/prime-protocol-py/README.md | 122 -------- .../prime-protocol-py/src/orchestrator/mod.rs | 284 ++++++++++++++++-- .../src/p2p_handler/common.rs | 81 +++++ .../src/p2p_handler/message_processor.rs | 29 ++ .../prime-protocol-py/src/p2p_handler/mod.rs | 3 + crates/prime-protocol-py/src/validator/mod.rs | 205 ++++++++++++- crates/prime-protocol-py/src/worker/client.rs | 89 ++---- 7 files changed, 590 insertions(+), 223 deletions(-) create mode 100644 
crates/prime-protocol-py/src/p2p_handler/common.rs diff --git a/crates/prime-protocol-py/README.md b/crates/prime-protocol-py/README.md index cd458fc9..e69de29b 100644 --- a/crates/prime-protocol-py/README.md +++ b/crates/prime-protocol-py/README.md @@ -1,122 +0,0 @@ -# Prime Protocol Python SDK - -Python bindings for the Prime Protocol, providing easy access to worker, orchestrator, and validator functionalities. - -## Installation - -```bash -# Clone the repository -git clone https://github.com/primeprotocol/protocol.git -cd protocol/crates/prime-protocol-py - -# Build and install the package -pip install maturin -maturin develop -``` - -## Validator Client - -The Validator Client allows validators to interact with the Prime Protocol network, particularly for listing and validating nodes. - -### Basic Usage - -```python -from primeprotocol import ValidatorClient - -# Initialize the validator client -validator = ValidatorClient( - rpc_url="http://localhost:8545", - private_key="YOUR_PRIVATE_KEY", - discovery_urls=["http://localhost:8089"] # Can specify multiple discovery services -) - -# List all non-validated nodes -non_validated_nodes = validator.list_non_validated_nodes() - -for node in non_validated_nodes: - print(f"Node ID: {node.id}") - print(f"Provider: {node.provider_address}") - print(f"Address: {node.ip_address}:{node.port}") - print(f"Active: {node.is_active}") - print(f"Whitelisted: {node.is_provider_whitelisted}") - print() - -# Get count of non-validated nodes -count = validator.get_non_validated_count() -print(f"Total non-validated nodes: {count}") - -# Get all nodes as dictionaries (includes compute specs) -all_nodes = validator.list_all_nodes_dict() -for node in all_nodes: - if not node['is_validated']: - print(f"Node {node['id']} needs validation") - - # Access compute specs if available - if 'node' in node and 'compute_specs' in node['node']: - specs = node['node']['compute_specs'] - if specs: - print(f" RAM: {specs.get('ram_mb', 'N/A')} MB") - print(f" Storage: {specs.get('storage_gb', 'N/A')} GB") -``` - -### Node Details - -Each node returned by `list_non_validated_nodes()` has the following attributes: - -- `id`: Unique identifier for the node -- `provider_address`: On-chain address of the node provider -- `ip_address`: IP address of the node -- `port`: Port number for node communication -- `compute_pool_id`: ID of the compute pool the node belongs to -- `is_validated`: Whether the node has been validated -- `is_active`: Whether the node is currently active -- `is_provider_whitelisted`: Whether the provider is whitelisted -- `is_blacklisted`: Whether the node is blacklisted -- `worker_p2p_id`: P2P identifier (optional) -- `last_updated`: Last update timestamp (optional) -- `created_at`: Creation timestamp (optional) - -### Environment Variables - -The validator client can be configured using environment variables: - -```bash -export RPC_URL="http://localhost:8545" -export VALIDATOR_PRIVATE_KEY="your_private_key_here" -export DISCOVERY_URLS="http://localhost:8089,http://backup-discovery:8089" -``` - -### Example Scripts - -See the `examples/` directory for complete examples: -- `validator_list_nodes.py`: Lists all non-validated nodes with detailed output - -## Worker Client - -[Documentation for WorkerClient...] - -## Orchestrator Client - -[Documentation for OrchestratorClient...] 
-
-## Development
-
-### Running Tests
-
-```bash
-# Install development dependencies
-pip install -r requirements-dev.txt
-
-# Run tests
-pytest tests/
-```
-
-### Building from Source
-
-```bash
-# Build in release mode
-maturin build --release
-
-# Build and install locally
-maturin develop
-```
\ No newline at end of file
diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/prime-protocol-py/src/orchestrator/mod.rs
index c610ea6f..4c8e5be3 100644
--- a/crates/prime-protocol-py/src/orchestrator/mod.rs
+++ b/crates/prime-protocol-py/src/orchestrator/mod.rs
@@ -1,9 +1,29 @@
+use crate::p2p_handler::auth::AuthenticationManager;
+use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig};
+use crate::p2p_handler::{Message, MessageType, Service as P2PService};
+use p2p::{Keypair, PeerId};
 use pyo3::prelude::*;
+use pythonize::pythonize;
+use shared::web3::wallet::Wallet;
+use std::sync::Arc;
+use tokio::sync::mpsc::{Receiver, Sender};
+use tokio::sync::Mutex;
+use tokio::task::JoinHandle;
+use tokio_util::sync::CancellationToken;
+use url::Url;
 
 /// Prime Protocol Orchestrator Client - for managing and distributing tasks
 #[pyclass]
 pub struct OrchestratorClient {
-    // TODO: Implement orchestrator-specific functionality
+    runtime: Option<tokio::runtime::Runtime>,
+    wallet: Option<Wallet>,
+    cancellation_token: CancellationToken,
+    // P2P fields
+    auth_manager: Option<Arc<AuthenticationManager>>,
+    outbound_tx: Option<Arc<Mutex<Sender<Message>>>>,
+    user_message_rx: Option<Arc<Mutex<Receiver<Message>>>>,
+    message_processor_handle: Option<JoinHandle<()>>,
+    peer_id: Option<PeerId>,
 }
 
 #[pymethods]
@@ -11,10 +31,174 @@ impl OrchestratorClient {
     #[new]
     #[pyo3(signature = (rpc_url, private_key=None))]
     pub fn new(rpc_url: String, private_key: Option<String>) -> PyResult<Self> {
-        // TODO: Implement orchestrator initialization
-        let _ = rpc_url;
-        let _ = private_key;
-        Ok(Self {})
+        let runtime = tokio::runtime::Builder::new_multi_thread()
+            .enable_all()
+            .build()
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+
+        let cancellation_token = CancellationToken::new();
+
+        // Create wallet if private key is provided
+        let wallet = if let Some(key) = private_key {
+            let rpc_url_parsed = Url::parse(&rpc_url).map_err(|e| {
+                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid RPC URL: {}", e))
+            })?;
+            Some(
+                Wallet::new(&key, rpc_url_parsed)
+                    .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?,
+            )
+        } else {
+            None
+        };
+
+        Ok(Self {
+            runtime: Some(runtime),
+            wallet,
+            cancellation_token,
+            auth_manager: None,
+            outbound_tx: None,
+            user_message_rx: None,
+            message_processor_handle: None,
+            peer_id: None,
+        })
+    }
+
+    /// Initialize the orchestrator client with optional P2P support
+    #[pyo3(signature = (p2p_port=None))]
+    pub fn start(&mut self, py: Python, p2p_port: Option<u16>) -> PyResult<()> {
+        let rt = self.get_or_create_runtime()?;
+
+        if let Some(port) = p2p_port {
+            // Initialize P2P if port is provided
+            let wallet = self
+                .wallet
+                .as_ref()
+                .ok_or_else(|| {
+                    PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                        "Wallet not initialized. Provide private_key when creating client.",
+                    )
+                })?
+                .clone();
+
+            let cancellation_token = self.cancellation_token.clone();
+
+            // Create the P2P components
+            let (auth_manager, peer_id, outbound_tx, user_message_rx, message_processor_handle) =
+                py.allow_threads(|| {
+                    rt.block_on(async {
+                        Self::create_p2p_components(wallet, port, cancellation_token)
+                            .await
+                            .map_err(|e| {
+                                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string())
+                            })
+                    })
+                })?;
+
+            // Update self with the created components
+            self.auth_manager = Some(auth_manager);
+            self.peer_id = Some(peer_id);
+            self.outbound_tx = Some(outbound_tx);
+            self.user_message_rx = Some(user_message_rx);
+            self.message_processor_handle = Some(message_processor_handle);
+        }
+
+        Ok(())
+    }
+
+    /// Send a message to a peer
+    pub fn send_message(
+        &self,
+        py: Python,
+        peer_id: String,
+        multiaddrs: Vec<String>,
+        data: Vec<u8>,
+    ) -> PyResult<()> {
+        let rt = self.get_or_create_runtime()?;
+
+        let auth_manager = self.auth_manager.as_ref().ok_or_else(|| {
+            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                "P2P not initialized. Call start() with p2p_port parameter.",
+            )
+        })?;
+
+        let outbound_tx = self.outbound_tx.as_ref().ok_or_else(|| {
+            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                "P2P not initialized. Call start() with p2p_port parameter.",
+            )
+        })?;
+
+        let message = Message {
+            message_type: MessageType::General { data },
+            peer_id,
+            multiaddrs,
+            sender_address: None,
+            response_tx: None,
+        };
+
+        py.allow_threads(|| {
+            rt.block_on(async {
+                crate::p2p_handler::send_message_with_auth(message, auth_manager, outbound_tx)
+                    .await
+                    .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
+            })
+        })
+    }
+
+    /// Get the next message from the P2P network
+    pub fn get_next_message(&self, py: Python) -> PyResult<Option<PyObject>> {
+        let rt = self.get_or_create_runtime()?;
+
+        let user_message_rx = self.user_message_rx.as_ref().ok_or_else(|| {
+            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                "P2P not initialized. Call start() with p2p_port parameter.",
+            )
+        })?;
+
+        let message = py.allow_threads(|| {
+            rt.block_on(async {
+                tokio::time::timeout(
+                    crate::constants::MESSAGE_QUEUE_TIMEOUT,
+                    user_message_rx.lock().await.recv(),
+                )
+                .await
+                .ok()
+                .flatten()
+            })
+        });
+
+        match message {
+            Some(msg) => {
+                let py_msg = pythonize(py, &msg)
+                    .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+                Ok(Some(py_msg.into()))
+            }
+            None => Ok(None),
+        }
+    }
+
+    /// Get the orchestrator's peer ID
+    pub fn get_peer_id(&self) -> PyResult<Option<String>> {
+        Ok(self.peer_id.map(|id| id.to_string()))
+    }
+
+    /// Get the orchestrator's wallet address
+    pub fn get_wallet_address(&self) -> PyResult<Option<String>> {
+        Ok(self
+            .wallet
+            .as_ref()
+            .map(|w| w.wallet.default_signer().address().to_string()))
+    }
+
+    /// Stop the orchestrator client
+    pub fn stop(&mut self, _py: Python) -> PyResult<()> {
+        self.cancellation_token.cancel();
+
+        // Stop message processor
+        if let Some(handle) = self.message_processor_handle.take() {
+            handle.abort();
+        }
+
+        Ok(())
     }
 
     pub fn list_validated_nodes(&self) -> PyResult<Vec<String>> {
@@ -26,30 +210,78 @@ impl OrchestratorClient {
         // TODO: Implement orchestrator node listing from chain
         Ok(vec![])
     }
+}
-
-    // pub fn get_node_details(&self, node_id: String) -> PyResult<Option<String>> {
-    //     // TODO: Implement orchestrator node details fetching
-    //     Ok(None)
-    // }
+// Private implementation methods
+impl OrchestratorClient {
+    fn get_or_create_runtime(&self) -> PyResult<&tokio::runtime::Runtime> {
+        if let Some(ref rt) = self.runtime {
+            Ok(rt)
+        } else {
+            Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                "Runtime not initialized. Call start() first.",
+            ))
+        }
+    }
+
+    async fn create_p2p_components(
+        wallet: Wallet,
+        port: u16,
+        cancellation_token: CancellationToken,
+    ) -> Result<
+        (
+            Arc<AuthenticationManager>,
+            PeerId,
+            Arc<Mutex<Sender<Message>>>,
+            Arc<Mutex<Receiver<Message>>>,
+            JoinHandle<()>,
+        ),
+        anyhow::Error,
+    > {
+        // Initialize authentication manager
+        let auth_manager = Arc::new(AuthenticationManager::new(Arc::new(wallet.clone())));
+
+        // Create P2P service
+        let keypair = Keypair::generate_ed25519();
+        let wallet_address = Some(wallet.wallet.default_signer().address().to_string());
+
+        let (user_message_tx, user_message_rx) = tokio::sync::mpsc::channel::<Message>(1000);
-
+        let (p2p_service, outbound_tx, message_queue_rx, authenticated_peers) =
+            P2PService::new(keypair, port, cancellation_token.clone(), wallet_address)?;
 
-    // pub fn get_node_details_from_chain(&self, node_id: String) -> PyResult<Option<String>> {
-    //     // TODO: Implement orchestrator node details fetching from chain
-    //     Ok(None)
-    // }
+        let peer_id = p2p_service.node.peer_id();
+        let outbound_tx = Arc::new(Mutex::new(outbound_tx));
+        let user_message_rx = Arc::new(Mutex::new(user_message_rx));
 
-    // pub fn send_invite_to_node(&self, node_id: String) -> PyResult<()> {
-    //     // TODO: Implement orchestrator node invite sending
-    //     Ok(())
-    // }
+        // Start P2P service
+        tokio::task::spawn(p2p_service.run());
 
-    // pub fn send_request_to_node(&self, node_id: String, request: String) -> PyResult<()> {
-    //     // TODO: Implement orchestrator node request sending
-    //     Ok(())
-    // }
+        // Start message processor
+        let config = MessageProcessorConfig {
+            auth_manager: auth_manager.clone(),
+            message_queue_rx: Arc::new(Mutex::new(message_queue_rx)),
+            user_message_tx,
+            outbound_tx: outbound_tx.clone(),
+            authenticated_peers,
+            cancellation_token,
+        };
 
-    // // TODO: Sender of this message?
-    // pub fn read_message(&self) -> PyResult<Option<String>> {
-    //     // TODO: Implement orchestrator message reading
-    //     Ok(None)
-    // }
+        let message_processor = MessageProcessor::from_config(config);
+        let message_processor_handle = message_processor.spawn();
+
+        log::info!(
+            "P2P service started on port {} with peer ID: {:?}",
+            port,
+            peer_id
+        );
+
+        Ok((
+            auth_manager,
+            peer_id,
+            outbound_tx,
+            user_message_rx,
+            message_processor_handle,
+        ))
+    }
 }
diff --git a/crates/prime-protocol-py/src/p2p_handler/common.rs b/crates/prime-protocol-py/src/p2p_handler/common.rs
new file mode 100644
index 00000000..7fd930ec
--- /dev/null
+++ b/crates/prime-protocol-py/src/p2p_handler/common.rs
@@ -0,0 +1,81 @@
+use crate::error::{PrimeProtocolError, Result};
+use crate::p2p_handler::auth::AuthenticationManager;
+use crate::p2p_handler::{Message, MessageType};
+use std::sync::Arc;
+use tokio::sync::mpsc::Sender;
+use tokio::sync::Mutex;
+
+/// Shared send_message function that handles authentication and message sending
+pub async fn send_message_with_auth(
+    message: Message,
+    auth_manager: &Arc<AuthenticationManager>,
+    outbound_tx: &Arc<Mutex<Sender<Message>>>,
+) -> Result<()> {
+    log::debug!("Sending message to peer: {}", message.peer_id);
+
+    // Check if we're already authenticated with this peer
+    if auth_manager.is_authenticated(&message.peer_id).await {
+        log::debug!(
+            "Already authenticated with peer {}, sending message directly",
+            message.peer_id
+        );
+        return outbound_tx.lock().await.send(message).await.map_err(|e| {
+            PrimeProtocolError::InvalidConfig(format!("Failed to send message: {}", e))
+        });
+    }
+
+    // Not authenticated yet, check if we have ongoing authentication
+    log::debug!("Not authenticated with peer {}", message.peer_id);
+
+    // Check if there's already an ongoing auth request
+    if let Some(role) = auth_manager.get_auth_role(&message.peer_id).await {
+        match role.as_str() {
+            "initiator" => {
+                return Err(PrimeProtocolError::InvalidConfig(format!(
+                    "Already initiated authentication with peer {}",
+                    message.peer_id
+                )));
+            }
+            "responder" => {
+                // We're responding to their auth, queue the message
+                log::debug!(
+                    "Queuing message for peer {} (we're responding to their auth)",
+                    message.peer_id
+                );
+                return auth_manager
+                    .queue_message_as_responder(message.peer_id.clone(), message)
+                    .await;
+            }
+            _ => {}
+        }
+    }
+
+    // Extract fields we need before moving the message
+    let peer_id = message.peer_id.clone();
+    let multiaddrs = message.multiaddrs.clone();
+
+    // Start authentication (takes ownership of message)
+    let auth_challenge = auth_manager
+        .start_authentication(peer_id.clone(), message)
+        .await?;
+
+    // Send authentication initiation
+    let auth_message = Message {
+        message_type: MessageType::AuthenticationInitiation {
+            challenge: auth_challenge,
+        },
+        peer_id,
+        multiaddrs,
+        sender_address: Some(auth_manager.wallet_address()),
+        response_tx: None,
+    };
+
+    outbound_tx
+        .lock()
+        .await
+        .send(auth_message)
+        .await
+        .map_err(|e| {
+            PrimeProtocolError::InvalidConfig(format!("Failed to send auth message: {}", e))
+        })
+}
diff --git a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs
index 543cfde3..e062bfc6 100644
--- a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs
+++ b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs
@@ -8,8 +8,19 @@ use tokio::sync::{
     mpsc::{Receiver, Sender},
     Mutex, RwLock,
 };
+use tokio::task::JoinHandle;
 use tokio_util::sync::CancellationToken;
 
+/// Configuration for creating a MessageProcessor
+pub struct MessageProcessorConfig {
+    pub auth_manager: Arc<AuthenticationManager>,
+    pub message_queue_rx: Arc<Mutex<Receiver<Message>>>,
+    pub user_message_tx: Sender<Message>,
+    pub outbound_tx: Arc<Mutex<Sender<Message>>>,
+    pub authenticated_peers: Arc<RwLock<HashMap<String, String>>>,
+    pub cancellation_token: CancellationToken,
+}
+
 /// Handles processing of incoming P2P messages
 pub struct MessageProcessor {
     auth_manager: Arc<AuthenticationManager>,
@@ -39,6 +50,24 @@ impl MessageProcessor {
         }
     }
 
+    /// Create a MessageProcessor from a config struct
+    pub fn from_config(config: MessageProcessorConfig) -> Self {
+        Self::new(
+            config.auth_manager,
+            config.message_queue_rx,
+            config.user_message_tx,
+            config.outbound_tx,
+            config.authenticated_peers,
+            config.cancellation_token,
+        )
+    }
+
+    /// Start the message processor as a background task
+    /// Returns a JoinHandle that can be used to await or abort the task
+    pub fn spawn(self) -> JoinHandle<()> {
+        tokio::task::spawn(self.run())
+    }
+
     /// Run the message processing loop
     pub async fn run(self) {
         loop {
diff --git a/crates/prime-protocol-py/src/p2p_handler/mod.rs b/crates/prime-protocol-py/src/p2p_handler/mod.rs
index 6952aac5..2cf4b730 100644
--- a/crates/prime-protocol-py/src/p2p_handler/mod.rs
+++ b/crates/prime-protocol-py/src/p2p_handler/mod.rs
@@ -13,8 +13,11 @@ use tokio_util::sync::CancellationToken;
 use crate::constants::{MESSAGE_QUEUE_CHANNEL_SIZE, P2P_CHANNEL_SIZE};
 
 pub(crate) mod auth;
+pub(crate) mod common;
 pub(crate) mod message_processor;
 
+pub use common::send_message_with_auth;
+
 // Type alias for the complex return type of Service::new
 type ServiceNewResult = Result<(
     Service,
diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/prime-protocol-py/src/validator/mod.rs
index cf636cba..855b4e6e 100644
--- a/crates/prime-protocol-py/src/validator/mod.rs
+++ 
b/crates/prime-protocol-py/src/validator/mod.rs
@@ -1,9 +1,18 @@
+use crate::p2p_handler::auth::AuthenticationManager;
+use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig};
+use crate::p2p_handler::{Message, MessageType, Service as P2PService};
+use p2p::{Keypair, PeerId};
 use pyo3::prelude::*;
 use pythonize::pythonize;
 use shared::models::node::DiscoveryNode;
 use shared::security::request_signer::sign_request_with_nonce;
 use shared::web3::wallet::Wallet;
+use std::sync::Arc;
 use std::time::Duration;
+use tokio::sync::mpsc::{Receiver, Sender};
+use tokio::sync::Mutex;
+use tokio::task::JoinHandle;
+use tokio_util::sync::CancellationToken;
 use url::Url;
 
 /// Node details for validator operations
@@ -78,6 +87,13 @@ pub(crate) struct ValidatorClient {
     runtime: Option<tokio::runtime::Runtime>,
     wallet: Option<Wallet>,
     discovery_urls: Vec<String>,
+    cancellation_token: CancellationToken,
+    // P2P fields
+    auth_manager: Option<Arc<AuthenticationManager>>,
+    outbound_tx: Option<Arc<Mutex<Sender<Message>>>>,
+    user_message_rx: Option<Arc<Mutex<Receiver<Message>>>>,
+    message_processor_handle: Option<JoinHandle<()>>,
+    peer_id: Option<PeerId>,
 }
 
 #[pymethods]
@@ -89,10 +105,10 @@ impl ValidatorClient {
         private_key: String,
         discovery_urls: Vec<String>,
     ) -> PyResult<Self> {
-        let rpc_url = Url::parse(&rpc_url).map_err(|e| {
+        let rpc_url_parsed = Url::parse(&rpc_url).map_err(|e| {
             PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid RPC URL: {}", e))
         })?;
-        let wallet = Wallet::new(&private_key, rpc_url)
+        let wallet = Wallet::new(&private_key, rpc_url_parsed)
             .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
 
         let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -100,10 +116,18 @@ impl ValidatorClient {
             .build()
             .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
 
+        let cancellation_token = CancellationToken::new();
+
         Ok(Self {
             runtime: Some(runtime),
             wallet: Some(wallet),
             discovery_urls,
+            cancellation_token,
+            auth_manager: None,
+            outbound_tx: None,
+            user_message_rx: None,
+            message_processor_handle: None,
+            peer_id: None,
         })
     }
 
@@ -161,11 +185,121 @@ impl ValidatorClient {
         Ok(nodes.len())
     }
 
-    /// Initialize the validator client
-    pub fn start(&mut self, _py: Python) -> PyResult<()> {
-        self.get_or_create_runtime()?;
+    /// Initialize the validator client with optional P2P support
+    #[pyo3(signature = (p2p_port=None))]
+    pub fn start(&mut self, py: Python, p2p_port: Option<u16>) -> PyResult<()> {
+        let rt = self.get_or_create_runtime()?;
+
+        if let Some(port) = p2p_port {
+            // Initialize P2P if port is provided
+            let wallet = self
+                .wallet
+                .as_ref()
+                .ok_or_else(|| {
+                    PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Wallet not initialized")
+                })?
+                .clone();
+
+            let cancellation_token = self.cancellation_token.clone();
+
+            // Create the P2P components
+            let (auth_manager, peer_id, outbound_tx, user_message_rx, message_processor_handle) =
+                py.allow_threads(|| {
+                    rt.block_on(async {
+                        Self::create_p2p_components(wallet, port, cancellation_token)
+                            .await
+                            .map_err(|e| {
+                                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string())
+                            })
+                    })
+                })?;
+
+            // Update self with the created components
+            self.auth_manager = Some(auth_manager);
+            self.peer_id = Some(peer_id);
+            self.outbound_tx = Some(outbound_tx);
+            self.user_message_rx = Some(user_message_rx);
+            self.message_processor_handle = Some(message_processor_handle);
+        }
+
         Ok(())
     }
+
+    /// Send a message to a peer
+    pub fn send_message(
+        &self,
+        py: Python,
+        peer_id: String,
+        multiaddrs: Vec<String>,
+        data: Vec<u8>,
+    ) -> PyResult<()> {
+        let rt = self.get_or_create_runtime()?;
+
+        let auth_manager = self.auth_manager.as_ref().ok_or_else(|| {
+            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                "P2P not initialized. Call start() with p2p_port parameter.",
+            )
+        })?;
+
+        let outbound_tx = self.outbound_tx.as_ref().ok_or_else(|| {
+            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                "P2P not initialized. Call start() with p2p_port parameter.",
+            )
+        })?;
+
+        let message = Message {
+            message_type: MessageType::General { data },
+            peer_id,
+            multiaddrs,
+            sender_address: None,
+            response_tx: None,
+        };
+
+        py.allow_threads(|| {
+            rt.block_on(async {
+                crate::p2p_handler::send_message_with_auth(message, auth_manager, outbound_tx)
+                    .await
+                    .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
+            })
+        })
+    }
+
+    /// Get the next message from the P2P network
+    pub fn get_next_message(&self, py: Python) -> PyResult<Option<PyObject>> {
+        let rt = self.get_or_create_runtime()?;
+
+        let user_message_rx = self.user_message_rx.as_ref().ok_or_else(|| {
+            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
+                "P2P not initialized. Call start() with p2p_port parameter.",
+            )
+        })?;
+
+        let message = py.allow_threads(|| {
+            rt.block_on(async {
+                tokio::time::timeout(
+                    crate::constants::MESSAGE_QUEUE_TIMEOUT,
+                    user_message_rx.lock().await.recv(),
+                )
+                .await
+                .ok()
+                .flatten()
+            })
+        });
+
+        match message {
+            Some(msg) => {
+                let py_msg = pythonize(py, &msg)
+                    .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+                Ok(Some(py_msg.into()))
+            }
+            None => Ok(None),
+        }
+    }
+
+    /// Get the validator's peer ID
+    pub fn get_peer_id(&self) -> PyResult<Option<String>> {
+        Ok(self.peer_id.map(|id| id.to_string()))
+    }
 }
 
 // Private implementation methods
@@ -180,6 +314,67 @@ impl ValidatorClient {
         }
     }
 
+    async fn create_p2p_components(
+        wallet: Wallet,
+        port: u16,
+        cancellation_token: CancellationToken,
+    ) -> Result<
+        (
+            Arc<AuthenticationManager>,
+            PeerId,
+            Arc<Mutex<Sender<Message>>>,
+            Arc<Mutex<Receiver<Message>>>,
+            JoinHandle<()>,
+        ),
+        anyhow::Error,
+    > {
+        // Initialize authentication manager
+        let auth_manager = Arc::new(AuthenticationManager::new(Arc::new(wallet.clone())));
+
+        // Create P2P service
+        let keypair = Keypair::generate_ed25519();
+        let wallet_address = Some(wallet.wallet.default_signer().address().to_string());
+
+        let (user_message_tx, user_message_rx) = tokio::sync::mpsc::channel::<Message>(1000);
+
+        let (p2p_service, outbound_tx, message_queue_rx, authenticated_peers) =
+            P2PService::new(keypair, port, cancellation_token.clone(), wallet_address)?;
+
+        let peer_id = p2p_service.node.peer_id();
+        let outbound_tx = Arc::new(Mutex::new(outbound_tx));
+        let user_message_rx = Arc::new(Mutex::new(user_message_rx));
+
+        // Start P2P service
+        tokio::task::spawn(p2p_service.run());
+
+        // Start message processor
+        let config = MessageProcessorConfig {
+            auth_manager: auth_manager.clone(),
+            message_queue_rx: Arc::new(Mutex::new(message_queue_rx)),
+            user_message_tx,
+            outbound_tx: outbound_tx.clone(),
+            authenticated_peers,
+            cancellation_token,
+        };
+
+        let message_processor = MessageProcessor::from_config(config);
+        let message_processor_handle = message_processor.spawn();
+
+        log::info!(
+            "P2P service started on port {} with peer ID: {:?}",
+            port,
+            peer_id
+        );
+
+        Ok((
+            auth_manager,
+            peer_id,
+            outbound_tx,
+            user_message_rx,
+            message_processor_handle,
+        ))
+    }
+
     async fn fetch_non_validated_nodes(
         &self,
         discovery_urls: &[String],
diff --git a/crates/prime-protocol-py/src/worker/client.rs b/crates/prime-protocol-py/src/worker/client.rs
index 73b91196..62f87b94 100644
--- a/crates/prime-protocol-py/src/worker/client.rs
+++ b/crates/prime-protocol-py/src/worker/client.rs
@@ -1,9 +1,9 @@
 use crate::constants::{DEFAULT_FUNDING_RETRY_COUNT, MESSAGE_QUEUE_TIMEOUT, P2P_SHUTDOWN_TIMEOUT};
 use crate::error::{PrimeProtocolError, Result};
 use crate::p2p_handler::auth::AuthenticationManager;
-use
crate::p2p_handler::message_processor::MessageProcessor; +use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig}; use crate::worker::blockchain::{BlockchainConfig, BlockchainService}; -use crate::worker::p2p_handler::{Message, MessageType, Service as P2PService}; +use crate::worker::p2p_handler::{Message, Service as P2PService}; use p2p::{Keypair, PeerId}; use std::sync::Arc; use tokio::sync::mpsc::{Receiver, Sender}; @@ -166,8 +166,6 @@ impl WorkerClientCore { /// Send a message to a peer pub async fn send_message(&self, message: Message) -> Result<()> { - log::debug!("Sending message to peer: {}", message.peer_id); - let auth_manager = self.auth_manager.as_ref().ok_or_else(|| { PrimeProtocolError::InvalidConfig("Authentication manager not initialized".to_string()) })?; @@ -176,66 +174,7 @@ impl WorkerClientCore { PrimeProtocolError::InvalidConfig("P2P service not initialized".to_string()) })?; - // Check if we're already authenticated with this peer - if auth_manager.is_authenticated(&message.peer_id).await { - log::debug!( - "Already authenticated with peer {}, sending message directly", - message.peer_id - ); - return tx.lock().await.send(message).await.map_err(|e| { - PrimeProtocolError::InvalidConfig(format!("Failed to send message: {}", e)) - }); - } - - // Not authenticated yet, check if we have ongoing authentication - log::debug!("Not authenticated with peer {}", message.peer_id); - - // Check if there's already an ongoing auth request - if let Some(role) = auth_manager.get_auth_role(&message.peer_id).await { - match role.as_str() { - "initiator" => { - return Err(PrimeProtocolError::InvalidConfig(format!( - "Already initiated authentication with peer {}", - message.peer_id - ))); - } - "responder" => { - // We're responding to their auth, queue the message - log::debug!( - "Queuing message for peer {} (we're responding to their auth)", - message.peer_id - ); - return auth_manager - .queue_message_as_responder(message.peer_id.clone(), message) - .await; - } - _ => {} - } - } - - // Extract fields we need before moving the message - let peer_id = message.peer_id.clone(); - let multiaddrs = message.multiaddrs.clone(); - - // Start authentication (takes ownership of message) - let auth_challenge = auth_manager - .start_authentication(peer_id.clone(), message) - .await?; - - // Send authentication initiation - let auth_message = Message { - message_type: MessageType::AuthenticationInitiation { - challenge: auth_challenge, - }, - peer_id, - multiaddrs, - sender_address: Some(auth_manager.wallet_address()), - response_tx: None, - }; - - tx.lock().await.send(auth_message).await.map_err(|e| { - PrimeProtocolError::InvalidConfig(format!("Failed to send auth message: {}", e)) - }) + crate::p2p_handler::send_message_with_auth(message, auth_manager, tx).await } // Private helper methods @@ -310,6 +249,19 @@ impl WorkerClientCore { } async fn start_message_processor(&mut self) -> Result<()> { + // Build the message processor configuration + let config = self.build_message_processor_config()?; + + // Create and spawn the message processor + let message_processor = MessageProcessor::from_config(config); + self.message_processor_handle = Some(message_processor.spawn()); + + Ok(()) + } + + /// Build configuration for the message processor + /// This method is public to allow reuse in other crates + pub fn build_message_processor_config(&self) -> Result { let message_queue_rx = self .p2p_state .message_queue_rx @@ -359,17 +311,14 @@ impl WorkerClientCore { })? 
.clone(); - let message_processor = MessageProcessor::new( + Ok(MessageProcessorConfig { auth_manager, message_queue_rx, user_message_tx, outbound_tx, authenticated_peers, - self.cancellation_token.clone(), - ); - - self.message_processor_handle = Some(tokio::task::spawn(message_processor.run())); - Ok(()) + cancellation_token: self.cancellation_token.clone(), + }) } /// Get the provider's Ethereum address From 31870ad340afa21fc4ac2f5e080ca3d10dbf9301 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 16:03:21 +0200 Subject: [PATCH 12/23] allow validator to validate nodes using sdk --- .../examples/validator_list_nodes.py | 48 ++--- crates/prime-protocol-py/src/validator/mod.rs | 164 +++++++++++++++++- 2 files changed, 175 insertions(+), 37 deletions(-) diff --git a/crates/prime-protocol-py/examples/validator_list_nodes.py b/crates/prime-protocol-py/examples/validator_list_nodes.py index c8398c92..d20e1660 100644 --- a/crates/prime-protocol-py/examples/validator_list_nodes.py +++ b/crates/prime-protocol-py/examples/validator_list_nodes.py @@ -11,37 +11,6 @@ logging.basicConfig(format=FORMAT) logging.getLogger().setLevel(logging.INFO) - -def print_node_summary(nodes: List) -> None: - """Print a summary of nodes""" - print(f"\nTotal nodes found: {len(nodes)}") - - if not nodes: - print("No non-validated nodes found.") - return - - print("\nNon-validated nodes:") - print("-" * 80) - - for idx, node in enumerate(nodes, 1): - print(f"\n{idx}. Node ID: {node.id}") - print(f" Provider Address: {node.provider_address}") - print(f" IP: {node.ip_address}:{node.port}") - print(f" Compute Pool ID: {node.compute_pool_id}") - print(f" Active: {node.is_active}") - print(f" Whitelisted: {node.is_provider_whitelisted}") - print(f" Blacklisted: {node.is_blacklisted}") - - if node.worker_p2p_id: - print(f" P2P ID: {node.worker_p2p_id}") - - if node.created_at: - print(f" Created At: {node.created_at}") - - if node.last_updated: - print(f" Last Updated: {node.last_updated}") - - def main(): # Get configuration from environment variables rpc_url = os.getenv("RPC_URL", "http://localhost:8545") @@ -62,19 +31,28 @@ def main(): validator = ValidatorClient( rpc_url=rpc_url, private_key=private_key, - discovery_urls=discovery_urls + discovery_urls=discovery_urls, ) + print("Starting validator client...") + validator.start() + print("Validator client started") # List all non-validated nodes print("\nFetching non-validated nodes from discovery service...") non_validated_nodes = validator.list_non_validated_nodes() - - # Print summary - print_node_summary(non_validated_nodes) + for node in non_validated_nodes: + print(node.id) + if node.is_validated is False: + print(f"Validating node {node.id}...") + validator.validate_node(node.id, node.provider_address) + print(f"Node {node.id} validated") + else: + print(f"Node {node.id} is already validated") # You can also get all nodes as dictionaries for more flexibility print("\n\nFetching all nodes as dictionaries...") all_nodes = validator.list_all_nodes_dict() + print(all_nodes) # Count validated vs non-validated validated_count = sum(1 for node in all_nodes if node['is_validated']) diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/prime-protocol-py/src/validator/mod.rs index 855b4e6e..c716448b 100644 --- a/crates/prime-protocol-py/src/validator/mod.rs +++ b/crates/prime-protocol-py/src/validator/mod.rs @@ -1,12 +1,14 @@ use crate::p2p_handler::auth::AuthenticationManager; use crate::p2p_handler::message_processor::{MessageProcessor, 
MessageProcessorConfig}; use crate::p2p_handler::{Message, MessageType, Service as P2PService}; +use alloy::primitives::Address; use p2p::{Keypair, PeerId}; use pyo3::prelude::*; use pythonize::pythonize; use shared::models::node::DiscoveryNode; use shared::security::request_signer::sign_request_with_nonce; -use shared::web3::wallet::Wallet; +use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; +use shared::web3::wallet::{Wallet, WalletProvider}; use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc::{Receiver, Sender}; @@ -94,6 +96,7 @@ pub(crate) struct ValidatorClient { user_message_rx: Option>>>, message_processor_handle: Option>, peer_id: Option, + contracts: Option>>, } #[pymethods] @@ -128,6 +131,7 @@ impl ValidatorClient { user_message_rx: None, message_processor_handle: None, peer_id: None, + contracts: None, }) } @@ -188,7 +192,40 @@ impl ValidatorClient { /// Initialize the validator client with optional P2P support #[pyo3(signature = (p2p_port=None))] pub fn start(&mut self, py: Python, p2p_port: Option) -> PyResult<()> { - let rt = self.get_or_create_runtime()?; + // Initialize contracts if not already done + if self.contracts.is_none() { + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::("Wallet not initialized") + })?; + + let wallet_provider = wallet.provider(); + + // Get runtime reference and use it before mutating self + let rt = self.runtime.as_ref().ok_or_else(|| { + PyErr::new::("Runtime not initialized") + })?; + + let contracts = py.allow_threads(|| { + rt.block_on(async { + // Build all contracts (required due to known bug) + ContractBuilder::new(wallet_provider) + .with_compute_pool() + .with_compute_registry() + .with_ai_token() + .with_prime_network() + .with_stake_manager() + .build() + .map_err(|e| { + PyErr::new::(format!( + "Failed to build contracts: {}", + e + )) + }) + }) + })?; + + self.contracts = Some(Arc::new(contracts)); + } if let Some(port) = p2p_port { // Initialize P2P if port is provided @@ -202,6 +239,11 @@ impl ValidatorClient { let cancellation_token = self.cancellation_token.clone(); + // Get runtime reference for P2P initialization + let rt = self.runtime.as_ref().ok_or_else(|| { + PyErr::new::("Runtime not initialized") + })?; + // Create the P2P components let (auth_manager, peer_id, outbound_tx, user_message_rx, message_processor_handle) = py.allow_threads(|| { @@ -296,6 +338,124 @@ impl ValidatorClient { } } + /// Validate a node on the Prime Network contract + /// + /// Args: + /// node_address: The node address to validate + /// provider_address: The provider's address + /// + /// Returns: + /// Transaction hash as a string if successful + pub fn validate_node( + &self, + py: Python, + node_address: String, + provider_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. 
Call start() first.", + ) + })?; + + let contracts_clone = contracts.clone(); + + // Release the GIL while performing async operations + py.allow_threads(|| { + rt.block_on(async { + // Parse addresses + let provider_addr = + Address::parse_checksummed(&provider_address, None).map_err(|e| { + PyErr::new::(format!( + "Invalid provider address: {}", + e + )) + })?; + + let node_addr = Address::parse_checksummed(&node_address, None).map_err(|e| { + PyErr::new::(format!( + "Invalid node address: {}", + e + )) + })?; + + let tx_hash = contracts_clone + .prime_network + .validate_node(provider_addr, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to validate node: {}", + e + )) + })?; + + Ok(format!("0x{}", hex::encode(tx_hash))) + }) + }) + } + + /// Validate a node on the Prime Network contract with explicit addresses + /// + /// Args: + /// provider_address: The provider's address + /// node_address: The node's address + /// + /// Returns: + /// Transaction hash as a string if successful + pub fn validate_node_with_addresses( + &self, + py: Python, + provider_address: String, + node_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. Call start() first.", + ) + })?; + + let contracts_clone = contracts.clone(); + + // Release the GIL while performing async operations + py.allow_threads(|| { + rt.block_on(async { + // Parse addresses + let provider_addr = + Address::parse_checksummed(&provider_address, None).map_err(|e| { + PyErr::new::(format!( + "Invalid provider address: {}", + e + )) + })?; + + let node_addr = Address::parse_checksummed(&node_address, None).map_err(|e| { + PyErr::new::(format!( + "Invalid node address: {}", + e + )) + })?; + + let tx_hash = contracts_clone + .prime_network + .validate_node(provider_addr, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to validate node: {}", + e + )) + })?; + + Ok(format!("0x{}", hex::encode(tx_hash))) + }) + }) + } + /// Get the validator's peer ID pub fn get_peer_id(&self) -> PyResult> { Ok(self.peer_id.map(|id| id.to_string())) From 5e327ac5e3f3b4f9a36a55cf18a9abd948cfbe11 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 17:16:02 +0200 Subject: [PATCH 13/23] ability to list all nodes for orchestrator compute pool --- .../examples/orchestrator_list_nodes.py | 47 +++++++++++++ crates/prime-protocol-py/src/common/mod.rs | 68 ++++++++++++++++++ crates/prime-protocol-py/src/lib.rs | 3 + .../prime-protocol-py/src/orchestrator/mod.rs | 43 +++++++++--- crates/prime-protocol-py/src/validator/mod.rs | 69 +------------------ 5 files changed, 154 insertions(+), 76 deletions(-) create mode 100644 crates/prime-protocol-py/examples/orchestrator_list_nodes.py create mode 100644 crates/prime-protocol-py/src/common/mod.rs diff --git a/crates/prime-protocol-py/examples/orchestrator_list_nodes.py b/crates/prime-protocol-py/examples/orchestrator_list_nodes.py new file mode 100644 index 00000000..e2f94106 --- /dev/null +++ b/crates/prime-protocol-py/examples/orchestrator_list_nodes.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +""" +Example demonstrating how to list nodes for a specific pool using the OrchestratorClient. 
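+
+A rough usage sketch of the flow this script follows (the localhost RPC and
+discovery endpoints used below are assumptions for a local dev setup; adjust
+them for your environment):
+
+    orchestrator = OrchestratorClient(rpc_url, private_key, discovery_urls)
+    orchestrator.start()
+    nodes = orchestrator.list_nodes_for_pool(pool_id)
+    orchestrator.stop()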
+""" + +import os +from primeprotocol import OrchestratorClient + +def main(): + # Replace with your actual RPC URL and private key + RPC_URL = "http://localhost:8545" + PRIVATE_KEY = os.getenv("ORCHESTRATOR_PRIVATE_KEY") + DISCOVERY_URLS = ["http://localhost:8089"] # Discovery service URLs + + # Create orchestrator client + orchestrator = OrchestratorClient( + rpc_url=RPC_URL, + private_key=PRIVATE_KEY, + discovery_urls=DISCOVERY_URLS + ) + + # Initialize the orchestrator (without P2P for this example) + orchestrator.start() + + # List nodes for a specific pool (example pool ID: 0) + pool_id = 0 + pool_nodes = orchestrator.list_nodes_for_pool(pool_id) + print(f"Nodes in pool {pool_id}: {len(pool_nodes)}") + + # Print details of all nodes in the pool + for i, node in enumerate(pool_nodes): + print(f"\nNode {i+1}:") + print(f" ID: {node.id}") + print(f" Provider Address: {node.provider_address}") + print(f" IP Address: {node.ip_address}") + print(f" Port: {node.port}") + print(f" Pool ID: {node.compute_pool_id}") + print(f" Validated: {node.is_validated}") + print(f" Active: {node.is_active}") + if node.worker_p2p_id: + print(f" Worker P2P ID: {node.worker_p2p_id}") + + # Stop the orchestrator + orchestrator.stop() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/prime-protocol-py/src/common/mod.rs b/crates/prime-protocol-py/src/common/mod.rs new file mode 100644 index 00000000..eb529d13 --- /dev/null +++ b/crates/prime-protocol-py/src/common/mod.rs @@ -0,0 +1,68 @@ +use pyo3::prelude::*; +use shared::models::node::DiscoveryNode; + +/// Node details structure shared between validator and orchestrator +#[pyclass] +#[derive(Clone)] +pub struct NodeDetails { + #[pyo3(get)] + pub id: String, + #[pyo3(get)] + pub provider_address: String, + #[pyo3(get)] + pub ip_address: String, + #[pyo3(get)] + pub port: u16, + #[pyo3(get)] + pub compute_pool_id: u32, + #[pyo3(get)] + pub is_validated: bool, + #[pyo3(get)] + pub is_active: bool, + #[pyo3(get)] + pub is_provider_whitelisted: bool, + #[pyo3(get)] + pub is_blacklisted: bool, + #[pyo3(get)] + pub worker_p2p_id: Option, + #[pyo3(get)] + pub last_updated: Option, + #[pyo3(get)] + pub created_at: Option, +} + +impl From for NodeDetails { + fn from(node: DiscoveryNode) -> Self { + Self { + id: node.node.id, + provider_address: node.node.provider_address, + ip_address: node.node.ip_address, + port: node.node.port, + compute_pool_id: node.node.compute_pool_id, + is_validated: node.is_validated, + is_active: node.is_active, + is_provider_whitelisted: node.is_provider_whitelisted, + is_blacklisted: node.is_blacklisted, + worker_p2p_id: node.node.worker_p2p_id, + last_updated: node.last_updated.map(|dt| dt.to_rfc3339()), + created_at: node.created_at.map(|dt| dt.to_rfc3339()), + } + } +} + +#[pymethods] +impl NodeDetails { + /// Get compute specifications as a Python dictionary + pub fn get_compute_specs(&self, py: Python) -> PyResult { + // This method would need access to the original node data + // For now, return None since we don't store compute specs in NodeDetails + Ok(py.None()) + } + + /// Get location information as a Python dictionary + pub fn get_location(&self, py: Python) -> PyResult { + // This method would need access to the original node data + // For now, return None since we don't store location in NodeDetails + Ok(py.None()) + } +} diff --git a/crates/prime-protocol-py/src/lib.rs b/crates/prime-protocol-py/src/lib.rs index b5c6bb2b..a2e721b9 100644 --- a/crates/prime-protocol-py/src/lib.rs +++ 
b/crates/prime-protocol-py/src/lib.rs @@ -1,8 +1,10 @@ +use crate::common::NodeDetails; use crate::orchestrator::OrchestratorClient; use crate::validator::ValidatorClient; use crate::worker::WorkerClient; use pyo3::prelude::*; +mod common; mod constants; mod error; mod orchestrator; @@ -17,5 +19,6 @@ fn primeprotocol(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/prime-protocol-py/src/orchestrator/mod.rs index 4c8e5be3..aab134b9 100644 --- a/crates/prime-protocol-py/src/orchestrator/mod.rs +++ b/crates/prime-protocol-py/src/orchestrator/mod.rs @@ -1,3 +1,4 @@ +use crate::common::NodeDetails; use crate::p2p_handler::auth::AuthenticationManager; use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig}; use crate::p2p_handler::{Message, MessageType, Service as P2PService}; @@ -12,6 +13,9 @@ use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use url::Url; +// Add new imports for discovery functionality +use shared::discovery::fetch_nodes_from_discovery_urls; + /// Prime Protocol Orchestrator Client - for managing and distributing tasks #[pyclass] pub struct OrchestratorClient { @@ -24,13 +28,19 @@ pub struct OrchestratorClient { user_message_rx: Option>>>, message_processor_handle: Option>, peer_id: Option, + // Discovery service URLs + discovery_urls: Vec, } #[pymethods] impl OrchestratorClient { #[new] - #[pyo3(signature = (rpc_url, private_key=None))] - pub fn new(rpc_url: String, private_key: Option) -> PyResult { + #[pyo3(signature = (rpc_url, private_key=None, discovery_urls=vec!["http://localhost:8089".to_string()]))] + pub fn new( + rpc_url: String, + private_key: Option, + discovery_urls: Vec, + ) -> PyResult { let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() @@ -60,6 +70,7 @@ impl OrchestratorClient { user_message_rx: None, message_processor_handle: None, peer_id: None, + discovery_urls, }) } @@ -201,14 +212,28 @@ impl OrchestratorClient { Ok(()) } - pub fn list_validated_nodes(&self) -> PyResult> { - // TODO: Implement orchestrator node listing - Ok(vec![]) - } + /// List nodes for a specific compute pool + pub fn list_nodes_for_pool(&self, py: Python, pool_id: u32) -> PyResult> { + let rt = self.get_or_create_runtime()?; + + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::( + "Wallet not initialized. 
Provide private_key when creating client.", + ) + })?; + + let discovery_urls = self.discovery_urls.clone(); + let route = format!("/api/pool/{}", pool_id); + + let nodes = py.allow_threads(|| { + rt.block_on(async { + fetch_nodes_from_discovery_urls(&discovery_urls, &route, wallet) + .await + .map_err(|e| PyErr::new::(e.to_string())) + }) + })?; - pub fn list_nodes_from_chain(&self) -> PyResult> { - // TODO: Implement orchestrator node listing from chain - Ok(vec![]) + Ok(nodes.into_iter().map(NodeDetails::from).collect()) } } diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/prime-protocol-py/src/validator/mod.rs index c716448b..bfeec6a2 100644 --- a/crates/prime-protocol-py/src/validator/mod.rs +++ b/crates/prime-protocol-py/src/validator/mod.rs @@ -1,3 +1,4 @@ +use crate::common::NodeDetails; use crate::p2p_handler::auth::AuthenticationManager; use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig}; use crate::p2p_handler::{Message, MessageType, Service as P2PService}; @@ -17,75 +18,9 @@ use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use url::Url; -/// Node details for validator operations -#[pyclass] -#[derive(Clone)] -pub(crate) struct NodeDetails { - #[pyo3(get)] - pub id: String, - #[pyo3(get)] - pub provider_address: String, - #[pyo3(get)] - pub ip_address: String, - #[pyo3(get)] - pub port: u16, - #[pyo3(get)] - pub compute_pool_id: u32, - #[pyo3(get)] - pub is_validated: bool, - #[pyo3(get)] - pub is_active: bool, - #[pyo3(get)] - pub is_provider_whitelisted: bool, - #[pyo3(get)] - pub is_blacklisted: bool, - #[pyo3(get)] - pub worker_p2p_id: Option, - #[pyo3(get)] - pub last_updated: Option, - #[pyo3(get)] - pub created_at: Option, -} - -impl From for NodeDetails { - fn from(node: DiscoveryNode) -> Self { - Self { - id: node.node.id, - provider_address: node.node.provider_address, - ip_address: node.node.ip_address, - port: node.node.port, - compute_pool_id: node.node.compute_pool_id, - is_validated: node.is_validated, - is_active: node.is_active, - is_provider_whitelisted: node.is_provider_whitelisted, - is_blacklisted: node.is_blacklisted, - worker_p2p_id: node.node.worker_p2p_id, - last_updated: node.last_updated.map(|dt| dt.to_rfc3339()), - created_at: node.created_at.map(|dt| dt.to_rfc3339()), - } - } -} - -#[pymethods] -impl NodeDetails { - /// Get compute specs as a Python dictionary - pub fn get_compute_specs(&self, py: Python) -> PyResult { - // This would need access to the original DiscoveryNode's compute_specs - // For now returning None - Ok(py.None()) - } - - /// Get location as a Python dictionary - pub fn get_location(&self, py: Python) -> PyResult { - // This would need access to the original DiscoveryNode's location - // For now returning None - Ok(py.None()) - } -} - /// Prime Protocol Validator Client - for validating nodes and tasks #[pyclass] -pub(crate) struct ValidatorClient { +pub struct ValidatorClient { runtime: Option, wallet: Option, discovery_urls: Vec, From c505468d2f0bc43a6f8061cd4a22c73b8a094797 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 17:27:54 +0200 Subject: [PATCH 14/23] move invite logic to prime-core crate --- Cargo.lock | 2 + crates/orchestrator/Cargo.toml | 1 + crates/orchestrator/src/node/invite.rs | 72 ++++-------- crates/prime-core/Cargo.toml | 1 + crates/prime-core/src/invite/admin.rs | 81 +++++++++++++ crates/prime-core/src/invite/common.rs | 108 +++++++++++++++++ crates/prime-core/src/invite/mod.rs | 7 ++ 
crates/prime-core/src/invite/worker.rs | 153 +++++++++++++++++++++++++ crates/prime-core/src/lib.rs | 1 + crates/worker/src/p2p/mod.rs | 14 +-- 10 files changed, 385 insertions(+), 55 deletions(-) create mode 100644 crates/prime-core/src/invite/admin.rs create mode 100644 crates/prime-core/src/invite/common.rs create mode 100644 crates/prime-core/src/invite/mod.rs create mode 100644 crates/prime-core/src/invite/worker.rs diff --git a/Cargo.lock b/Cargo.lock index 339c175c..48a90c27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6204,6 +6204,7 @@ dependencies = [ "log", "mockito", "p2p", + "prime-core", "prometheus 0.14.0", "rand 0.9.1", "redis", @@ -6721,6 +6722,7 @@ dependencies = [ "futures-util", "hex", "log", + "p2p", "rand 0.8.5", "redis", "serde", diff --git a/crates/orchestrator/Cargo.toml b/crates/orchestrator/Cargo.toml index ce733ee6..2703facf 100644 --- a/crates/orchestrator/Cargo.toml +++ b/crates/orchestrator/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] p2p = { workspace = true} shared = { workspace = true } +prime-core = { workspace = true } actix-web = { workspace = true } alloy = { workspace = true } diff --git a/crates/orchestrator/src/node/invite.rs b/crates/orchestrator/src/node/invite.rs index 8391d047..4e3cc874 100644 --- a/crates/orchestrator/src/node/invite.rs +++ b/crates/orchestrator/src/node/invite.rs @@ -3,19 +3,17 @@ use crate::models::node::OrchestratorNode; use crate::p2p::InviteRequest as InviteRequestWithMetadata; use crate::store::core::StoreContext; use crate::utils::loop_heartbeats::LoopHeartbeats; -use alloy::primitives::utils::keccak256 as keccak; -use alloy::primitives::U256; -use alloy::signers::Signer; use anyhow::{bail, Result}; use futures::stream; use futures::StreamExt; use log::{debug, error, info, warn}; -use p2p::InviteRequest; use p2p::InviteRequestUrl; +use prime_core::invite::{ + admin::{generate_invite_expiration, generate_invite_nonce, generate_invite_signature}, + common::InviteBuilder, +}; use shared::web3::wallet::Wallet; use std::sync::Arc; -use std::time::SystemTime; -use std::time::UNIX_EPOCH; use tokio::sync::mpsc::Sender; use tokio::time::{interval, Duration}; @@ -89,29 +87,15 @@ impl NodeInviter { nonce: [u8; 32], expiration: [u8; 32], ) -> Result<[u8; 65]> { - let domain_id: [u8; 32] = U256::from(self.domain_id).to_be_bytes(); - let pool_id: [u8; 32] = U256::from(self.pool_id).to_be_bytes(); - - let digest = keccak( - [ - &domain_id, - &pool_id, - node.address.as_slice(), - &nonce, - &expiration, - ] - .concat(), - ); - - let signature = self - .wallet - .signer - .sign_message(digest.as_slice()) - .await? - .as_bytes() - .to_owned(); - - Ok(signature) + generate_invite_signature( + &self.wallet, + self.domain_id, + self.pool_id, + node.address, + nonce, + expiration, + ) + .await } async fn send_invite(&self, node: &OrchestratorNode) -> Result<(), anyhow::Error> { @@ -122,29 +106,21 @@ impl NodeInviter { let p2p_addresses = node.worker_p2p_addresses.as_ref().unwrap(); // Generate random nonce and expiration - let nonce: [u8; 32] = rand::random(); - let expiration: [u8; 32] = U256::from( - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|e| anyhow::anyhow!("System time error: {}", e))? 
- .as_secs() - + 1000, - ) - .to_be_bytes(); + let nonce = generate_invite_nonce(); + let expiration = generate_invite_expiration(Some(1000))?; let invite_signature = self.generate_invite(node, nonce, expiration).await?; - let payload = InviteRequest { - invite: hex::encode(invite_signature), - pool_id: self.pool_id, - url: self.url.clone(), - timestamp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|e| anyhow::anyhow!("System time error: {}", e))? - .as_secs(), - expiration, - nonce, + + // Build the invite request using the builder + let builder = match &self.url { + InviteRequestUrl::MasterUrl(url) => InviteBuilder::with_url(self.pool_id, url.clone()), + InviteRequestUrl::MasterIpPort(ip, port) => { + InviteBuilder::with_ip_port(self.pool_id, ip.clone(), *port) + } }; + let payload = builder.build(invite_signature, nonce, expiration)?; + info!("Sending invite to node: {p2p_id}"); let (response_tx, response_rx) = tokio::sync::oneshot::channel(); diff --git a/crates/prime-core/Cargo.toml b/crates/prime-core/Cargo.toml index bfcef45e..4b6ec28c 100644 --- a/crates/prime-core/Cargo.toml +++ b/crates/prime-core/Cargo.toml @@ -12,6 +12,7 @@ path = "src/lib.rs" [dependencies] shared = { workspace = true } +p2p = { workspace = true } alloy = { workspace = true } alloy-provider = { workspace = true } serde = { workspace = true } diff --git a/crates/prime-core/src/invite/admin.rs b/crates/prime-core/src/invite/admin.rs new file mode 100644 index 00000000..02de9a1e --- /dev/null +++ b/crates/prime-core/src/invite/admin.rs @@ -0,0 +1,81 @@ +use alloy::primitives::utils::keccak256 as keccak; +use alloy::primitives::{Address, U256}; +use alloy::signers::Signer; +use anyhow::Result; +use rand_v8::prelude::*; +use shared::web3::wallet::Wallet; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Generates an invite signature for a node +/// +/// This function is used by pool owners/admins to create signed invites +/// that authorize nodes to join their pool. +pub async fn generate_invite_signature( + wallet: &Wallet, + domain_id: u32, + pool_id: u32, + node_address: Address, + nonce: [u8; 32], + expiration: [u8; 32], +) -> Result<[u8; 65]> { + let domain_id_bytes: [u8; 32] = U256::from(domain_id).to_be_bytes(); + let pool_id_bytes: [u8; 32] = U256::from(pool_id).to_be_bytes(); + + let digest = keccak( + [ + &domain_id_bytes, + &pool_id_bytes, + node_address.as_slice(), + &nonce, + &expiration, + ] + .concat(), + ); + + let signature = wallet + .signer + .sign_message(digest.as_slice()) + .await? + .as_bytes() + .to_owned(); + + Ok(signature) +} + +/// Generates an invite expiration timestamp +/// +/// Creates an expiration timestamp for an invite, defaulting to 1000 seconds from now +pub fn generate_invite_expiration(seconds_from_now: Option) -> Result<[u8; 32]> { + let duration = seconds_from_now.unwrap_or(1000); + let expiration = U256::from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| anyhow::anyhow!("System time error: {}", e))? 
+ .as_secs() + + duration, + ); + Ok(expiration.to_be_bytes()) +} + +/// Generates a random nonce for invite +pub fn generate_invite_nonce() -> [u8; 32] { + rand_v8::rngs::OsRng.gen() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_invite_nonce() { + let nonce1 = generate_invite_nonce(); + let nonce2 = generate_invite_nonce(); + assert_ne!(nonce1, nonce2, "Nonces should be unique"); + } + + #[test] + fn test_generate_invite_expiration() { + let expiration = generate_invite_expiration(Some(3600)).unwrap(); + assert_eq!(expiration.len(), 32); + } +} diff --git a/crates/prime-core/src/invite/common.rs b/crates/prime-core/src/invite/common.rs new file mode 100644 index 00000000..dd153980 --- /dev/null +++ b/crates/prime-core/src/invite/common.rs @@ -0,0 +1,108 @@ +use alloy::primitives::Address; +use anyhow::Result; +use p2p::{InviteRequest, InviteRequestUrl}; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Builder for creating invite requests +pub struct InviteBuilder { + pool_id: u32, + url: InviteRequestUrl, +} + +impl InviteBuilder { + /// Creates a new InviteBuilder with a master URL + pub fn with_url(pool_id: u32, url: String) -> Self { + Self { + pool_id, + url: InviteRequestUrl::MasterUrl(url), + } + } + + /// Creates a new InviteBuilder with IP and port + pub fn with_ip_port(pool_id: u32, ip: String, port: u16) -> Self { + Self { + pool_id, + url: InviteRequestUrl::MasterIpPort(ip, port), + } + } + + /// Builds an InviteRequest with the given parameters + pub fn build( + self, + invite_signature: [u8; 65], + nonce: [u8; 32], + expiration: [u8; 32], + ) -> Result { + Ok(InviteRequest { + invite: hex::encode(invite_signature), + pool_id: self.pool_id, + url: self.url, + timestamp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| anyhow::anyhow!("System time error: {}", e))? 
+ .as_secs(), + expiration, + nonce, + }) + } +} + +/// Metadata for an invite request that includes worker information +#[derive(Debug, Clone)] +pub struct InviteMetadata { + pub worker_wallet_address: Address, + pub worker_p2p_id: String, + pub worker_addresses: Vec, +} + +/// Full invite request with metadata +#[derive(Debug)] +pub struct InviteWithMetadata { + pub metadata: InviteMetadata, + pub request: InviteRequest, +} + +/// Helper to parse InviteRequestUrl into a usable endpoint +pub fn get_endpoint_from_url(url: &InviteRequestUrl, path: &str) -> String { + match url { + InviteRequestUrl::MasterIpPort(ip, port) => { + format!("http://{ip}:{port}/{path}") + } + InviteRequestUrl::MasterUrl(url) => { + let url = url.trim_end_matches('/'); + format!("{url}/{path}") + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_invite_builder_with_url() { + let builder = InviteBuilder::with_url(1, "https://example.com".to_string()); + let signature = [0u8; 65]; + let nonce = [1u8; 32]; + let expiration = [2u8; 32]; + + let invite = builder.build(signature, nonce, expiration).unwrap(); + assert_eq!(invite.pool_id, 1); + assert!(matches!(invite.url, InviteRequestUrl::MasterUrl(_))); + } + + #[test] + fn test_get_endpoint_from_url() { + let url = InviteRequestUrl::MasterUrl("https://example.com".to_string()); + assert_eq!( + get_endpoint_from_url(&url, "heartbeat"), + "https://example.com/heartbeat" + ); + + let ip_port = InviteRequestUrl::MasterIpPort("192.168.1.1".to_string(), 8080); + assert_eq!( + get_endpoint_from_url(&ip_port, "heartbeat"), + "http://192.168.1.1:8080/heartbeat" + ); + } +} diff --git a/crates/prime-core/src/invite/mod.rs b/crates/prime-core/src/invite/mod.rs new file mode 100644 index 00000000..5d7d044d --- /dev/null +++ b/crates/prime-core/src/invite/mod.rs @@ -0,0 +1,7 @@ +pub mod admin; +pub mod common; +pub mod worker; + +pub use admin::*; +pub use common::*; +pub use worker::*; diff --git a/crates/prime-core/src/invite/worker.rs b/crates/prime-core/src/invite/worker.rs new file mode 100644 index 00000000..682855a5 --- /dev/null +++ b/crates/prime-core/src/invite/worker.rs @@ -0,0 +1,153 @@ +use alloy::primitives::utils::keccak256 as keccak; +use alloy::primitives::{Address, Signature, U256}; +use anyhow::{bail, Result}; +use p2p::InviteRequest; +use std::str::FromStr; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Verifies an invite signature +/// +/// This function is used by workers to verify that an invite is valid +/// and was signed by the correct pool owner/admin. 
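+///
+/// The digest checked here is keccak256(domain_id || pool_id || worker_address
+/// || nonce || expiration), with the two IDs encoded as 32-byte big-endian
+/// words, mirroring what `admin::generate_invite_signature` signs.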
+pub fn verify_invite_signature( + invite: &InviteRequest, + domain_id: u32, + worker_address: Address, + signer_address: Address, +) -> Result { + // Parse the signature from hex string + let signature = Signature::from_str(&invite.invite) + .map_err(|e| anyhow::anyhow!("Failed to parse invite signature: {}", e))?; + + // Recreate the message that was signed + let domain_id_bytes: [u8; 32] = U256::from(domain_id).to_be_bytes(); + let pool_id_bytes: [u8; 32] = U256::from(invite.pool_id).to_be_bytes(); + + let message = keccak( + [ + &domain_id_bytes, + &pool_id_bytes, + worker_address.as_slice(), + &invite.nonce, + &invite.expiration, + ] + .concat(), + ); + + // Verify the signature + let recovered = signature + .recover_address_from_msg(message.as_slice()) + .map_err(|e| anyhow::anyhow!("Failed to recover address: {}", e))?; + + Ok(recovered == signer_address) +} + +/// Checks if an invite has expired +pub fn is_invite_expired(invite: &InviteRequest) -> Result { + let expiration_u256 = U256::from_be_bytes(invite.expiration); + let expiration_secs = expiration_u256.to::(); + + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| anyhow::anyhow!("System time error: {}", e))? + .as_secs(); + + Ok(current_time > expiration_secs) +} + +/// Validates an invite for a worker +/// +/// Performs full validation including signature verification and expiration check +pub fn validate_invite( + invite: &InviteRequest, + domain_id: u32, + worker_address: Address, + pool_owner_address: Address, +) -> Result<()> { + // Check expiration + if is_invite_expired(invite)? { + bail!("Invite has expired"); + } + + // Verify signature + if !verify_invite_signature(invite, domain_id, worker_address, pool_owner_address)? { + bail!("Invalid invite signature"); + } + + Ok(()) +} + +/// Extract pool information from an invite +#[derive(Debug, Clone)] +pub struct PoolInfo { + pub pool_id: u32, + pub endpoint_url: String, +} + +impl PoolInfo { + pub fn from_invite(invite: &InviteRequest) -> Self { + use crate::invite::common::get_endpoint_from_url; + + Self { + pool_id: invite.pool_id, + endpoint_url: get_endpoint_from_url(&invite.url, ""), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use p2p::InviteRequestUrl; + + #[test] + fn test_is_invite_expired() { + // Create an expired invite + let past_time = U256::from(1000u64); + let invite = InviteRequest { + invite: "test".to_string(), + pool_id: 1, + url: InviteRequestUrl::MasterUrl("https://example.com".to_string()), + timestamp: 0, + expiration: past_time.to_be_bytes(), + nonce: [0u8; 32], + }; + + assert!(is_invite_expired(&invite).unwrap()); + + // Create a future invite + let future_time = U256::from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 3600, + ); + let future_invite = InviteRequest { + invite: "test".to_string(), + pool_id: 1, + url: InviteRequestUrl::MasterUrl("https://example.com".to_string()), + timestamp: 0, + expiration: future_time.to_be_bytes(), + nonce: [0u8; 32], + }; + + assert!(!is_invite_expired(&future_invite).unwrap()); + } + + #[test] + fn test_pool_info_from_invite() { + let invite = InviteRequest { + invite: "test".to_string(), + pool_id: 42, + url: InviteRequestUrl::MasterUrl("https://example.com/api".to_string()), + timestamp: 0, + expiration: [0u8; 32], + nonce: [0u8; 32], + }; + + let pool_info = PoolInfo::from_invite(&invite); + assert_eq!(pool_info.pool_id, 42); + assert_eq!(pool_info.endpoint_url, "https://example.com/api/"); + } +} diff --git 
a/crates/prime-core/src/lib.rs b/crates/prime-core/src/lib.rs index 1bf04f8a..9734c863 100644 --- a/crates/prime-core/src/lib.rs +++ b/crates/prime-core/src/lib.rs @@ -1 +1,2 @@ +pub mod invite; pub mod operations; diff --git a/crates/worker/src/p2p/mod.rs b/crates/worker/src/p2p/mod.rs index 94fe10a3..2635f3f9 100644 --- a/crates/worker/src/p2p/mod.rs +++ b/crates/worker/src/p2p/mod.rs @@ -1,12 +1,12 @@ use anyhow::Context as _; use anyhow::Result; use futures::stream::FuturesUnordered; -use p2p::InviteRequestUrl; use p2p::Node; use p2p::NodeBuilder; use p2p::PeerId; use p2p::Response; use p2p::{IncomingMessage, Libp2pIncomingMessage, OutgoingMessage}; +use prime_core::invite::{common::get_endpoint_from_url, worker::is_invite_expired}; use shared::web3::contracts::core::builder::Contracts; use shared::web3::wallet::Wallet; use std::collections::HashMap; @@ -421,6 +421,11 @@ async fn handle_invite_request( ); } + // Check if invite has expired + if is_invite_expired(&req)? { + anyhow::bail!("invite has expired"); + } + let invite_bytes = hex::decode(&req.invite).context("failed to decode invite hex")?; if invite_bytes.len() < 65 { @@ -481,12 +486,7 @@ async fn handle_invite_request( } } - let heartbeat_endpoint = match req.url { - InviteRequestUrl::MasterIpPort(ip, port) => { - format!("http://{ip}:{port}/heartbeat") - } - InviteRequestUrl::MasterUrl(url) => format!("{url}/heartbeat"), - }; + let heartbeat_endpoint = get_endpoint_from_url(&req.url, "heartbeat"); context .heartbeat_service From 26a79bb696e96bdfe6a99862cabfa3831ba1cfb6 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 19:14:08 +0200 Subject: [PATCH 15/23] implement orchestrator automatic invite flow, fix auth race condition --- .../examples/orchestrator_list_nodes.py | 84 ++++++++---- crates/prime-protocol-py/src/common/mod.rs | 5 +- .../prime-protocol-py/src/orchestrator/mod.rs | 127 ++++++++++++++++++ .../prime-protocol-py/src/p2p_handler/auth.rs | 42 ++++-- .../src/p2p_handler/message_processor.rs | 41 ++++-- .../prime-protocol-py/src/p2p_handler/mod.rs | 38 +++++- .../src/worker/blockchain.rs | 60 +++++++++ crates/prime-protocol-py/src/worker/client.rs | 85 ++++++++---- crates/prime-protocol-py/src/worker/mod.rs | 5 + 9 files changed, 413 insertions(+), 74 deletions(-) diff --git a/crates/prime-protocol-py/examples/orchestrator_list_nodes.py b/crates/prime-protocol-py/examples/orchestrator_list_nodes.py index e2f94106..cd09e65f 100644 --- a/crates/prime-protocol-py/examples/orchestrator_list_nodes.py +++ b/crates/prime-protocol-py/examples/orchestrator_list_nodes.py @@ -4,9 +4,18 @@ """ import os +import signal +import sys from primeprotocol import OrchestratorClient +def signal_handler(sig, frame): + print('\nShutting down gracefully...') + sys.exit(0) + def main(): + # Set up signal handler for Ctrl+C + signal.signal(signal.SIGINT, signal_handler) + # Replace with your actual RPC URL and private key RPC_URL = "http://localhost:8545" PRIVATE_KEY = os.getenv("ORCHESTRATOR_PRIVATE_KEY") @@ -19,29 +28,58 @@ def main(): discovery_urls=DISCOVERY_URLS ) - # Initialize the orchestrator (without P2P for this example) - orchestrator.start() - - # List nodes for a specific pool (example pool ID: 0) - pool_id = 0 - pool_nodes = orchestrator.list_nodes_for_pool(pool_id) - print(f"Nodes in pool {pool_id}: {len(pool_nodes)}") - - # Print details of all nodes in the pool - for i, node in enumerate(pool_nodes): - print(f"\nNode {i+1}:") - print(f" ID: {node.id}") - print(f" Provider Address: {node.provider_address}") 
- print(f" IP Address: {node.ip_address}") - print(f" Port: {node.port}") - print(f" Pool ID: {node.compute_pool_id}") - print(f" Validated: {node.is_validated}") - print(f" Active: {node.is_active}") - if node.worker_p2p_id: - print(f" Worker P2P ID: {node.worker_p2p_id}") - - # Stop the orchestrator - orchestrator.stop() + try: + # Initialize the orchestrator (without P2P for this example) + orchestrator.start(p2p_port=8180) + + # List nodes for a specific pool (example pool ID: 0) + pool_id = 0 + pool_nodes = orchestrator.list_nodes_for_pool(pool_id) + print(f"Nodes in pool {pool_id}: {len(pool_nodes)}") + + # Print details of all nodes in the pool + for i, node in enumerate(pool_nodes): + print(f"\nNode {i+1}:") + print(f" ID: {node.id}") + print(f" Provider Address: {node.provider_address}") + print(f" IP Address: {node.ip_address}") + print(f" Port: {node.port}") + print(f" Pool ID: {node.compute_pool_id}") + print(f" Validated: {node.is_validated}") + print(f" Worker P2P Addresses: {node.worker_p2p_addresses}") + print(f" Active: {node.is_active}") + if node.worker_p2p_id: + print(f" Worker P2P ID: {node.worker_p2p_id}") + if node.is_active is False: + # Invite node with required parameters + orchestrator.invite_node( + peer_id=node.worker_p2p_id, + worker_address=node.id, + pool_id=pool_id, + multiaddrs=node.worker_p2p_addresses, + # todo: automatically fetch from contract + domain_id=0, + # tood: deprecate + orchestrator_url=None, + expiration_seconds=1000 + ) + + print("\nPress Ctrl+C to exit...") + + # Keep the program running until Ctrl+C + while True: + try: + signal.pause() + except AttributeError: + # signal.pause() is not available on Windows + import time + time.sleep(1) + + except KeyboardInterrupt: + print('\nShutting down gracefully...') + finally: + # Stop the orchestrator + orchestrator.stop() if __name__ == "__main__": main() \ No newline at end of file diff --git a/crates/prime-protocol-py/src/common/mod.rs b/crates/prime-protocol-py/src/common/mod.rs index eb529d13..89aae39a 100644 --- a/crates/prime-protocol-py/src/common/mod.rs +++ b/crates/prime-protocol-py/src/common/mod.rs @@ -26,6 +26,8 @@ pub struct NodeDetails { #[pyo3(get)] pub worker_p2p_id: Option, #[pyo3(get)] + pub worker_p2p_addresses: Option>, + #[pyo3(get)] pub last_updated: Option, #[pyo3(get)] pub created_at: Option, @@ -43,9 +45,10 @@ impl From for NodeDetails { is_active: node.is_active, is_provider_whitelisted: node.is_provider_whitelisted, is_blacklisted: node.is_blacklisted, - worker_p2p_id: node.node.worker_p2p_id, last_updated: node.last_updated.map(|dt| dt.to_rfc3339()), created_at: node.created_at.map(|dt| dt.to_rfc3339()), + worker_p2p_id: node.node.worker_p2p_id, + worker_p2p_addresses: node.node.worker_p2p_addresses, } } } diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/prime-protocol-py/src/orchestrator/mod.rs index aab134b9..8d3a014e 100644 --- a/crates/prime-protocol-py/src/orchestrator/mod.rs +++ b/crates/prime-protocol-py/src/orchestrator/mod.rs @@ -16,6 +16,14 @@ use url::Url; // Add new imports for discovery functionality use shared::discovery::fetch_nodes_from_discovery_urls; +// Add imports for invite functionality +use alloy::primitives::Address; +use prime_core::invite::{ + admin::{generate_invite_expiration, generate_invite_nonce, generate_invite_signature}, + common::InviteBuilder, +}; +use std::str::FromStr; + /// Prime Protocol Orchestrator Client - for managing and distributing tasks #[pyclass] pub struct OrchestratorClient { @@ -235,6 +243,125 @@ 
impl OrchestratorClient { Ok(nodes.into_iter().map(NodeDetails::from).collect()) } + + /// Invite a node to join a compute pool + /// + /// This method creates a signed invite and sends it to the specified worker node. + /// The invite contains pool information and authentication details that the worker + /// can validate before joining the pool. + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (peer_id, worker_address, pool_id, multiaddrs, domain_id=1, orchestrator_url=None, expiration_seconds=1000))] + pub fn invite_node( + &self, + py: Python, + peer_id: String, + worker_address: String, + pool_id: u32, + multiaddrs: Vec, + domain_id: u32, + orchestrator_url: Option, + expiration_seconds: u64, + ) -> PyResult<()> { + println!("invite_node"); + let rt = self.get_or_create_runtime()?; + + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::( + "Wallet not initialized. Provide private_key when creating client.", + ) + })?; + + let outbound_tx = self.outbound_tx.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. Call start() with p2p_port parameter.", + ) + })?; + + let auth_manager = self.auth_manager.as_ref().ok_or_else(|| { + PyErr::new::( + "P2P not initialized. Call start() with p2p_port parameter.", + ) + })?; + + // Parse worker address + let worker_addr = Address::from_str(&worker_address).map_err(|e| { + PyErr::new::(format!( + "Invalid worker address: {}", + e + )) + })?; + + println!("worker_addr: {:?}", worker_addr); + + let wallet = wallet.clone(); + let outbound_tx = outbound_tx.clone(); + let auth_manager = auth_manager.clone(); + + println!("invite_node 2"); + + py.allow_threads(|| { + rt.block_on(async { + // Generate invite parameters + let nonce = generate_invite_nonce(); + let expiration = + generate_invite_expiration(Some(expiration_seconds)).map_err(|e| { + PyErr::new::(e.to_string()) + })?; + + // Generate the invite signature + let invite_signature = generate_invite_signature( + &wallet, + domain_id, + pool_id, + worker_addr, + nonce, + expiration, + ) + .await + .map_err(|e| PyErr::new::(e.to_string()))?; + + // Build the invite request + let invite_builder = if let Some(url) = orchestrator_url { + InviteBuilder::with_url(pool_id, url) + } else { + // Default to localhost if no URL provided + InviteBuilder::with_url(pool_id, "http://localhost:8080".to_string()) + }; + + let invite_request = invite_builder + .build(invite_signature, nonce, expiration) + .map_err(|e| { + PyErr::new::(e.to_string()) + })?; + + // Serialize the invite request + let invite_data = serde_json::to_vec(&invite_request).map_err(|e| { + PyErr::new::(e.to_string()) + })?; + + // Create a general message with the invite data + let message = Message { + message_type: MessageType::General { data: invite_data }, + peer_id, + multiaddrs, + sender_address: Some(wallet.wallet.default_signer().address().to_string()), + response_tx: None, + }; + + // Send the invite + println!("sending invite"); + crate::p2p_handler::send_message_with_auth(message, &auth_manager, &outbound_tx) + .await + .map_err(|e| { + PyErr::new::(e.to_string()) + })?; + + println!("invite sent"); + + Ok(()) + }) + }) + } } // Private implementation methods diff --git a/crates/prime-protocol-py/src/p2p_handler/auth.rs b/crates/prime-protocol-py/src/p2p_handler/auth.rs index f8ecaa70..a6aa1c99 100644 --- a/crates/prime-protocol-py/src/p2p_handler/auth.rs +++ b/crates/prime-protocol-py/src/p2p_handler/auth.rs @@ -64,6 +64,8 @@ pub struct AuthenticationManager { responder_message_queue: Arc>>>, /// 
Track authenticated peers authenticated_peers: Arc>>, + /// Track peers we're waiting for authentication acknowledgment from + pending_auth_acknowledgment: Arc>>, /// Our wallet for signing node_wallet: Arc, } @@ -76,6 +78,7 @@ impl AuthenticationManager { responding_to_peers: Arc::new(RwLock::new(HashSet::new())), responder_message_queue: Arc::new(RwLock::new(HashMap::new())), authenticated_peers: Arc::new(RwLock::new(HashSet::new())), + pending_auth_acknowledgment: Arc::new(RwLock::new(HashMap::new())), node_wallet, } } @@ -174,14 +177,14 @@ impl AuthenticationManager { Ok((our_challenge, signature)) } - /// Handle authentication response from peer + /// Handle authentication response from peer (when we initiated) pub async fn handle_auth_response( &self, peer_id: &str, their_challenge: &str, their_signature: &str, - ) -> Result<(String, Option)> { - // Verify we have an ongoing auth request for this peer + ) -> Result { + // Get our ongoing auth challenge let mut ongoing_auth = self.ongoing_auth_requests.write().await; let auth_challenge = ongoing_auth.get_mut(peer_id).ok_or_else(|| { PrimeProtocolError::InvalidConfig(format!( @@ -190,12 +193,11 @@ impl AuthenticationManager { )) })?; - // Verify their signature to get their address + // Verify the signature matches the challenge we sent let parsed_signature = Signature::from_str(their_signature).map_err(|e| { PrimeProtocolError::InvalidConfig(format!("Invalid signature format: {}", e)) })?; - // Recover the peer's address from their signature let recovered_address = parsed_signature .recover_address_from_msg(&auth_challenge.auth_challenge_request_message) .map_err(|e| { @@ -214,15 +216,13 @@ impl AuthenticationManager { PrimeProtocolError::BlockchainError(format!("Failed to sign challenge: {}", e)) })?; - // Mark peer as authenticated - self.mark_authenticated(peer_id.to_string()).await; - - // Get the queued message to send after auth - let queued_message = ongoing_auth - .remove(peer_id) - .map(|auth| auth.outgoing_message); + // Store the queued message for sending after acknowledgment + if let Some(auth) = ongoing_auth.remove(peer_id) { + self.queue_for_acknowledgment(peer_id.to_string(), auth.outgoing_message) + .await; + } - Ok((our_signature, queued_message)) + Ok(our_signature) } /// Handle authentication solution from peer @@ -281,4 +281,20 @@ impl AuthenticationManager { .address() .to_string() } + + /// Store a message to send after receiving authentication acknowledgment + pub async fn queue_for_acknowledgment(&self, peer_id: String, message: Message) { + self.pending_auth_acknowledgment + .write() + .await + .insert(peer_id, message); + } + + /// Handle authentication acknowledgment and return queued message if any + pub async fn handle_auth_acknowledgment(&self, peer_id: &str) -> Option { + self.pending_auth_acknowledgment + .write() + .await + .remove(peer_id) + } } diff --git a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs index e062bfc6..8a9c2b68 100644 --- a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs +++ b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs @@ -159,6 +159,10 @@ impl MessageProcessor { Ok(()) } } + MessageType::AuthenticationComplete => { + // Authentication has been acknowledged by the peer + self.handle_auth_complete(peer_id).await + } MessageType::General { data } => { // Forward general messages to user let msg = Message { @@ -210,7 +214,7 @@ impl MessageProcessor { challenge: String, 
signature: String, ) -> Result<()> { - let (our_signature, queued_message) = self + let our_signature = self .auth_manager .handle_auth_response(&message.peer_id, &challenge, &signature) .await?; @@ -248,16 +252,6 @@ impl MessageProcessor { )) })?; - // Send the queued message if any - if let Some(msg) = queued_message { - self.outbound_tx.lock().await.send(msg).await.map_err(|e| { - crate::error::PrimeProtocolError::InvalidConfig(format!( - "Failed to send queued message: {}", - e - )) - })?; - } - Ok(()) } @@ -298,4 +292,29 @@ impl MessageProcessor { ) }) } + + async fn handle_auth_complete(&self, peer_id: String) -> Result<()> { + log::info!("Authentication complete for peer: {}", peer_id); + + // Mark peer as authenticated + self.auth_manager.mark_authenticated(peer_id.clone()).await; + + // Get and send any queued message + if let Some(queued_message) = self.auth_manager.handle_auth_acknowledgment(&peer_id).await { + log::debug!("Sending queued message to peer {}", peer_id); + self.outbound_tx + .lock() + .await + .send(queued_message) + .await + .map_err(|e| { + crate::error::PrimeProtocolError::InvalidConfig(format!( + "Failed to send queued message: {}", + e + )) + })?; + } + + Ok(()) + } } diff --git a/crates/prime-protocol-py/src/p2p_handler/mod.rs b/crates/prime-protocol-py/src/p2p_handler/mod.rs index 2cf4b730..4f7c2d6a 100644 --- a/crates/prime-protocol-py/src/p2p_handler/mod.rs +++ b/crates/prime-protocol-py/src/p2p_handler/mod.rs @@ -41,6 +41,7 @@ pub enum MessageType { AuthenticationSolution { signature: String, }, + AuthenticationComplete, } #[derive(Debug, Serialize, Deserialize)] @@ -77,7 +78,10 @@ impl Service { let (message_queue_tx, message_queue_rx) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_CHANNEL_SIZE); - let protocols = Protocols::new().with_general().with_authentication(); + let protocols = Protocols::new() + .with_general() + .with_authentication() + .with_invite(); let listen_addr = format!("/ip4/0.0.0.0/tcp/{}", port) .parse() @@ -146,6 +150,14 @@ impl Service { signature: signature.clone(), }), ), + MessageType::AuthenticationComplete => { + // This message type should not be sent as a request + // It should be handled via response channels + log::error!( + "AuthenticationComplete should be sent via response channel, not as a request" + ); + return Ok(()); + } }; let peer_id = PeerId::from_str(&message.peer_id).context("Failed to parse peer ID")?; @@ -453,7 +465,29 @@ impl Service { } p2p::AuthenticationResponse::Solution(sol_resp) => { log::debug!("Auth solution response: {:?}", sol_resp); - // This is handled by the authentication flow + // Forward the auth solution response to the message processor + match sol_resp { + p2p::AuthenticationSolutionResponse::Granted => { + // Create a special message to indicate auth is complete + let message = Message { + message_type: MessageType::AuthenticationComplete, + peer_id: peer_id.to_string(), + multiaddrs: vec![], + sender_address: None, + response_tx: None, + }; + + if let Err(e) = message_queue_tx.send(message).await { + log::error!( + "Failed to forward auth complete to application: {}", + e + ); + } + } + p2p::AuthenticationSolutionResponse::Rejected => { + log::warn!("Authentication rejected by peer: {}", peer_id); + } + } } } } diff --git a/crates/prime-protocol-py/src/worker/blockchain.rs b/crates/prime-protocol-py/src/worker/blockchain.rs index c03d4e09..af5b97f3 100644 --- a/crates/prime-protocol-py/src/worker/blockchain.rs +++ b/crates/prime-protocol-py/src/worker/blockchain.rs @@ -271,4 +271,64 @@ impl 
BlockchainService { log::info!("Compute node registered successfully"); Ok(()) } + + /// Join compute pool using an invite + pub async fn join_compute_pool_with_invite(&self, invite: &p2p::InviteRequest) -> Result<()> { + use shared::web3::contracts::core::builder::ContractBuilder; + use shared::web3::contracts::helpers::utils::retry_call; + + let invite_bytes = hex::decode(&invite.invite).context("Failed to decode invite")?; + + if invite_bytes.len() < 65 { + anyhow::bail!("Invite data is too short"); + } + + let bytes_array: [u8; 65] = invite_bytes[..65] + .try_into() + .map_err(|_| anyhow::anyhow!("Failed to convert invite to array"))?; + + // Get wallets + let provider_wallet = self + .provider_wallet + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Provider wallet not initialized"))?; + let node_wallet = self + .node_wallet + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Node wallet not initialized"))?; + + // Create contracts + let contracts = ContractBuilder::new(provider_wallet.provider()) + .with_compute_pool() + .with_compute_registry() + .with_ai_token() + .with_prime_network() + .with_stake_manager() + .build() + .context("Failed to build contracts")?; + + let pool_id = alloy::primitives::U256::from(invite.pool_id); + let provider_address = provider_wallet.wallet.default_signer().address(); + let node_address = vec![node_wallet.wallet.default_signer().address()]; + let signatures = vec![alloy::primitives::FixedBytes::from(&bytes_array)]; + + let call = contracts + .compute_pool + .build_join_compute_pool_call( + pool_id, + provider_address, + node_address, + vec![invite.nonce], + vec![invite.expiration], + signatures, + ) + .map_err(|e| anyhow::anyhow!("Failed to build join compute pool call: {}", e))?; + + let result = retry_call(call, 3, provider_wallet.provider.clone(), None) + .await + .context("Failed to join compute pool")?; + + log::info!("Successfully joined compute pool with tx: {}", result); + Ok(()) + } } diff --git a/crates/prime-protocol-py/src/worker/client.rs b/crates/prime-protocol-py/src/worker/client.rs index 62f87b94..e87907c8 100644 --- a/crates/prime-protocol-py/src/worker/client.rs +++ b/crates/prime-protocol-py/src/worker/client.rs @@ -3,7 +3,7 @@ use crate::error::{PrimeProtocolError, Result}; use crate::p2p_handler::auth::AuthenticationManager; use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig}; use crate::worker::blockchain::{BlockchainConfig, BlockchainService}; -use crate::worker::p2p_handler::{Message, Service as P2PService}; +use crate::worker::p2p_handler::{Message, MessageType, Service as P2PService}; use p2p::{Keypair, PeerId}; use std::sync::Arc; use tokio::sync::mpsc::{Receiver, Sender}; @@ -106,29 +106,20 @@ impl WorkerClientCore { /// Start the worker client asynchronously pub async fn start_async(&mut self) -> Result<()> { - log::info!("Starting worker client..."); + log::info!("Starting WorkerClient"); - // Initialize blockchain components self.initialize_blockchain().await?; - - // Initialize authentication manager self.initialize_auth_manager()?; - - // Start P2P networking self.start_p2p_service().await?; - - // Start message processor self.start_message_processor().await?; - log::info!("Worker client started successfully"); + log::info!("WorkerClient started successfully"); Ok(()) } - /// Stop the worker client and clean up resources + /// Stop the worker client asynchronously pub async fn stop_async(&mut self) -> Result<()> { log::info!("Stopping worker client..."); - - // Cancel all background tasks 
self.cancellation_token.cancel(); // Stop message processor @@ -136,13 +127,12 @@ impl WorkerClientCore { handle.abort(); } - // Wait for P2P service to shut down + // Give P2P service time to shutdown gracefully + tokio::time::sleep(P2P_SHUTDOWN_TIMEOUT).await; + + // Stop P2P service if let Some(handle) = self.p2p_state.handle.take() { - match tokio::time::timeout(P2P_SHUTDOWN_TIMEOUT, handle).await { - Ok(Ok(_)) => log::info!("P2P service shut down gracefully"), - Ok(Err(e)) => log::error!("P2P service error during shutdown: {:?}", e), - Err(_) => log::warn!("P2P service shutdown timed out"), - } + let _ = tokio::time::timeout(P2P_SHUTDOWN_TIMEOUT, handle).await; } log::info!("Worker client stopped"); @@ -154,14 +144,60 @@ impl WorkerClientCore { self.p2p_state.peer_id } - /// Get the next message from the P2P network (only returns general messages) + /// Get the next message from the P2P network pub async fn get_next_message(&self) -> Option { let rx = self.user_message_rx.as_ref()?; - tokio::time::timeout(MESSAGE_QUEUE_TIMEOUT, rx.lock().await.recv()) - .await - .ok() - .flatten() + match tokio::time::timeout(MESSAGE_QUEUE_TIMEOUT, rx.lock().await.recv()).await { + Ok(Some(message)) => { + // Check if it's an invite and process it automatically + if let MessageType::General { ref data } = message.message_type { + if let Ok(invite) = serde_json::from_slice::(data) { + println!("Received invite from peer: {}", message.peer_id); + log::info!("Received invite from peer: {}", message.peer_id); + + // Check if invite has expired + if let Ok(true) = prime_core::invite::worker::is_invite_expired(&invite) { + log::warn!("Received expired invite from peer: {}", message.peer_id); + return Some(message); // Return it so user can see the expired invite + } + + // Verify pool ID matches + if invite.pool_id != self.config.compute_pool_id as u32 { + log::warn!( + "Received invite for wrong pool: expected {}, got {}", + self.config.compute_pool_id, + invite.pool_id + ); + return Some(message); // Return it so user can see the wrong pool invite + } + + // Process the invite automatically + if let Some(blockchain_service) = &self.blockchain_service { + match blockchain_service + .join_compute_pool_with_invite(&invite) + .await + { + Ok(()) => { + log::info!( + "Successfully processed invite and joined compute pool" + ); + // Don't return the invite message since we handled it + return None; + } + Err(e) => { + log::error!("Failed to process invite: {}", e); + // Return the message so user knows about the failed invite + } + } + } + } + } + Some(message) + } + Ok(None) => None, + Err(_) => None, // Timeout + } } /// Send a message to a peer @@ -256,6 +292,7 @@ impl WorkerClientCore { let message_processor = MessageProcessor::from_config(config); self.message_processor_handle = Some(message_processor.spawn()); + log::info!("Message processor started"); Ok(()) } diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/prime-protocol-py/src/worker/mod.rs index aaeb3cb4..13fbe027 100644 --- a/crates/prime-protocol-py/src/worker/mod.rs +++ b/crates/prime-protocol-py/src/worker/mod.rs @@ -216,6 +216,11 @@ fn message_to_pyobject(message: Message) -> PyObject { "signature": signature, }) } + p2p_handler::MessageType::AuthenticationComplete => { + serde_json::json!({ + "type": "auth_complete", + }) + } }; let json_value = serde_json::json!({ From dd641981b9185515d0e87bb8278bc77731573d90 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Sun, 13 Jul 2025 19:56:05 +0200 Subject: [PATCH 16/23] ability to 
eject nodes on orchestrator --- .../examples/orchestrator_list_nodes.py | 14 + .../prime-protocol-py/src/orchestrator/mod.rs | 273 ++++++++++++++++++ 2 files changed, 287 insertions(+) diff --git a/crates/prime-protocol-py/examples/orchestrator_list_nodes.py b/crates/prime-protocol-py/examples/orchestrator_list_nodes.py index cd09e65f..874f9409 100644 --- a/crates/prime-protocol-py/examples/orchestrator_list_nodes.py +++ b/crates/prime-protocol-py/examples/orchestrator_list_nodes.py @@ -6,6 +6,7 @@ import os import signal import sys +from time import sleep from primeprotocol import OrchestratorClient def signal_handler(sig, frame): @@ -31,6 +32,8 @@ def main(): try: # Initialize the orchestrator (without P2P for this example) orchestrator.start(p2p_port=8180) + # todo: temp fix for establishing p2p connections + sleep(5) # List nodes for a specific pool (example pool ID: 0) pool_id = 0 @@ -63,6 +66,17 @@ def main(): orchestrator_url=None, expiration_seconds=1000 ) + else: + try: + # todo: we need an actual ack + orchestrator.send_message( + peer_id=node.worker_p2p_id, + multiaddrs=node.worker_p2p_addresses, + data=b"Hello, world!", + ) + print(f"Message sent to node {node.id}") + except Exception as e: + print(f"Error sending message to node {node.id}: {e}") print("\nPress Ctrl+C to exit...") diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/prime-protocol-py/src/orchestrator/mod.rs index 8d3a014e..7af6894f 100644 --- a/crates/prime-protocol-py/src/orchestrator/mod.rs +++ b/crates/prime-protocol-py/src/orchestrator/mod.rs @@ -24,6 +24,10 @@ use prime_core::invite::{ }; use std::str::FromStr; +// Add imports for compute pool contract functionality +use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; +use shared::web3::wallet::WalletProvider; + /// Prime Protocol Orchestrator Client - for managing and distributing tasks #[pyclass] pub struct OrchestratorClient { @@ -38,6 +42,8 @@ pub struct OrchestratorClient { peer_id: Option, // Discovery service URLs discovery_urls: Vec, + // Contracts + contracts: Option>>, } #[pymethods] @@ -79,12 +85,50 @@ impl OrchestratorClient { message_processor_handle: None, peer_id: None, discovery_urls, + contracts: None, }) } /// Initialize the orchestrator client with optional P2P support #[pyo3(signature = (p2p_port=None))] pub fn start(&mut self, py: Python, p2p_port: Option) -> PyResult<()> { + // Initialize contracts if not already done + if self.contracts.is_none() { + let wallet = self.wallet.as_ref().ok_or_else(|| { + PyErr::new::( + "Wallet not initialized. Provide private_key when creating client.", + ) + })?; + + let wallet_provider = wallet.provider(); + + // Get runtime reference and use it before mutating self + let rt = self.runtime.as_ref().ok_or_else(|| { + PyErr::new::("Runtime not initialized") + })?; + + let contracts = py.allow_threads(|| { + rt.block_on(async { + // Build all contracts (required due to known bug) + ContractBuilder::new(wallet_provider) + .with_compute_pool() + .with_compute_registry() + .with_ai_token() + .with_prime_network() + .with_stake_manager() + .build() + .map_err(|e| { + PyErr::new::(format!( + "Failed to build contracts: {}", + e + )) + }) + }) + })?; + + self.contracts = Some(Arc::new(contracts)); + } + let rt = self.get_or_create_runtime()?; if let Some(port) = p2p_port { @@ -362,6 +406,235 @@ impl OrchestratorClient { }) }) } + + /// Eject a node from a compute pool + /// + /// This method removes a node from the specified compute pool. 
Only the pool's + /// compute manager can eject nodes from their pool. + /// + /// Args: + /// pool_id: The ID of the compute pool + /// node_address: The address of the node to eject + /// + /// Returns: + /// The transaction hash as a hex string + pub fn eject_node(&self, py: Python, pool_id: u32, node_address: String) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. Call start() first.", + ) + })?; + + // Parse node address + let node_addr = Address::from_str(&node_address).map_err(|e| { + PyErr::new::(format!("Invalid node address: {}", e)) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Call eject_node on the contract + let tx_hash = contracts_clone + .compute_pool + .eject_node(pool_id, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to eject node: {}", + e + )) + })?; + + // Convert transaction hash to hex string + Ok(format!("0x{}", tx_hash)) + }) + }) + } + + /// Blacklist a node from a compute pool + /// + /// This method blacklists a node from the specified compute pool, preventing it from + /// rejoining. Only the pool's compute manager can blacklist nodes in their pool. + /// + /// Args: + /// pool_id: The ID of the compute pool + /// node_address: The address of the node to blacklist + /// + /// Returns: + /// The transaction hash as a hex string + pub fn blacklist_node( + &self, + py: Python, + pool_id: u32, + node_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. Call start() first.", + ) + })?; + + // Parse node address + let node_addr = Address::from_str(&node_address).map_err(|e| { + PyErr::new::(format!("Invalid node address: {}", e)) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Call blacklist_node on the contract + let tx_hash = contracts_clone + .compute_pool + .blacklist_node(pool_id, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to blacklist node: {}", + e + )) + })?; + + // Convert transaction hash to hex string + Ok(format!("0x{}", tx_hash)) + }) + }) + } + + /// Check if a node is blacklisted from a compute pool + /// + /// Args: + /// pool_id: The ID of the compute pool + /// node_address: The address of the node to check + /// + /// Returns: + /// True if the node is blacklisted, False otherwise + pub fn is_node_blacklisted( + &self, + py: Python, + pool_id: u32, + node_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. 
Call start() first.", + ) + })?; + + // Parse node address + let node_addr = Address::from_str(&node_address).map_err(|e| { + PyErr::new::(format!("Invalid node address: {}", e)) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Check if node is blacklisted + contracts_clone + .compute_pool + .is_node_blacklisted(pool_id, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to check blacklist status: {}", + e + )) + }) + }) + }) + } + + /// Get all blacklisted nodes for a compute pool + /// + /// Args: + /// pool_id: The ID of the compute pool + /// + /// Returns: + /// List of blacklisted node addresses + pub fn get_blacklisted_nodes(&self, py: Python, pool_id: u32) -> PyResult> { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. Call start() first.", + ) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Get blacklisted nodes + let nodes = contracts_clone + .compute_pool + .get_blacklisted_nodes(pool_id) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to get blacklisted nodes: {}", + e + )) + })?; + + // Convert addresses to strings + Ok(nodes.iter().map(|addr| addr.to_string()).collect()) + }) + }) + } + + /// Check if a node is in a compute pool + /// + /// Args: + /// pool_id: The ID of the compute pool + /// node_address: The address of the node to check + /// + /// Returns: + /// True if the node is in the pool, False otherwise + pub fn is_node_in_pool( + &self, + py: Python, + pool_id: u32, + node_address: String, + ) -> PyResult { + let rt = self.get_or_create_runtime()?; + + let contracts = self.contracts.as_ref().ok_or_else(|| { + PyErr::new::( + "Contracts not initialized. 
Call start() first.", + ) + })?; + + // Parse node address + let node_addr = Address::from_str(&node_address).map_err(|e| { + PyErr::new::(format!("Invalid node address: {}", e)) + })?; + + let contracts_clone = contracts.clone(); + + py.allow_threads(|| { + rt.block_on(async { + // Check if node is in pool + contracts_clone + .compute_pool + .is_node_in_pool(pool_id, node_addr) + .await + .map_err(|e| { + PyErr::new::(format!( + "Failed to check if node is in pool: {}", + e + )) + }) + }) + }) + } } // Private implementation methods From efef3215b374761cc35d794fc46960fb78dc0e1b Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Wed, 23 Jul 2025 11:10:34 +0200 Subject: [PATCH 17/23] update readme and improve examples --- crates/prime-protocol-py/README.md | 78 +++++++++ .../prime-protocol-py/examples/basic_usage.py | 153 ----------------- .../examples/orchestrator.py | 159 ++++++++++++++++++ .../examples/orchestrator_list_nodes.py | 99 ----------- .../prime-protocol-py/examples/validator.py | 103 ++++++++++++ .../examples/validator_list_nodes.py | 71 -------- crates/prime-protocol-py/examples/worker.py | 93 ++++++++++ 7 files changed, 433 insertions(+), 323 deletions(-) delete mode 100644 crates/prime-protocol-py/examples/basic_usage.py create mode 100644 crates/prime-protocol-py/examples/orchestrator.py delete mode 100644 crates/prime-protocol-py/examples/orchestrator_list_nodes.py create mode 100644 crates/prime-protocol-py/examples/validator.py delete mode 100644 crates/prime-protocol-py/examples/validator_list_nodes.py create mode 100644 crates/prime-protocol-py/examples/worker.py diff --git a/crates/prime-protocol-py/README.md b/crates/prime-protocol-py/README.md index e69de29b..3a496628 100644 --- a/crates/prime-protocol-py/README.md +++ b/crates/prime-protocol-py/README.md @@ -0,0 +1,78 @@ +# Prime Protocol Python SDK + +## Local Development Setup +Startup the local dev chain with contracts using the following Make cmd from the base folder: +```bash +make bootstrap +``` + +Within this folder run the following to install the setup: +```bash +make install +``` + +Whenever you change something in your code, simply run `make build`. + +## Example flow: +The following example flow implements the typical prime protocol toplogy with an orchestrator, validator and worker node. +It still uses a centralized discovery service which will be replaced shortly. + +```bash +export PRIVATE_KEY_NODE=XYZ +export PRIVATE_KEY_PROVIDER=XYZ +``` +(for the private keys you can just use the keys from the `.env.example`) + +```bash +uv run examples/worker.py +``` + +You'll likely need to whitelist the provider from the base folder: +```bash +make whitelist-provider +``` + +You should be able to see the node now on the discovery service: +```bash +curl http://localhost:8089/api/platform -H "Authorization: Bearer prime" | jq +``` + +### Validator: +Next run the validator to ensure nodes are validated on chain. This is very basic logic. The pool owner could come up with something more sophisticated in the future. +```bash +uv run examples/validator.py +``` + +On underlying python side all the logic here is within: +```python +from primeprotocol import ValidatorClient +``` + +You should actually see that the node is validated now on the discovery service. + +### Orchestrator / pool owner +Once we actually have validated nodes, we still want to invite them to the compute pool and have them contribute work. 
+ +Use the pool owner private key from the .env.example and export: +```bash +export PRIVATE_KEY_ORCHESTRATOR=xyz +``` +To run the orchestrator simply execute: +```bash +uv run examples/orchestrator.py +``` +You should see that the orchestrator invites the worker node that you started in the earlier step and keeps sending messages via p2p to this node. + +## TODOs: +- [ ] restarting node - no longer getting messages (since p2p id changed)? +- [ ] can validator send message? (special validation message) +- [ ] can orchestrator send message? +- [ ] whitelist using python api? +- [ ] borrow bug? +- [ ] restart keeps increasing provider stake? +- [ ] p2p cleanup + +## Known Limitations: +1. Orchestrator can no longer send messages to a node when the worker node restarts. +This is because the discovery info is not refreshed when the node was already active in the pool. This will change when introducing a proper DHT layer. +Simply run `make down` and restart the setup from scatch. \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/basic_usage.py b/crates/prime-protocol-py/examples/basic_usage.py deleted file mode 100644 index 365f883d..00000000 --- a/crates/prime-protocol-py/examples/basic_usage.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python3 -"""Example usage of the Prime Protocol Python client.""" - -import asyncio -import logging -import os -import signal -import sys -import time -from typing import Dict, Any, Optional -from primeprotocol import WorkerClient - -FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' -logging.basicConfig(format=FORMAT) -logging.getLogger().setLevel(logging.INFO) - - -def handle_pool_owner_message(message: Dict[str, Any]) -> None: - """Handle messages from pool owner""" - logging.info(f"Received message from pool owner: {message}") - - if message.get("type") == "inference_request": - prompt = message.get("prompt", "") - # Simulate processing the inference request - response = f"Processed: {prompt}" - - logging.info(f"Processing inference request: {prompt}") - logging.info(f"Generated response: {response}") - - # In a real implementation, you would send the response back - # client.send_response({"type": "inference_response", "result": response}) - else: - logging.info("Sending PONG response") - # client.send_response("PONG") - - -def handle_validator_message(message: Dict[str, Any]) -> None: - """Handle messages from validator""" - logging.info(f"Received message from validator: {message}") - - if message.get("type") == "inference_request": - prompt = message.get("prompt", "") - # Simulate processing the inference request - response = f"Validated: {prompt}" - - logging.info(f"Processing validation request: {prompt}") - logging.info(f"Generated response: {response}") - - # In a real implementation, you would send the response back - # client.send_response({"type": "inference_response", "result": response}) - - -def check_for_messages(client: WorkerClient) -> None: - """Check for new messages from pool owner and validator""" - try: - # Check for pool owner messages - pool_owner_message = client.get_pool_owner_message() - if pool_owner_message: - handle_pool_owner_message(pool_owner_message) - - # Check for validator messages - validator_message = client.get_validator_message() - if validator_message: - handle_validator_message(validator_message) - - except Exception as e: - logging.error(f"Error checking for messages: {e}") - - -def main(): - rpc_url = os.getenv("RPC_URL", "http://localhost:8545") - 
pool_id = os.getenv("POOL_ID", 0) - private_key_provider = os.getenv("PRIVATE_KEY_PROVIDER", None) - private_key_node = os.getenv("PRIVATE_KEY_NODE", None) - - logging.info(f"Connecting to: {rpc_url}") - - peer_id = os.getenv("PEER_ID", "12D3KooWELi4p1oR3QBSYiq1rvPpyjbkiQVhQJqCobBBUS7C6JrX") - port = int(os.getenv("PORT", 8003)) - peer_port = int(os.getenv("PEER_PORT", port-1)) - send_message_to_peer = os.getenv("SEND_MESSAGE_TO_PEER", "True").lower() == "true" - client = WorkerClient(pool_id, rpc_url, private_key_provider, private_key_node, port) - - def signal_handler(sig, frame): - logging.info("Received interrupt signal, shutting down gracefully...") - try: - client.stop() - logging.info("Client stopped successfully") - except Exception as e: - logging.error(f"Error during shutdown: {e}") - sys.exit(0) - - # Register signal handler for Ctrl+C before starting client - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - try: - logging.info("Starting client... (Press Ctrl+C to interrupt)") - client.start() - - client.upload_to_discovery("127.0.0.1", None) - - my_peer_id = client.get_own_peer_id() - logging.info(f"My Peer ID: {my_peer_id}") - - time.sleep(5) - peer_multi_addr = f"/ip4/127.0.0.1/tcp/{peer_port}" - - if send_message_to_peer: - print(f"Sending message to peer: {peer_id} on {peer_multi_addr}") - client.send_message(peer_id, b"Hello, world!", [peer_multi_addr]) - - logging.info("Setup completed. Starting message polling loop...") - print("Worker client started. Polling for messages. Press Ctrl+C to stop.") - - # Message polling loop - while True: - try: - message = client.get_next_message() - if message: - logging.info(f"Received full message: {message}") - logging.info(f"Received message from peer {message['peer_id']}") - if message.get('sender_address'): - logging.info(f"Sender Ethereum address: {message['sender_address']}") - - msg_data = message.get('message', {}) - if msg_data.get('type') == 'general': - data = bytes(msg_data.get('data', [])) - logging.info(f"Message data: {data}") - else: - logging.info(f"Message type: {msg_data.get('type')}") - - time.sleep(0.1) # Small delay to prevent busy waiting - except KeyboardInterrupt: - # Handle Ctrl+C during message polling - logging.info("Keyboard interrupt received during polling") - signal_handler(signal.SIGINT, None) - break - - except KeyboardInterrupt: - # Handle Ctrl+C during client startup - logging.info("Keyboard interrupt received during startup") - signal_handler(signal.SIGINT, None) - except Exception as e: - logging.error(f"Unexpected error: {e}") - try: - client.stop() - except: - pass - sys.exit(1) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/orchestrator.py b/crates/prime-protocol-py/examples/orchestrator.py new file mode 100644 index 00000000..a4d770fb --- /dev/null +++ b/crates/prime-protocol-py/examples/orchestrator.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +""" +Example demonstrating orchestrator client with continuous node invitation and messaging loop. 
+""" + +import os +import logging +import time +from primeprotocol import OrchestratorClient + +# Configure logging +FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' +logging.basicConfig(format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + +def main(): + # Get configuration from environment variables + rpc_url = os.getenv("RPC_URL", "http://localhost:8545") + private_key = os.getenv("PRIVATE_KEY_ORCHESTRATOR") + discovery_urls_str = os.getenv("DISCOVERY_URLS", "http://localhost:8089") + discovery_urls = [url.strip() for url in discovery_urls_str.split(",")] + + # Orchestrator loop configuration - process every 10 seconds + orchestrator_interval = 10 + pool_id = 0 # Example pool ID + + if not private_key: + print("Error: PRIVATE_KEY_ORCHESTRATOR environment variable is required") + return + + try: + # Initialize the orchestrator client + print(f"Initializing orchestrator client...") + print(f"RPC URL: {rpc_url}") + print(f"Discovery URLs: {discovery_urls}") + print(f"Pool ID: {pool_id}") + print(f"Orchestrator interval: {orchestrator_interval} seconds") + + orchestrator = OrchestratorClient( + rpc_url=rpc_url, + private_key=private_key, + discovery_urls=discovery_urls + ) + + print("Starting orchestrator client...") + orchestrator.start(p2p_port=8180) + print("Orchestrator client started") + + # Wait for P2P connections to establish + print("Waiting for P2P connections to establish...") + time.sleep(5) + + # Continuous orchestrator loop + print("\nStarting continuous orchestrator loop...") + print("Press Ctrl+C to stop the orchestrator") + + while True: + try: + print(f"\n{'='*50}") + print(f"Starting orchestrator cycle at {time.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"{'='*50}") + + # List nodes for the specific pool + print(f"Fetching nodes for pool {pool_id}...") + pool_nodes = orchestrator.list_nodes_for_pool(pool_id) + + if not pool_nodes: + print(f"No nodes found in pool {pool_id}") + else: + print(f"Found {len(pool_nodes)} nodes in pool {pool_id}") + + invited_count = 0 + messaged_count = 0 + error_count = 0 + + for i, node in enumerate(pool_nodes): + print(f"\nProcessing node {i+1}/{len(pool_nodes)}:") + print(f" ID: {node.id}") + print(f" Provider Address: {node.provider_address}") + print(f" Validated: {node.is_validated}") + print(f" Active: {node.is_active}") + + if not node.is_validated: + print(f" ⚠ Node {node.id} is not validated, skipping") + continue + + if node.is_active is False: + # Invite inactive but validated nodes + try: + print(f" 📨 Inviting node {node.id}...") + orchestrator.invite_node( + peer_id=node.worker_p2p_id, + worker_address=node.id, + pool_id=pool_id, + multiaddrs=node.worker_p2p_addresses, + domain_id=0, # todo: automatically fetch from contract + orchestrator_url=None, # todo: deprecate + expiration_seconds=1000 + ) + print(f" ✓ Node {node.id} invited successfully") + invited_count += 1 + except Exception as e: + print(f" ✗ Error inviting node {node.id}: {e}") + error_count += 1 + else: + # Send message to active nodes + try: + print(f" 💬 Sending message to active node {node.id}...") + orchestrator.send_message( + peer_id=node.worker_p2p_id, + multiaddrs=node.worker_p2p_addresses, + data=b"Hello, world!", + ) + print(f" ✓ Message sent to node {node.id}") + messaged_count += 1 + except Exception as e: + print(f" ✗ Error sending message to node {node.id}: {e}") + error_count += 1 + + # Get summary statistics + active_count = sum(1 for node in pool_nodes if node.is_active) + inactive_count = sum(1 for node in 
pool_nodes if not node.is_active and node.is_validated) + unvalidated_count = sum(1 for node in pool_nodes if not node.is_validated) + + print(f"\nOrchestrator cycle summary:") + print(f" Total nodes in pool: {len(pool_nodes)}") + print(f" Active nodes: {active_count}") + print(f" Inactive (validated) nodes: {inactive_count}") + print(f" Unvalidated nodes: {unvalidated_count}") + print(f" Nodes invited: {invited_count}") + print(f" Messages sent: {messaged_count}") + print(f" Errors: {error_count}") + + # Wait before next orchestrator cycle + print(f"\nWaiting {orchestrator_interval} seconds before next orchestrator cycle...") + time.sleep(orchestrator_interval) + + except KeyboardInterrupt: + print("\n\nReceived interrupt signal. Shutting down orchestrator...") + break + except Exception as e: + logging.error(f"Error during orchestrator cycle: {e}") + print(f"Waiting {orchestrator_interval} seconds before retrying...") + time.sleep(orchestrator_interval) + + print("Orchestrator stopped") + + except Exception as e: + logging.error(f"Fatal error: {e}") + raise + finally: + # Stop the orchestrator + try: + orchestrator.stop() + except: + pass + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/orchestrator_list_nodes.py b/crates/prime-protocol-py/examples/orchestrator_list_nodes.py deleted file mode 100644 index 874f9409..00000000 --- a/crates/prime-protocol-py/examples/orchestrator_list_nodes.py +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env python3 -""" -Example demonstrating how to list nodes for a specific pool using the OrchestratorClient. -""" - -import os -import signal -import sys -from time import sleep -from primeprotocol import OrchestratorClient - -def signal_handler(sig, frame): - print('\nShutting down gracefully...') - sys.exit(0) - -def main(): - # Set up signal handler for Ctrl+C - signal.signal(signal.SIGINT, signal_handler) - - # Replace with your actual RPC URL and private key - RPC_URL = "http://localhost:8545" - PRIVATE_KEY = os.getenv("ORCHESTRATOR_PRIVATE_KEY") - DISCOVERY_URLS = ["http://localhost:8089"] # Discovery service URLs - - # Create orchestrator client - orchestrator = OrchestratorClient( - rpc_url=RPC_URL, - private_key=PRIVATE_KEY, - discovery_urls=DISCOVERY_URLS - ) - - try: - # Initialize the orchestrator (without P2P for this example) - orchestrator.start(p2p_port=8180) - # todo: temp fix for establishing p2p connections - sleep(5) - - # List nodes for a specific pool (example pool ID: 0) - pool_id = 0 - pool_nodes = orchestrator.list_nodes_for_pool(pool_id) - print(f"Nodes in pool {pool_id}: {len(pool_nodes)}") - - # Print details of all nodes in the pool - for i, node in enumerate(pool_nodes): - print(f"\nNode {i+1}:") - print(f" ID: {node.id}") - print(f" Provider Address: {node.provider_address}") - print(f" IP Address: {node.ip_address}") - print(f" Port: {node.port}") - print(f" Pool ID: {node.compute_pool_id}") - print(f" Validated: {node.is_validated}") - print(f" Worker P2P Addresses: {node.worker_p2p_addresses}") - print(f" Active: {node.is_active}") - if node.worker_p2p_id: - print(f" Worker P2P ID: {node.worker_p2p_id}") - if node.is_active is False: - # Invite node with required parameters - orchestrator.invite_node( - peer_id=node.worker_p2p_id, - worker_address=node.id, - pool_id=pool_id, - multiaddrs=node.worker_p2p_addresses, - # todo: automatically fetch from contract - domain_id=0, - # tood: deprecate - orchestrator_url=None, - expiration_seconds=1000 - ) - else: - try: - # 
todo: we need an actual ack - orchestrator.send_message( - peer_id=node.worker_p2p_id, - multiaddrs=node.worker_p2p_addresses, - data=b"Hello, world!", - ) - print(f"Message sent to node {node.id}") - except Exception as e: - print(f"Error sending message to node {node.id}: {e}") - - print("\nPress Ctrl+C to exit...") - - # Keep the program running until Ctrl+C - while True: - try: - signal.pause() - except AttributeError: - # signal.pause() is not available on Windows - import time - time.sleep(1) - - except KeyboardInterrupt: - print('\nShutting down gracefully...') - finally: - # Stop the orchestrator - orchestrator.stop() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/validator.py b/crates/prime-protocol-py/examples/validator.py new file mode 100644 index 00000000..688572ba --- /dev/null +++ b/crates/prime-protocol-py/examples/validator.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +"""Example usage of the Prime Protocol Validator Client to continuously validate nodes.""" + +import os +import logging +import time +from typing import List +from primeprotocol import ValidatorClient + +# Configure logging +FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' +logging.basicConfig(format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + +def main(): + # Get configuration from environment variables + rpc_url = os.getenv("RPC_URL", "http://localhost:8545") + private_key = os.getenv("PRIVATE_KEY_VALIDATOR") + discovery_urls_str = os.getenv("DISCOVERY_URLS", "http://localhost:8089") + discovery_urls = [url.strip() for url in discovery_urls_str.split(",")] + + # Validation loop configuration - validate every 10 seconds + validation_interval = 10 + + if not private_key: + print("Error: VALIDATOR_PRIVATE_KEY environment variable is required") + return + + try: + # Initialize the validator client + print(f"Initializing validator client...") + print(f"RPC URL: {rpc_url}") + print(f"Discovery URLs: {discovery_urls}") + print(f"Validation interval: {validation_interval} seconds") + + validator = ValidatorClient( + rpc_url=rpc_url, + private_key=private_key, + discovery_urls=discovery_urls, + ) + print("Starting validator client...") + validator.start() + print("Validator client started") + + # Continuous validation loop + print("\nStarting continuous validation loop...") + print("Press Ctrl+C to stop the validator") + + while True: + try: + print(f"\n{'='*50}") + print(f"Starting validation cycle at {time.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"{'='*50}") + + # List all non-validated nodes + print("Fetching non-validated nodes from discovery service...") + non_validated_nodes = validator.list_non_validated_nodes() + + if not non_validated_nodes: + print("No non-validated nodes found") + else: + print(f"Found {len(non_validated_nodes)} non-validated nodes") + + for node in non_validated_nodes: + print(f"Processing node {node.id}...") + if node.is_validated is False: + print(f" Validating node {node.id}...") + validator.validate_node(node.id, node.provider_address) + print(f" ✓ Node {node.id} validated successfully") + else: + print(f" ℹ Node {node.id} is already validated") + + # Get summary statistics + all_nodes = validator.list_all_nodes_dict() + validated_count = sum(1 for node in all_nodes if node['is_validated']) + non_validated_count = len(all_nodes) - validated_count + + print(f"\nValidation cycle summary:") + print(f" Total nodes: {len(all_nodes)}") + print(f" Validated: {validated_count}") 
+ print(f" Non-validated: {non_validated_count}") + + # Wait before next validation cycle + print(f"\nWaiting {validation_interval} seconds before next validation cycle...") + time.sleep(validation_interval) + + except KeyboardInterrupt: + print("\n\nReceived interrupt signal. Shutting down validator...") + break + except Exception as e: + logging.error(f"Error during validation cycle: {e}") + print(f"Waiting {validation_interval} seconds before retrying...") + time.sleep(validation_interval) + + print("Validator stopped") + + except Exception as e: + logging.error(f"Fatal error: {e}") + raise + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/validator_list_nodes.py b/crates/prime-protocol-py/examples/validator_list_nodes.py deleted file mode 100644 index d20e1660..00000000 --- a/crates/prime-protocol-py/examples/validator_list_nodes.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 -"""Example usage of the Prime Protocol Validator Client to list non-validated nodes.""" - -import os -import logging -from typing import List -from primeprotocol import ValidatorClient - -# Configure logging -FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' -logging.basicConfig(format=FORMAT) -logging.getLogger().setLevel(logging.INFO) - -def main(): - # Get configuration from environment variables - rpc_url = os.getenv("RPC_URL", "http://localhost:8545") - private_key = os.getenv("VALIDATOR_PRIVATE_KEY") - discovery_urls_str = os.getenv("DISCOVERY_URLS", "http://localhost:8089") - discovery_urls = [url.strip() for url in discovery_urls_str.split(",")] - - if not private_key: - print("Error: VALIDATOR_PRIVATE_KEY environment variable is required") - return - - try: - # Initialize the validator client - print(f"Initializing validator client...") - print(f"RPC URL: {rpc_url}") - print(f"Discovery URLs: {discovery_urls}") - - validator = ValidatorClient( - rpc_url=rpc_url, - private_key=private_key, - discovery_urls=discovery_urls, - ) - print("Starting validator client...") - validator.start() - print("Validator client started") - - # List all non-validated nodes - print("\nFetching non-validated nodes from discovery service...") - non_validated_nodes = validator.list_non_validated_nodes() - for node in non_validated_nodes: - print(node.id) - if node.is_validated is False: - print(f"Validating node {node.id}...") - validator.validate_node(node.id, node.provider_address) - print(f"Node {node.id} validated") - else: - print(f"Node {node.id} is already validated") - - # You can also get all nodes as dictionaries for more flexibility - print("\n\nFetching all nodes as dictionaries...") - all_nodes = validator.list_all_nodes_dict() - print(all_nodes) - - # Count validated vs non-validated - validated_count = sum(1 for node in all_nodes if node['is_validated']) - non_validated_count = len(all_nodes) - validated_count - - print(f"\nTotal nodes: {len(all_nodes)}") - print(f"Validated: {validated_count}") - print(f"Non-validated: {non_validated_count}") - - except Exception as e: - logging.error(f"Error: {e}") - raise - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/worker.py b/crates/prime-protocol-py/examples/worker.py new file mode 100644 index 00000000..267d52fe --- /dev/null +++ b/crates/prime-protocol-py/examples/worker.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +"""Example usage of the Prime Protocol Python client.""" + +import asyncio 
+import logging +import os +import signal +import sys +import time +from typing import Dict, Any, Optional +from primeprotocol import WorkerClient + +FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' +logging.basicConfig(format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + + +def main(): + rpc_url = os.getenv("RPC_URL", "http://localhost:8545") + pool_id = os.getenv("POOL_ID", 0) + private_key_provider = os.getenv("PRIVATE_KEY_PROVIDER", None) + private_key_node = os.getenv("PRIVATE_KEY_NODE", None) + + logging.info(f"Connecting to: {rpc_url}") + + port = int(os.getenv("PORT", 8003)) + client = WorkerClient(pool_id, rpc_url, private_key_provider, private_key_node, port) + + def signal_handler(sig, frame): + logging.info("Received interrupt signal, shutting down gracefully...") + try: + client.stop() + logging.info("Client stopped successfully") + except Exception as e: + logging.error(f"Error during shutdown: {e}") + sys.exit(0) + + # Register signal handler for Ctrl+C before starting client + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + logging.info("Starting client... (Press Ctrl+C to interrupt)") + client.start() + + client.upload_to_discovery("127.0.0.1", None) + + my_peer_id = client.get_own_peer_id() + logging.info(f"My Peer ID: {my_peer_id}") + + time.sleep(5) + + # Note: To send messages to other peers manually, you can use: + # peer_id = "12D3KooWELi4p1oR3QBSYiq1rvPpyjbkiQVhQJqCobBBUS7C6JrX" + # peer_multi_addr = "/ip4/127.0.0.1/tcp/8002" + # client.send_message(peer_id, b"Hello, world!", [peer_multi_addr]) + + logging.info("Setup completed. Starting message polling loop...") + print("Worker client started. Polling for orchestrator/validator messages. 
Press Ctrl+C to stop.") + + # Message polling loop - listening for orchestrator and validator messages + while True: + try: + message = client.get_next_message() + if message: + msg_data = message.get('message', {}) + if msg_data.get('type') == 'general': + data = bytes(msg_data.get('data', [])) + print(f"Message from {message['peer_id']}: {data}") + else: + print(f"Message from {message['peer_id']}: type={msg_data.get('type')}") + + time.sleep(0.1) # Small delay to prevent busy waiting + except KeyboardInterrupt: + # Handle Ctrl+C during message polling + logging.info("Keyboard interrupt received during polling") + signal_handler(signal.SIGINT, None) + break + + except KeyboardInterrupt: + # Handle Ctrl+C during client startup + logging.info("Keyboard interrupt received during startup") + signal_handler(signal.SIGINT, None) + except Exception as e: + logging.error(f"Unexpected error: {e}") + try: + client.stop() + except: + pass + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file From 4722bbb5eed54e8dede03ba27ed2fcccb11fd64a Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Wed, 23 Jul 2025 13:12:33 +0200 Subject: [PATCH 18/23] basic detection of validator + orchestrator role, basic message sending between nodes --- crates/prime-protocol-py/README.md | 16 +- .../examples/orchestrator.py | 238 +++++++++--------- .../prime-protocol-py/examples/validator.py | 173 +++++++------ crates/prime-protocol-py/examples/worker.py | 44 +++- .../prime-protocol-py/src/orchestrator/mod.rs | 35 ++- .../src/p2p_handler/common.rs | 2 + .../src/p2p_handler/message_processor.rs | 85 ++++++- .../prime-protocol-py/src/p2p_handler/mod.rs | 16 +- crates/prime-protocol-py/src/validator/mod.rs | 5 + .../src/worker/blockchain.rs | 121 ++++++++- crates/prime-protocol-py/src/worker/client.rs | 54 +++- crates/prime-protocol-py/src/worker/mod.rs | 4 + 12 files changed, 580 insertions(+), 213 deletions(-) diff --git a/crates/prime-protocol-py/README.md b/crates/prime-protocol-py/README.md index 3a496628..0ba36ef8 100644 --- a/crates/prime-protocol-py/README.md +++ b/crates/prime-protocol-py/README.md @@ -1,5 +1,11 @@ # Prime Protocol Python SDK +## Features + +- Worker Client +- Validator Client +- Orchestrator Client + ## Local Development Setup Startup the local dev chain with contracts using the following Make cmd from the base folder: ```bash @@ -64,15 +70,21 @@ uv run examples/orchestrator.py You should see that the orchestrator invites the worker node that you started in the earlier step and keeps sending messages via p2p to this node. ## TODOs: -- [ ] restarting node - no longer getting messages (since p2p id changed)? - [ ] can validator send message? (special validation message) +- [ ] can the worker send messages as response to the validator / orchestrator? - [ ] can orchestrator send message? - [ ] whitelist using python api? - [ ] borrow bug? - [ ] restart keeps increasing provider stake? - [ ] p2p cleanup +- [ ] I keep forgetting to run make build +- [ ] what about formatting? +- 09, in main + validator.stop() + ^^^^^^^^^^^^^^ +AttributeError: 'builtins.ValidatorClient' object has no attribute 'stop' ## Known Limitations: 1. Orchestrator can no longer send messages to a node when the worker node restarts. This is because the discovery info is not refreshed when the node was already active in the pool. This will change when introducing a proper DHT layer. -Simply run `make down` and restart the setup from scatch. 
\ No newline at end of file +Shortcut here is to simply reset the discovery service using: `redis-cli -p 6380 "FLUSHALL"` \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/orchestrator.py b/crates/prime-protocol-py/examples/orchestrator.py index a4d770fb..f7f05773 100644 --- a/crates/prime-protocol-py/examples/orchestrator.py +++ b/crates/prime-protocol-py/examples/orchestrator.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 -""" -Example demonstrating orchestrator client with continuous node invitation and messaging loop. -""" +"""Example usage of the Prime Protocol Orchestrator Client.""" import os import logging @@ -19,141 +17,133 @@ def main(): private_key = os.getenv("PRIVATE_KEY_ORCHESTRATOR") discovery_urls_str = os.getenv("DISCOVERY_URLS", "http://localhost:8089") discovery_urls = [url.strip() for url in discovery_urls_str.split(",")] - - # Orchestrator loop configuration - process every 10 seconds - orchestrator_interval = 10 - pool_id = 0 # Example pool ID + pool_id = int(os.getenv("POOL_ID", "0")) if not private_key: print("Error: PRIVATE_KEY_ORCHESTRATOR environment variable is required") return + print(f"Initializing orchestrator client...") + print(f"RPC URL: {rpc_url}") + print(f"Discovery URLs: {discovery_urls}") + print(f"Pool ID: {pool_id}") + + # Initialize and start the orchestrator + orchestrator = OrchestratorClient( + rpc_url=rpc_url, + private_key=private_key, + discovery_urls=discovery_urls, + ) + + print("Starting orchestrator client...") + orchestrator.start(p2p_port=8180) + print("Orchestrator client started") + + # Wait for P2P to initialize + time.sleep(5) + + print("\nStarting orchestrator loop...") + print("Press Ctrl+C to stop\n") + + # Track nodes in our pool + pool_node_ids = set() + try: - # Initialize the orchestrator client - print(f"Initializing orchestrator client...") - print(f"RPC URL: {rpc_url}") - print(f"Discovery URLs: {discovery_urls}") - print(f"Pool ID: {pool_id}") - print(f"Orchestrator interval: {orchestrator_interval} seconds") - - orchestrator = OrchestratorClient( - rpc_url=rpc_url, - private_key=private_key, - discovery_urls=discovery_urls - ) - - print("Starting orchestrator client...") - orchestrator.start(p2p_port=8180) - print("Orchestrator client started") - - # Wait for P2P connections to establish - print("Waiting for P2P connections to establish...") - time.sleep(5) - - # Continuous orchestrator loop - print("\nStarting continuous orchestrator loop...") - print("Press Ctrl+C to stop the orchestrator") - while True: - try: - print(f"\n{'='*50}") - print(f"Starting orchestrator cycle at {time.strftime('%Y-%m-%d %H:%M:%S')}") - print(f"{'='*50}") - - # List nodes for the specific pool - print(f"Fetching nodes for pool {pool_id}...") - pool_nodes = orchestrator.list_nodes_for_pool(pool_id) + print(f"{'='*50}") + print(f"Cycle at {time.strftime('%H:%M:%S')}") + + # Check for any pending messages first + print("Checking for pending messages...") + message = orchestrator.get_next_message() + while message: + peer_id = message['peer_id'] + msg_data = message.get('message', {}) - if not pool_nodes: - print(f"No nodes found in pool {pool_id}") + if msg_data.get('type') == 'general': + data = bytes(msg_data.get('data', [])) + sender_type = "VALIDATOR" if message.get('is_sender_validator') else \ + "POOL_OWNER" if message.get('is_sender_pool_owner') else \ + "WORKER" + print(f" 📨 From {peer_id[:16]}... 
({sender_type}): {data}") + elif msg_data.get('type') == 'authentication_complete': + print(f" ✓ Auth complete with {peer_id[:16]}...") else: - print(f"Found {len(pool_nodes)} nodes in pool {pool_id}") + print(f" 📋 {msg_data.get('type')} from {peer_id[:16]}...") + + # Check for more messages + message = orchestrator.get_next_message() + + # Get nodes in the pool + pool_nodes = orchestrator.list_nodes_for_pool(pool_id) + print(f"\nFound {len(pool_nodes)} nodes in pool {pool_id}") + + # Update our tracking of pool nodes + pool_node_ids = {node.worker_p2p_id for node in pool_nodes if node.worker_p2p_id} + + # Process each node + for node in pool_nodes: + if not node.worker_p2p_id: + continue - invited_count = 0 - messaged_count = 0 - error_count = 0 + try: + if not node.is_active and node.is_validated: + # Invite inactive but validated nodes + print(f" 📨 Inviting {node.id[:8]}...") + orchestrator.invite_node( + peer_id=node.worker_p2p_id, + worker_address=node.id, + pool_id=pool_id, + multiaddrs=node.worker_p2p_addresses, + domain_id=0, + orchestrator_url=None, + expiration_seconds=1000 + ) + elif node.is_active: + # Send message to active nodes + print(f" 💬 Messaging {node.id[:8]}...") + orchestrator.send_message( + peer_id=node.worker_p2p_id, + multiaddrs=node.worker_p2p_addresses, + data=b"Hello from orchestrator!", + ) + except Exception as e: + print(f" ❌ Error with {node.id[:8]}: {e}") + + # Check for messages throughout the wait period + print(f"\nWaiting 30 seconds (checking for messages)...") + end_time = time.time() + 30 + messages_during_wait = 0 + + while time.time() < end_time: + message = orchestrator.get_next_message() + if message: + peer_id = message['peer_id'] + msg_data = message.get('message', {}) - for i, node in enumerate(pool_nodes): - print(f"\nProcessing node {i+1}/{len(pool_nodes)}:") - print(f" ID: {node.id}") - print(f" Provider Address: {node.provider_address}") - print(f" Validated: {node.is_validated}") - print(f" Active: {node.is_active}") + if msg_data.get('type') == 'general': + data = bytes(msg_data.get('data', [])) + sender_type = "VALIDATOR" if message.get('is_sender_validator') else \ + "POOL_OWNER" if message.get('is_sender_pool_owner') else \ + "WORKER" - if not node.is_validated: - print(f" ⚠ Node {node.id} is not validated, skipping") - continue - - if node.is_active is False: - # Invite inactive but validated nodes - try: - print(f" 📨 Inviting node {node.id}...") - orchestrator.invite_node( - peer_id=node.worker_p2p_id, - worker_address=node.id, - pool_id=pool_id, - multiaddrs=node.worker_p2p_addresses, - domain_id=0, # todo: automatically fetch from contract - orchestrator_url=None, # todo: deprecate - expiration_seconds=1000 - ) - print(f" ✓ Node {node.id} invited successfully") - invited_count += 1 - except Exception as e: - print(f" ✗ Error inviting node {node.id}: {e}") - error_count += 1 + # Check if message is from a pool node + if peer_id in pool_node_ids: + print(f" 📨 From pool node {peer_id[:16]}... 
({sender_type}): {data}") else: - # Send message to active nodes - try: - print(f" 💬 Sending message to active node {node.id}...") - orchestrator.send_message( - peer_id=node.worker_p2p_id, - multiaddrs=node.worker_p2p_addresses, - data=b"Hello, world!", - ) - print(f" ✓ Message sent to node {node.id}") - messaged_count += 1 - except Exception as e: - print(f" ✗ Error sending message to node {node.id}: {e}") - error_count += 1 - - # Get summary statistics - active_count = sum(1 for node in pool_nodes if node.is_active) - inactive_count = sum(1 for node in pool_nodes if not node.is_active and node.is_validated) - unvalidated_count = sum(1 for node in pool_nodes if not node.is_validated) - - print(f"\nOrchestrator cycle summary:") - print(f" Total nodes in pool: {len(pool_nodes)}") - print(f" Active nodes: {active_count}") - print(f" Inactive (validated) nodes: {inactive_count}") - print(f" Unvalidated nodes: {unvalidated_count}") - print(f" Nodes invited: {invited_count}") - print(f" Messages sent: {messaged_count}") - print(f" Errors: {error_count}") - - # Wait before next orchestrator cycle - print(f"\nWaiting {orchestrator_interval} seconds before next orchestrator cycle...") - time.sleep(orchestrator_interval) - - except KeyboardInterrupt: - print("\n\nReceived interrupt signal. Shutting down orchestrator...") - break - except Exception as e: - logging.error(f"Error during orchestrator cycle: {e}") - print(f"Waiting {orchestrator_interval} seconds before retrying...") - time.sleep(orchestrator_interval) - + print(f" 📨 From external {peer_id[:16]}... ({sender_type}): {data}") + messages_during_wait += 1 + else: + time.sleep(0.1) + + if messages_during_wait > 0: + print(f"Received {messages_during_wait} messages during wait") + print() + + except KeyboardInterrupt: + print("\n\nShutting down orchestrator...") + orchestrator.stop() print("Orchestrator stopped") - - except Exception as e: - logging.error(f"Fatal error: {e}") - raise - finally: - # Stop the orchestrator - try: - orchestrator.stop() - except: - pass if __name__ == "__main__": main() \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/validator.py b/crates/prime-protocol-py/examples/validator.py index 688572ba..2b4af7cc 100644 --- a/crates/prime-protocol-py/examples/validator.py +++ b/crates/prime-protocol-py/examples/validator.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 -"""Example usage of the Prime Protocol Validator Client to continuously validate nodes.""" +"""Example usage of the Prime Protocol Validator Client.""" import os import logging import time -from typing import List from primeprotocol import ValidatorClient # Configure logging @@ -18,86 +17,116 @@ def main(): private_key = os.getenv("PRIVATE_KEY_VALIDATOR") discovery_urls_str = os.getenv("DISCOVERY_URLS", "http://localhost:8089") discovery_urls = [url.strip() for url in discovery_urls_str.split(",")] - - # Validation loop configuration - validate every 10 seconds - validation_interval = 10 + p2p_port = int(os.getenv("VALIDATOR_P2P_PORT", "8665")) if not private_key: - print("Error: VALIDATOR_PRIVATE_KEY environment variable is required") + print("Error: PRIVATE_KEY_VALIDATOR environment variable is required") return + print(f"Initializing validator client...") + print(f"RPC URL: {rpc_url}") + print(f"Discovery URLs: {discovery_urls}") + print(f"P2P Port: {p2p_port}") + + # Initialize and start the validator + validator = ValidatorClient( + rpc_url=rpc_url, + private_key=private_key, + discovery_urls=discovery_urls, + ) + + print("Starting 
validator client...") + validator.start(p2p_port=p2p_port) + print(f"Validator started with peer ID: {validator.get_peer_id()}") + + print("\nStarting validator loop...") + print("Press Ctrl+C to stop\n") + try: - # Initialize the validator client - print(f"Initializing validator client...") - print(f"RPC URL: {rpc_url}") - print(f"Discovery URLs: {discovery_urls}") - print(f"Validation interval: {validation_interval} seconds") - - validator = ValidatorClient( - rpc_url=rpc_url, - private_key=private_key, - discovery_urls=discovery_urls, - ) - print("Starting validator client...") - validator.start() - print("Validator client started") - - # Continuous validation loop - print("\nStarting continuous validation loop...") - print("Press Ctrl+C to stop the validator") - while True: - try: - print(f"\n{'='*50}") - print(f"Starting validation cycle at {time.strftime('%Y-%m-%d %H:%M:%S')}") - print(f"{'='*50}") - - # List all non-validated nodes - print("Fetching non-validated nodes from discovery service...") - non_validated_nodes = validator.list_non_validated_nodes() + print(f"{'='*50}") + print(f"Cycle at {time.strftime('%H:%M:%S')}") + + # Check for messages first + print("Checking for any pending messages...") + message = validator.get_next_message() + while message: + peer_id = message['peer_id'] + msg_data = message.get('message', {}) - if not non_validated_nodes: - print("No non-validated nodes found") + if msg_data.get('type') == 'general': + data = bytes(msg_data.get('data', [])) + sender_type = "VALIDATOR" if message.get('is_sender_validator') else \ + "POOL_OWNER" if message.get('is_sender_pool_owner') else \ + "WORKER" + print(f" 📨 From {peer_id[:16]}... ({sender_type}): {data}") + elif msg_data.get('type') == 'authentication_complete': + print(f" ✓ Auth complete with {peer_id[:16]}...") else: - print(f"Found {len(non_validated_nodes)} non-validated nodes") - - for node in non_validated_nodes: - print(f"Processing node {node.id}...") - if node.is_validated is False: - print(f" Validating node {node.id}...") - validator.validate_node(node.id, node.provider_address) - print(f" ✓ Node {node.id} validated successfully") - else: - print(f" ℹ Node {node.id} is already validated") + print(f" 📋 {msg_data.get('type')} from {peer_id[:16]}...") - # Get summary statistics - all_nodes = validator.list_all_nodes_dict() - validated_count = sum(1 for node in all_nodes if node['is_validated']) - non_validated_count = len(all_nodes) - validated_count - - print(f"\nValidation cycle summary:") - print(f" Total nodes: {len(all_nodes)}") - print(f" Validated: {validated_count}") - print(f" Non-validated: {non_validated_count}") - - # Wait before next validation cycle - print(f"\nWaiting {validation_interval} seconds before next validation cycle...") - time.sleep(validation_interval) - - except KeyboardInterrupt: - print("\n\nReceived interrupt signal. Shutting down validator...") - break - except Exception as e: - logging.error(f"Error during validation cycle: {e}") - print(f"Waiting {validation_interval} seconds before retrying...") - time.sleep(validation_interval) - + # Check for more messages + message = validator.get_next_message() + + # 1. 
Validate any non-validated nodes + non_validated = validator.list_non_validated_nodes() + if non_validated: + print(f"\nValidating {len(non_validated)} nodes...") + for node in non_validated: + try: + print(f" ✅ Validating {node.id[:8]}...") + validator.validate_node(node.id, node.provider_address) + except Exception as e: + print(f" ❌ Error validating {node.id[:8]}: {e}") + else: + print("\nNo nodes to validate") + + # 2. Send messages to validated nodes + all_nodes = validator.list_all_nodes_dict() + validated = [n for n in all_nodes if n.get('is_validated') and n.get('worker_p2p_id')] + + if validated: + print(f"\nMessaging {len(validated)} validated nodes...") + for node in validated: + try: + validator.send_message( + peer_id=node['worker_p2p_id'], + multiaddrs=node.get('worker_p2p_addresses', []), + data=b"Hello from validator!", + ) + print(f" 💬 Sent to {node['id'][:8]}...") + except Exception as e: + print(f" ❌ Error messaging {node['id'][:8]}: {e}") + + # 3. Check for more messages throughout the wait period + print(f"\nWaiting 10 seconds (checking for messages)...") + end_time = time.time() + 10 + messages_during_wait = 0 + + while time.time() < end_time: + message = validator.get_next_message() + if message: + peer_id = message['peer_id'] + msg_data = message.get('message', {}) + + if msg_data.get('type') == 'general': + data = bytes(msg_data.get('data', [])) + sender_type = "VALIDATOR" if message.get('is_sender_validator') else \ + "POOL_OWNER" if message.get('is_sender_pool_owner') else \ + "WORKER" + print(f" 📨 From {peer_id[:16]}... ({sender_type}): {data}") + messages_during_wait += 1 + else: + time.sleep(0.1) + + if messages_during_wait > 0: + print(f"Received {messages_during_wait} messages during wait") + print() + + except KeyboardInterrupt: + print("\n\nShutting down validator...") + validator.stop() print("Validator stopped") - - except Exception as e: - logging.error(f"Fatal error: {e}") - raise - if __name__ == "__main__": main() \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/worker.py b/crates/prime-protocol-py/examples/worker.py index 267d52fe..cd708040 100644 --- a/crates/prime-protocol-py/examples/worker.py +++ b/crates/prime-protocol-py/examples/worker.py @@ -26,6 +26,9 @@ def main(): port = int(os.getenv("PORT", 8003)) client = WorkerClient(pool_id, rpc_url, private_key_provider, private_key_node, port) + # Track known peer addresses + known_peers = {} + def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down gracefully...") try: @@ -64,11 +67,48 @@ def signal_handler(sig, frame): message = client.get_next_message() if message: msg_data = message.get('message', {}) + peer_id = message['peer_id'] + multiaddrs = message.get('multiaddrs', []) + + # Determine sender type + sender_type = "worker" + if message.get('is_sender_validator'): + sender_type = "VALIDATOR" + elif message.get('is_sender_pool_owner'): + sender_type = "POOL_OWNER" + if msg_data.get('type') == 'general': data = bytes(msg_data.get('data', [])) - print(f"Message from {message['peer_id']}: {data}") + print(f"\n[{time.strftime('%H:%M:%S')}] Message from {peer_id} ({sender_type}): {data}") + print(f" Multiaddrs received: {multiaddrs}") + + # Check if it's an error message + if data.startswith(b"ERROR:"): + print(f" ⚠️ Received error: {data}") + continue + + # Respond to validators and pool owners + if sender_type in ["VALIDATOR", "POOL_OWNER"]: + try: + response_msg = f"Hello {sender_type}! 
Worker received: {data.decode('utf-8', errors='ignore')}" + + print(f" Attempting to respond to {peer_id}...") + print(f" Response message: {response_msg}") + + # Send response using empty multiaddrs (peer should already be connected) + client.send_message( + peer_id=peer_id, + multiaddrs=[], # Empty - peer already connected + data=response_msg.encode() + ) + print(f" ✓ Response sent to {sender_type}") + except Exception as e: + print(f" ✗ Error sending response: {e}") + print(f" Error type: {type(e).__name__}") + elif msg_data.get('type') == 'authentication_complete': + print(f"\n[{time.strftime('%H:%M:%S')}] ✓ Authentication complete with {peer_id} ({sender_type})") else: - print(f"Message from {message['peer_id']}: type={msg_data.get('type')}") + print(f"\n[{time.strftime('%H:%M:%S')}] Message from {peer_id} ({sender_type}): type={msg_data.get('type')}") time.sleep(0.1) # Small delay to prevent busy waiting except KeyboardInterrupt: diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/prime-protocol-py/src/orchestrator/mod.rs index 7af6894f..f6adff74 100644 --- a/crates/prime-protocol-py/src/orchestrator/mod.rs +++ b/crates/prime-protocol-py/src/orchestrator/mod.rs @@ -67,10 +67,13 @@ impl OrchestratorClient { let rpc_url_parsed = Url::parse(&rpc_url).map_err(|e| { PyErr::new::(format!("Invalid RPC URL: {}", e)) })?; - Some( - Wallet::new(&key, rpc_url_parsed) - .map_err(|e| PyErr::new::(e.to_string()))?, - ) + let w = Wallet::new(&key, rpc_url_parsed) + .map_err(|e| PyErr::new::(e.to_string()))?; + + let wallet_address = w.wallet.default_signer().address().to_string(); + log::info!("Orchestrator wallet address: {}", wallet_address); + + Some(w) } else { None }; @@ -194,7 +197,12 @@ impl OrchestratorClient { message_type: MessageType::General { data }, peer_id, multiaddrs, - sender_address: None, + sender_address: self + .wallet + .as_ref() + .map(|w| w.wallet.default_signer().address().to_string()), + is_sender_validator: false, + is_sender_pool_owner: false, // This will be determined by the receiver response_tx: None, }; @@ -389,6 +397,8 @@ impl OrchestratorClient { peer_id, multiaddrs, sender_address: Some(wallet.wallet.default_signer().address().to_string()), + is_sender_validator: false, + is_sender_pool_owner: false, // This will be determined by the receiver response_tx: None, }; @@ -672,8 +682,12 @@ impl OrchestratorClient { let (user_message_tx, user_message_rx) = tokio::sync::mpsc::channel::(1000); - let (p2p_service, outbound_tx, message_queue_rx, authenticated_peers) = - P2PService::new(keypair, port, cancellation_token.clone(), wallet_address)?; + let (p2p_service, outbound_tx, message_queue_rx, authenticated_peers) = P2PService::new( + keypair, + port, + cancellation_token.clone(), + wallet_address.clone(), + )?; let peer_id = p2p_service.node.peer_id(); let outbound_tx = Arc::new(Mutex::new(outbound_tx)); @@ -689,7 +703,12 @@ impl OrchestratorClient { user_message_tx, outbound_tx: outbound_tx.clone(), authenticated_peers, - cancellation_token, + cancellation_token: cancellation_token.clone(), + validator_addresses: Arc::new(std::collections::HashSet::new()), // Orchestrator doesn't need validator info + pool_owner_address: wallet_address + .as_ref() + .and_then(|addr| alloy::primitives::Address::from_str(addr).ok()), // Parse orchestrator's address + compute_manager_address: None, // Orchestrator doesn't need this }; let message_processor = MessageProcessor::from_config(config); diff --git a/crates/prime-protocol-py/src/p2p_handler/common.rs 
b/crates/prime-protocol-py/src/p2p_handler/common.rs index 7fd930ec..22703dcd 100644 --- a/crates/prime-protocol-py/src/p2p_handler/common.rs +++ b/crates/prime-protocol-py/src/p2p_handler/common.rs @@ -67,6 +67,8 @@ pub async fn send_message_with_auth( peer_id, multiaddrs, sender_address: Some(auth_manager.wallet_address()), + is_sender_validator: false, + is_sender_pool_owner: false, response_tx: None, }; diff --git a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs index 8a9c2b68..8efa0eb3 100644 --- a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs +++ b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs @@ -1,6 +1,7 @@ use crate::error::Result; use crate::p2p_handler::auth::AuthenticationManager; use crate::p2p_handler::{Message, MessageType}; +use alloy::primitives::Address; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -19,6 +20,9 @@ pub struct MessageProcessorConfig { pub outbound_tx: Arc>>, pub authenticated_peers: Arc>>, pub cancellation_token: CancellationToken, + pub validator_addresses: Arc>, + pub pool_owner_address: Option
<Address>, + pub compute_manager_address: Option<Address>
, } /// Handles processing of incoming P2P messages @@ -29,6 +33,9 @@ pub struct MessageProcessor { outbound_tx: Arc>>, authenticated_peers: Arc>>, cancellation_token: CancellationToken, + validator_addresses: Arc>, + pool_owner_address: Option
<Address>, + compute_manager_address: Option<Address>
, } impl MessageProcessor { @@ -39,6 +46,9 @@ impl MessageProcessor { outbound_tx: Arc>>, authenticated_peers: Arc>>, cancellation_token: CancellationToken, + validator_addresses: Arc>, + pool_owner_address: Option
<Address>, + compute_manager_address: Option<Address>
, ) -> Self { Self { auth_manager, @@ -47,6 +57,9 @@ impl MessageProcessor { outbound_tx, authenticated_peers, cancellation_token, + validator_addresses, + pool_owner_address, + compute_manager_address, } } @@ -59,6 +72,9 @@ impl MessageProcessor { config.outbound_tx, config.authenticated_peers, config.cancellation_token, + config.validator_addresses, + config.pool_owner_address, + config.compute_manager_address, ) } @@ -109,6 +125,8 @@ impl MessageProcessor { peer_id, multiaddrs, sender_address, + is_sender_validator, + is_sender_pool_owner, response_tx, } = message; @@ -139,7 +157,9 @@ impl MessageProcessor { }, peer_id, multiaddrs, - sender_address, + sender_address: sender_address.clone(), + is_sender_validator: self.is_address_validator(&sender_address), + is_sender_pool_owner: self.is_address_pool_owner(&sender_address), response_tx: None, }; self.handle_auth_response(msg, challenge, signature).await @@ -169,7 +189,9 @@ impl MessageProcessor { message_type: MessageType::General { data }, peer_id, multiaddrs, - sender_address, + sender_address: sender_address.clone(), + is_sender_validator: self.is_address_validator(&sender_address), + is_sender_pool_owner: self.is_address_pool_owner(&sender_address), response_tx: None, }; self.user_message_tx.send(msg).await.map_err(|e| { @@ -237,6 +259,8 @@ impl MessageProcessor { peer_id: message.peer_id.clone(), multiaddrs: message.multiaddrs, sender_address: Some(self.auth_manager.wallet_address()), + is_sender_validator: false, // We're sending this message, so we check our own address + is_sender_pool_owner: false, // We're sending this message, so we check our own address response_tx: None, }; @@ -317,4 +341,61 @@ impl MessageProcessor { Ok(()) } + + /// Check if an address is a validator + fn is_address_validator(&self, address: &Option) -> bool { + if let Some(addr_str) = address { + // Parse the string address to an Address type + match Address::from_str(addr_str) { + Ok(addr) => { + let is_validator = self.validator_addresses.contains(&addr); + log::debug!( + "Checking if {} is a validator: {} (validator set has {} addresses)", + addr_str, + is_validator, + self.validator_addresses.len() + ); + is_validator + } + Err(e) => { + log::warn!("Failed to parse address {}: {}", addr_str, e); + false + } + } + } else { + log::debug!("No sender address provided, cannot check validator status"); + false + } + } + + /// Check if an address is the pool owner + fn is_address_pool_owner(&self, address: &Option) -> bool { + if let Some(addr_str) = address { + // Parse the string address to an Address type + match Address::from_str(addr_str) { + Ok(addr) => { + // Check if it's the pool creator OR the compute manager + let is_creator = self.pool_owner_address == Some(addr); + let is_compute_manager = self.compute_manager_address == Some(addr); + let is_owner = is_creator || is_compute_manager; + + log::debug!( + "Checking if {} is pool owner/manager: {} (pool owner: {:?}, compute manager: {:?})", + addr_str, + is_owner, + self.pool_owner_address, + self.compute_manager_address + ); + is_owner + } + Err(e) => { + log::warn!("Failed to parse address {}: {}", addr_str, e); + false + } + } + } else { + log::debug!("No sender address provided, cannot check pool owner status"); + false + } + } } diff --git a/crates/prime-protocol-py/src/p2p_handler/mod.rs b/crates/prime-protocol-py/src/p2p_handler/mod.rs index 4f7c2d6a..89b98cf6 100644 --- a/crates/prime-protocol-py/src/p2p_handler/mod.rs +++ b/crates/prime-protocol-py/src/p2p_handler/mod.rs @@ -50,6 +50,8 @@ 
pub struct Message { pub peer_id: String, pub multiaddrs: Vec, pub sender_address: Option, // Ethereum address of the sender + pub is_sender_validator: bool, // Whether the sender is a validator + pub is_sender_pool_owner: bool, // Whether the sender is the pool owner #[serde(skip)] pub response_tx: Option>, // For sending responses to auth requests } @@ -326,7 +328,9 @@ impl Service { peer_id: peer_id.to_string(), multiaddrs: vec![], // TODO: Extract multiaddrs from peer info sender_address: Some(sender_address), - response_tx: None, // General messages don't need response channels + is_sender_validator: false, // Will be populated by message processor + is_sender_pool_owner: false, // Will be populated by message processor + response_tx: None, // General messages don't need response channels }; if let Err(e) = message_queue_tx.send(message).await { @@ -359,6 +363,8 @@ impl Service { peer_id: peer_id.to_string(), multiaddrs: vec![], sender_address: None, + is_sender_validator: false, + is_sender_pool_owner: false, response_tx: Some(response_tx), // Pass the sender for response }; @@ -392,6 +398,8 @@ impl Service { peer_id: peer_id.to_string(), multiaddrs: vec![], sender_address: None, + is_sender_validator: false, + is_sender_pool_owner: false, response_tx: Some(response_tx), // Pass the sender for response }; @@ -456,6 +464,8 @@ impl Service { peer_id: peer_id.to_string(), multiaddrs: vec![], sender_address: None, + is_sender_validator: false, + is_sender_pool_owner: false, response_tx: None, }; @@ -474,6 +484,8 @@ impl Service { peer_id: peer_id.to_string(), multiaddrs: vec![], sender_address: None, + is_sender_validator: false, + is_sender_pool_owner: false, response_tx: None, }; @@ -513,6 +525,8 @@ mod tests { peer_id: "12D3KooWExample".to_string(), multiaddrs: vec!["/ip4/127.0.0.1/tcp/4001".to_string()], sender_address: Some("0x1234567890123456789012345678901234567890".to_string()), + is_sender_validator: false, + is_sender_pool_owner: false, response_tx: None, }; diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/prime-protocol-py/src/validator/mod.rs index bfeec6a2..514ed6f5 100644 --- a/crates/prime-protocol-py/src/validator/mod.rs +++ b/crates/prime-protocol-py/src/validator/mod.rs @@ -229,6 +229,8 @@ impl ValidatorClient { peer_id, multiaddrs, sender_address: None, + is_sender_validator: true, // Validator is sending this message + is_sender_pool_owner: false, response_tx: None, }; @@ -450,6 +452,9 @@ impl ValidatorClient { outbound_tx: outbound_tx.clone(), authenticated_peers, cancellation_token, + validator_addresses: Arc::new(std::collections::HashSet::new()), // Validator doesn't need to check other validators + pool_owner_address: None, // Validator doesn't need pool owner info + compute_manager_address: None, // Validator doesn't need this }; let message_processor = MessageProcessor::from_config(config); diff --git a/crates/prime-protocol-py/src/worker/blockchain.rs b/crates/prime-protocol-py/src/worker/blockchain.rs index af5b97f3..793e3b05 100644 --- a/crates/prime-protocol-py/src/worker/blockchain.rs +++ b/crates/prime-protocol-py/src/worker/blockchain.rs @@ -1,11 +1,12 @@ use alloy::primitives::utils::format_ether; -use alloy::primitives::U256; +use alloy::primitives::{Address, U256}; use anyhow::{Context, Result}; use prime_core::operations::compute_node::ComputeNodeOperations; use prime_core::operations::provider::ProviderOperations; use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; use 
shared::web3::contracts::structs::compute_pool::PoolStatus; use shared::web3::wallet::{Wallet, WalletProvider}; +use std::sync::Arc; use url::Url; use crate::constants::{BLOCKCHAIN_OPERATION_TIMEOUT, DEFAULT_COMPUTE_UNITS}; @@ -25,6 +26,7 @@ pub struct BlockchainService { config: BlockchainConfig, provider_wallet: Option, node_wallet: Option, + contracts: Option>, } impl BlockchainService { @@ -36,6 +38,7 @@ impl BlockchainService { config, provider_wallet: None, node_wallet: None, + contracts: None, }) } @@ -56,6 +59,7 @@ impl BlockchainService { // Store the wallets self.provider_wallet = Some(provider_wallet.clone()); self.node_wallet = Some(node_wallet.clone()); + self.contracts = Some(contracts.clone()); self.wait_for_active_pool(&contracts).await?; self.ensure_provider_registered(&provider_wallet, &contracts) @@ -331,4 +335,119 @@ impl BlockchainService { log::info!("Successfully joined compute pool with tx: {}", result); Ok(()) } + + /// Get all validator addresses from the PrimeNetwork contract + pub async fn get_validator_addresses(&self) -> Arc> { + let contracts = match self.contracts.as_ref() { + Some(contracts) => contracts, + None => { + log::error!("Contracts not initialized"); + return Arc::new(std::collections::HashSet::new()); + } + }; + + match contracts.prime_network.get_validator_role().await { + Ok(validators) => { + log::info!( + "Fetched {} validator addresses from chain", + validators.len() + ); + let validator_set: std::collections::HashSet
<Address>
= validators + .into_iter() + .inspect(|&addr| { + log::debug!("Validator address: {:?}", addr); + }) + .collect(); + log::info!("Validator addresses: {:?}", validator_set); + Arc::new(validator_set) + } + Err(e) => { + log::error!("Failed to get validator addresses: {}", e); + Arc::new(std::collections::HashSet::new()) + } + } + } + + /// Get the pool owner address for the current compute pool + pub async fn get_pool_owner_address(&self) -> Option
<Address>
{ + let contracts = match self.contracts.as_ref() { + Some(contracts) => contracts, + None => { + log::error!("Contracts not initialized"); + return None; + } + }; + + log::info!( + "Fetching pool owner for pool ID: {}", + self.config.compute_pool_id + ); + + match contracts + .compute_pool + .get_pool_info(U256::from(self.config.compute_pool_id)) + .await + { + Ok(pool_info) => { + log::info!( + "Pool {} owner address: {:?}", + self.config.compute_pool_id, + pool_info.creator + ); + log::info!( + "Pool {} owner address (checksummed): {}", + self.config.compute_pool_id, + pool_info.creator + ); + Some(pool_info.creator) + } + Err(e) => { + log::error!( + "Failed to get pool info for pool {}: {}", + self.config.compute_pool_id, + e + ); + None + } + } + } + + /// Get the compute manager address for the current compute pool + pub async fn get_compute_manager_address(&self) -> Option
<Address>
{ + let contracts = match self.contracts.as_ref() { + Some(contracts) => contracts, + None => { + log::error!("Contracts not initialized"); + return None; + } + }; + + match contracts + .compute_pool + .get_pool_info(U256::from(self.config.compute_pool_id)) + .await + { + Ok(pool_info) => { + log::info!( + "Pool {} compute manager address: {:?}", + self.config.compute_pool_id, + pool_info.compute_manager_key + ); + log::info!( + "Pool {} compute manager address (checksummed): {}", + self.config.compute_pool_id, + pool_info.compute_manager_key + ); + Some(pool_info.compute_manager_key) + } + Err(e) => { + log::error!( + "Failed to get pool info for pool {}: {}", + self.config.compute_pool_id, + e + ); + None + } + } + } } diff --git a/crates/prime-protocol-py/src/worker/client.rs b/crates/prime-protocol-py/src/worker/client.rs index e87907c8..8eb2ca4d 100644 --- a/crates/prime-protocol-py/src/worker/client.rs +++ b/crates/prime-protocol-py/src/worker/client.rs @@ -4,6 +4,7 @@ use crate::p2p_handler::auth::AuthenticationManager; use crate::p2p_handler::message_processor::{MessageProcessor, MessageProcessorConfig}; use crate::worker::blockchain::{BlockchainConfig, BlockchainService}; use crate::worker::p2p_handler::{Message, MessageType, Service as P2PService}; +use alloy::primitives::Address; use p2p::{Keypair, PeerId}; use std::sync::Arc; use tokio::sync::mpsc::{Receiver, Sender}; @@ -23,6 +24,10 @@ pub struct WorkerClientCore { user_message_tx: Option>, user_message_rx: Option>>>, message_processor_handle: Option>, + // Validator and pool owner info + validator_addresses: Option>>, + pool_owner_address: Option
<Address>, + compute_manager_address: Option<Address>
, } /// Configuration for the worker client @@ -101,6 +106,9 @@ impl WorkerClientCore { user_message_tx: Some(user_message_tx), user_message_rx: Some(Arc::new(Mutex::new(user_message_rx))), message_processor_handle: None, + validator_addresses: None, + pool_owner_address: None, + compute_manager_address: None, }) } @@ -109,6 +117,7 @@ impl WorkerClientCore { log::info!("Starting WorkerClient"); self.initialize_blockchain().await?; + self.initialize_validator_and_pool_info().await?; self.initialize_auth_manager()?; self.start_p2p_service().await?; self.start_message_processor().await?; @@ -153,7 +162,6 @@ impl WorkerClientCore { // Check if it's an invite and process it automatically if let MessageType::General { ref data } = message.message_type { if let Ok(invite) = serde_json::from_slice::(data) { - println!("Received invite from peer: {}", message.peer_id); log::info!("Received invite from peer: {}", message.peer_id); // Check if invite has expired @@ -241,6 +249,33 @@ impl WorkerClientCore { Ok(()) } + async fn initialize_validator_and_pool_info(&mut self) -> Result<()> { + let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string()) + })?; + + let validator_addresses = blockchain_service.get_validator_addresses().await; + self.validator_addresses = Some(validator_addresses); + + let pool_owner_address = blockchain_service.get_pool_owner_address().await; + self.pool_owner_address = pool_owner_address; + + let compute_manager_address = blockchain_service.get_compute_manager_address().await; + self.compute_manager_address = compute_manager_address; + + log::info!( + "Validator addresses: {:?}", + self.validator_addresses.as_ref().map(|s| s.len()) + ); + log::info!("Pool owner address: {:?}", self.pool_owner_address); + log::info!( + "Compute manager address: {:?}", + self.compute_manager_address + ); + + Ok(()) + } + fn initialize_auth_manager(&mut self) -> Result<()> { let blockchain_service = self.blockchain_service.as_ref().ok_or_else(|| { PrimeProtocolError::InvalidConfig("Blockchain service not initialized".to_string()) @@ -348,6 +383,20 @@ impl WorkerClientCore { })? .clone(); + let validator_addresses = self + .validator_addresses + .as_ref() + .ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Validator addresses not initialized".to_string()) + })? 
+ .clone(); + + let pool_owner_address = self.pool_owner_address.ok_or_else(|| { + PrimeProtocolError::InvalidConfig("Pool owner address not initialized".to_string()) + })?; + + let compute_manager_address = self.compute_manager_address; + Ok(MessageProcessorConfig { auth_manager, message_queue_rx, @@ -355,6 +404,9 @@ impl WorkerClientCore { outbound_tx, authenticated_peers, cancellation_token: self.cancellation_token.clone(), + validator_addresses, + pool_owner_address: Some(pool_owner_address), + compute_manager_address, }) } diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/prime-protocol-py/src/worker/mod.rs index 13fbe027..bd720c9a 100644 --- a/crates/prime-protocol-py/src/worker/mod.rs +++ b/crates/prime-protocol-py/src/worker/mod.rs @@ -87,6 +87,8 @@ impl WorkerClient { peer_id, multiaddrs, sender_address: None, // Will be filled from our wallet automatically + is_sender_validator: false, + is_sender_pool_owner: false, response_tx: None, }; @@ -228,6 +230,8 @@ fn message_to_pyobject(message: Message) -> PyObject { "peer_id": message.peer_id, "multiaddrs": message.multiaddrs, "sender_address": message.sender_address, + "is_sender_validator": message.is_sender_validator, + "is_sender_pool_owner": message.is_sender_pool_owner, }); Python::with_gil(|py| crate::utils::json_parser::json_to_pyobject(py, &json_value)) From a5ffd0e676bcf37fe894cdb56eddd131c57f0567 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Wed, 23 Jul 2025 14:11:52 +0200 Subject: [PATCH 19/23] basic message passing example --- .../examples/orchestrator.py | 101 +++++------------- .../prime-protocol-py/examples/validator.py | 51 ++------- .../src/p2p_handler/message_processor.rs | 4 +- .../prime-protocol-py/src/p2p_handler/mod.rs | 8 +- 4 files changed, 44 insertions(+), 120 deletions(-) diff --git a/crates/prime-protocol-py/examples/orchestrator.py b/crates/prime-protocol-py/examples/orchestrator.py index f7f05773..4fba6717 100644 --- a/crates/prime-protocol-py/examples/orchestrator.py +++ b/crates/prime-protocol-py/examples/orchestrator.py @@ -18,6 +18,7 @@ def main(): discovery_urls_str = os.getenv("DISCOVERY_URLS", "http://localhost:8089") discovery_urls = [url.strip() for url in discovery_urls_str.split(",")] pool_id = int(os.getenv("POOL_ID", "0")) + p2p_port = int(os.getenv("ORCHESTRATOR_P2P_PORT", "8180")) if not private_key: print("Error: PRIVATE_KEY_ORCHESTRATOR environment variable is required") @@ -27,6 +28,7 @@ def main(): print(f"RPC URL: {rpc_url}") print(f"Discovery URLs: {discovery_urls}") print(f"Pool ID: {pool_id}") + print(f"P2P Port: {p2p_port}") # Initialize and start the orchestrator orchestrator = OrchestratorClient( @@ -36,59 +38,29 @@ def main(): ) print("Starting orchestrator client...") - orchestrator.start(p2p_port=8180) - print("Orchestrator client started") - - # Wait for P2P to initialize - time.sleep(5) + orchestrator.start(p2p_port=p2p_port) + print(f"Orchestrator started with peer ID: {orchestrator.get_peer_id()}") print("\nStarting orchestrator loop...") print("Press Ctrl+C to stop\n") - # Track nodes in our pool - pool_node_ids = set() - try: while True: print(f"{'='*50}") print(f"Cycle at {time.strftime('%H:%M:%S')}") - # Check for any pending messages first - print("Checking for pending messages...") + # Check for a single message message = orchestrator.get_next_message() - while message: - peer_id = message['peer_id'] - msg_data = message.get('message', {}) - - if msg_data.get('type') == 'general': - data = bytes(msg_data.get('data', [])) - sender_type = 
"VALIDATOR" if message.get('is_sender_validator') else \ - "POOL_OWNER" if message.get('is_sender_pool_owner') else \ - "WORKER" - print(f" 📨 From {peer_id[:16]}... ({sender_type}): {data}") - elif msg_data.get('type') == 'authentication_complete': - print(f" ✓ Auth complete with {peer_id[:16]}...") - else: - print(f" 📋 {msg_data.get('type')} from {peer_id[:16]}...") - - # Check for more messages - message = orchestrator.get_next_message() + print(f"Got message - python orchestrator: {message}") - # Get nodes in the pool + # 1. Invite validated but inactive nodes pool_nodes = orchestrator.list_nodes_for_pool(pool_id) - print(f"\nFound {len(pool_nodes)} nodes in pool {pool_id}") + validated_inactive = [n for n in pool_nodes if n.is_validated and not n.is_active and n.worker_p2p_id] - # Update our tracking of pool nodes - pool_node_ids = {node.worker_p2p_id for node in pool_nodes if node.worker_p2p_id} - - # Process each node - for node in pool_nodes: - if not node.worker_p2p_id: - continue - - try: - if not node.is_active and node.is_validated: - # Invite inactive but validated nodes + if validated_inactive: + print(f"\nInviting {len(validated_inactive)} validated inactive nodes...") + for node in validated_inactive: + try: print(f" 📨 Inviting {node.id[:8]}...") orchestrator.invite_node( peer_id=node.worker_p2p_id, @@ -99,45 +71,28 @@ def main(): orchestrator_url=None, expiration_seconds=1000 ) - elif node.is_active: - # Send message to active nodes - print(f" 💬 Messaging {node.id[:8]}...") + except Exception as e: + print(f" ❌ Error inviting {node.id[:8]}: {e}") + + # 2. Send messages to active nodes + active_nodes = [n for n in pool_nodes if n.is_active and n.worker_p2p_id] + + if active_nodes: + print(f"\nMessaging {len(active_nodes)} active nodes...") + for node in active_nodes: + try: orchestrator.send_message( peer_id=node.worker_p2p_id, multiaddrs=node.worker_p2p_addresses, data=b"Hello from orchestrator!", ) - except Exception as e: - print(f" ❌ Error with {node.id[:8]}: {e}") - - # Check for messages throughout the wait period - print(f"\nWaiting 30 seconds (checking for messages)...") - end_time = time.time() + 30 - messages_during_wait = 0 - - while time.time() < end_time: - message = orchestrator.get_next_message() - if message: - peer_id = message['peer_id'] - msg_data = message.get('message', {}) - - if msg_data.get('type') == 'general': - data = bytes(msg_data.get('data', [])) - sender_type = "VALIDATOR" if message.get('is_sender_validator') else \ - "POOL_OWNER" if message.get('is_sender_pool_owner') else \ - "WORKER" - - # Check if message is from a pool node - if peer_id in pool_node_ids: - print(f" 📨 From pool node {peer_id[:16]}... ({sender_type}): {data}") - else: - print(f" 📨 From external {peer_id[:16]}... 
({sender_type}): {data}") - messages_during_wait += 1 - else: - time.sleep(0.1) + print(f" 💬 Sent to {node.id[:8]}...") + except Exception as e: + print(f" ❌ Error messaging {node.id[:8]}: {e}") - if messages_during_wait > 0: - print(f"Received {messages_during_wait} messages during wait") + # Wait before next cycle + print(f"\nWaiting 10 seconds...") + time.sleep(10) print() except KeyboardInterrupt: diff --git a/crates/prime-protocol-py/examples/validator.py b/crates/prime-protocol-py/examples/validator.py index 2b4af7cc..2fdf2f79 100644 --- a/crates/prime-protocol-py/examples/validator.py +++ b/crates/prime-protocol-py/examples/validator.py @@ -46,29 +46,14 @@ def main(): while True: print(f"{'='*50}") print(f"Cycle at {time.strftime('%H:%M:%S')}") - - # Check for messages first - print("Checking for any pending messages...") + # Check for a single message message = validator.get_next_message() - while message: - peer_id = message['peer_id'] + print(f"Message: {message}") + if message: msg_data = message.get('message', {}) - if msg_data.get('type') == 'general': data = bytes(msg_data.get('data', [])) - sender_type = "VALIDATOR" if message.get('is_sender_validator') else \ - "POOL_OWNER" if message.get('is_sender_pool_owner') else \ - "WORKER" - print(f" 📨 From {peer_id[:16]}... ({sender_type}): {data}") - elif msg_data.get('type') == 'authentication_complete': - print(f" ✓ Auth complete with {peer_id[:16]}...") - else: - print(f" 📋 {msg_data.get('type')} from {peer_id[:16]}...") - - # Check for more messages - message = validator.get_next_message() - - # 1. Validate any non-validated nodes + print(f"Message payload: {data}") non_validated = validator.list_non_validated_nodes() if non_validated: print(f"\nValidating {len(non_validated)} nodes...") @@ -98,34 +83,14 @@ def main(): except Exception as e: print(f" ❌ Error messaging {node['id'][:8]}: {e}") - # 3. Check for more messages throughout the wait period - print(f"\nWaiting 10 seconds (checking for messages)...") - end_time = time.time() + 10 - messages_during_wait = 0 - - while time.time() < end_time: - message = validator.get_next_message() - if message: - peer_id = message['peer_id'] - msg_data = message.get('message', {}) - - if msg_data.get('type') == 'general': - data = bytes(msg_data.get('data', [])) - sender_type = "VALIDATOR" if message.get('is_sender_validator') else \ - "POOL_OWNER" if message.get('is_sender_pool_owner') else \ - "WORKER" - print(f" 📨 From {peer_id[:16]}... 
({sender_type}): {data}") - messages_during_wait += 1 - else: - time.sleep(0.1) - - if messages_during_wait > 0: - print(f"Received {messages_during_wait} messages during wait") + # Wait before next cycle + print(f"\nWaiting 10 seconds...") + time.sleep(10) print() except KeyboardInterrupt: print("\n\nShutting down validator...") - validator.stop() + # ValidatorClient doesn't have a stop() method print("Validator stopped") if __name__ == "__main__": diff --git a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs index 8efa0eb3..0e1d98f4 100644 --- a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs +++ b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs @@ -107,7 +107,6 @@ impl MessageProcessor { } Err(_) => continue, // Timeout, continue loop }; - log::debug!("Received message: {:?}", message); if let Err(e) = self.process_message(message).await { @@ -125,9 +124,8 @@ impl MessageProcessor { peer_id, multiaddrs, sender_address, - is_sender_validator, - is_sender_pool_owner, response_tx, + .. } = message; match message_type { diff --git a/crates/prime-protocol-py/src/p2p_handler/mod.rs b/crates/prime-protocol-py/src/p2p_handler/mod.rs index 89b98cf6..8e9f160c 100644 --- a/crates/prime-protocol-py/src/p2p_handler/mod.rs +++ b/crates/prime-protocol-py/src/p2p_handler/mod.rs @@ -44,7 +44,13 @@ pub enum MessageType { AuthenticationComplete, } -#[derive(Debug, Serialize, Deserialize)] +impl Default for MessageType { + fn default() -> Self { + MessageType::General { data: vec![] } + } +} + +#[derive(Debug, Serialize, Deserialize, Default)] pub struct Message { pub message_type: MessageType, pub peer_id: String, From a25726d9c13e69d57cda58b27fa9eea39544d394 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Wed, 23 Jul 2025 14:24:11 +0200 Subject: [PATCH 20/23] clippy --- crates/prime-protocol-py/src/p2p_handler/message_processor.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs index 0e1d98f4..8345e4df 100644 --- a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs +++ b/crates/prime-protocol-py/src/p2p_handler/message_processor.rs @@ -39,6 +39,7 @@ pub struct MessageProcessor { } impl MessageProcessor { + #[allow(clippy::too_many_arguments)] pub fn new( auth_manager: Arc, message_queue_rx: Arc>>, From 203f21ba80d72fea7465431c412c7792da88e1b5 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Wed, 23 Jul 2025 14:46:52 +0200 Subject: [PATCH 21/23] remove additional print statements --- crates/dev-utils/examples/compute_pool.rs | 13 +++---------- crates/prime-protocol-py/README.md | 10 ---------- crates/prime-protocol-py/src/orchestrator/mod.rs | 8 -------- .../implementations/compute_pool_contract.rs | 5 ----- 4 files changed, 3 insertions(+), 33 deletions(-) diff --git a/crates/dev-utils/examples/compute_pool.rs b/crates/dev-utils/examples/compute_pool.rs index 51658d59..8870a528 100644 --- a/crates/dev-utils/examples/compute_pool.rs +++ b/crates/dev-utils/examples/compute_pool.rs @@ -58,7 +58,7 @@ async fn main() -> Result<()> { let compute_limit = U256::from(0); - let tx = contracts + let _tx = contracts .compute_pool .create_compute_pool( domain_id, @@ -68,28 +68,21 @@ async fn main() -> Result<()> { compute_limit, ) .await; - println!("Transaction: {tx:?}"); let rewards_distributor_address = contracts .compute_pool 
.get_reward_distributor_address(U256::from(0)) .await .unwrap(); - println!("Rewards distributor address: {rewards_distributor_address:?}"); let rewards_distributor = RewardsDistributor::new( rewards_distributor_address, wallet.provider(), "rewards_distributor.json", ); let rate = U256::from(10000000000000000u64); - let tx = rewards_distributor.set_reward_rate(rate).await; - println!("Setting reward rate: {tx:?}"); + let _tx = rewards_distributor.set_reward_rate(rate).await; - let reward_rate = rewards_distributor.get_reward_rate().await.unwrap(); - println!( - "Reward rate: {}", - reward_rate.to_string().parse::().unwrap_or(0.0) / 10f64.powf(18.0) - ); + let _reward_rate = rewards_distributor.get_reward_rate().await.unwrap(); Ok(()) } diff --git a/crates/prime-protocol-py/README.md b/crates/prime-protocol-py/README.md index 0ba36ef8..1b3cdc28 100644 --- a/crates/prime-protocol-py/README.md +++ b/crates/prime-protocol-py/README.md @@ -70,19 +70,9 @@ uv run examples/orchestrator.py You should see that the orchestrator invites the worker node that you started in the earlier step and keeps sending messages via p2p to this node. ## TODOs: -- [ ] can validator send message? (special validation message) -- [ ] can the worker send messages as response to the validator / orchestrator? -- [ ] can orchestrator send message? -- [ ] whitelist using python api? -- [ ] borrow bug? - [ ] restart keeps increasing provider stake? -- [ ] p2p cleanup - [ ] I keep forgetting to run make build - [ ] what about formatting? -- 09, in main - validator.stop() - ^^^^^^^^^^^^^^ -AttributeError: 'builtins.ValidatorClient' object has no attribute 'stop' ## Known Limitations: 1. Orchestrator can no longer send messages to a node when the worker node restarts. diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/prime-protocol-py/src/orchestrator/mod.rs index f6adff74..5a2ee2e7 100644 --- a/crates/prime-protocol-py/src/orchestrator/mod.rs +++ b/crates/prime-protocol-py/src/orchestrator/mod.rs @@ -314,7 +314,6 @@ impl OrchestratorClient { orchestrator_url: Option, expiration_seconds: u64, ) -> PyResult<()> { - println!("invite_node"); let rt = self.get_or_create_runtime()?; let wallet = self.wallet.as_ref().ok_or_else(|| { @@ -343,14 +342,10 @@ impl OrchestratorClient { )) })?; - println!("worker_addr: {:?}", worker_addr); - let wallet = wallet.clone(); let outbound_tx = outbound_tx.clone(); let auth_manager = auth_manager.clone(); - println!("invite_node 2"); - py.allow_threads(|| { rt.block_on(async { // Generate invite parameters @@ -403,15 +398,12 @@ impl OrchestratorClient { }; // Send the invite - println!("sending invite"); crate::p2p_handler::send_message_with_auth(message, &auth_manager, &outbound_tx) .await .map_err(|e| { PyErr::new::(e.to_string()) })?; - println!("invite sent"); - Ok(()) }) }) diff --git a/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs b/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs index b52f96e2..32637b62 100644 --- a/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs +++ b/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs @@ -61,9 +61,6 @@ impl ComputePool
{ _ => panic!("Unknown status value: {status}"), }; - println!("Mapped status: {mapped_status:?}"); - println!("Returning pool info"); - let pool_info = PoolInfo { pool_id, domain_id, @@ -392,8 +389,6 @@ impl ComputePool { pool_id: u32, node: Address, ) -> Result, Box> { - println!("Ejecting node"); - let arg_pool_id: U256 = U256::from(pool_id); let result = self From 27328057bd523782d9b84c9e473c8db62290f593 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Wed, 23 Jul 2025 18:45:36 +0200 Subject: [PATCH 22/23] rename prime-protocol-py to python-sdk --- Cargo.lock | 56 +++++++++---------- Cargo.toml | 2 +- .../.gitignore | 0 .../.python-version | 0 .../Cargo.toml | 2 +- .../Makefile | 0 .../README.md | 0 .../examples/orchestrator.py | 0 .../examples/validator.py | 0 .../examples/worker.py | 0 .../pyproject.toml | 0 .../requirements-dev.txt | 0 .../setup.sh | 0 .../src/common/mod.rs | 0 .../src/constants.rs | 0 .../src/error.rs | 0 .../src/lib.rs | 0 .../src/orchestrator/mod.rs | 0 .../src/p2p_handler/auth.rs | 0 .../src/p2p_handler/common.rs | 0 .../src/p2p_handler/message_processor.rs | 0 .../src/p2p_handler/mod.rs | 0 .../src/utils/json_parser.rs | 0 .../src/utils/mod.rs | 0 .../src/validator/mod.rs | 0 .../src/worker/blockchain.rs | 0 .../src/worker/client.rs | 0 .../src/worker/discovery.rs | 0 .../src/worker/mod.rs | 0 .../tests/integration/test_worker.rs | 0 .../tests/test_client.py | 0 .../tests/test_validator.py | 0 .../{prime-protocol-py => python-sdk}/uv.lock | 0 33 files changed, 30 insertions(+), 30 deletions(-) rename crates/{prime-protocol-py => python-sdk}/.gitignore (100%) rename crates/{prime-protocol-py => python-sdk}/.python-version (100%) rename crates/{prime-protocol-py => python-sdk}/Cargo.toml (97%) rename crates/{prime-protocol-py => python-sdk}/Makefile (100%) rename crates/{prime-protocol-py => python-sdk}/README.md (100%) rename crates/{prime-protocol-py => python-sdk}/examples/orchestrator.py (100%) rename crates/{prime-protocol-py => python-sdk}/examples/validator.py (100%) rename crates/{prime-protocol-py => python-sdk}/examples/worker.py (100%) rename crates/{prime-protocol-py => python-sdk}/pyproject.toml (100%) rename crates/{prime-protocol-py => python-sdk}/requirements-dev.txt (100%) rename crates/{prime-protocol-py => python-sdk}/setup.sh (100%) rename crates/{prime-protocol-py => python-sdk}/src/common/mod.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/constants.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/error.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/lib.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/orchestrator/mod.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/p2p_handler/auth.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/p2p_handler/common.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/p2p_handler/message_processor.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/p2p_handler/mod.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/utils/json_parser.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/utils/mod.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/validator/mod.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/worker/blockchain.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/worker/client.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/worker/discovery.rs (100%) rename crates/{prime-protocol-py => python-sdk}/src/worker/mod.rs (100%) rename 
crates/{prime-protocol-py => python-sdk}/tests/integration/test_worker.rs (100%) rename crates/{prime-protocol-py => python-sdk}/tests/test_client.py (100%) rename crates/{prime-protocol-py => python-sdk}/tests/test_validator.py (100%) rename crates/{prime-protocol-py => python-sdk}/uv.lock (100%) diff --git a/Cargo.lock b/Cargo.lock index 48a90c27..f20ba450 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6735,34 +6735,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "prime-protocol-py" -version = "0.1.0" -dependencies = [ - "alloy", - "alloy-provider", - "anyhow", - "futures", - "hex", - "log", - "p2p", - "prime-core", - "pyo3", - "pyo3-log", - "pythonize", - "rand 0.8.5", - "reqwest 0.11.27", - "serde", - "serde_json", - "shared", - "test-log", - "thiserror 1.0.69", - "tokio", - "tokio-test", - "tokio-util", - "url", -] - [[package]] name = "primeorder" version = "0.13.6" @@ -7020,6 +6992,34 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "python-sdk" +version = "0.1.0" +dependencies = [ + "alloy", + "alloy-provider", + "anyhow", + "futures", + "hex", + "log", + "p2p", + "prime-core", + "pyo3", + "pyo3-log", + "pythonize", + "rand 0.8.5", + "reqwest 0.11.27", + "serde", + "serde_json", + "shared", + "test-log", + "thiserror 1.0.69", + "tokio", + "tokio-test", + "tokio-util", + "url", +] + [[package]] name = "pythonize" version = "0.25.0" diff --git a/Cargo.toml b/Cargo.toml index 15655fbc..6b6db7ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ members = [ "crates/orchestrator", "crates/p2p", "crates/dev-utils", - "crates/prime-protocol-py", + "crates/python-sdk", "crates/prime-core", ] resolver = "2" diff --git a/crates/prime-protocol-py/.gitignore b/crates/python-sdk/.gitignore similarity index 100% rename from crates/prime-protocol-py/.gitignore rename to crates/python-sdk/.gitignore diff --git a/crates/prime-protocol-py/.python-version b/crates/python-sdk/.python-version similarity index 100% rename from crates/prime-protocol-py/.python-version rename to crates/python-sdk/.python-version diff --git a/crates/prime-protocol-py/Cargo.toml b/crates/python-sdk/Cargo.toml similarity index 97% rename from crates/prime-protocol-py/Cargo.toml rename to crates/python-sdk/Cargo.toml index 971480e3..a39b3a81 100644 --- a/crates/prime-protocol-py/Cargo.toml +++ b/crates/python-sdk/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "prime-protocol-py" +name = "python-sdk" version = "0.1.0" authors = ["Prime Protocol"] edition = "2021" diff --git a/crates/prime-protocol-py/Makefile b/crates/python-sdk/Makefile similarity index 100% rename from crates/prime-protocol-py/Makefile rename to crates/python-sdk/Makefile diff --git a/crates/prime-protocol-py/README.md b/crates/python-sdk/README.md similarity index 100% rename from crates/prime-protocol-py/README.md rename to crates/python-sdk/README.md diff --git a/crates/prime-protocol-py/examples/orchestrator.py b/crates/python-sdk/examples/orchestrator.py similarity index 100% rename from crates/prime-protocol-py/examples/orchestrator.py rename to crates/python-sdk/examples/orchestrator.py diff --git a/crates/prime-protocol-py/examples/validator.py b/crates/python-sdk/examples/validator.py similarity index 100% rename from crates/prime-protocol-py/examples/validator.py rename to crates/python-sdk/examples/validator.py diff --git a/crates/prime-protocol-py/examples/worker.py b/crates/python-sdk/examples/worker.py similarity index 100% rename from crates/prime-protocol-py/examples/worker.py rename to crates/python-sdk/examples/worker.py 
diff --git a/crates/prime-protocol-py/pyproject.toml b/crates/python-sdk/pyproject.toml similarity index 100% rename from crates/prime-protocol-py/pyproject.toml rename to crates/python-sdk/pyproject.toml diff --git a/crates/prime-protocol-py/requirements-dev.txt b/crates/python-sdk/requirements-dev.txt similarity index 100% rename from crates/prime-protocol-py/requirements-dev.txt rename to crates/python-sdk/requirements-dev.txt diff --git a/crates/prime-protocol-py/setup.sh b/crates/python-sdk/setup.sh similarity index 100% rename from crates/prime-protocol-py/setup.sh rename to crates/python-sdk/setup.sh diff --git a/crates/prime-protocol-py/src/common/mod.rs b/crates/python-sdk/src/common/mod.rs similarity index 100% rename from crates/prime-protocol-py/src/common/mod.rs rename to crates/python-sdk/src/common/mod.rs diff --git a/crates/prime-protocol-py/src/constants.rs b/crates/python-sdk/src/constants.rs similarity index 100% rename from crates/prime-protocol-py/src/constants.rs rename to crates/python-sdk/src/constants.rs diff --git a/crates/prime-protocol-py/src/error.rs b/crates/python-sdk/src/error.rs similarity index 100% rename from crates/prime-protocol-py/src/error.rs rename to crates/python-sdk/src/error.rs diff --git a/crates/prime-protocol-py/src/lib.rs b/crates/python-sdk/src/lib.rs similarity index 100% rename from crates/prime-protocol-py/src/lib.rs rename to crates/python-sdk/src/lib.rs diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/python-sdk/src/orchestrator/mod.rs similarity index 100% rename from crates/prime-protocol-py/src/orchestrator/mod.rs rename to crates/python-sdk/src/orchestrator/mod.rs diff --git a/crates/prime-protocol-py/src/p2p_handler/auth.rs b/crates/python-sdk/src/p2p_handler/auth.rs similarity index 100% rename from crates/prime-protocol-py/src/p2p_handler/auth.rs rename to crates/python-sdk/src/p2p_handler/auth.rs diff --git a/crates/prime-protocol-py/src/p2p_handler/common.rs b/crates/python-sdk/src/p2p_handler/common.rs similarity index 100% rename from crates/prime-protocol-py/src/p2p_handler/common.rs rename to crates/python-sdk/src/p2p_handler/common.rs diff --git a/crates/prime-protocol-py/src/p2p_handler/message_processor.rs b/crates/python-sdk/src/p2p_handler/message_processor.rs similarity index 100% rename from crates/prime-protocol-py/src/p2p_handler/message_processor.rs rename to crates/python-sdk/src/p2p_handler/message_processor.rs diff --git a/crates/prime-protocol-py/src/p2p_handler/mod.rs b/crates/python-sdk/src/p2p_handler/mod.rs similarity index 100% rename from crates/prime-protocol-py/src/p2p_handler/mod.rs rename to crates/python-sdk/src/p2p_handler/mod.rs diff --git a/crates/prime-protocol-py/src/utils/json_parser.rs b/crates/python-sdk/src/utils/json_parser.rs similarity index 100% rename from crates/prime-protocol-py/src/utils/json_parser.rs rename to crates/python-sdk/src/utils/json_parser.rs diff --git a/crates/prime-protocol-py/src/utils/mod.rs b/crates/python-sdk/src/utils/mod.rs similarity index 100% rename from crates/prime-protocol-py/src/utils/mod.rs rename to crates/python-sdk/src/utils/mod.rs diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/python-sdk/src/validator/mod.rs similarity index 100% rename from crates/prime-protocol-py/src/validator/mod.rs rename to crates/python-sdk/src/validator/mod.rs diff --git a/crates/prime-protocol-py/src/worker/blockchain.rs b/crates/python-sdk/src/worker/blockchain.rs similarity index 100% rename from 
crates/prime-protocol-py/src/worker/blockchain.rs rename to crates/python-sdk/src/worker/blockchain.rs diff --git a/crates/prime-protocol-py/src/worker/client.rs b/crates/python-sdk/src/worker/client.rs similarity index 100% rename from crates/prime-protocol-py/src/worker/client.rs rename to crates/python-sdk/src/worker/client.rs diff --git a/crates/prime-protocol-py/src/worker/discovery.rs b/crates/python-sdk/src/worker/discovery.rs similarity index 100% rename from crates/prime-protocol-py/src/worker/discovery.rs rename to crates/python-sdk/src/worker/discovery.rs diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/python-sdk/src/worker/mod.rs similarity index 100% rename from crates/prime-protocol-py/src/worker/mod.rs rename to crates/python-sdk/src/worker/mod.rs diff --git a/crates/prime-protocol-py/tests/integration/test_worker.rs b/crates/python-sdk/tests/integration/test_worker.rs similarity index 100% rename from crates/prime-protocol-py/tests/integration/test_worker.rs rename to crates/python-sdk/tests/integration/test_worker.rs diff --git a/crates/prime-protocol-py/tests/test_client.py b/crates/python-sdk/tests/test_client.py similarity index 100% rename from crates/prime-protocol-py/tests/test_client.py rename to crates/python-sdk/tests/test_client.py diff --git a/crates/prime-protocol-py/tests/test_validator.py b/crates/python-sdk/tests/test_validator.py similarity index 100% rename from crates/prime-protocol-py/tests/test_validator.py rename to crates/python-sdk/tests/test_validator.py diff --git a/crates/prime-protocol-py/uv.lock b/crates/python-sdk/uv.lock similarity index 100% rename from crates/prime-protocol-py/uv.lock rename to crates/python-sdk/uv.lock From 295c7fc0b402a7476c61bb1965ce09d80a47e606 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Wed, 23 Jul 2025 18:48:01 +0200 Subject: [PATCH 23/23] rename prime-core to operations --- Cargo.lock | 56 +++++++++---------- Cargo.toml | 4 +- crates/{prime-core => operations}/Cargo.toml | 4 +- .../src/invite/admin.rs | 0 .../src/invite/common.rs | 0 .../src/invite/mod.rs | 0 .../src/invite/worker.rs | 0 crates/{prime-core => operations}/src/lib.rs | 0 .../src/operations/compute_node.rs | 0 .../src/operations/mod.rs | 0 .../src/operations/provider.rs | 0 crates/orchestrator/Cargo.toml | 2 +- crates/orchestrator/src/node/invite.rs | 4 +- crates/python-sdk/Cargo.toml | 2 +- crates/python-sdk/src/orchestrator/mod.rs | 2 +- crates/python-sdk/src/worker/blockchain.rs | 4 +- crates/python-sdk/src/worker/client.rs | 2 +- crates/worker/Cargo.toml | 2 +- crates/worker/src/cli/command.rs | 4 +- crates/worker/src/p2p/mod.rs | 2 +- 20 files changed, 44 insertions(+), 44 deletions(-) rename crates/{prime-core => operations}/Cargo.toml (94%) rename crates/{prime-core => operations}/src/invite/admin.rs (100%) rename crates/{prime-core => operations}/src/invite/common.rs (100%) rename crates/{prime-core => operations}/src/invite/mod.rs (100%) rename crates/{prime-core => operations}/src/invite/worker.rs (100%) rename crates/{prime-core => operations}/src/lib.rs (100%) rename crates/{prime-core => operations}/src/operations/compute_node.rs (100%) rename crates/{prime-core => operations}/src/operations/mod.rs (100%) rename crates/{prime-core => operations}/src/operations/provider.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index f20ba450..d9b610be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6180,6 +6180,31 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "operations" +version = "0.1.0" +dependencies = [ + "actix-web", + 
"alloy", + "alloy-provider", + "anyhow", + "env_logger", + "futures-util", + "hex", + "log", + "p2p", + "rand 0.8.5", + "redis", + "serde", + "serde_json", + "shared", + "subtle", + "tokio", + "tokio-util", + "url", + "uuid", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -6203,8 +6228,8 @@ dependencies = [ "hex", "log", "mockito", + "operations", "p2p", - "prime-core", "prometheus 0.14.0", "rand 0.9.1", "redis", @@ -6710,31 +6735,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "prime-core" -version = "0.1.0" -dependencies = [ - "actix-web", - "alloy", - "alloy-provider", - "anyhow", - "env_logger", - "futures-util", - "hex", - "log", - "p2p", - "rand 0.8.5", - "redis", - "serde", - "serde_json", - "shared", - "subtle", - "tokio", - "tokio-util", - "url", - "uuid", -] - [[package]] name = "primeorder" version = "0.13.6" @@ -7002,8 +7002,8 @@ dependencies = [ "futures", "hex", "log", + "operations", "p2p", - "prime-core", "pyo3", "pyo3-log", "pythonize", @@ -10478,8 +10478,8 @@ dependencies = [ "libc", "log", "nvml-wrapper", + "operations", "p2p", - "prime-core", "rand 0.8.5", "rand 0.9.1", "reqwest 0.12.15", diff --git a/Cargo.toml b/Cargo.toml index 6b6db7ff..f35a3e7e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,15 +8,15 @@ members = [ "crates/p2p", "crates/dev-utils", "crates/python-sdk", - "crates/prime-core", + "crates/operations", ] resolver = "2" [workspace.dependencies] shared = { path = "crates/shared" } p2p = { path = "crates/p2p" } +operations = { path = "crates/operations" } -prime-core = { path = "crates/prime-core" } actix-web = "4.9.0" clap = { version = "4.5.27", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] } diff --git a/crates/prime-core/Cargo.toml b/crates/operations/Cargo.toml similarity index 94% rename from crates/prime-core/Cargo.toml rename to crates/operations/Cargo.toml index 4b6ec28c..875478cd 100644 --- a/crates/prime-core/Cargo.toml +++ b/crates/operations/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "prime-core" +name = "operations" version = "0.1.0" edition = "2021" @@ -7,7 +7,7 @@ edition = "2021" workspace = true [lib] -name = "prime_core" +name = "operations" path = "src/lib.rs" [dependencies] diff --git a/crates/prime-core/src/invite/admin.rs b/crates/operations/src/invite/admin.rs similarity index 100% rename from crates/prime-core/src/invite/admin.rs rename to crates/operations/src/invite/admin.rs diff --git a/crates/prime-core/src/invite/common.rs b/crates/operations/src/invite/common.rs similarity index 100% rename from crates/prime-core/src/invite/common.rs rename to crates/operations/src/invite/common.rs diff --git a/crates/prime-core/src/invite/mod.rs b/crates/operations/src/invite/mod.rs similarity index 100% rename from crates/prime-core/src/invite/mod.rs rename to crates/operations/src/invite/mod.rs diff --git a/crates/prime-core/src/invite/worker.rs b/crates/operations/src/invite/worker.rs similarity index 100% rename from crates/prime-core/src/invite/worker.rs rename to crates/operations/src/invite/worker.rs diff --git a/crates/prime-core/src/lib.rs b/crates/operations/src/lib.rs similarity index 100% rename from crates/prime-core/src/lib.rs rename to crates/operations/src/lib.rs diff --git a/crates/prime-core/src/operations/compute_node.rs b/crates/operations/src/operations/compute_node.rs similarity index 100% rename from crates/prime-core/src/operations/compute_node.rs rename to crates/operations/src/operations/compute_node.rs diff --git a/crates/prime-core/src/operations/mod.rs 
b/crates/operations/src/operations/mod.rs similarity index 100% rename from crates/prime-core/src/operations/mod.rs rename to crates/operations/src/operations/mod.rs diff --git a/crates/prime-core/src/operations/provider.rs b/crates/operations/src/operations/provider.rs similarity index 100% rename from crates/prime-core/src/operations/provider.rs rename to crates/operations/src/operations/provider.rs diff --git a/crates/orchestrator/Cargo.toml b/crates/orchestrator/Cargo.toml index 2703facf..594df148 100644 --- a/crates/orchestrator/Cargo.toml +++ b/crates/orchestrator/Cargo.toml @@ -9,7 +9,7 @@ workspace = true [dependencies] p2p = { workspace = true} shared = { workspace = true } -prime-core = { workspace = true } +operations = { workspace = true } actix-web = { workspace = true } alloy = { workspace = true } diff --git a/crates/orchestrator/src/node/invite.rs b/crates/orchestrator/src/node/invite.rs index 4e3cc874..7aa68558 100644 --- a/crates/orchestrator/src/node/invite.rs +++ b/crates/orchestrator/src/node/invite.rs @@ -7,11 +7,11 @@ use anyhow::{bail, Result}; use futures::stream; use futures::StreamExt; use log::{debug, error, info, warn}; -use p2p::InviteRequestUrl; -use prime_core::invite::{ +use operations::invite::{ admin::{generate_invite_expiration, generate_invite_nonce, generate_invite_signature}, common::InviteBuilder, }; +use p2p::InviteRequestUrl; use shared::web3::wallet::Wallet; use std::sync::Arc; use tokio::sync::mpsc::Sender; diff --git a/crates/python-sdk/Cargo.toml b/crates/python-sdk/Cargo.toml index a39b3a81..fffe1880 100644 --- a/crates/python-sdk/Cargo.toml +++ b/crates/python-sdk/Cargo.toml @@ -13,7 +13,7 @@ crate-type = ["cdylib"] pyo3 = { version = "0.25.1", features = ["extension-module"] } thiserror = "1.0" shared = { workspace = true } -prime-core = { workspace = true } +operations = { workspace = true } p2p = { workspace = true } alloy = { workspace = true } alloy-provider = { workspace = true } diff --git a/crates/python-sdk/src/orchestrator/mod.rs b/crates/python-sdk/src/orchestrator/mod.rs index 5a2ee2e7..34f51cba 100644 --- a/crates/python-sdk/src/orchestrator/mod.rs +++ b/crates/python-sdk/src/orchestrator/mod.rs @@ -18,7 +18,7 @@ use shared::discovery::fetch_nodes_from_discovery_urls; // Add imports for invite functionality use alloy::primitives::Address; -use prime_core::invite::{ +use operations::invite::{ admin::{generate_invite_expiration, generate_invite_nonce, generate_invite_signature}, common::InviteBuilder, }; diff --git a/crates/python-sdk/src/worker/blockchain.rs b/crates/python-sdk/src/worker/blockchain.rs index 793e3b05..0ed4ea5b 100644 --- a/crates/python-sdk/src/worker/blockchain.rs +++ b/crates/python-sdk/src/worker/blockchain.rs @@ -1,8 +1,8 @@ use alloy::primitives::utils::format_ether; use alloy::primitives::{Address, U256}; use anyhow::{Context, Result}; -use prime_core::operations::compute_node::ComputeNodeOperations; -use prime_core::operations::provider::ProviderOperations; +use operations::operations::compute_node::ComputeNodeOperations; +use operations::operations::provider::ProviderOperations; use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; use shared::web3::contracts::structs::compute_pool::PoolStatus; use shared::web3::wallet::{Wallet, WalletProvider}; diff --git a/crates/python-sdk/src/worker/client.rs b/crates/python-sdk/src/worker/client.rs index 8eb2ca4d..ec69e448 100644 --- a/crates/python-sdk/src/worker/client.rs +++ b/crates/python-sdk/src/worker/client.rs @@ -165,7 +165,7 @@ impl 
WorkerClientCore { log::info!("Received invite from peer: {}", message.peer_id); // Check if invite has expired - if let Ok(true) = prime_core::invite::worker::is_invite_expired(&invite) { + if let Ok(true) = operations::invite::worker::is_invite_expired(&invite) { log::warn!("Received expired invite from peer: {}", message.peer_id); return Some(message); // Return it so user can see the expired invite } diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml index bd35ca32..51981fa5 100644 --- a/crates/worker/Cargo.toml +++ b/crates/worker/Cargo.toml @@ -9,7 +9,7 @@ workspace = true [dependencies] shared = { workspace = true } p2p = { workspace = true } -prime-core = { workspace = true} +operations = { workspace = true} actix-web = { workspace = true } alloy = { workspace = true } diff --git a/crates/worker/src/cli/command.rs b/crates/worker/src/cli/command.rs index f53f2762..e0214d83 100644 --- a/crates/worker/src/cli/command.rs +++ b/crates/worker/src/cli/command.rs @@ -19,8 +19,8 @@ use alloy::signers::local::PrivateKeySigner; use alloy::signers::Signer; use clap::{Parser, Subcommand}; use log::{error, info}; -use prime_core::operations::compute_node::ComputeNodeOperations; -use prime_core::operations::provider::ProviderOperations; +use operations::operations::compute_node::ComputeNodeOperations; +use operations::operations::provider::ProviderOperations; use shared::models::node::ComputeRequirements; use shared::models::node::Node; use shared::web3::contracts::core::builder::ContractBuilder; diff --git a/crates/worker/src/p2p/mod.rs b/crates/worker/src/p2p/mod.rs index 2635f3f9..8d31d27a 100644 --- a/crates/worker/src/p2p/mod.rs +++ b/crates/worker/src/p2p/mod.rs @@ -1,12 +1,12 @@ use anyhow::Context as _; use anyhow::Result; use futures::stream::FuturesUnordered; +use operations::invite::{common::get_endpoint_from_url, worker::is_invite_expired}; use p2p::Node; use p2p::NodeBuilder; use p2p::PeerId; use p2p::Response; use p2p::{IncomingMessage, Libp2pIncomingMessage, OutgoingMessage}; -use prime_core::invite::{common::get_endpoint_from_url, worker::is_invite_expired}; use shared::web3::contracts::core::builder::Contracts; use shared::web3::wallet::Wallet; use std::collections::HashMap;
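
Note on the message-passing flow at the end of this series: the worker-side round-trip enabled by the new is_sender_validator / is_sender_pool_owner flags is small enough to show end to end. The sketch below is condensed from crates/python-sdk/examples/worker.py as it stands after patch 23; the positional WorkerClient constructor and the message dict shape come from that example, while the argument-less client.start() call, the RPC URL default, and the "ack:" reply payload are assumptions not shown in this excerpt.

    #!/usr/bin/env python3
    """Minimal worker round-trip, condensed from examples/worker.py."""
    import os
    import time

    from primeprotocol import WorkerClient

    client = WorkerClient(
        int(os.getenv("POOL_ID", "0")),
        os.getenv("RPC_URL", "http://localhost:8545"),  # default is an assumption
        os.getenv("PRIVATE_KEY_PROVIDER"),
        os.getenv("PRIVATE_KEY_NODE"),
        int(os.getenv("PORT", "8003")),
    )
    client.start()  # assumed argument-less, as WorkerClientCore::start suggests

    while True:
        message = client.get_next_message()
        if not message:
            time.sleep(0.1)  # same polling interval as the example loop
            continue
        msg_data = message.get("message", {})
        if msg_data.get("type") != "general":
            continue
        data = bytes(msg_data.get("data", []))
        # is_sender_validator / is_sender_pool_owner are filled in by the
        # message processor from on-chain data (validator role set, pool
        # creator, compute manager key) before the message reaches Python.
        if message.get("is_sender_validator") or message.get("is_sender_pool_owner"):
            client.send_message(
                peer_id=message["peer_id"],
                multiaddrs=[],  # peer is already connected after authentication
                data=b"ack: " + data,  # illustrative reply payload
            )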