diff --git a/examples/account_summary.rs b/examples/account_summary.rs index fd6f1106..7f09ceec 100644 --- a/examples/account_summary.rs +++ b/examples/account_summary.rs @@ -15,7 +15,7 @@ fn main() { for update in &subscription { match update { AccountSummaries::Summary(summary) => println!("{summary:?}"), - AccountSummaries::End => subscription.cancel().expect("cancel failed"), + AccountSummaries::End => subscription.cancel(), } } } diff --git a/examples/account_updates.rs b/examples/account_updates.rs new file mode 100644 index 00000000..f45616d4 --- /dev/null +++ b/examples/account_updates.rs @@ -0,0 +1,20 @@ +use ibapi::accounts::AccountUpdate; +use ibapi::Client; + +fn main() { + env_logger::init(); + + let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); + + let account = "DU1234567"; + + let subscription = client.account_updates(account).expect("error requesting account updates"); + for update in &subscription { + println!("{update:?}"); + + // stop after full initial update + if let AccountUpdate::End = update { + subscription.cancel(); + } + } +} diff --git a/examples/account_updates_multi.rs b/examples/account_updates_multi.rs new file mode 100644 index 00000000..a8ef2739 --- /dev/null +++ b/examples/account_updates_multi.rs @@ -0,0 +1,22 @@ +use ibapi::accounts::AccountUpdateMulti; +use ibapi::Client; + +fn main() { + env_logger::init(); + + let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); + + let account = Some("DU1234567"); + + let subscription = client + .account_updates_multi(account, None) + .expect("error requesting account updates multi"); + for update in &subscription { + println!("{update:?}"); + + // stop after full initial update + if let AccountUpdateMulti::End = update { + subscription.cancel(); + } + } +} diff --git a/examples/readme_realtime_data_1.rs b/examples/readme_realtime_data_1.rs index db6800f3..c3538c97 100644 --- a/examples/readme_realtime_data_1.rs +++ b/examples/readme_realtime_data_1.rs @@ -19,6 +19,6 @@ fn main() { println!("bar: {bar:?}"); // when your algorithm is done, cancel subscription - subscription.cancel().expect("cancel failed"); + subscription.cancel(); } } diff --git a/examples/readme_realtime_data_2.rs b/examples/readme_realtime_data_2.rs index 8f6ee8b5..006bf5e6 100644 --- a/examples/readme_realtime_data_2.rs +++ b/examples/readme_realtime_data_2.rs @@ -22,7 +22,7 @@ fn main() { println!("NVDA {}, AAPL {}", bar_nvda.close, bar_aapl.close); // when your algorithm is done, cancel subscription - subscription_aapl.cancel().expect("cancel failed"); - subscription_nvda.cancel().expect("cancel failed"); + subscription_aapl.cancel(); + subscription_nvda.cancel(); } } diff --git a/src/accounts.rs b/src/accounts.rs index 90cdb725..c70d3107 100644 --- a/src/accounts.rs +++ b/src/accounts.rs @@ -273,31 +273,60 @@ pub struct FamilyCode { /// Account's information, portfolio and last update time #[allow(clippy::large_enum_variant)] -pub enum AccountUpdates { +#[derive(Debug)] +pub enum AccountUpdate { /// Receives the subscribed account's information. - Value(AccountValue), + AccountValue(AccountValue), /// Receives the subscribed account's portfolio. - Portfolio(AccountPortfolio), + PortfolioValue(AccountPortfolioValue), /// Receives the last time on which the account was updated. - Time(AccountTime), + UpdateTime(AccountUpdateTime), /// Notifies when all the account’s information has finished. End, } +impl Subscribable for AccountUpdate { + const RESPONSE_MESSAGE_IDS: &[IncomingMessages] = &[ + IncomingMessages::AccountValue, + IncomingMessages::PortfolioValue, + IncomingMessages::AccountUpdateTime, + IncomingMessages::AccountDownloadEnd, + ]; + + fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + match message.message_type() { + IncomingMessages::AccountValue => Ok(AccountUpdate::AccountValue(decoders::decode_account_value(message)?)), + IncomingMessages::PortfolioValue => Ok(AccountUpdate::PortfolioValue(decoders::decode_account_portfolio_value( + server_version, + message, + )?)), + IncomingMessages::AccountUpdateTime => Ok(AccountUpdate::UpdateTime(decoders::decode_account_update_time(message)?)), + IncomingMessages::AccountDownloadEnd => Ok(AccountUpdate::End), + message => Err(Error::Simple(format!("unexpected message: {message:?}"))), + } + } + + fn cancel_message(server_version: i32, _request_id: Option) -> Result { + encoders::encode_cancel_account_updates(server_version) + } +} + /// A value of subscribed account's information. +#[derive(Debug, Default)] pub struct AccountValue { /// The value being updated. pub key: String, /// Current value pub value: String, - /// The currency inn which the value is expressed. + /// The currency in which the value is expressed. pub currency: String, /// The account identifier. - pub account: String, + pub account: Option, } /// Subscribed account's portfolio. -pub struct AccountPortfolio { +#[derive(Debug, Default)] +pub struct AccountPortfolioValue { /// The Contract for which a position is held. pub contract: Contract, /// The number of positions held. @@ -310,18 +339,61 @@ pub struct AccountPortfolio { pub average_cost: f64, /// Daily unrealized profit and loss on the position. pub unrealized_pnl: f64, - /// Daily realized profit and loss on the position. + /// Daily realized profit and loss on the position. pub realized_pnl: f64, /// Account identifier for the update. - pub account: String, + pub account: Option, } /// Last time at which the account was updated. -pub struct AccountTime { +#[derive(Debug, Default)] +pub struct AccountUpdateTime { /// The last update system time. pub timestamp: String, } +/// Account's information, portfolio and last update time +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub enum AccountUpdateMulti { + /// Receives the subscribed account's information. + AccountMultiValue(AccountMultiValue), + /// Notifies when all the account’s information has finished. + End, +} + +// Provides the account updates. +#[derive(Debug, Default)] +pub struct AccountMultiValue { + /// he account with updates. + pub account: String, + /// The model code with updates. + pub model_code: String, + /// The name of parameter. + pub key: String, + /// The value of parameter. + pub value: String, + /// The currency of parameter. + pub currency: String, +} + +impl Subscribable for AccountUpdateMulti { + const RESPONSE_MESSAGE_IDS: &[IncomingMessages] = &[IncomingMessages::AccountUpdateMulti, IncomingMessages::AccountUpdateMultiEnd]; + + fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + match message.message_type() { + IncomingMessages::AccountUpdateMulti => Ok(AccountUpdateMulti::AccountMultiValue(decoders::decode_account_multi_value(message)?)), + IncomingMessages::AccountUpdateMultiEnd => Ok(AccountUpdateMulti::End), + message => Err(Error::Simple(format!("unexpected message: {message:?}"))), + } + } + + fn cancel_message(server_version: i32, request_id: Option) -> Result { + let request_id = request_id.expect("Request ID required to encode cancel account updates multi"); + encoders::encode_cancel_account_updates_multi(server_version, request_id) + } +} + // Subscribes to position updates for all accessible accounts. // All positions sent initially, and then only updates as positions change. pub(crate) fn positions(client: &Client) -> Result, Error> { @@ -412,6 +484,27 @@ pub fn account_summary<'a>(client: &'a Client, group: &str, tags: &[&str]) -> Re Ok(Subscription::new(client, subscription)) } +pub fn account_updates<'a>(client: &'a Client, account: &str) -> Result, Error> { + let request = encoders::encode_request_account_updates(client.server_version(), account)?; + let subscription = client.send_shared_request(OutgoingMessages::RequestAccountData, request)?; + + Ok(Subscription::new(client, subscription)) +} + +pub fn account_updates_multi<'a>( + client: &'a Client, + account: Option<&str>, + model_code: Option<&str>, +) -> Result, Error> { + client.check_server_version(server_versions::MODELS_SUPPORT, "It does not support account updates multi requests.")?; + + let request_id = client.next_request_id(); + let request = encoders::encode_request_account_updates_multi(request_id, account, model_code)?; + let subscription = client.send_request(request_id, request)?; + + Ok(Subscription::new(client, subscription)) +} + pub fn managed_accounts(client: &Client) -> Result, Error> { let request = encoders::encode_request_managed_accounts()?; let subscription = client.send_shared_request(OutgoingMessages::RequestManagedAccounts, request)?; diff --git a/src/accounts/decoders.rs b/src/accounts/decoders.rs index add4da19..0cfb1c26 100644 --- a/src/accounts/decoders.rs +++ b/src/accounts/decoders.rs @@ -1,8 +1,10 @@ -use crate::contracts::SecurityType; +use crate::contracts::{Contract, SecurityType}; use crate::messages::ResponseMessage; use crate::{server_versions, Error}; -use super::{AccountSummary, FamilyCode, PnL, PnLSingle, Position, PositionMulti}; +use super::{ + AccountMultiValue, AccountPortfolioValue, AccountSummary, AccountUpdateTime, AccountValue, FamilyCode, PnL, PnLSingle, Position, PositionMulti, +}; pub(crate) fn decode_position(message: &mut ResponseMessage) -> Result { message.skip(); // message type @@ -144,5 +146,98 @@ pub(crate) fn decode_account_summary(_server_version: i32, message: &mut Respons }) } +pub(crate) fn decode_account_value(message: &mut ResponseMessage) -> Result { + message.skip(); // message type + + let message_version = message.next_int()?; + + let mut account_value = AccountValue { + key: message.next_string()?, + value: message.next_string()?, + currency: message.next_string()?, + ..Default::default() + }; + + if message_version >= 2 { + account_value.account = Some(message.next_string()?); + } + + Ok(account_value) +} + +pub(crate) fn decode_account_portfolio_value(server_version: i32, message: &mut ResponseMessage) -> Result { + message.skip(); // message type + + let message_version = message.next_int()?; + + let mut contract = Contract::default(); + if message_version >= 6 { + contract.contract_id = message.next_int()?; + } + contract.symbol = message.next_string()?; + contract.security_type = SecurityType::from(&message.next_string()?); + contract.last_trade_date_or_contract_month = message.next_string()?; + contract.strike = message.next_double()?; + contract.right = message.next_string()?; + if message_version >= 7 { + contract.multiplier = message.next_string()?; + contract.primary_exchange = message.next_string()?; + } + contract.currency = message.next_string()?; + if message_version >= 2 { + contract.local_symbol = message.next_string()?; + } + if message_version >= 8 { + contract.trading_class = message.next_string()?; + } + + let mut portfolio_value = AccountPortfolioValue { + contract, + ..Default::default() + }; + + portfolio_value.position = message.next_double()?; + portfolio_value.market_price = message.next_double()?; + portfolio_value.market_value = message.next_double()?; + if message_version >= 3 { + portfolio_value.average_cost = message.next_double()?; + portfolio_value.unrealized_pnl = message.next_double()?; + portfolio_value.realized_pnl = message.next_double()?; + } + if message_version >= 4 { + portfolio_value.account = Some(message.next_string()?); + } + if message_version == 6 && server_version == 39 { + portfolio_value.contract.primary_exchange = message.next_string()? + } + + Ok(portfolio_value) +} + +pub(crate) fn decode_account_update_time(message: &mut ResponseMessage) -> Result { + message.skip(); // message type + message.skip(); // version + + Ok(AccountUpdateTime { + timestamp: message.next_string()?, + }) +} + +pub(crate) fn decode_account_multi_value(message: &mut ResponseMessage) -> Result { + message.skip(); // message type + message.skip(); // message version + message.skip(); // request id + + let value = AccountMultiValue { + account: message.next_string()?, + model_code: message.next_string()?, + key: message.next_string()?, + value: message.next_string()?, + currency: message.next_string()?, + }; + + Ok(value) +} + #[cfg(test)] mod tests; diff --git a/src/accounts/decoders/tests.rs b/src/accounts/decoders/tests.rs index 9b640832..880fdfb9 100644 --- a/src/accounts/decoders/tests.rs +++ b/src/accounts/decoders/tests.rs @@ -1,4 +1,4 @@ -use crate::{accounts::AccountSummaryTags, server_versions}; +use crate::{accounts::AccountSummaryTags, server_versions, testdata::responses}; #[test] fn test_decode_positions() { @@ -119,3 +119,16 @@ fn test_decode_account_summary() { assert_eq!(account_summary.value, "FA", "account_summary.value"); assert_eq!(account_summary.currency, "", "account_summary.currency"); } + +#[test] +fn test_decode_account_multi_value() { + let mut message = super::ResponseMessage::from_simple(responses::ACCOUNT_UPDATE_MULTI_CURRENCY); + + let value = super::decode_account_multi_value(&mut message).expect("error decoding account multi value"); + + assert_eq!(value.account, "DU1234567", "value.account"); + assert_eq!(value.model_code, "", "value.model_code"); + assert_eq!(value.key, "Currency", "value.key"); + assert_eq!(value.value, "USD", "value.value"); + assert_eq!(value.currency, "USD", "value.currency"); +} diff --git a/src/accounts/encoders.rs b/src/accounts/encoders.rs index 17a41403..18ac8a3c 100644 --- a/src/accounts/encoders.rs +++ b/src/accounts/encoders.rs @@ -96,6 +96,67 @@ pub(crate) fn encode_request_managed_accounts() -> Result Ok(message) } +pub(crate) fn encode_request_account_updates(server_version: i32, account: &str) -> Result { + const VERSION: i32 = 2; + + let mut message = RequestMessage::new(); + + message.push_field(&OutgoingMessages::RequestAccountData); + message.push_field(&VERSION); + message.push_field(&true); // subscribe + if server_version > 9 { + message.push_field(&account); + } + + Ok(message) +} + +pub(crate) fn encode_request_account_updates_multi( + request_id: i32, + account: Option<&str>, + model_code: Option<&str>, +) -> Result { + const VERSION: i32 = 1; + + let mut message = RequestMessage::new(); + + message.push_field(&OutgoingMessages::RequestAccountUpdatesMulti); + message.push_field(&VERSION); + message.push_field(&request_id); + message.push_field(&account); + message.push_field(&model_code); + message.push_field(&true); // subscribe + + Ok(message) +} + +pub(crate) fn encode_cancel_account_updates(server_version: i32) -> Result { + const VERSION: i32 = 2; + + let mut message = RequestMessage::new(); + + message.push_field(&OutgoingMessages::RequestAccountData); + message.push_field(&VERSION); + message.push_field(&false); // subscribe + if server_version > 9 { + message.push_field(&""); + } + + Ok(message) +} + +pub(crate) fn encode_cancel_account_updates_multi(_server_version: i32, request_id: i32) -> Result { + const VERSION: i32 = 1; + + let mut message = RequestMessage::new(); + + message.push_field(&OutgoingMessages::CancelAccountUpdatesMulti); + message.push_field(&VERSION); + message.push_field(&request_id); + + Ok(message) +} + fn encode_simple(message_type: OutgoingMessages, version: i32) -> Result { let mut message = RequestMessage::new(); diff --git a/src/accounts/tests.rs b/src/accounts/tests.rs index b9d1a21c..c0e066c3 100644 --- a/src/accounts/tests.rs +++ b/src/accounts/tests.rs @@ -1,5 +1,6 @@ use std::sync::{Arc, RwLock}; +use crate::accounts::AccountUpdateMulti; use crate::testdata::responses; use crate::{accounts::AccountSummaryTags, server_versions, stubs::MessageBusStub, Client}; @@ -126,3 +127,37 @@ fn test_managed_accounts() { assert_eq!(accounts, &["DU1234567", "DU7654321"]); } + +#[test] +fn test_account_updates_multi() { + let message_bus = Arc::new(MessageBusStub { + request_messages: RwLock::new(vec![]), + response_messages: vec![ + responses::ACCOUNT_UPDATE_MULTI_CASH_BALANCE.into(), + responses::ACCOUNT_UPDATE_MULTI_CURRENCY.into(), + responses::ACCOUNT_UPDATE_MULTI_END.into(), + ], + }); + + let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); + + let account = Some("DU1234567"); + let subscription = client.account_updates_multi(account, None).expect("request managed accounts failed"); + + let update = subscription.next().unwrap(); + match update { + AccountUpdateMulti::AccountMultiValue(value) => { + assert_eq!(value.key, "CashBalance"); + } + AccountUpdateMulti::End => { + panic!("value expected") + } + } + + subscription.cancel(); + + let request_messages = client.message_bus.request_messages(); + + assert_eq!(request_messages[0].encode_simple(), "76|1|9000|DU1234567||1|"); + assert_eq!(request_messages[1].encode_simple(), "77|1|9000|"); +} diff --git a/src/client.rs b/src/client.rs index d693bc61..b9d6ffff 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,14 +1,14 @@ use std::fmt::Debug; use std::marker::PhantomData; -use std::sync::atomic::{AtomicI32, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicI32, Ordering}; use std::sync::Arc; use std::time::Duration; -use log::{debug, error}; +use log::{debug, error, info, warn}; use time::OffsetDateTime; use time_tz::Tz; -use crate::accounts::{AccountSummaries, FamilyCode, PnL, PnLSingle, PositionUpdate, PositionUpdateMulti}; +use crate::accounts::{AccountSummaries, AccountUpdate, AccountUpdateMulti, FamilyCode, PnL, PnLSingle, PositionUpdate, PositionUpdateMulti}; use crate::contracts::Contract; use crate::errors::Error; use crate::market_data::historical; @@ -229,6 +229,74 @@ impl Client { accounts::account_summary(self, group, tags) } + /// Subscribes to a specific account’s information and portfolio. + /// + /// All account values and positions will be returned initially, and then there will only be updates when there is a change in a position, or to an account value every 3 minutes if it has changed. Only one account can be subscribed at a time. + /// + /// # Arguments + /// * `account` - The account id (i.e. U1234567) for which the information is requested. + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::Client; + /// use ibapi::accounts::AccountUpdate; + /// + /// let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); + /// + /// let account = "U1234567"; + /// + /// let subscription = client.account_updates(account).expect("error requesting account updates"); + /// for update in &subscription { + /// println!("{update:?}"); + /// + /// // stop after full initial update + /// if let AccountUpdate::End = update { + /// subscription.cancel(); + /// } + /// } + /// ``` + pub fn account_updates<'a>(&'a self, account: &str) -> Result, Error> { + accounts::account_updates(self, account) + } + + /// Requests account updates for account and/or model. + /// + /// All account values and positions will be returned initially, and then there will only be updates when there is a change in a position, or to an account value every 3 minutes if it has changed. Only one account can be subscribed at a time. + /// + /// # Arguments + /// * `account` - Account values can be requested for a particular account. + /// * `model_code` - Account values can also be requested for a model. + /// * `ledger_and_nlv` - Returns light-weight request; only currency positions as opposed to account values and currency positions. + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::Client; + /// use ibapi::accounts::AccountUpdateMulti; + /// + /// let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); + /// + /// let account = Some("U1234567"); + /// + /// let subscription = client.account_updates_multi(account, None).expect("error requesting account updates multi"); + /// for update in &subscription { + /// println!("{update:?}"); + /// + /// // stop after full initial update + /// if let AccountUpdateMulti::End = update { + /// subscription.cancel(); + /// } + /// } + /// ``` + pub fn account_updates_multi<'a>( + &'a self, + account: Option<&str>, + model_code: Option<&str>, + ) -> Result, Error> { + accounts::account_updates_multi(self, account, model_code) + } + /// Requests the accounts to which the logged user has access to. /// /// # Examples @@ -979,6 +1047,7 @@ pub struct Subscription<'a, T: Subscribable> { pub(crate) message_type: Option, pub(crate) subscription: InternalSubscription, pub(crate) phantom: PhantomData, + cancelled: AtomicBool, } #[allow(private_bounds)] @@ -992,6 +1061,7 @@ impl<'a, T: Subscribable> Subscription<'a, T> { message_type: None, subscription, phantom: PhantomData, + cancelled: AtomicBool::new(false), } } else if let Some(order_id) = subscription.order_id { Subscription { @@ -1001,6 +1071,7 @@ impl<'a, T: Subscribable> Subscription<'a, T> { message_type: None, subscription, phantom: PhantomData, + cancelled: AtomicBool::new(false), } } else if let Some(message_type) = subscription.message_type { Subscription { @@ -1010,6 +1081,7 @@ impl<'a, T: Subscribable> Subscription<'a, T> { message_type: Some(message_type), subscription, phantom: PhantomData, + cancelled: AtomicBool::new(false), } } else { panic!("unsupported internal subscription: {:?}", subscription) @@ -1033,7 +1105,7 @@ impl<'a, T: Subscribable> Subscription<'a, T> { error!("{error_message}"); return None; } else { - error!("subscription iterator unexpected message: {message:?}"); + info!("subscription iterator unexpected message: {message:?}"); } } Some(Response::Cancelled) => { @@ -1112,26 +1184,37 @@ impl<'a, T: Subscribable> Subscription<'a, T> { } /// Cancel the subscription - pub fn cancel(&self) -> Result<(), Error> { + pub fn cancel(&self) { + if self.cancelled.load(Ordering::Relaxed) { + return; + } + + self.cancelled.store(true, Ordering::Relaxed); + if let Some(request_id) = self.request_id { if let Ok(message) = T::cancel_message(self.client.server_version(), self.request_id) { - self.client.message_bus.cancel_subscription(request_id, &message)?; + if let Err(e) = self.client.message_bus.cancel_subscription(request_id, &message) { + warn!("error cancelling subscription: {e}") + } self.subscription.cancel(); } } else if let Some(order_id) = self.order_id { if let Ok(message) = T::cancel_message(self.client.server_version(), self.request_id) { - self.client.message_bus.cancel_order_subscription(order_id, &message)?; + if let Err(e) = self.client.message_bus.cancel_order_subscription(order_id, &message) { + warn!("error cancelling order subscription: {e}") + } self.subscription.cancel(); } } else if let Some(message_type) = self.message_type { if let Ok(message) = T::cancel_message(self.client.server_version(), self.request_id) { - self.client.message_bus.cancel_shared_subscription(message_type, &message)?; + if let Err(e) = self.client.message_bus.cancel_shared_subscription(message_type, &message) { + warn!("error cancelling shared subscription: {e}") + } self.subscription.cancel(); } } else { debug!("Could not determine cancel method") } - Ok(()) } pub fn iter(&self) -> SubscriptionIter { @@ -1149,9 +1232,7 @@ impl<'a, T: Subscribable> Subscription<'a, T> { impl<'a, T: Subscribable> Drop for Subscription<'a, T> { fn drop(&mut self) { - if let Err(err) = self.cancel() { - error!("error cancelling subscription: {err}"); - } + self.cancel(); } } diff --git a/src/messages.rs b/src/messages.rs index 0693c5a3..8d313a03 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,7 +1,7 @@ use std::ops::Index; use std::str::{self, FromStr}; -use log::error; +use log::debug; use time::OffsetDateTime; use crate::{Error, ToField}; @@ -216,9 +216,11 @@ pub fn request_id_index(kind: IncomingMessages) -> Option { | IncomingMessages::Error | IncomingMessages::ExecutionDataEnd | IncomingMessages::AccountSummary - | IncomingMessages::AccountSummaryEnd => Some(2), + | IncomingMessages::AccountSummaryEnd + | IncomingMessages::AccountUpdateMulti + | IncomingMessages::AccountUpdateMultiEnd => Some(2), _ => { - error!("could not determine request id index for {kind:?}"); + debug!("could not determine request id index for {kind:?}"); None } } @@ -517,12 +519,19 @@ impl ResponseMessage { } pub fn from(fields: &str) -> ResponseMessage { + let fields = fields.replace("|", "\0"); ResponseMessage { i: 0, fields: fields.split('\x00').map(|x| x.to_string()).collect(), } } + #[cfg(test)] + pub fn from_simple(fields: &str) -> ResponseMessage { + let fields = fields.replace("|", "\0"); + Self::from(&fields) + } + pub fn skip(&mut self) { self.i += 1; } @@ -534,5 +543,6 @@ impl ResponseMessage { } } +pub(crate) mod shared_channel_configuration; #[cfg(test)] mod tests; diff --git a/src/messages/shared_channel_configuration.rs b/src/messages/shared_channel_configuration.rs new file mode 100644 index 00000000..2ff963c6 --- /dev/null +++ b/src/messages/shared_channel_configuration.rs @@ -0,0 +1,47 @@ +use super::{IncomingMessages, OutgoingMessages}; + +pub struct ChannelMapping<'a> { + pub request: OutgoingMessages, + pub responses: &'a [IncomingMessages], +} + +// For shared channels configures mapping of request message id to response message ids. +pub(crate) const CHANNEL_MAPPINGS: &[ChannelMapping] = &[ + ChannelMapping { + request: OutgoingMessages::RequestIds, + responses: &[IncomingMessages::NextValidId], + }, + ChannelMapping { + request: OutgoingMessages::RequestFamilyCodes, + responses: &[IncomingMessages::FamilyCodes], + }, + ChannelMapping { + request: OutgoingMessages::RequestMarketRule, + responses: &[IncomingMessages::MarketRule], + }, + ChannelMapping { + request: OutgoingMessages::RequestPositions, + responses: &[IncomingMessages::Position, IncomingMessages::PositionEnd], + }, + ChannelMapping { + request: OutgoingMessages::RequestPositionsMulti, + responses: &[IncomingMessages::PositionMulti, IncomingMessages::PositionMultiEnd], + }, + ChannelMapping { + request: OutgoingMessages::RequestOpenOrders, + responses: &[IncomingMessages::OpenOrder, IncomingMessages::OpenOrderEnd], + }, + ChannelMapping { + request: OutgoingMessages::RequestManagedAccounts, + responses: &[IncomingMessages::ManagedAccounts], + }, + ChannelMapping { + request: OutgoingMessages::RequestAccountData, + responses: &[ + IncomingMessages::AccountValue, + IncomingMessages::PortfolioValue, + IncomingMessages::AccountDownloadEnd, + IncomingMessages::AccountUpdateTime, + ], + }, +]; diff --git a/src/stubs.rs b/src/stubs.rs index f8ebb1e0..c7a27fcb 100644 --- a/src/stubs.rs +++ b/src/stubs.rs @@ -46,6 +46,8 @@ impl MessageBus for MessageBusStub { Ok(()) } + fn ensure_shutdown(&self) {} + // fn process_messages(&mut self, _server_version: i32) -> Result<(), Error> { // Ok(()) // } diff --git a/src/testdata/responses.rs b/src/testdata/responses.rs index bb3e6afa..d3cdd09f 100644 --- a/src/testdata/responses.rs +++ b/src/testdata/responses.rs @@ -1,2 +1,7 @@ pub const POSITION: &str = "61\03\0DU1234567\076792991\0TSLA\0STK\0\00.0\0\0\0NASDAQ\0USD\0TSLA\0NMS\0500\0196.77\0"; pub const MANAGED_ACCOUNT: &str = "15|1|DU1234567,DU7654321|"; + +pub const ACCOUNT_UPDATE_MULTI_CASH_BALANCE: &str = "73|1|9000|DU1234567||CashBalance|94629.71|USD||"; +pub const ACCOUNT_UPDATE_MULTI_CURRENCY: &str = "73|1|9000|DU1234567||Currency|USD|USD||"; +pub const ACCOUNT_UPDATE_MULTI_STOCK_MARKET_VALUE: &str = "73|1|9000|DU1234567||StockMarketValue|0.00|BASE||"; +pub const ACCOUNT_UPDATE_MULTI_END: &str = "74|1|9000||"; diff --git a/src/transport.rs b/src/transport.rs index ee3519f0..b272d0a0 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -17,8 +17,7 @@ use time::macros::format_description; use time::OffsetDateTime; use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt, Tz}; -use crate::messages::{IncomingMessages, OutgoingMessages}; -use crate::messages::{RequestMessage, ResponseMessage}; +use crate::messages::{shared_channel_configuration, IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage}; use crate::{server_versions, Error}; use recorder::MessageRecorder; @@ -46,7 +45,7 @@ pub(crate) trait MessageBus: Send + Sync { fn cancel_order_subscription(&self, request_id: i32, packet: &RequestMessage) -> Result<(), Error>; - fn ensure_shutdown(&self) {} + fn ensure_shutdown(&self); // Testing interface. Tracks requests sent messages when Bus is stubbed. #[cfg(test)] @@ -87,22 +86,9 @@ impl SharedChannels { }; // Register request/response pairs. - instance.register(OutgoingMessages::RequestIds, &[IncomingMessages::NextValidId]); - instance.register(OutgoingMessages::RequestFamilyCodes, &[IncomingMessages::FamilyCodes]); - instance.register(OutgoingMessages::RequestMarketRule, &[IncomingMessages::MarketRule]); - instance.register( - OutgoingMessages::RequestPositions, - &[IncomingMessages::Position, IncomingMessages::PositionEnd], - ); - instance.register( - OutgoingMessages::RequestPositionsMulti, - &[IncomingMessages::PositionMulti, IncomingMessages::PositionMultiEnd], - ); - instance.register( - OutgoingMessages::RequestOpenOrders, - &[IncomingMessages::OpenOrder, IncomingMessages::OpenOrderEnd], - ); - instance.register(OutgoingMessages::RequestManagedAccounts, &[IncomingMessages::ManagedAccounts]); + for mapping in shared_channel_configuration::CHANNEL_MAPPINGS { + instance.register(mapping.request, mapping.responses); + } instance } @@ -186,6 +172,8 @@ impl TcpMessageBus { } fn request_shutdown(&self) { + debug!("shutdown requested"); + self.requests.notify_all(&Response::Disconnected); self.orders.notify_all(&Response::Disconnected); @@ -232,6 +220,13 @@ impl TcpMessageBus { backoff.reset(); retry_attempt = 0; } + Err(Error::Io(e)) if e.kind() == ErrorKind::WouldBlock => { + if message_bus.is_shutting_down() { + debug!("dispatcher thread exiting"); + return; + } + thread::sleep(Duration::from_millis(1)); + } Err(Error::Io(e)) if RECONNECT_ERRORS.contains(&e.kind()) => { error!("error reading packet: {:?}", e); // reset hashes @@ -262,10 +257,6 @@ impl TcpMessageBus { return; } }; - - if message_bus.is_shutting_down() { - return; - } } }) } @@ -395,20 +386,23 @@ impl TcpMessageBus { fn start_cleanup_thread(self: &Arc) -> JoinHandle<()> { let message_bus = Arc::clone(self); - thread::spawn(move || loop { + thread::spawn(move || { let signal_recv = message_bus.signals_recv.clone(); - for signal in &signal_recv { - match signal { - Signal::Request(request_id) => { - message_bus.clean_request(request_id); - } - Signal::Order(order_id) => { - message_bus.clean_order(order_id); + loop { + if let Ok(signal) = signal_recv.recv_timeout(Duration::from_secs(1)) { + match signal { + Signal::Request(request_id) => { + message_bus.clean_request(request_id); + } + Signal::Order(order_id) => { + message_bus.clean_order(order_id); + } } } if message_bus.is_shutting_down() { + debug!("cleanup thread exiting"); return; } } @@ -524,6 +518,7 @@ impl MessageBus for TcpMessageBus { } fn ensure_shutdown(&self) { + self.request_shutdown(); self.join(); } } @@ -823,6 +818,8 @@ impl Connection { let reader = TcpStream::connect(connection_url)?; let writer = reader.try_clone()?; + reader.set_read_timeout(Some(Duration::from_secs(1)))?; + let connection = Self { client_id, connection_url: connection_url.into(), @@ -859,6 +856,8 @@ impl Connection { let mut writer = self.writer.lock()?; *reader = stream.try_clone()?; + reader.set_read_timeout(Some(Duration::from_secs(1)))?; + *writer = stream; }