Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ Options:
--commitment solana commitment level to use for state updates (default:
confirmed)
--default-sub-account-id
default sub_account_id to use (default: 0)
default sub_account_id to use as default. Use the new active-sub-accounts param to subscribe to multiple sub accounts. This param will override active-sub-accounts.
--active-sub-accounts sub accounts to subscribe to. (default: 0)
--skip-tx-preflight
skip tx preflight checks
--extra-rpcs extra solana RPC urls for improved Tx broadcast
Expand Down
93 changes: 60 additions & 33 deletions src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ pub struct AppState {
/// Solana tx commitment level for preflight confirmation
tx_commitment: CommitmentConfig,
/// default sub_account_id to use if not provided
default_subaccount_id: u16,
sub_account_ids: Vec<u16>,
/// skip tx preflight on send or not (default: false)
skip_tx_preflight: bool,
priority_fee_subscriber: Arc<PriorityFeeSubscriber>,
slot_subscriber: Arc<SlotSubscriber>,
/// list of additional RPC endpoints for tx broadcast
extra_rpcs: Vec<Arc<RpcClient>>,
/// swift node url
swift_node: Option<String>,
swift_node: String,
}

impl AppState {
Expand All @@ -111,12 +111,12 @@ impl AppState {
pub fn signer(&self) -> Pubkey {
self.wallet.signer()
}
pub fn default_sub_account(&self) -> Pubkey {
self.wallet.sub_account(self.default_subaccount_id)
pub fn sub_account(&self, sub_account_id: u16) -> Pubkey {
self.wallet.sub_account(sub_account_id)
}
pub fn resolve_sub_account(&self, sub_account_id: Option<u16>) -> Pubkey {
self.wallet
.sub_account(sub_account_id.unwrap_or(self.default_subaccount_id))
.sub_account(sub_account_id.unwrap_or(self.sub_account_ids[0]))
}

/// Initialize Gateway Drift client
Expand All @@ -125,18 +125,18 @@ impl AppState {
/// * `devnet` - whether to run against devnet or not
/// * `wallet` - wallet to use for tx signing
/// * `commitment` - Slot finalisation/commitement levels
/// * `default_subaccount_id` - by default all queries will use this sub-account
/// * `sub_account_ids` - the sub_accounts to subscribe too. In your query specify a specific subaccount, otherwise subaccount 0 will be used as default
/// * `skip_tx_preflight` - submit txs without checking preflight results
/// * `extra_rpcs` - list of additional RPC endpoints for tx submission
pub async fn new(
endpoint: &str,
devnet: bool,
wallet: Wallet,
commitment: Option<(CommitmentConfig, CommitmentConfig)>,
default_subaccount_id: Option<u16>,
sub_account_ids: Vec<u16>,
skip_tx_preflight: bool,
extra_rpcs: Vec<&str>,
swift_node: Option<String>,
swift_node: String,
) -> Self {
let (state_commitment, tx_commitment) =
commitment.unwrap_or((CommitmentConfig::confirmed(), CommitmentConfig::confirmed()));
Expand All @@ -151,11 +151,13 @@ impl AppState {
.await
.expect("ok");

let default_subaccount = wallet.sub_account(default_subaccount_id.unwrap_or(0));
if let Err(err) = client.subscribe_account(&default_subaccount).await {
log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}");
} else {
log::info!(target: LOG_TARGET, "subscribed to subaccount: {default_subaccount}");
for sub_account_id in sub_account_ids.clone() {
let sub_account = wallet.sub_account(sub_account_id);
if let Err(err) = client.subscribe_account(&sub_account).await {
log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}. subaccount: {sub_account_id}");
} else {
log::info!(target: LOG_TARGET, "subscribed to subaccount: {sub_account}");
}
}

let priority_fee_subscriber = PriorityFeeSubscriber::with_config(
Expand Down Expand Up @@ -190,7 +192,7 @@ impl AppState {
Self {
client: Arc::new(client),
tx_commitment,
default_subaccount_id: default_subaccount_id.unwrap_or(0),
sub_account_ids,
skip_tx_preflight,
priority_fee_subscriber,
slot_subscriber: Arc::new(slot_subscriber),
Expand All @@ -207,11 +209,26 @@ impl AppState {
&self,
configured_markets: &[MarketId],
) -> Result<(), SdkError> {
let default_sub_account = self.default_sub_account();
let sub_account_ids = self.sub_account_ids.clone();
for id in sub_account_ids {
self.sync_market_subscriptions_on_user_subaccount_changes(configured_markets, id)
.await?;
}

Ok(())
}

async fn sync_market_subscriptions_on_user_subaccount_changes(
&self,
configured_markets: &[MarketId],
sub_account_id: u16,
) -> Result<(), SdkError> {
let sub_account = self.sub_account(sub_account_id);
let state_commitment = self.tx_commitment;
let configured_markets_vec = configured_markets.to_vec();
let self_clone = self.clone();
let mut current_user_markets_to_subscribe = self.get_marketids_to_subscribe().await?;
let mut current_user_markets_to_subscribe =
self.get_marketids_to_subscribe(sub_account).await?;

tokio::spawn(async move {
let pubsub_config = RpcAccountInfoConfig {
Expand All @@ -224,7 +241,7 @@ impl AppState {
let pubsub_client = self_clone.client.ws();

let (mut account_subscription, unsubscribe_fn) = match pubsub_client
.account_subscribe(&default_sub_account, Some(pubsub_config))
.account_subscribe(&sub_account, Some(pubsub_config))
.await
{
Ok(res) => res,
Expand All @@ -239,7 +256,7 @@ impl AppState {
// Process incoming account updates
while let Some(_) = account_subscription.next().await {
let current_market_ids_count = current_user_markets_to_subscribe.len();
match self_clone.get_marketids_to_subscribe().await {
match self_clone.get_marketids_to_subscribe(sub_account).await {
Ok(new_market_ids) => {
if new_market_ids.len() != current_market_ids_count {
if let Err(err) = self_clone
Expand Down Expand Up @@ -267,13 +284,13 @@ impl AppState {
Ok(())
}

async fn get_marketids_to_subscribe(&self) -> Result<Vec<MarketId>, SdkError> {
let (all_spot, all_perp) = self
.client
.all_positions(&self.default_sub_account())
.await?;
async fn get_marketids_to_subscribe(
&self,
sub_account: Pubkey,
) -> Result<Vec<MarketId>, SdkError> {
let (all_spot, all_perp) = self.client.all_positions(&sub_account).await?;

let open_orders = self.client.all_orders(&self.default_sub_account()).await?;
let open_orders = self.client.all_orders(&sub_account).await?;

let user_markets: Vec<MarketId> = all_spot
.iter()
Expand All @@ -296,11 +313,25 @@ impl AppState {
/// * configured_markets - list of static markets provided by user
///
/// additional subscriptions will be included based on user's current positions (on default sub-account)

pub(crate) async fn subscribe_market_data(
&self,
configured_markets: &[MarketId],
) -> Result<(), SdkError> {
let mut user_markets = self.get_marketids_to_subscribe().await?;
for id in self.sub_account_ids.clone() {
self.subscribe_market_data_for_subaccount(configured_markets, id)
.await?;
}
Ok(())
}

async fn subscribe_market_data_for_subaccount(
&self,
configured_markets: &[MarketId],
sub_account_id: u16,
) -> Result<(), SdkError> {
let sub_account = self.sub_account(sub_account_id);
let mut user_markets = self.get_marketids_to_subscribe(sub_account).await?;
user_markets.extend_from_slice(configured_markets);

let init_rpc_throttle: u64 = std::env::var("INIT_RPC_THROTTLE")
Expand Down Expand Up @@ -636,7 +667,7 @@ impl AppState {
let orders_len = orders_iter.len();
let mut signed_messages = Vec::with_capacity(orders_len);
let mut hashes: Vec<String> = Vec::with_capacity(orders_len);
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_subaccount_id);
let sub_account_id = ctx.sub_account_id.unwrap_or(self.sub_account_ids[0]);
let current_slot = self.slot_subscriber.current_slot();
let orders_with_hex: Vec<(OrderParams, Vec<u8>)> = orders_iter
.map(|order| {
Expand All @@ -663,7 +694,7 @@ impl AppState {
};
let incoming_msg = IncomingSignedMessage {
taker_authority: self.authority().to_string(),
signature: general_purpose::STANDARD.encode(signature),
signature: general_purpose::STANDARD.encode(signature), // TODO: test just using .to_string() for base64 encoding
message: String::from_utf8(message).unwrap(),
signing_authority: self.signer().to_string(),
market_type,
Expand All @@ -677,11 +708,7 @@ impl AppState {

let client = reqwest::Client::new();

let swift_orders_url = self
.swift_node
.clone()
.unwrap_or("https://master.swift.drift.trade".to_string())
+ "/orders";
let swift_orders_url = self.swift_node.clone() + "/orders";

let mut futures = FuturesOrdered::new();
for msg in signed_messages {
Expand Down Expand Up @@ -900,7 +927,7 @@ impl AppState {
ctx: Context,
new_margin_ratio: Decimal,
) -> GatewayResult<TxResponse> {
let sub_account_id = ctx.sub_account_id.unwrap_or(self.default_subaccount_id);
let sub_account_id = ctx.sub_account_id.unwrap_or(self.sub_account_ids[0]);
let sub_account_address = self.wallet.sub_account(sub_account_id);
let account_data = self.client.get_user_account(&sub_account_address).await?;

Expand Down
61 changes: 47 additions & 14 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,23 @@ async fn main() -> std::io::Result<()> {
let tx_commitment = CommitmentConfig::from_str(&config.tx_commitment)
.expect("one of: processed | confirmed | finalized");
let extra_rpcs = config.extra_rpcs.as_ref();

let sub_account_ids: Vec<u16> = if config.default_sub_account_id.is_some() {
vec![config.default_sub_account_id.unwrap()]
} else {
config
.active_sub_accounts
.split(",")
.map(|s| s.parse::<u16>().unwrap())
.collect()
};

let state = AppState::new(
&config.rpc_host,
config.dev,
wallet,
Some((state_commitment, tx_commitment)),
Some(config.default_sub_account_id),
sub_account_ids.clone(),
config.skip_tx_preflight,
extra_rpcs
.map(|s| s.split(",").collect())
Expand Down Expand Up @@ -331,15 +342,15 @@ async fn main() -> std::io::Result<()> {
target: LOG_TARGET,
"🪪 authority: {:?}, default sub-account: {:?}, 🔑 delegate: {:?}",
state.authority(),
state.default_sub_account(),
sub_account_ids.iter().map(|id| state.sub_account(*id)),
state.signer(),
);
} else {
info!(
target: LOG_TARGET,
"🪪 authority: {:?}, default sub-account: {:?}",
state.authority(),
state.default_sub_account()
state.sub_account(config.default_sub_account_id.unwrap_or(0))
);
if emulate.is_some() {
warn!("using emulation mode, tx signing unavailable");
Expand Down Expand Up @@ -466,8 +477,8 @@ struct GatewayConfig {
#[argh(option)]
markets: Option<String>,
/// swift node url
#[argh(option)]
swift_node: Option<String>,
#[argh(option, default = "String::from(\"https://master.swift.drift.trade\")")]
swift_node: String,
/// run in devnet mode
#[argh(switch)]
dev: bool,
Expand Down Expand Up @@ -498,8 +509,11 @@ struct GatewayConfig {
#[argh(option, default = "String::from(\"confirmed\")")]
commitment: String,
/// default sub_account_id to use (default: 0)
#[argh(option, default = "0")]
default_sub_account_id: u16,
#[argh(option)]
default_sub_account_id: Option<u16>,
/// list of active sub_account_ids to use (default: 0)
#[argh(option, default = "String::from(\"0\")")]
active_sub_accounts: String,
/// skip tx preflight checks
#[argh(switch)]
skip_tx_preflight: bool,
Expand Down Expand Up @@ -551,7 +565,17 @@ mod tests {
};
let rpc_endpoint = std::env::var("TEST_RPC_ENDPOINT")
.unwrap_or_else(|_| "https://api.devnet.solana.com".to_string());
AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await
AppState::new(
&rpc_endpoint,
true,
wallet,
None,
vec![0],
false,
vec![],
"https://master.swift.drift.trade".to_string(),
)
.await
}

// likely safe to ignore during development, mainly regression test for CI
Expand All @@ -572,8 +596,17 @@ mod tests {

let rpc_endpoint = std::env::var("TEST_MAINNET_RPC_ENDPOINT")
.unwrap_or_else(|_| "https://api.mainnet-beta.solana.com".to_string());
let state =
AppState::new(&rpc_endpoint, true, wallet, None, None, false, vec![], None).await;
let state = AppState::new(
&rpc_endpoint,
true,
wallet,
None,
vec![],
false,
vec![],
"https://master.swift.drift.trade".to_string(),
)
.await;

let app = test::init_service(
App::new()
Expand Down Expand Up @@ -616,10 +649,10 @@ mod tests {
false,
wallet,
None,
None,
vec![],
false,
vec![],
None,
"https://master.swift.drift.trade".to_string(),
)
.await;

Expand Down Expand Up @@ -659,10 +692,10 @@ mod tests {
false,
wallet,
None,
None,
vec![],
false,
vec![],
None,
"https://master.swift.drift.trade".to_string(),
)
.await;

Expand Down
1 change: 1 addition & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ impl PlaceOrder {
stop_loss_order_params: None, // TODO: add stop loss order params
};

// TODO: support delegate signed message type here
let signed_order_type = SignedOrderType::Authority(order);

let borsh_encoding = signed_order_type.to_borsh();
Expand Down
Loading