Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ Options:
--commitment solana commitment level to use for state updates (default:
confirmed)
--default-sub-account-id
default sub_account_id to use (default: 0)
default sub_account_id to use as default. Use the new active-sub-accounts param to subscribe to multiple sub accounts. This param will override active-sub-accounts.
--active-sub-accounts sub accounts to subscribe to. (default: 0)
--skip-tx-preflight
skip tx preflight checks
--extra-rpcs extra solana RPC urls for improved Tx broadcast
Expand Down
99 changes: 65 additions & 34 deletions src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,16 @@ pub struct AppState {
pub client: Arc<DriftClient>,
/// Solana tx commitment level for preflight confirmation
tx_commitment: CommitmentConfig,
/// default sub_account_id to use if not provided
default_subaccount_id: u16,
/// sub_account_ids to subscribe to
sub_account_ids: Vec<u16>,
/// skip tx preflight on send or not (default: false)
skip_tx_preflight: bool,
priority_fee_subscriber: Arc<PriorityFeeSubscriber>,
slot_subscriber: Arc<SlotSubscriber>,
/// list of additional RPC endpoints for tx broadcast
extra_rpcs: Vec<Arc<RpcClient>>,
/// swift node url
swift_node: Option<String>,
swift_node: String,
}

impl AppState {
Expand All @@ -111,12 +111,12 @@ impl AppState {
pub fn signer(&self) -> Pubkey {
self.wallet.signer()
}
pub fn default_sub_account(&self) -> Pubkey {
self.wallet.sub_account(self.default_subaccount_id)
pub fn sub_account(&self, sub_account_id: u16) -> Pubkey {
self.wallet.sub_account(sub_account_id)
}
pub fn resolve_sub_account(&self, sub_account_id: Option<u16>) -> Pubkey {
self.wallet
.sub_account(sub_account_id.unwrap_or(self.default_subaccount_id))
.sub_account(sub_account_id.unwrap_or(self.default_sub_account_id()))
}

/// Initialize Gateway Drift client
Expand All @@ -125,18 +125,18 @@ impl AppState {
/// * `devnet` - whether to run against devnet or not
/// * `wallet` - wallet to use for tx signing
/// * `commitment` - Slot finalisation/commitement levels
/// * `default_subaccount_id` - by default all queries will use this sub-account
/// * `sub_account_ids` - the sub_accounts to subscribe too. In your query specify a specific subaccount, otherwise subaccount 0 will be used as default
/// * `skip_tx_preflight` - submit txs without checking preflight results
/// * `extra_rpcs` - list of additional RPC endpoints for tx submission
pub async fn new(
endpoint: &str,
devnet: bool,
wallet: Wallet,
commitment: Option<(CommitmentConfig, CommitmentConfig)>,
default_subaccount_id: Option<u16>,
sub_account_ids: Vec<u16>,
skip_tx_preflight: bool,
extra_rpcs: Vec<&str>,
swift_node: Option<String>,
swift_node: String,
) -> Self {
let (state_commitment, tx_commitment) =
commitment.unwrap_or((CommitmentConfig::confirmed(), CommitmentConfig::confirmed()));
Expand All @@ -151,11 +151,13 @@ impl AppState {
.await
.expect("ok");

let default_subaccount = wallet.sub_account(default_subaccount_id.unwrap_or(0));
if let Err(err) = client.subscribe_account(&default_subaccount).await {
log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}");
} else {
log::info!(target: LOG_TARGET, "subscribed to subaccount: {default_subaccount}");
for sub_account_id in &sub_account_ids {
let sub_account = wallet.sub_account(*sub_account_id);
if let Err(err) = client.subscribe_account(&sub_account).await {
log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}. subaccount: {sub_account_id}");
} else {
log::info!(target: LOG_TARGET, "subscribed to subaccount: {sub_account}");
}
}

let priority_fee_subscriber = PriorityFeeSubscriber::with_config(
Expand Down Expand Up @@ -190,7 +192,7 @@ impl AppState {
Self {
client: Arc::new(client),
tx_commitment,
default_subaccount_id: default_subaccount_id.unwrap_or(0),
sub_account_ids,
skip_tx_preflight,
priority_fee_subscriber,
slot_subscriber: Arc::new(slot_subscriber),
Expand All @@ -207,11 +209,26 @@ impl AppState {
&self,
configured_markets: &[MarketId],
) -> Result<(), SdkError> {
let default_sub_account = self.default_sub_account();
let sub_account_ids = self.sub_account_ids.clone();
for id in sub_account_ids {
self.sync_market_subscriptions_on_user_subaccount_changes(configured_markets, id)
.await?;
}

Ok(())
}

async fn sync_market_subscriptions_on_user_subaccount_changes(
&self,
configured_markets: &[MarketId],
sub_account_id: u16,
) -> Result<(), SdkError> {
let sub_account = self.sub_account(sub_account_id);
let state_commitment = self.tx_commitment;
let configured_markets_vec = configured_markets.to_vec();
let self_clone = self.clone();
let mut current_user_markets_to_subscribe = self.get_marketids_to_subscribe().await?;
let mut current_user_markets_to_subscribe =
self.get_marketids_to_subscribe(sub_account).await?;

tokio::spawn(async move {
let pubsub_config = RpcAccountInfoConfig {
Expand All @@ -224,7 +241,7 @@ impl AppState {
let pubsub_client = self_clone.client.ws();

let (mut account_subscription, unsubscribe_fn) = match pubsub_client
.account_subscribe(&default_sub_account, Some(pubsub_config))
.account_subscribe(&sub_account, Some(pubsub_config))
.await
{
Ok(res) => res,
Expand All @@ -239,7 +256,7 @@ impl AppState {
// Process incoming account updates
while let Some(_) = account_subscription.next().await {
let current_market_ids_count = current_user_markets_to_subscribe.len();
match self_clone.get_marketids_to_subscribe().await {
match self_clone.get_marketids_to_subscribe(sub_account).await {
Ok(new_market_ids) => {
if new_market_ids.len() != current_market_ids_count {
if let Err(err) = self_clone
Expand Down Expand Up @@ -267,13 +284,13 @@ impl AppState {
Ok(())
}

async fn get_marketids_to_subscribe(&self) -> Result<Vec<MarketId>, SdkError> {
let (all_spot, all_perp) = self
.client
.all_positions(&self.default_sub_account())
.await?;
async fn get_marketids_to_subscribe(
&self,
sub_account: Pubkey,
) -> Result<Vec<MarketId>, SdkError> {
let (all_spot, all_perp) = self.client.all_positions(&sub_account).await?;

let open_orders = self.client.all_orders(&self.default_sub_account()).await?;
let open_orders = self.client.all_orders(&sub_account).await?;

let user_markets: Vec<MarketId> = all_spot
.iter()
Expand All @@ -296,11 +313,25 @@ impl AppState {
/// * configured_markets - list of static markets provided by user
///
/// additional subscriptions will be included based on user's current positions (on default sub-account)

pub(crate) async fn subscribe_market_data(
&self,
configured_markets: &[MarketId],
) -> Result<(), SdkError> {
let mut user_markets = self.get_marketids_to_subscribe().await?;
for id in self.sub_account_ids.clone() {
self.subscribe_market_data_for_subaccount(configured_markets, id)
.await?;
}
Ok(())
}

async fn subscribe_market_data_for_subaccount(
&self,
configured_markets: &[MarketId],
sub_account_id: u16,
) -> Result<(), SdkError> {
let sub_account = self.sub_account(sub_account_id);
let mut user_markets = self.get_marketids_to_subscribe(sub_account).await?;
user_markets.extend_from_slice(configured_markets);

let init_rpc_throttle: u64 = std::env::var("INIT_RPC_THROTTLE")
Expand Down Expand Up @@ -636,7 +667,7 @@ impl AppState {
let orders_len = orders_iter.len();
let mut signed_messages = Vec::with_capacity(orders_len);
let mut hashes: Vec<String> = Vec::with_capacity(orders_len);
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_subaccount_id);
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_sub_account_id());
let current_slot = self.slot_subscriber.current_slot();
let orders_with_hex: Vec<(OrderParams, Vec<u8>)> = orders_iter
.map(|order| {
Expand All @@ -663,7 +694,7 @@ impl AppState {
};
let incoming_msg = IncomingSignedMessage {
taker_authority: self.authority().to_string(),
signature: general_purpose::STANDARD.encode(signature),
signature: general_purpose::STANDARD.encode(signature), // TODO: test just using .to_string() for base64 encoding
message: String::from_utf8(message).unwrap(),
signing_authority: self.signer().to_string(),
market_type,
Expand All @@ -677,11 +708,7 @@ impl AppState {

let client = reqwest::Client::new();

let swift_orders_url = self
.swift_node
.clone()
.unwrap_or("https://master.swift.drift.trade".to_string())
+ "/orders";
let swift_orders_url = self.swift_node.clone() + "/orders";

let mut futures = FuturesOrdered::new();
for msg in signed_messages {
Expand Down Expand Up @@ -900,7 +927,7 @@ impl AppState {
ctx: Context,
new_margin_ratio: Decimal,
) -> GatewayResult<TxResponse> {
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_subaccount_id);
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_sub_account_id());
let sub_account_address = self.wallet.sub_account(sub_account_id);
let account_data = self.client.get_user_account(&sub_account_address).await?;

Expand All @@ -921,6 +948,10 @@ impl AppState {
self.send_tx(tx, "set_margin_ratio", ctx.ttl).await
}

pub fn default_sub_account_id(&self) -> u16 {
self.sub_account_ids[0]
}

fn get_priority_fee(&self) -> u64 {
self.priority_fee_subscriber.priority_fee_nth(0.9)
}
Expand Down
76 changes: 62 additions & 14 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,22 @@ async fn main() -> std::io::Result<()> {
let tx_commitment = CommitmentConfig::from_str(&config.tx_commitment)
.expect("one of: processed | confirmed | finalized");
let extra_rpcs = config.extra_rpcs.as_ref();

let mut sub_account_ids = vec![config.default_sub_account_id];
sub_account_ids.extend(
config
.active_sub_accounts
.split(",")
.map(|s| s.parse::<u16>().unwrap()),
);
sub_account_ids.dedup();

let state = AppState::new(
&config.rpc_host,
config.dev,
wallet,
Some((state_commitment, tx_commitment)),
Some(config.default_sub_account_id),
sub_account_ids.clone(),
config.skip_tx_preflight,
extra_rpcs
.map(|s| s.split(",").collect())
Expand Down Expand Up @@ -329,17 +339,17 @@ async fn main() -> std::io::Result<()> {
if delegate.is_some() {
info!(
target: LOG_TARGET,
"🪪 authority: {:?}, default sub-account: {:?}, 🔑 delegate: {:?}",
"🪪 authority: {:?}, sub-accounts: {:?}, 🔑 delegate: {:?}",
state.authority(),
state.default_sub_account(),
sub_account_ids,
state.signer(),
);
} else {
info!(
target: LOG_TARGET,
"🪪 authority: {:?}, default sub-account: {:?}",
"🪪 authority: {:?}, sub-accounts: {:?}",
state.authority(),
state.default_sub_account()
sub_account_ids
);
if emulate.is_some() {
warn!("using emulation mode, tx signing unavailable");
Expand Down Expand Up @@ -454,6 +464,22 @@ fn handle_deser_error<T>(err: serde_json::Error) -> Either<HttpResponse, Json<T>
)))
}

fn default_swift_node() -> String {
let strings: Vec<String> = std::env::args_os()
.map(|s| s.into_string())
.collect::<Result<Vec<_>, _>>()
.unwrap_or_else(|arg| {
eprintln!("Invalid utf8: {}", arg.to_string_lossy());
std::process::exit(1)
});
let is_dev = strings.iter().any(|s| s.to_string() == "--dev".to_string());
if is_dev {
"https://master.swift.drift.trade".to_string()
} else {
"https://swift.drift.trade".to_string()
}
}

#[derive(FromArgs)]
/// Drift gateway server
struct GatewayConfig {
Expand All @@ -466,8 +492,8 @@ struct GatewayConfig {
#[argh(option)]
markets: Option<String>,
/// swift node url
#[argh(option)]
swift_node: Option<String>,
#[argh(option, default = "default_swift_node()")]
swift_node: String,
/// run in devnet mode
#[argh(switch)]
dev: bool,
Expand Down Expand Up @@ -500,6 +526,9 @@ struct GatewayConfig {
/// default sub_account_id to use (default: 0)
#[argh(option, default = "0")]
default_sub_account_id: u16,
/// list of active sub_account_ids to use (default: 0)
#[argh(option, default = "String::from(\"0\")")]
active_sub_accounts: String,
/// skip tx preflight checks
#[argh(switch)]
skip_tx_preflight: bool,
Expand Down Expand Up @@ -551,7 +580,17 @@ mod tests {
};
let rpc_endpoint = std::env::var("TEST_RPC_ENDPOINT")
.unwrap_or_else(|_| "https://api.devnet.solana.com".to_string());
AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await
AppState::new(
&rpc_endpoint,
true,
wallet,
None,
vec![0],
false,
vec![],
"https://master.swift.drift.trade".to_string(),
)
.await
}

// likely safe to ignore during development, mainly regression test for CI
Expand All @@ -572,8 +611,17 @@ mod tests {

let rpc_endpoint = std::env::var("TEST_MAINNET_RPC_ENDPOINT")
.unwrap_or_else(|_| "https://api.mainnet-beta.solana.com".to_string());
let state =
AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await;
let state = AppState::new(
&rpc_endpoint,
true,
wallet,
None,
vec![],
false,
vec![],
"https://master.swift.drift.trade".to_string(),
)
.await;

let app = test::init_service(
App::new()
Expand Down Expand Up @@ -616,10 +664,10 @@ mod tests {
false,
wallet,
None,
None,
vec![],
false,
vec![],
None,
"https://master.swift.drift.trade".to_string(),
)
.await;

Expand Down Expand Up @@ -659,10 +707,10 @@ mod tests {
false,
wallet,
None,
None,
vec![],
false,
vec![],
None,
"https://master.swift.drift.trade".to_string(),
)
.await;

Expand Down
Loading
Loading