diff --git a/README.md b/README.md index 6d83d5e..e0a14c5 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,8 @@ Options: --commitment solana commitment level to use for state updates (default: confirmed) --default-sub-account-id - default sub_account_id to use (default: 0) + default sub_account_id to use as default. Use the new active-sub-accounts param to subscribe to multiple sub accounts. This param will override active-sub-accounts. + --active-sub-accounts sub accounts to subscribe to. (default: 0) --skip-tx-preflight skip tx preflight checks --extra-rpcs extra solana RPC urls for improved Tx broadcast diff --git a/src/controller.rs b/src/controller.rs index 419e55d..a99f434 100644 --- a/src/controller.rs +++ b/src/controller.rs @@ -90,8 +90,8 @@ pub struct AppState { pub client: Arc, /// Solana tx commitment level for preflight confirmation tx_commitment: CommitmentConfig, - /// default sub_account_id to use if not provided - default_subaccount_id: u16, + /// sub_account_ids to subscribe to + sub_account_ids: Vec, /// skip tx preflight on send or not (default: false) skip_tx_preflight: bool, priority_fee_subscriber: Arc, @@ -99,7 +99,7 @@ pub struct AppState { /// list of additional RPC endpoints for tx broadcast extra_rpcs: Vec>, /// swift node url - swift_node: Option, + swift_node: String, } impl AppState { @@ -111,12 +111,12 @@ impl AppState { pub fn signer(&self) -> Pubkey { self.wallet.signer() } - pub fn default_sub_account(&self) -> Pubkey { - self.wallet.sub_account(self.default_subaccount_id) + pub fn sub_account(&self, sub_account_id: u16) -> Pubkey { + self.wallet.sub_account(sub_account_id) } pub fn resolve_sub_account(&self, sub_account_id: Option) -> Pubkey { self.wallet - .sub_account(sub_account_id.unwrap_or(self.default_subaccount_id)) + .sub_account(sub_account_id.unwrap_or(self.default_sub_account_id())) } /// Initialize Gateway Drift client @@ -125,7 +125,7 @@ impl AppState { /// * `devnet` - whether to run against devnet or not /// * `wallet` - wallet to use for tx signing /// * `commitment` - Slot finalisation/commitement levels - /// * `default_subaccount_id` - by default all queries will use this sub-account + /// * `sub_account_ids` - the sub_accounts to subscribe too. In your query specify a specific subaccount, otherwise subaccount 0 will be used as default /// * `skip_tx_preflight` - submit txs without checking preflight results /// * `extra_rpcs` - list of additional RPC endpoints for tx submission pub async fn new( @@ -133,10 +133,10 @@ impl AppState { devnet: bool, wallet: Wallet, commitment: Option<(CommitmentConfig, CommitmentConfig)>, - default_subaccount_id: Option, + sub_account_ids: Vec, skip_tx_preflight: bool, extra_rpcs: Vec<&str>, - swift_node: Option, + swift_node: String, ) -> Self { let (state_commitment, tx_commitment) = commitment.unwrap_or((CommitmentConfig::confirmed(), CommitmentConfig::confirmed())); @@ -151,11 +151,13 @@ impl AppState { .await .expect("ok"); - let default_subaccount = wallet.sub_account(default_subaccount_id.unwrap_or(0)); - if let Err(err) = client.subscribe_account(&default_subaccount).await { - log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}"); - } else { - log::info!(target: LOG_TARGET, "subscribed to subaccount: {default_subaccount}"); + for sub_account_id in &sub_account_ids { + let sub_account = wallet.sub_account(*sub_account_id); + if let Err(err) = client.subscribe_account(&sub_account).await { + log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}. subaccount: {sub_account_id}"); + } else { + log::info!(target: LOG_TARGET, "subscribed to subaccount: {sub_account}"); + } } let priority_fee_subscriber = PriorityFeeSubscriber::with_config( @@ -190,7 +192,7 @@ impl AppState { Self { client: Arc::new(client), tx_commitment, - default_subaccount_id: default_subaccount_id.unwrap_or(0), + sub_account_ids, skip_tx_preflight, priority_fee_subscriber, slot_subscriber: Arc::new(slot_subscriber), @@ -207,11 +209,26 @@ impl AppState { &self, configured_markets: &[MarketId], ) -> Result<(), SdkError> { - let default_sub_account = self.default_sub_account(); + let sub_account_ids = self.sub_account_ids.clone(); + for id in sub_account_ids { + self.sync_market_subscriptions_on_user_subaccount_changes(configured_markets, id) + .await?; + } + + Ok(()) + } + + async fn sync_market_subscriptions_on_user_subaccount_changes( + &self, + configured_markets: &[MarketId], + sub_account_id: u16, + ) -> Result<(), SdkError> { + let sub_account = self.sub_account(sub_account_id); let state_commitment = self.tx_commitment; let configured_markets_vec = configured_markets.to_vec(); let self_clone = self.clone(); - let mut current_user_markets_to_subscribe = self.get_marketids_to_subscribe().await?; + let mut current_user_markets_to_subscribe = + self.get_marketids_to_subscribe(sub_account).await?; tokio::spawn(async move { let pubsub_config = RpcAccountInfoConfig { @@ -224,7 +241,7 @@ impl AppState { let pubsub_client = self_clone.client.ws(); let (mut account_subscription, unsubscribe_fn) = match pubsub_client - .account_subscribe(&default_sub_account, Some(pubsub_config)) + .account_subscribe(&sub_account, Some(pubsub_config)) .await { Ok(res) => res, @@ -239,7 +256,7 @@ impl AppState { // Process incoming account updates while let Some(_) = account_subscription.next().await { let current_market_ids_count = current_user_markets_to_subscribe.len(); - match self_clone.get_marketids_to_subscribe().await { + match self_clone.get_marketids_to_subscribe(sub_account).await { Ok(new_market_ids) => { if new_market_ids.len() != current_market_ids_count { if let Err(err) = self_clone @@ -267,13 +284,13 @@ impl AppState { Ok(()) } - async fn get_marketids_to_subscribe(&self) -> Result, SdkError> { - let (all_spot, all_perp) = self - .client - .all_positions(&self.default_sub_account()) - .await?; + async fn get_marketids_to_subscribe( + &self, + sub_account: Pubkey, + ) -> Result, SdkError> { + let (all_spot, all_perp) = self.client.all_positions(&sub_account).await?; - let open_orders = self.client.all_orders(&self.default_sub_account()).await?; + let open_orders = self.client.all_orders(&sub_account).await?; let user_markets: Vec = all_spot .iter() @@ -296,11 +313,25 @@ impl AppState { /// * configured_markets - list of static markets provided by user /// /// additional subscriptions will be included based on user's current positions (on default sub-account) + pub(crate) async fn subscribe_market_data( &self, configured_markets: &[MarketId], ) -> Result<(), SdkError> { - let mut user_markets = self.get_marketids_to_subscribe().await?; + for id in self.sub_account_ids.clone() { + self.subscribe_market_data_for_subaccount(configured_markets, id) + .await?; + } + Ok(()) + } + + async fn subscribe_market_data_for_subaccount( + &self, + configured_markets: &[MarketId], + sub_account_id: u16, + ) -> Result<(), SdkError> { + let sub_account = self.sub_account(sub_account_id); + let mut user_markets = self.get_marketids_to_subscribe(sub_account).await?; user_markets.extend_from_slice(configured_markets); let init_rpc_throttle: u64 = std::env::var("INIT_RPC_THROTTLE") @@ -636,7 +667,7 @@ impl AppState { let orders_len = orders_iter.len(); let mut signed_messages = Vec::with_capacity(orders_len); let mut hashes: Vec = Vec::with_capacity(orders_len); - let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_subaccount_id); + let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_sub_account_id()); let current_slot = self.slot_subscriber.current_slot(); let orders_with_hex: Vec<(OrderParams, Vec)> = orders_iter .map(|order| { @@ -663,7 +694,7 @@ impl AppState { }; let incoming_msg = IncomingSignedMessage { taker_authority: self.authority().to_string(), - signature: general_purpose::STANDARD.encode(signature), + signature: general_purpose::STANDARD.encode(signature), // TODO: test just using .to_string() for base64 encoding message: String::from_utf8(message).unwrap(), signing_authority: self.signer().to_string(), market_type, @@ -677,11 +708,7 @@ impl AppState { let client = reqwest::Client::new(); - let swift_orders_url = self - .swift_node - .clone() - .unwrap_or("https://master.swift.drift.trade".to_string()) - + "/orders"; + let swift_orders_url = self.swift_node.clone() + "/orders"; let mut futures = FuturesOrdered::new(); for msg in signed_messages { @@ -900,7 +927,7 @@ impl AppState { ctx: Context, new_margin_ratio: Decimal, ) -> GatewayResult { - let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_subaccount_id); + let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_sub_account_id()); let sub_account_address = self.wallet.sub_account(sub_account_id); let account_data = self.client.get_user_account(&sub_account_address).await?; @@ -921,6 +948,10 @@ impl AppState { self.send_tx(tx, "set_margin_ratio", ctx.ttl).await } + pub fn default_sub_account_id(&self) -> u16 { + self.sub_account_ids[0] + } + fn get_priority_fee(&self) -> u64 { self.priority_fee_subscriber.priority_fee_nth(0.9) } diff --git a/src/main.rs b/src/main.rs index 88dfa8f..e7bce81 100644 --- a/src/main.rs +++ b/src/main.rs @@ -277,12 +277,22 @@ async fn main() -> std::io::Result<()> { let tx_commitment = CommitmentConfig::from_str(&config.tx_commitment) .expect("one of: processed | confirmed | finalized"); let extra_rpcs = config.extra_rpcs.as_ref(); + + let mut sub_account_ids = vec![config.default_sub_account_id]; + sub_account_ids.extend( + config + .active_sub_accounts + .split(",") + .map(|s| s.parse::().unwrap()), + ); + sub_account_ids.dedup(); + let state = AppState::new( &config.rpc_host, config.dev, wallet, Some((state_commitment, tx_commitment)), - Some(config.default_sub_account_id), + sub_account_ids.clone(), config.skip_tx_preflight, extra_rpcs .map(|s| s.split(",").collect()) @@ -329,17 +339,17 @@ async fn main() -> std::io::Result<()> { if delegate.is_some() { info!( target: LOG_TARGET, - "🪪 authority: {:?}, default sub-account: {:?}, 🔑 delegate: {:?}", + "🪪 authority: {:?}, sub-accounts: {:?}, 🔑 delegate: {:?}", state.authority(), - state.default_sub_account(), + sub_account_ids, state.signer(), ); } else { info!( target: LOG_TARGET, - "🪪 authority: {:?}, default sub-account: {:?}", + "🪪 authority: {:?}, sub-accounts: {:?}", state.authority(), - state.default_sub_account() + sub_account_ids ); if emulate.is_some() { warn!("using emulation mode, tx signing unavailable"); @@ -454,6 +464,22 @@ fn handle_deser_error(err: serde_json::Error) -> Either ))) } +fn default_swift_node() -> String { + let strings: Vec = std::env::args_os() + .map(|s| s.into_string()) + .collect::, _>>() + .unwrap_or_else(|arg| { + eprintln!("Invalid utf8: {}", arg.to_string_lossy()); + std::process::exit(1) + }); + let is_dev = strings.iter().any(|s| s.to_string() == "--dev".to_string()); + if is_dev { + "https://master.swift.drift.trade".to_string() + } else { + "https://swift.drift.trade".to_string() + } +} + #[derive(FromArgs)] /// Drift gateway server struct GatewayConfig { @@ -466,8 +492,8 @@ struct GatewayConfig { #[argh(option)] markets: Option, /// swift node url - #[argh(option)] - swift_node: Option, + #[argh(option, default = "default_swift_node()")] + swift_node: String, /// run in devnet mode #[argh(switch)] dev: bool, @@ -500,6 +526,9 @@ struct GatewayConfig { /// default sub_account_id to use (default: 0) #[argh(option, default = "0")] default_sub_account_id: u16, + /// list of active sub_account_ids to use (default: 0) + #[argh(option, default = "String::from(\"0\")")] + active_sub_accounts: String, /// skip tx preflight checks #[argh(switch)] skip_tx_preflight: bool, @@ -551,7 +580,17 @@ mod tests { }; let rpc_endpoint = std::env::var("TEST_RPC_ENDPOINT") .unwrap_or_else(|_| "https://api.devnet.solana.com".to_string()); - AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await + AppState::new( + &rpc_endpoint, + true, + wallet, + None, + vec![0], + false, + vec![], + "https://master.swift.drift.trade".to_string(), + ) + .await } // likely safe to ignore during development, mainly regression test for CI @@ -572,8 +611,17 @@ mod tests { let rpc_endpoint = std::env::var("TEST_MAINNET_RPC_ENDPOINT") .unwrap_or_else(|_| "https://api.mainnet-beta.solana.com".to_string()); - let state = - AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await; + let state = AppState::new( + &rpc_endpoint, + true, + wallet, + None, + vec![], + false, + vec![], + "https://master.swift.drift.trade".to_string(), + ) + .await; let app = test::init_service( App::new() @@ -616,10 +664,10 @@ mod tests { false, wallet, None, - None, + vec![], false, vec![], - None, + "https://master.swift.drift.trade".to_string(), ) .await; @@ -659,10 +707,10 @@ mod tests { false, wallet, None, - None, + vec![], false, vec![], - None, + "https://master.swift.drift.trade".to_string(), ) .await; diff --git a/src/types.rs b/src/types.rs index 1f374a7..943b3cc 100644 --- a/src/types.rs +++ b/src/types.rs @@ -417,6 +417,7 @@ impl PlaceOrder { stop_loss_order_params: None, // TODO: add stop loss order params }; + // TODO: support delegate signed message type here let signed_order_type = SignedOrderType::Authority(order); let borsh_encoding = signed_order_type.to_borsh();