diff --git a/Cargo.lock b/Cargo.lock index ac67bc0..793f2a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1586,18 +1586,23 @@ version = "1.5.2" dependencies = [ "actix-web", "argh", + "base64 0.22.1", "drift-rs", - "env_logger 0.9.3", + "env_logger 0.11.7", + "faster-hex", "futures-util", "log", + "nanoid", + "reqwest 0.12.15", "rust_decimal", "serde", "serde_json", + "sha256", "solana-account-decoder-client-types", "solana-rpc-client-api", "solana-sdk", "solana-transaction-status", - "thiserror 1.0.69", + "thiserror 2.0.12", "tokio", "tokio-tungstenite", ] @@ -1805,6 +1810,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "faster-hex" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7223ae2d2f179b803433d9c830478527e92b8117eab39460edae7f1614d9fb73" +dependencies = [ + "heapless", + "serde", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -2101,6 +2116,15 @@ dependencies = [ "byteorder", ] +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -2131,6 +2155,16 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +[[package]] +name = "heapless" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" +dependencies = [ + "hash32 0.3.1", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.3.3" @@ -2932,6 +2966,15 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +[[package]] +name = "nanoid" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8" +dependencies = [ + "rand 0.8.5", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -4094,6 +4137,19 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "sha256" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f880fc8562bdeb709793f00eb42a2ad0e672c4f883bbe59122b926eca935c8f6" +dependencies = [ + "async-trait", + "bytes", + "hex", + "sha2 0.10.8", + "tokio", +] + [[package]] name = "sha3" version = "0.10.8" @@ -5365,7 +5421,7 @@ checksum = "66a3ce7a0f4d6830124ceb2c263c36d1ee39444ec70146eb49b939e557e72b96" dependencies = [ "byteorder", "combine", - "hash32", + "hash32 0.2.1", "libc", "log", "rand 0.8.5", diff --git a/Cargo.toml b/Cargo.toml index e80e21e..675f2b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,18 +7,23 @@ edition = "2021" actix-web = "*" argh = "*" drift-rs = { git = "https://github.com/drift-labs/drift-rs", rev = "6d2d6e0" } +base64 = "0.22.1" env_logger = "*" +faster-hex = "0.10.0" futures-util = "*" log = "*" +nanoid = "0.4.0" +reqwest = { version = "*", features = ["json"] } rust_decimal = "*" serde = { version = "*", features = ["derive"] } serde_json = "*" +sha256 = "1.6.0" solana-account-decoder-client-types = "2.2.2" solana-rpc-client-api = "2" solana-sdk = "2" solana-transaction-status = "2" thiserror = "*" -tokio = "*" +tokio = {version ="*", features = ["full"]} tokio-tungstenite = "*" [profile.release] diff --git a/src/controller.rs b/src/controller.rs index a4deb02..419e55d 100644 --- a/src/controller.rs +++ b/src/controller.rs @@ -20,19 +20,21 @@ use drift_rs::{ }, }, priority_fee_subscriber::{PriorityFeeSubscriber, PriorityFeeSubscriberConfig}, + slot_subscriber::SlotSubscriber, types::{ - self, accounts::SpotMarket, MarketId, MarketType, ModifyOrderParams, OrderStatus, - ProgramError, RpcSendTransactionConfig, SdkError, SdkResult, VersionedMessage, + self, accounts::SpotMarket, MarketId, MarketType, ModifyOrderParams, OrderParams, + OrderStatus, ProgramError, RpcSendTransactionConfig, SdkError, SdkResult, VersionedMessage, }, utils::get_http_url, DriftClient, Pubkey, TransactionBuilder, Wallet, }; use futures_util::{ stream::{FuturesOrdered, FuturesUnordered}, - StreamExt, + FutureExt, StreamExt, }; -use log::{debug, info, warn}; +use log::{debug, info, trace, warn}; use rust_decimal::Decimal; +use sha256::digest; use solana_account_decoder_client_types::UiAccountEncoding; use solana_rpc_client_api::{ client_error::ErrorKind as ClientErrorKind, @@ -46,15 +48,23 @@ use crate::{ types::{ get_market_decimals, scale_decimal_to_u64, AllMarketsResponse, AuthorityResponse, CancelAndPlaceRequest, CancelOrdersRequest, GetOrdersRequest, GetOrdersResponse, - GetPositionsRequest, GetPositionsResponse, Market, MarketInfoResponse, ModifyOrdersRequest, - Order, PerpPosition, PerpPositionExtended, PlaceOrdersRequest, SolBalanceResponse, - SpotPosition, SwapRequest, TxEventsResponse, TxResponse, UserCollateralResponse, - UserLeverageResponse, UserMarginResponse, PRICE_DECIMALS, + GetPositionsRequest, GetPositionsResponse, IncomingSignedMessage, Market, + MarketInfoResponse, ModifyOrdersRequest, Order, PerpPosition, PerpPositionExtended, + PlaceOrderResponse, PlaceOrderType, PlaceOrdersRequest, SignedMsgOrderResult, + SignedMsgResponse, SolBalanceResponse, SpotPosition, SwapRequest, TxEventsResponse, + TxResponse, UserCollateralResponse, UserLeverageResponse, UserMarginResponse, + PRICE_DECIMALS, }, websocket::map_drift_event_for_account, Context, LOG_TARGET, }; +use base64::{ + alphabet, + engine::{self, general_purpose}, + Engine as _, +}; + /// Default TTL in seconds of gateway tx retry /// after which gateway will no longer resubmit or monitor the tx // ~15 slots @@ -85,8 +95,11 @@ pub struct AppState { /// skip tx preflight on send or not (default: false) skip_tx_preflight: bool, priority_fee_subscriber: Arc, + slot_subscriber: Arc, /// list of additional RPC endpoints for tx broadcast extra_rpcs: Vec>, + /// swift node url + swift_node: Option, } impl AppState { @@ -123,6 +136,7 @@ impl AppState { default_subaccount_id: Option, skip_tx_preflight: bool, extra_rpcs: Vec<&str>, + swift_node: Option, ) -> Self { let (state_commitment, tx_commitment) = commitment.unwrap_or((CommitmentConfig::confirmed(), CommitmentConfig::confirmed())); @@ -167,17 +181,25 @@ impl AppState { priority_fee_subscriber.subscribe() }; + let mut slot_subscriber = SlotSubscriber::new(client.ws()); + slot_subscriber + .subscribe(|new_slot| { + trace!(target: LOG_TARGET, "app_state slot_updated: {:#?}", new_slot); + }) + .expect("slot subscribed"); Self { client: Arc::new(client), tx_commitment, default_subaccount_id: default_subaccount_id.unwrap_or(0), skip_tx_preflight, priority_fee_subscriber, + slot_subscriber: Arc::new(slot_subscriber), wallet: Arc::new(wallet), extra_rpcs: extra_rpcs .into_iter() .map(|u| Arc::new(RpcClient::new(get_http_url(u).expect("valid RPC url")))) .collect(), + swift_node, } } @@ -253,7 +275,7 @@ impl AppState { let open_orders = self.client.all_orders(&self.default_sub_account()).await?; - let mut user_markets: Vec = all_spot + let user_markets: Vec = all_spot .iter() .map(|s| MarketId::spot(s.market_index)) .chain(all_perp.iter().map(|p| MarketId::perp(p.market_index))) @@ -265,7 +287,6 @@ impl AppState { } })) .collect(); - user_markets.push(MarketId::QUOTE_SPOT); Ok(user_markets) } @@ -578,32 +599,130 @@ impl AppState { &self, ctx: Context, req: PlaceOrdersRequest, - ) -> GatewayResult { + ) -> GatewayResult { let sub_account = self.resolve_sub_account(ctx.sub_account_id); let account_data = self.client.get_user_account(&sub_account).await?; let pf = self.get_priority_fee(); let priority_fee = ctx.cu_price.unwrap_or(pf); debug!(target: LOG_TARGET, "priority fee: {priority_fee:?}"); - let orders = req - .orders - .into_iter() - .map(|o| { - let base_decimals = get_market_decimals(self.client.program_data(), o.market); - o.to_order_params(base_decimals) - }) - .collect(); - let tx = TransactionBuilder::new( - self.client.program_data(), - sub_account, - Cow::Owned(account_data), - self.wallet.is_delegated(), - ) - .with_priority_fee(priority_fee, ctx.cu_limit) - .place_orders(orders) - .build(); + let orders_iter = req.orders.into_iter(); + match req.place_order_type { + PlaceOrderType::Tx => { + let orders = orders_iter + .map(|o| { + let base_decimals = + get_market_decimals(self.client.program_data(), o.market); + o.to_order_params(base_decimals) + }) + .collect(); + let tx = TransactionBuilder::new( + self.client.program_data(), + sub_account, + Cow::Owned(account_data), + self.wallet.is_delegated(), + ) + .with_priority_fee(priority_fee, ctx.cu_limit) + .place_orders(orders) + .build(); + + let tx_res = self.send_tx(tx, "place_orders", ctx.ttl).await; + match tx_res { + Ok(tx_res) => Ok(PlaceOrderResponse::Tx(tx_res)), + Err(e) => Err(e), + } + } + PlaceOrderType::SignedMsg => { + let orders_len = orders_iter.len(); + let mut signed_messages = Vec::with_capacity(orders_len); + let mut hashes: Vec = Vec::with_capacity(orders_len); + let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_subaccount_id); + let current_slot = self.slot_subscriber.current_slot(); + let orders_with_hex: Vec<(OrderParams, Vec)> = orders_iter + .map(|order| { + let base_decimals = + get_market_decimals(self.client.program_data(), order.market); + let order_for_signing_hex = order.clone(); + let order_params = order.to_order_params(base_decimals); + ( + order_params, + order_for_signing_hex.to_signed_order_hex( + order_params, + current_slot, + sub_account_id, + ), + ) + }) + .collect(); + + for (order, message) in orders_with_hex { + let signature = self.wallet.sign_message(message.as_slice())?; + let market_type: &'static str = match order.market_type { + MarketType::Spot => "spot", + MarketType::Perp => "perp", + }; + let incoming_msg = IncomingSignedMessage { + taker_authority: self.authority().to_string(), + signature: general_purpose::STANDARD.encode(signature), + message: String::from_utf8(message).unwrap(), + signing_authority: self.signer().to_string(), + market_type, + market_index: order.market_index, + }; + + signed_messages.push(incoming_msg); + let hash = digest(signature.as_ref()); + hashes.push(hash); + } - self.send_tx(tx, "place_orders", ctx.ttl).await + let client = reqwest::Client::new(); + + let swift_orders_url = self + .swift_node + .clone() + .unwrap_or("https://master.swift.drift.trade".to_string()) + + "/orders"; + + let mut futures = FuturesOrdered::new(); + for msg in signed_messages { + let future = + client + .post(&swift_orders_url) + .json(&msg) + .send() + .then(|resp| async move { + match resp { + Ok(response) => { + let status = response.status(); + let response_text = + response.text().await.unwrap_or_default(); + (status.to_string(), response_text) + } + Err(e) => { + ("500".to_string(), format!("swift server error: {:?}", e)) + } + } + }); + futures.push_back(future); + } + + let responses: Vec<_> = futures.collect().await; + + let signed_msg = SignedMsgResponse { + results: hashes + .iter() + .zip(responses) + .map(|(hash, (status, response))| SignedMsgOrderResult { + hash: hash.clone(), + status: status.clone(), + error: Some(response.clone()), + }) + .collect(), + }; + + Ok(PlaceOrderResponse::SignedMsg(signed_msg)) + } + } } pub async fn modify_orders( diff --git a/src/main.rs b/src/main.rs index 897fd9d..88dfa8f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -287,6 +287,7 @@ async fn main() -> std::io::Result<()> { extra_rpcs .map(|s| s.split(",").collect()) .unwrap_or_default(), + config.swift_node, ) .await; @@ -464,6 +465,9 @@ struct GatewayConfig { /// gateway creates market subscriptions for responsive trading #[argh(option)] markets: Option, + /// swift node url + #[argh(option)] + swift_node: Option, /// run in devnet mode #[argh(switch)] dev: bool, @@ -547,7 +551,7 @@ mod tests { }; let rpc_endpoint = std::env::var("TEST_RPC_ENDPOINT") .unwrap_or_else(|_| "https://api.devnet.solana.com".to_string()); - AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![]).await + AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await } // likely safe to ignore during development, mainly regression test for CI @@ -568,7 +572,8 @@ mod tests { let rpc_endpoint = std::env::var("TEST_MAINNET_RPC_ENDPOINT") .unwrap_or_else(|_| "https://api.mainnet-beta.solana.com".to_string()); - let state = AppState::new(&rpc_endpoint, false, wallet, None, None, false, vec![]).await; + let state = + AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await; let app = test::init_service( App::new() @@ -606,7 +611,17 @@ mod tests { let rpc_endpoint = std::env::var("TEST_MAINNET_RPC_ENDPOINT") .unwrap_or_else(|_| "https://api.mainnet-beta.solana.com".to_string()); - let state = AppState::new(&rpc_endpoint, false, wallet, None, None, false, vec![]).await; + let state = AppState::new( + &rpc_endpoint, + false, + wallet, + None, + None, + false, + vec![], + None, + ) + .await; let app = test::init_service(App::new().app_data(web::Data::new(state)).service(swap)).await; @@ -639,7 +654,17 @@ mod tests { let rpc_endpoint = std::env::var("TEST_MAINNET_RPC_ENDPOINT") .unwrap_or_else(|_| "https://api.mainnet-beta.solana.com".to_string()); - let state = AppState::new(&rpc_endpoint, false, wallet, None, None, false, vec![]).await; + let state = AppState::new( + &rpc_endpoint, + false, + wallet, + None, + None, + false, + vec![], + None, + ) + .await; let app = test::init_service(App::new().app_data(web::Data::new(state)).service(swap)).await; diff --git a/src/types.rs b/src/types.rs index 54d4a4a..1f374a7 100644 --- a/src/types.rs +++ b/src/types.rs @@ -8,15 +8,18 @@ use drift_rs::{ constants::{BASE_PRECISION, PRICE_PRECISION, QUOTE_PRECISION}, liquidation::{CollateralInfo, MarginRequirementInfo}, }, + swift_order_subscriber::SignedOrderType, types::{ self as sdk_types, accounts::{PerpMarket, SpotMarket}, MarketPrecision, MarketType, ModifyOrderParams, OrderParams, OrderTriggerCondition, - PositionDirection, PostOnlyParam, + PositionDirection, PostOnlyParam, SignedMsgOrderParamsMessage, }, }; +use nanoid::nanoid; use rust_decimal::Decimal; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::convert::TryInto; use crate::websocket::AccountEvent; @@ -244,12 +247,15 @@ impl ModifyOrder { } #[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "camelCase")] pub struct PlaceOrdersRequest { pub orders: Vec, + #[serde(default)] + pub place_order_type: PlaceOrderType, } #[cfg_attr(test, derive(Default))] -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct PlaceOrder { #[serde(flatten)] @@ -277,6 +283,41 @@ pub struct PlaceOrder { #[serde(default)] oracle_price_offset: Option, max_ts: Option, + #[serde(default)] + auction_duration: Option, + #[serde(default)] + auction_start_price: Option, + #[serde(default)] + auction_end_price: Option, +} + +#[derive(Serialize, Debug)] +pub enum PlaceOrderType { + Tx, + SignedMsg, +} + +impl Default for PlaceOrderType { + fn default() -> Self { + Self::Tx + } +} + +impl<'de> Deserialize<'de> for PlaceOrderType { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s: &str = Deserialize::deserialize(deserializer)?; + match s { + "tx" => Ok(PlaceOrderType::Tx), + "swift" => Ok(PlaceOrderType::SignedMsg), + _ => Err(serde::de::Error::custom(format!( + "unknown place order type: {}", + s + ))), + } + } } pub fn ser_market_type(x: &MarketType, s: S) -> Result @@ -293,8 +334,8 @@ pub fn de_market_type<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { - let s = String::deserialize(deserializer)?; - match s.as_str() { + let s: &str = Deserialize::deserialize(deserializer)?; + match s { "perp" => Ok(MarketType::Perp), "spot" => Ok(MarketType::Spot), _ => Err(serde::de::Error::custom(format!( @@ -326,6 +367,10 @@ impl PlaceOrder { 0 }; + let oracle_price_offset = self + .oracle_price_offset + .map(|x| scale_decimal_to_i64(x, PRICE_PRECISION as u32) as i32); + OrderParams { market_index: self.market.market_index, market_type: self.market.market_type, @@ -344,17 +389,42 @@ impl PlaceOrder { PostOnlyParam::None }, user_order_id: self.user_order_id, - oracle_price_offset: self - .oracle_price_offset - .map(|x| scale_decimal_to_i64(x, PRICE_PRECISION as u32) as i32), + oracle_price_offset, max_ts: self.max_ts, trigger_price: self .trigger_price .map(|v| scale_decimal_to_u64(v, PRICE_PRECISION as u32)), trigger_condition: self.trigger_condition.unwrap_or_default(), + auction_duration: Some(self.auction_duration.unwrap_or(20)), + auction_start_price: self.auction_start_price, + auction_end_price: self.auction_end_price, ..Default::default() } } + + pub fn to_signed_order_hex( + self, + order_params: OrderParams, + slot: u64, + sub_account_id: u16, + ) -> Vec { + let order = SignedMsgOrderParamsMessage { + signed_msg_order_params: order_params, + slot, + uuid: nanoid!(8).as_bytes().try_into().unwrap(), + sub_account_id, + take_profit_order_params: None, // TODO: add take profit order params + stop_loss_order_params: None, // TODO: add stop loss order params + }; + + let signed_order_type = SignedOrderType::Authority(order); + + let borsh_encoding = signed_order_type.to_borsh(); + let borsh_bytes = borsh_encoding.as_slice(); + let mut hex_bytes = vec![0; borsh_bytes.len() * 2]; // 2 hex bytes per msg byte + let _ = faster_hex::hex_encode(borsh_bytes, &mut hex_bytes).expect("hexified"); + hex_bytes + } } #[cfg_attr(test, derive(Default))] @@ -534,6 +604,25 @@ impl TxEventsResponse { } } +#[derive(Serialize, Deserialize, Debug)] +pub struct SignedMsgResponse { + pub results: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SignedMsgOrderResult { + pub hash: String, + pub status: String, + pub error: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "camelCase", untagged)] +pub enum PlaceOrderResponse { + Tx(TxResponse), + SignedMsg(SignedMsgResponse), +} + #[derive(Debug, Serialize, Deserialize)] pub struct CancelAndPlaceRequest { pub cancel: CancelOrdersRequest, @@ -610,6 +699,16 @@ impl From for UserCollateralResponse { } } +#[derive(serde::Serialize, Clone, Debug, PartialEq)] +pub struct IncomingSignedMessage { + pub taker_authority: String, + pub signature: String, + pub message: String, + pub signing_authority: String, + pub market_type: &'static str, + pub market_index: u16, +} + #[cfg(test)] mod tests { use std::str::FromStr;