From 2efb9769953fee0998140267b0d3fe1914e287e9 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Tue, 15 Oct 2024 13:38:06 -0700 Subject: [PATCH] Refactors connection management (#127) --- examples/account_summary.rs | 3 + examples/breakout.rs | 2 + examples/contract_details.rs | 2 + examples/executions.rs | 2 + examples/family_codes.rs | 2 + examples/managed_accounts.rs | 2 + examples/market_rule.rs | 2 + examples/matching_symbols.rs | 2 + examples/next_order_id.rs | 2 + examples/positions.rs | 2 + examples/positions_multi.rs | 2 + examples/readme_connection.rs | 2 + examples/readme_historical_data.rs | 2 + examples/readme_place_order.rs | 2 + examples/readme_realtime_data_1.rs | 2 + src/accounts.rs | 65 +- src/accounts/tests.rs | 36 +- src/client.rs | 292 ++++----- src/client/tests.rs | 14 - src/contracts.rs | 9 +- src/contracts/tests.rs | 9 +- src/errors.rs | 4 + src/lib.rs | 5 +- src/market_data/historical.rs | 17 +- src/market_data/historical/tests.rs | 15 +- src/market_data/realtime.rs | 37 +- src/market_data/realtime/tests.rs | 8 +- src/orders.rs | 12 +- src/orders/tests.rs | 69 +- src/stubs.rs | 55 +- src/tests.rs | 1 + src/transport.rs | 941 ++++++++++++++++++++-------- src/transport/tests.rs | 39 ++ 33 files changed, 1009 insertions(+), 650 deletions(-) create mode 100644 src/tests.rs diff --git a/examples/account_summary.rs b/examples/account_summary.rs index 321de0d6..fd6f1106 100644 --- a/examples/account_summary.rs +++ b/examples/account_summary.rs @@ -2,6 +2,8 @@ use ibapi::accounts::{AccountSummaries, AccountSummaryTags}; use ibapi::Client; fn main() { + env_logger::init(); + let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); let group = "All"; @@ -9,6 +11,7 @@ fn main() { let subscription = client .account_summary(group, AccountSummaryTags::ALL) .expect("error requesting account summary"); + for update in &subscription { match update { AccountSummaries::Summary(summary) => println!("{summary:?}"), diff --git a/examples/breakout.rs b/examples/breakout.rs index ad329f5e..0502b393 100644 --- a/examples/breakout.rs +++ b/examples/breakout.rs @@ -6,6 +6,8 @@ use ibapi::orders::{order_builder, Action, OrderNotification}; use ibapi::Client; fn main() { + env_logger::init(); + let client = Client::connect("127.0.0.1:4002", 100).unwrap(); let symbol = "TSLA"; diff --git a/examples/contract_details.rs b/examples/contract_details.rs index 2387b4c7..68b87735 100644 --- a/examples/contract_details.rs +++ b/examples/contract_details.rs @@ -2,6 +2,8 @@ use ibapi::contracts::Contract; use ibapi::Client; fn main() -> anyhow::Result<()> { + env_logger::init(); + let client = Client::connect("127.0.0.1:4002", 100)?; println!("server_version: {}", client.server_version()); diff --git a/examples/executions.rs b/examples/executions.rs index 8fcff307..6020a13e 100644 --- a/examples/executions.rs +++ b/examples/executions.rs @@ -2,6 +2,8 @@ use ibapi::orders::ExecutionFilter; use ibapi::Client; fn main() -> anyhow::Result<()> { + env_logger::init(); + let mut filter = ExecutionFilter::default(); filter.client_id = Some(32); diff --git a/examples/family_codes.rs b/examples/family_codes.rs index 0cca1974..a4ae21e0 100644 --- a/examples/family_codes.rs +++ b/examples/family_codes.rs @@ -1,6 +1,8 @@ use ibapi::Client; fn main() { + env_logger::init(); + let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); let family_codes = client.family_codes().expect("request failed"); diff --git a/examples/managed_accounts.rs b/examples/managed_accounts.rs index 08f6cc7f..05584b8e 100644 --- a/examples/managed_accounts.rs +++ b/examples/managed_accounts.rs @@ -1,6 +1,8 @@ use ibapi::Client; fn main() { + env_logger::init(); + let client = Client::connect("127.0.0.1:4002", 101).expect("connection failed"); let accounts = client.managed_accounts().expect("error requesting managed accounts"); diff --git a/examples/market_rule.rs b/examples/market_rule.rs index 5a60e2d4..d7a9f327 100644 --- a/examples/market_rule.rs +++ b/examples/market_rule.rs @@ -1,6 +1,8 @@ use ibapi::Client; fn main() { + env_logger::init(); + let client = Client::connect("127.0.0.1:4002", 100).unwrap(); let market_rule_id = 12; diff --git a/examples/matching_symbols.rs b/examples/matching_symbols.rs index 0385c200..dc489037 100644 --- a/examples/matching_symbols.rs +++ b/examples/matching_symbols.rs @@ -1,6 +1,8 @@ use ibapi::Client; fn main() { + env_logger::init(); + let client = Client::connect("127.0.0.1:4002", 100).unwrap(); let pattern = "TSLA"; diff --git a/examples/next_order_id.rs b/examples/next_order_id.rs index bbc95043..d86d8da5 100644 --- a/examples/next_order_id.rs +++ b/examples/next_order_id.rs @@ -1,6 +1,8 @@ use ibapi::Client; fn main() { + env_logger::init(); + let client = Client::connect("127.0.0.1:4002", 100).unwrap(); let order_id = client.next_valid_order_id().unwrap(); diff --git a/examples/positions.rs b/examples/positions.rs index 0f66065d..ac6f3691 100644 --- a/examples/positions.rs +++ b/examples/positions.rs @@ -1,6 +1,8 @@ use ibapi::{accounts::PositionUpdate, Client}; fn main() { + env_logger::init(); + let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); let positions = client.positions().expect("request failed"); diff --git a/examples/positions_multi.rs b/examples/positions_multi.rs index 7dbb96ad..9411a3a5 100644 --- a/examples/positions_multi.rs +++ b/examples/positions_multi.rs @@ -3,6 +3,8 @@ use std::env; use ibapi::Client; pub fn main() { + env_logger::init(); + let account = env::var("IBKR_ACCOUNT").expect("Please set IBKR_ACCOUNT environment variable to an account ID"); let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); diff --git a/examples/readme_connection.rs b/examples/readme_connection.rs index a31ad0f9..c7a320f1 100644 --- a/examples/readme_connection.rs +++ b/examples/readme_connection.rs @@ -2,6 +2,8 @@ use ibapi::Client; fn main() { + env_logger::init(); + let connection_url = "127.0.0.1:4002"; let _client = Client::connect(connection_url, 100).expect("connection to TWS failed!"); diff --git a/examples/readme_historical_data.rs b/examples/readme_historical_data.rs index ae7a6875..cd4d80f3 100644 --- a/examples/readme_historical_data.rs +++ b/examples/readme_historical_data.rs @@ -5,6 +5,8 @@ use ibapi::market_data::historical::{BarSize, ToDuration, WhatToShow}; use ibapi::Client; fn main() { + env_logger::init(); + let connection_url = "127.0.0.1:4002"; let client = Client::connect(connection_url, 100).expect("connection to TWS failed!"); diff --git a/examples/readme_place_order.rs b/examples/readme_place_order.rs index f9e85727..5565c5ca 100644 --- a/examples/readme_place_order.rs +++ b/examples/readme_place_order.rs @@ -3,6 +3,8 @@ use ibapi::orders::{order_builder, Action, OrderNotification}; use ibapi::Client; pub fn main() { + env_logger::init(); + let connection_url = "127.0.0.1:4002"; let client = Client::connect(connection_url, 100).expect("connection to TWS failed!"); diff --git a/examples/readme_realtime_data_1.rs b/examples/readme_realtime_data_1.rs index 514d0ce9..db6800f3 100644 --- a/examples/readme_realtime_data_1.rs +++ b/examples/readme_realtime_data_1.rs @@ -3,6 +3,8 @@ use ibapi::market_data::realtime::{BarSize, WhatToShow}; use ibapi::Client; fn main() { + env_logger::init(); + let connection_url = "127.0.0.1:4002"; let client = Client::connect(connection_url, 100).expect("connection to TWS failed!"); diff --git a/src/accounts.rs b/src/accounts.rs index d4ead9bd..90cdb725 100644 --- a/src/accounts.rs +++ b/src/accounts.rs @@ -9,11 +9,10 @@ //! - Real-time PnL updates for individual positions //! -use std::marker::PhantomData; - use crate::client::{SharesChannel, Subscribable, Subscription}; use crate::contracts::Contract; use crate::messages::{IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage}; +use crate::transport::Response; use crate::{server_versions, Client, Error}; mod decoders; @@ -329,14 +328,9 @@ pub(crate) fn positions(client: &Client) -> Result, client.check_server_version(server_versions::ACCOUNT_SUMMARY, "It does not support position requests.")?; let request = encoders::encode_request_positions()?; - let responses = client.send_shared_request(OutgoingMessages::RequestPositions, request)?; - - Ok(Subscription { - client, - request_id: None, - subscription: responses, - phantom: PhantomData, - }) + let subscription = client.send_shared_request(OutgoingMessages::RequestPositions, request)?; + + Ok(Subscription::new(client, subscription)) } impl SharesChannel for Subscription<'_, PositionUpdate> {} @@ -349,16 +343,10 @@ pub(crate) fn positions_multi<'a>( client.check_server_version(server_versions::MODELS_SUPPORT, "It does not support positions multi requests.")?; let request_id = client.next_request_id(); - let request = encoders::encode_request_positions_multi(request_id, account, model_code)?; - let responses = client.send_request(request_id, request)?; - - Ok(Subscription { - client, - request_id: Some(request_id), - subscription: responses, - phantom: PhantomData, - }) + let subscription = client.send_request(request_id, request)?; + + Ok(Subscription::new(client, subscription)) } // Determine whether an account exists under an account family and find the account family code. @@ -368,7 +356,8 @@ pub(crate) fn family_codes(client: &Client) -> Result, Error> { let request = encoders::encode_request_family_codes()?; let subscription = client.send_shared_request(OutgoingMessages::RequestFamilyCodes, request)?; - if let Some(mut message) = subscription.next() { + // TODO: enumerate + if let Some(Response::Message(mut message)) = subscription.next() { decoders::decode_family_codes(&mut message) } else { Ok(Vec::default()) @@ -385,16 +374,10 @@ pub(crate) fn pnl<'a>(client: &'a Client, account: &str, model_code: Option<&str client.check_server_version(server_versions::PNL, "It does not support PnL requests.")?; let request_id = client.next_request_id(); - let request = encoders::encode_request_pnl(request_id, account, model_code)?; - let responses = client.send_request(request_id, request)?; - - Ok(Subscription { - client, - request_id: Some(request_id), - subscription: responses, - phantom: PhantomData, - }) + let subscription = client.send_request(request_id, request)?; + + Ok(Subscription::new(client, subscription)) } // Requests real time updates for daily PnL of individual positions. @@ -413,32 +396,20 @@ pub(crate) fn pnl_single<'a>( client.check_server_version(server_versions::REALIZED_PNL, "It does not support PnL requests.")?; let request_id = client.next_request_id(); - let request = encoders::encode_request_pnl_single(request_id, account, contract_id, model_code)?; - let responses = client.send_request(request_id, request)?; - - Ok(Subscription { - client, - request_id: Some(request_id), - subscription: responses, - phantom: PhantomData, - }) + let subscription = client.send_request(request_id, request)?; + + Ok(Subscription::new(client, subscription)) } pub fn account_summary<'a>(client: &'a Client, group: &str, tags: &[&str]) -> Result, Error> { client.check_server_version(server_versions::ACCOUNT_SUMMARY, "It does not support account summary requests.")?; let request_id = client.next_request_id(); - let request = encoders::encode_request_account_summary(request_id, group, tags)?; let subscription = client.send_request(request_id, request)?; - Ok(Subscription { - client, - request_id: Some(request_id), - subscription, - phantom: PhantomData, - }) + Ok(Subscription::new(client, subscription)) } pub fn managed_accounts(client: &Client) -> Result, Error> { @@ -446,13 +417,15 @@ pub fn managed_accounts(client: &Client) -> Result, Error> { let subscription = client.send_shared_request(OutgoingMessages::RequestManagedAccounts, request)?; match subscription.next() { - Some(mut message) => { + Some(Response::Message(mut message)) => { message.skip(); // message type message.skip(); // message version let accounts = message.next_string()?; Ok(accounts.split(",").map(String::from).collect()) } + Some(Response::Cancelled) => Err(Error::Cancelled), + Some(Response::Disconnected) => Err(Error::ConnectionFailed), None => Ok(Vec::default()), } } diff --git a/src/accounts/tests.rs b/src/accounts/tests.rs index 70c267d9..b9d1a21c 100644 --- a/src/accounts/tests.rs +++ b/src/accounts/tests.rs @@ -1,14 +1,14 @@ -use std::sync::{Arc, Mutex, RwLock}; +use std::sync::{Arc, RwLock}; use crate::testdata::responses; use crate::{accounts::AccountSummaryTags, server_versions, stubs::MessageBusStub, Client}; #[test] fn test_pnl() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -18,7 +18,7 @@ fn test_pnl() { let _ = client.pnl(account, model_code).expect("request pnl failed"); let _ = client.pnl(account, None).expect("request pnl failed"); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode_simple(), "92|9000|DU1234567|TARGET2024|"); assert_eq!(request_messages[1].encode_simple(), "93|9000|"); @@ -29,10 +29,10 @@ fn test_pnl() { #[test] fn test_pnl_single() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -43,7 +43,7 @@ fn test_pnl_single() { let _ = client.pnl_single(account, contract_id, model_code).expect("request pnl failed"); let _ = client.pnl_single(account, contract_id, None).expect("request pnl failed"); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode_simple(), "94|9000|DU1234567|TARGET2024|1001|"); assert_eq!(request_messages[1].encode_simple(), "95|9000|"); @@ -54,16 +54,16 @@ fn test_pnl_single() { #[test] fn test_positions() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); let _ = client.positions().expect("request positions failed"); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode_simple(), "61|1|"); assert_eq!(request_messages[1].encode_simple(), "64|1|"); @@ -71,10 +71,10 @@ fn test_positions() { #[test] fn test_positions_multi() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -84,7 +84,7 @@ fn test_positions_multi() { let _ = client.positions_multi(account, model_code).expect("request positions failed"); let _ = client.positions_multi(None, model_code).expect("request positions failed"); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode_simple(), "74|1|9000|DU1234567|TARGET2024|"); assert_eq!(request_messages[1].encode_simple(), "75|1|9000|"); @@ -95,10 +95,10 @@ fn test_positions_multi() { #[test] fn test_account_summary() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -107,7 +107,7 @@ fn test_account_summary() { let _ = client.account_summary(group, tags).expect("request account summary failed"); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode_simple(), "62|1|9000|All|AccountType|"); assert_eq!(request_messages[1].encode_simple(), "64|1|"); @@ -115,10 +115,10 @@ fn test_account_summary() { #[test] fn test_managed_accounts() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![responses::MANAGED_ACCOUNT.into()], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); diff --git a/src/client.rs b/src/client.rs index cb7e827b..d693bc61 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,15 +1,12 @@ use std::fmt::Debug; -use std::io::Write; use std::marker::PhantomData; use std::sync::atomic::{AtomicI32, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Duration; -use byteorder::{BigEndian, WriteBytesExt}; -use log::{debug, error, info}; -use time::macros::format_description; +use log::{debug, error}; use time::OffsetDateTime; -use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt, Tz}; +use time_tz::Tz; use crate::accounts::{AccountSummaries, FamilyCode, PnL, PnLSingle, PositionUpdate, PositionUpdateMulti}; use crate::contracts::Contract; @@ -19,14 +16,11 @@ use crate::market_data::realtime::{self, Bar, BarSize, MidPoint, WhatToShow}; use crate::messages::{IncomingMessages, OutgoingMessages}; use crate::messages::{RequestMessage, ResponseMessage}; use crate::orders::{Order, OrderDataResult, OrderNotification}; -use crate::transport::{InternalSubscription, MessageBus, TcpMessageBus}; -use crate::{accounts, contracts, orders, server_versions}; +use crate::transport::{Connection, ConnectionMetadata, InternalSubscription, MessageBus, Response, TcpMessageBus}; +use crate::{accounts, contracts, orders}; // Client -const MIN_SERVER_VERSION: i32 = 100; -const MAX_SERVER_VERSION: i32 = server_versions::HISTORICAL_SCHEDULE; - /// TWS API Client. Manages the connection to TWS or Gateway. /// Tracks some global information such as server version and server time. /// Supports generation of order ids @@ -37,10 +31,9 @@ pub struct Client { // pub server_time: OffsetDateTime, pub(crate) connection_time: Option, pub(crate) time_zone: Option<&'static Tz>, + pub(crate) message_bus: Arc, - managed_accounts: String, - client_id: i32, // ID of client. - pub(crate) message_bus: Arc>, + client_id: i32, // ID of client. next_request_id: AtomicI32, // Next available request_id. order_id: AtomicI32, // Next available order_id. Starts with value returned on connection. } @@ -66,119 +59,31 @@ impl Client { /// println!("next_order_id: {}", client.next_order_id()); /// ``` pub fn connect(address: &str, client_id: i32) -> Result { - let message_bus = Arc::new(Mutex::new(TcpMessageBus::connect(address)?)); - Client::do_connect(client_id, message_bus) + let connection = Connection::connect(client_id, address)?; + let connection_metadata = connection.connection_metadata(); + + let message_bus = Arc::new(TcpMessageBus::new(connection)?); + + // Starts thread to read messages from TWS + message_bus.process_messages(connection_metadata.server_version)?; + + Client::new(connection_metadata, message_bus) } - fn do_connect(client_id: i32, message_bus: Arc>) -> Result { - let mut client = Client { - server_version: 0, - connection_time: None, - time_zone: None, - managed_accounts: String::from(""), + fn new(connection_metadata: ConnectionMetadata, message_bus: Arc) -> Result { + let client = Client { + server_version: connection_metadata.server_version, + connection_time: connection_metadata.connection_time, + time_zone: connection_metadata.time_zone, message_bus, - client_id, + client_id: connection_metadata.client_id, next_request_id: AtomicI32::new(9000), order_id: AtomicI32::new(-1), }; - client.handshake()?; - client.start_api()?; - client.receive_account_info()?; - - client.message_bus.lock()?.process_messages(client.server_version)?; - Ok(client) } - // sends server handshake - fn handshake(&mut self) -> Result<(), Error> { - let prefix = "API\0"; - let version = format!("v{MIN_SERVER_VERSION}..{MAX_SERVER_VERSION}"); - - let packet = prefix.to_owned() + &encode_packet(&version); - self.message_bus.lock()?.write(&packet)?; - - let ack = self.message_bus.lock()?.read_message(); - - match ack { - Ok(mut response_message) => { - self.server_version = response_message.next_int()?; - - let time = response_message.next_string()?; - (self.connection_time, self.time_zone) = parse_connection_time(time.as_str()); - } - Err(Error::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { - return Err(Error::Simple(format!("The server may be rejecting connections from this host: {err}"))); - } - Err(err) => { - return Err(err); - } - } - Ok(()) - } - - // asks server to start processing messages - fn start_api(&mut self) -> Result<(), Error> { - const VERSION: i32 = 2; - - let prelude = &mut RequestMessage::default(); - - prelude.push_field(&OutgoingMessages::StartApi); - prelude.push_field(&VERSION); - prelude.push_field(&self.client_id); - - if self.server_version > server_versions::OPTIONAL_CAPABILITIES { - prelude.push_field(&""); - } - - self.message_bus.lock()?.write_message(prelude)?; - - Ok(()) - } - - // Fetches next order id and managed accounts. - fn receive_account_info(&mut self) -> Result<(), Error> { - let mut saw_next_order_id: bool = false; - let mut saw_managed_accounts: bool = false; - - let mut attempts = 0; - const MAX_ATTEMPTS: i32 = 100; - loop { - let mut message = self.message_bus.lock()?.read_message()?; - - match message.message_type() { - IncomingMessages::NextValidId => { - saw_next_order_id = true; - - message.skip(); // message type - message.skip(); // message version - - self.order_id.store(message.next_int()?, Ordering::Relaxed); - } - IncomingMessages::ManagedAccounts => { - saw_managed_accounts = true; - - message.skip(); // message type - message.skip(); // message version - - self.managed_accounts = message.next_string()?; - } - IncomingMessages::Error => { - error!("message: {message:?}") - } - _ => info!("message: {message:?}"), - } - - attempts += 1; - if (saw_next_order_id && saw_managed_accounts) || attempts > MAX_ATTEMPTS { - break; - } - } - - Ok(()) - } - /// Returns the next request ID. pub fn next_request_id(&self) -> i32 { self.next_request_id.fetch_add(1, Ordering::Relaxed) @@ -1009,12 +914,11 @@ impl Client { // == Internal Use == #[cfg(test)] - pub(crate) fn stubbed(message_bus: Arc>, server_version: i32) -> Client { + pub(crate) fn stubbed(message_bus: Arc, server_version: i32) -> Client { Client { server_version: server_version, connection_time: None, time_zone: None, - managed_accounts: String::from(""), message_bus, client_id: 100, next_request_id: AtomicI32::new(9000), @@ -1022,23 +926,19 @@ impl Client { } } - pub(crate) fn send_message(&self, packet: RequestMessage) -> Result<(), Error> { - self.message_bus.lock()?.write_message(&packet) - } - pub(crate) fn send_request(&self, request_id: i32, message: RequestMessage) -> Result { debug!("send_message({:?}, {:?})", request_id, message); - self.message_bus.lock()?.send_request(request_id, &message) + self.message_bus.send_request(request_id, &message) } pub(crate) fn send_order(&self, order_id: i32, message: RequestMessage) -> Result { debug!("send_order({:?}, {:?})", order_id, message); - self.message_bus.lock()?.send_order_request(order_id, &message) + self.message_bus.send_order_request(order_id, &message) } /// Sends request for the next valid order id. pub(crate) fn send_shared_request(&self, message_id: OutgoingMessages, message: RequestMessage) -> Result { - self.message_bus.lock()?.send_shared_request(message_id, &message) + self.message_bus.send_shared_request(message_id, &message) } pub(crate) fn check_server_version(&self, version: i32, message: &str) -> Result<(), Error> { @@ -1052,7 +952,8 @@ impl Client { impl Drop for Client { fn drop(&mut self) { - debug!("dropping basic client") + debug!("dropping basic client"); + self.message_bus.ensure_shutdown(); } } @@ -1074,32 +975,78 @@ impl Debug for Client { pub struct Subscription<'a, T: Subscribable> { pub(crate) client: &'a Client, pub(crate) request_id: Option, + pub(crate) order_id: Option, + pub(crate) message_type: Option, pub(crate) subscription: InternalSubscription, pub(crate) phantom: PhantomData, } #[allow(private_bounds)] impl<'a, T: Subscribable> Subscription<'a, T> { + pub(crate) fn new(client: &'a Client, subscription: InternalSubscription) -> Self { + if let Some(request_id) = subscription.request_id { + Subscription { + client, + request_id: Some(request_id), + order_id: None, + message_type: None, + subscription, + phantom: PhantomData, + } + } else if let Some(order_id) = subscription.order_id { + Subscription { + client, + request_id: None, + order_id: Some(order_id), + message_type: None, + subscription, + phantom: PhantomData, + } + } else if let Some(message_type) = subscription.message_type { + Subscription { + client, + request_id: None, + order_id: None, + message_type: Some(message_type), + subscription, + phantom: PhantomData, + } + } else { + panic!("unsupported internal subscription: {:?}", subscription) + } + } + /// Blocks until the item become available. pub fn next(&self) -> Option { loop { - if let Some(mut message) = self.subscription.next() { - if T::RESPONSE_MESSAGE_IDS.contains(&message.message_type()) { - match T::decode(self.client.server_version(), &mut message) { - Ok(val) => return Some(val), - Err(err) => { - error!("error decoding execution data: {err}"); + match self.subscription.next() { + Some(Response::Message(mut message)) => { + if T::RESPONSE_MESSAGE_IDS.contains(&message.message_type()) { + match T::decode(self.client.server_version(), &mut message) { + Ok(val) => return Some(val), + Err(err) => { + error!("error decoding execution data: {err}"); + } } + } else if message.message_type() == IncomingMessages::Error { + let error_message = message.peek_string(4); + error!("{error_message}"); + return None; + } else { + error!("subscription iterator unexpected message: {message:?}"); } - } else if message.message_type() == IncomingMessages::Error { - let error_message = message.peek_string(4); - error!("{error_message}"); + } + Some(Response::Cancelled) => { + debug!("subscription cancelled"); + return None; + } + Some(Response::Disconnected) => { + debug!("server disconnected"); + return None; + } + _ => { return None; - } else { - error!("subscription iterator unexpected message: {message:?}"); } - } else { - return None; } } } @@ -1116,7 +1063,7 @@ impl<'a, T: Subscribable> Subscription<'a, T> { /// //} /// ``` pub fn try_next(&self) -> Option { - if let Some(mut message) = self.subscription.try_next() { + if let Some(Response::Message(mut message)) = self.subscription.try_next() { if message.message_type() == IncomingMessages::Error { error!("{}", message.peek_string(4)); return None; @@ -1146,7 +1093,7 @@ impl<'a, T: Subscribable> Subscription<'a, T> { /// //} /// ``` pub fn next_timeout(&self, timeout: Duration) -> Option { - if let Some(mut message) = self.subscription.next_timeout(timeout) { + if let Some(Response::Message(mut message)) = self.subscription.next_timeout(timeout) { if message.message_type() == IncomingMessages::Error { error!("{}", message.peek_string(4)); return None; @@ -1166,9 +1113,23 @@ impl<'a, T: Subscribable> Subscription<'a, T> { /// Cancel the subscription pub fn cancel(&self) -> Result<(), Error> { - if let Ok(message) = T::cancel_message(self.client.server_version(), self.request_id) { - self.client.send_message(message)?; - self.subscription.cancel()?; + if let Some(request_id) = self.request_id { + if let Ok(message) = T::cancel_message(self.client.server_version(), self.request_id) { + self.client.message_bus.cancel_subscription(request_id, &message)?; + self.subscription.cancel(); + } + } else if let Some(order_id) = self.order_id { + if let Ok(message) = T::cancel_message(self.client.server_version(), self.request_id) { + self.client.message_bus.cancel_order_subscription(order_id, &message)?; + self.subscription.cancel(); + } + } else if let Some(message_type) = self.message_type { + if let Ok(message) = T::cancel_message(self.client.server_version(), self.request_id) { + self.client.message_bus.cancel_shared_subscription(message_type, &message)?; + self.subscription.cancel(); + } + } else { + debug!("Could not determine cancel method") } Ok(()) } @@ -1259,46 +1220,5 @@ impl<'a, T: Subscribable> Iterator for SubscriptionTimeoutIter<'a, T> { /// Marker trait for shared channels pub trait SharesChannel {} -// Parses following format: 20230405 22:20:39 PST -fn parse_connection_time(connection_time: &str) -> (Option, Option<&'static Tz>) { - let parts: Vec<&str> = connection_time.split(' ').collect(); - - let zones = timezones::find_by_name(parts[2]); - if zones.is_empty() { - error!("time zone not found for {}", parts[2]); - return (None, None); - } - - let timezone = zones[0]; - - let format = format_description!("[year][month][day] [hour]:[minute]:[second]"); - let date_str = format!("{} {}", parts[0], parts[1]); - let date = time::PrimitiveDateTime::parse(date_str.as_str(), format); - match date { - Ok(connected_at) => match connected_at.assume_timezone(timezone) { - OffsetResult::Some(date) => (Some(date), Some(timezone)), - _ => { - error!("error setting timezone"); - (None, Some(timezone)) - } - }, - Err(err) => { - error!("could not parse connection time from {date_str}: {err}"); - (None, Some(timezone)) - } - } -} - -fn encode_packet(message: &str) -> String { - let data = message.as_bytes(); - - let mut packet: Vec = Vec::with_capacity(data.len() + 4); - - packet.write_u32::(data.len() as u32).unwrap(); - packet.write_all(data).unwrap(); - - std::str::from_utf8(&packet).unwrap().into() -} - #[cfg(test)] mod tests; diff --git a/src/client/tests.rs b/src/client/tests.rs index 2c4138e3..8b137891 100644 --- a/src/client/tests.rs +++ b/src/client/tests.rs @@ -1,15 +1 @@ -use time::macros::datetime; -use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt}; -use super::*; - -#[test] -fn test_parse_connection_time() { - let example = "20230405 22:20:39 PST"; - let (connection_time, _) = parse_connection_time(example); - - let la = timezones::db::america::LOS_ANGELES; - if let OffsetResult::Some(other) = datetime!(2023-04-05 22:20:39).assume_timezone(la) { - assert_eq!(connection_time, Some(other)); - } -} diff --git a/src/contracts.rs b/src/contracts.rs index c46cfae8..aacf951d 100644 --- a/src/contracts.rs +++ b/src/contracts.rs @@ -8,6 +8,7 @@ use crate::encode_option_field; use crate::messages::IncomingMessages; use crate::messages::OutgoingMessages; use crate::messages::RequestMessage; +use crate::transport::Response; use crate::Client; use crate::{server_versions, Error, ToField}; @@ -403,7 +404,7 @@ pub(crate) fn contract_details(client: &Client, contract: &Contract) -> Result = Vec::default(); // TODO create iterator - while let Some(mut message) = responses.next() { + while let Some(Response::Message(mut message)) = responses.next() { match message.message_type() { IncomingMessages::ContractData => { let decoded = decoders::contract_details(client.server_version(), &mut message)?; @@ -476,7 +477,7 @@ pub(crate) fn matching_symbols(client: &Client, pattern: &str) -> Result { return decoders::contract_descriptions(client.server_version(), &mut message); @@ -519,7 +520,9 @@ pub(crate) fn market_rule(client: &Client, market_rule_id: i32) -> Result Ok(decoders::market_rule(&mut message)?), + Some(Response::Message(mut message)) => Ok(decoders::market_rule(&mut message)?), + Some(Response::Cancelled) => Err(Error::Simple("subscription cancelled".into())), + Some(Response::Disconnected) => Err(Error::Simple("server gone".into())), None => Err(Error::Simple("no market rule found".into())), } } diff --git a/src/contracts/tests.rs b/src/contracts/tests.rs index 1ae91d14..46a995db 100644 --- a/src/contracts/tests.rs +++ b/src/contracts/tests.rs @@ -1,5 +1,4 @@ -use std::sync::RwLock; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use super::*; @@ -7,14 +6,14 @@ use crate::stubs::MessageBusStub; #[test] fn request_stock_contract_details() { - let message_bus = Arc::new(Mutex::new(MessageBusStub{ + let message_bus = Arc::new(MessageBusStub{ request_messages: RwLock::new(vec![]), response_messages: vec![ "10|9001|TSLA|STK||0||SMART|USD|TSLA|NMS|NMS|76792991|0.01||ACTIVETIM,AD,ADJUST,ALERT,ALGO,ALLOC,AON,AVGCOST,BASKET,BENCHPX,CASHQTY,COND,CONDORDER,DARKONLY,DARKPOLL,DAY,DEACT,DEACTDIS,DEACTEOD,DIS,DUR,GAT,GTC,GTD,GTT,HID,IBKRATS,ICE,IMB,IOC,LIT,LMT,LOC,MIDPX,MIT,MKT,MOC,MTL,NGCOMB,NODARK,NONALGO,OCA,OPG,OPGREROUT,PEGBENCH,PEGMID,POSTATS,POSTONLY,PREOPGRTH,PRICECHK,REL,REL2MID,RELPCTOFS,RPI,RTH,SCALE,SCALEODD,SCALERST,SIZECHK,SNAPMID,SNAPMKT,SNAPREL,STP,STPLMT,SWEEP,TRAIL,TRAILLIT,TRAILLMT,TRAILMIT,WHATIF|SMART,AMEX,NYSE,CBOE,PHLX,ISE,CHX,ARCA,ISLAND,DRCTEDGE,BEX,BATS,EDGEA,CSFBALGO,JEFFALGO,BYX,IEX,EDGX,FOXRIVER,PEARL,NYSENAT,LTSE,MEMX,PSX|1|0|TESLA INC|NASDAQ||Consumer, Cyclical|Auto Manufacturers|Auto-Cars/Light Trucks|US/Eastern|20221229:0400-20221229:2000;20221230:0400-20221230:2000;20221231:CLOSED;20230101:CLOSED;20230102:CLOSED;20230103:0400-20230103:2000|20221229:0930-20221229:1600;20221230:0930-20221230:1600;20221231:CLOSED;20230101:CLOSED;20230102:CLOSED;20230103:0930-20230103:1600|||1|ISIN|US88160R1014|1|||26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26||COMMON|1|1|100||".to_string(), "10|9001|TSLA|STK||0||AMEX|USD|TSLA|NMS|NMS|76792991|0.01||ACTIVETIM,AD,ADJUST,ALERT,ALLOC,AVGCOST,BASKET,BENCHPX,CASHQTY,COND,CONDORDER,DAY,DEACT,DEACTDIS,DEACTEOD,GAT,GTC,GTD,GTT,HID,IOC,LIT,LMT,MIT,MKT,MTL,NGCOMB,NONALGO,OCA,PEGBENCH,SCALE,SCALERST,SNAPMID,SNAPMKT,SNAPREL,STP,STPLMT,TRAIL,TRAILLIT,TRAILLMT,TRAILMIT,WHATIF|SMART,AMEX,NYSE,CBOE,PHLX,ISE,CHX,ARCA,ISLAND,DRCTEDGE,BEX,BATS,EDGEA,CSFBALGO,JEFFALGO,BYX,IEX,EDGX,FOXRIVER,PEARL,NYSENAT,LTSE,MEMX,PSX|1|0|TESLA INC|NASDAQ||Consumer, Cyclical|Auto Manufacturers|Auto-Cars/Light Trucks|US/Eastern|20221229:0700-20221229:2000;20221230:0700-20221230:2000;20221231:CLOSED;20230101:CLOSED;20230102:CLOSED;20230103:0700-20230103:2000|20221229:0700-20221229:2000;20221230:0700-20221230:2000;20221231:CLOSED;20230101:CLOSED;20230102:CLOSED;20230103:0700-20230103:2000|||1|ISIN|US88160R1014|1|||26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26||COMMON|1|1|100||".to_string(), "52|1|9001||".to_string(), ] - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -22,7 +21,7 @@ fn request_stock_contract_details() { let results = client.contract_details(&contract); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode_simple(), "9|8|9000|0|TSLA|STK||0|||SMART||USD|||0|||"); diff --git a/src/errors.rs b/src/errors.rs index 141df5f7..cc1e6720 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -15,6 +15,8 @@ pub enum Error { Parse(usize, String, String), ServerVersion(i32, i32, String), Simple(String), + ConnectionFailed, + Cancelled, } impl std::error::Error for Error {} @@ -31,6 +33,8 @@ impl std::fmt::Display for Error { Error::NotImplemented => write!(f, "not implemented"), Error::Parse(i, value, message) => write!(f, "parse error: {i} - {value} - {message}"), Error::ServerVersion(wanted, have, message) => write!(f, "server version {wanted} required, got {have}: {message}"), + Error::ConnectionFailed => write!(f, "ConnectionFailed"), + Error::Cancelled => write!(f, "Cancelled"), Error::Simple(ref err) => write!(f, "error occurred: {err}"), } diff --git a/src/lib.rs b/src/lib.rs index ff56311e..52e26430 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -116,7 +116,7 @@ pub mod errors; /// APIs for retrieving market data pub mod market_data; mod messages; -pub(crate) mod news; +pub mod news; /// Data types for building and placing orders. pub mod orders; @@ -131,6 +131,9 @@ pub use client::Client; #[cfg(test)] pub(crate) mod stubs; +#[cfg(test)] +pub(crate) mod tests; + #[cfg(test)] pub(crate) mod testdata; diff --git a/src/market_data/historical.rs b/src/market_data/historical.rs index 6bfe88f3..aeef9cd6 100644 --- a/src/market_data/historical.rs +++ b/src/market_data/historical.rs @@ -6,7 +6,7 @@ use time::{Date, OffsetDateTime}; use crate::contracts::Contract; use crate::messages::{IncomingMessages, RequestMessage, ResponseMessage}; -use crate::transport::InternalSubscription; +use crate::transport::{InternalSubscription, Response}; use crate::{server_versions, Client, Error, ToField}; mod decoders; @@ -167,7 +167,7 @@ impl ToDuration for i32 { } #[derive(Debug)] -struct HistogramData { +pub struct HistogramData { pub price: f64, pub count: i32, } @@ -304,7 +304,7 @@ pub(crate) fn head_timestamp(client: &Client, contract: &Contract, what_to_show: let subscription = client.send_request(request_id, request)?; - if let Some(mut message) = subscription.next() { + if let Some(Response::Message(mut message)) = subscription.next() { decoders::decode_head_timestamp(&mut message) } else { Err(Error::Simple("did not receive head timestamp message".into())) @@ -359,7 +359,7 @@ pub(crate) fn historical_data( let subscription = client.send_request(request_id, request)?; - if let Some(mut message) = subscription.next() { + if let Some(Response::Message(mut message)) = subscription.next() { let time_zone = if let Some(tz) = client.time_zone { tz } else { @@ -410,7 +410,7 @@ pub(crate) fn historical_schedule( let subscription = client.send_request(request_id, request)?; - if let Some(mut message) = subscription.next() { + if let Some(Response::Message(mut message)) = subscription.next() { match message.message_type() { IncomingMessages::HistoricalSchedule => decoders::decode_historical_schedule(&mut message), IncomingMessages::Error => Err(Error::Simple(message.peek_string(4))), @@ -547,7 +547,7 @@ impl + Debug> Iterator for TickIterator { loop { match self.messages.next() { - Some(mut message) => { + Some(Response::Message(mut message)) => { if message.message_type() == Self::Item::message_type() { let (ticks, done) = Self::Item::decode(&mut message).unwrap(); @@ -568,10 +568,11 @@ impl + Debug> Iterator for TickIterator { error!("unexpected message: {:?}", message) } } - None => return None, + // TODO enumerate + _ => return None, } } } } -struct HistogramDataIterator {} +pub struct HistogramDataIterator {} diff --git a/src/market_data/historical/tests.rs b/src/market_data/historical/tests.rs index 96df075e..03a2fc28 100644 --- a/src/market_data/historical/tests.rs +++ b/src/market_data/historical/tests.rs @@ -1,5 +1,4 @@ -use std::sync::RwLock; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use time::macros::datetime; @@ -11,10 +10,10 @@ use super::*; #[test] fn test_head_timestamp() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec!["9|9000|1678323335|".to_owned()], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -28,7 +27,7 @@ fn test_head_timestamp() { assert_eq!(head_timestamp, OffsetDateTime::from_unix_timestamp(1678323335).unwrap(), "bar.date"); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); let head_timestamp_request = &request_messages[0]; assert_eq!( @@ -71,12 +70,12 @@ fn test_histogram_data() { #[test] fn test_historical_data() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![ "17\09000\020230413 16:31:22\020230415 16:31:22\02\020230413\0182.9400\0186.5000\0180.9400\0185.9000\0948837.22\0184.869\0324891\020230414\0183.8800\0186.2800\0182.0100\0185.0000\0810998.27\0183.9865\0277547\0".to_owned() ], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -108,7 +107,7 @@ fn test_historical_data() { // Assert Request - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); let head_timestamp_request = &request_messages[0]; assert_eq!( diff --git a/src/market_data/realtime.rs b/src/market_data/realtime.rs index 9387712e..ace13a65 100644 --- a/src/market_data/realtime.rs +++ b/src/market_data/realtime.rs @@ -1,5 +1,3 @@ -use std::marker::PhantomData; - use log::error; use time::OffsetDateTime; @@ -8,7 +6,7 @@ use crate::contracts::Contract; use crate::messages::{IncomingMessages, RequestMessage, ResponseMessage}; use crate::orders::TagValue; use crate::server_versions; -use crate::transport::InternalSubscription; +use crate::transport::{InternalSubscription, Response}; use crate::ToField; use crate::{Client, Error}; @@ -195,16 +193,10 @@ pub(crate) fn realtime_bars<'a>( } let request_id = client.next_request_id(); - let packet = encoders::encode_request_realtime_bars(client.server_version(), request_id, contract, bar_size, what_to_show, use_rth, options)?; - - let responses = client.send_request(request_id, packet)?; + let request = encoders::encode_request_realtime_bars(client.server_version(), request_id, contract, bar_size, what_to_show, use_rth, options)?; + let subscription = client.send_request(request_id, request)?; - Ok(Subscription { - client, - request_id: Some(request_id), - subscription: responses, - phantom: PhantomData, - }) + Ok(Subscription::new(client, subscription)) } // Requests tick by tick AllLast ticks. @@ -300,14 +292,9 @@ pub(crate) fn tick_by_tick_midpoint<'a>( let request_id = client.next_request_id(); let message = encoders::tick_by_tick(server_version, request_id, contract, "MidPoint", number_of_ticks, ignore_size)?; - let responses = client.send_request(request_id, message)?; + let subscription = client.send_request(request_id, message)?; - Ok(Subscription { - client, - request_id: Some(request_id), - subscription: responses, - phantom: PhantomData, - }) + Ok(Subscription::new(client, subscription)) } // Iterators @@ -333,14 +320,15 @@ impl<'a> Iterator for TradeIterator<'a> { fn next(&mut self) -> Option { loop { match self.responses.next() { - Some(mut message) => match message.message_type() { + Some(Response::Message(mut message)) => match message.message_type() { IncomingMessages::TickByTick => match decoders::decode_trade_tick(&mut message) { Ok(tick) => return Some(tick), Err(e) => error!("unexpected message {message:?}: {e:?}"), }, _ => error!("unexpected message {message:?}"), }, - None => return None, + // TODO enumerate + _ => return None, } } } @@ -357,7 +345,7 @@ pub(crate) struct BidAskIterator<'a> { fn cancel_tick_by_tick(client: &Client, request_id: i32) { if client.server_version() >= server_versions::TICK_BY_TICK { let message = encoders::cancel_tick_by_tick(request_id).unwrap(); - client.send_message(message).unwrap(); + client.message_bus.cancel_subscription(request_id, &message).unwrap(); } } @@ -375,14 +363,15 @@ impl<'a> Iterator for BidAskIterator<'a> { fn next(&mut self) -> Option { loop { match self.responses.next() { - Some(mut message) => match message.message_type() { + Some(Response::Message(mut message)) => match message.message_type() { IncomingMessages::TickByTick => match decoders::bid_ask_tick(&mut message) { Ok(tick) => return Some(tick), Err(e) => error!("unexpected message {message:?}: {e:?}"), }, _ => error!("unexpected message {message:?}"), }, - None => return None, + // TODO enumerate + _ => return None, } } } diff --git a/src/market_data/realtime/tests.rs b/src/market_data/realtime/tests.rs index 2431b60e..b501beac 100644 --- a/src/market_data/realtime/tests.rs +++ b/src/market_data/realtime/tests.rs @@ -1,5 +1,5 @@ +use std::sync::Arc; use std::sync::RwLock; -use std::sync::{Arc, Mutex}; use time::OffsetDateTime; @@ -12,10 +12,10 @@ use super::*; #[test] fn realtime_bars() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec!["50|3|9001|1678323335|4028.75|4029.00|4028.25|4028.50|2|4026.75|1|".to_owned()], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -47,7 +47,7 @@ fn realtime_bars() { // Should trigger cancel realtime bars drop(bars); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); // Verify Requests let realtime_bars_request = &request_messages[0]; diff --git a/src/orders.rs b/src/orders.rs index cbc41956..84aa8005 100644 --- a/src/orders.rs +++ b/src/orders.rs @@ -6,7 +6,7 @@ use log::{error, info}; use crate::contracts::{ComboLeg, ComboLegOpenClose, Contract, DeltaNeutralContract, SecurityType}; use crate::messages::{IncomingMessages, OutgoingMessages}; use crate::messages::{RequestMessage, ResponseMessage}; -use crate::transport::InternalSubscription; +use crate::transport::{InternalSubscription, Response}; use crate::Client; use crate::{encode_option_field, ToField}; use crate::{server_versions, Error}; @@ -1055,7 +1055,7 @@ impl Iterator for OrderNotificationIterator { } loop { - if let Some(mut message) = self.messages.next() { + if let Some(Response::Message(mut message)) = self.messages.next() { match message.message_type() { IncomingMessages::OpenOrder => { let open_order = decoders::decode_open_order(self.server_version, message); @@ -1332,7 +1332,7 @@ impl Iterator for CancelOrderResultIterator { /// Returns the next [CancelOrderResult]. Waits up to x seconds for next [CancelOrderResult]. fn next(&mut self) -> Option { loop { - if let Some(mut message) = self.messages.next() { + if let Some(Response::Message(mut message)) = self.messages.next() { match message.message_type() { IncomingMessages::OrderStatus => match decoders::decode_order_status(self.server_version, &mut message) { Ok(val) => return Some(CancelOrderResult::OrderStatus(val)), @@ -1373,7 +1373,7 @@ pub(crate) fn next_valid_order_id(client: &Client) -> Result { let subscription = client.send_shared_request(OutgoingMessages::RequestIds, message)?; - if let Some(message) = subscription.next() { + if let Some(Response::Message(message)) = subscription.next() { let order_id_index = 2; let next_order_id = message.peek_int(order_id_index)?; @@ -1418,7 +1418,7 @@ impl Iterator for OrderDataIterator { /// Returns the next [OrderDataResult]. Waits up to x seconds for next [OrderDataResult]. fn next(&mut self) -> Option { loop { - if let Some(mut message) = self.messages.next() { + if let Some(Response::Message(mut message)) = self.messages.next() { match message.message_type() { IncomingMessages::CompletedOrder => match decoders::decode_completed_order(self.server_version, message) { Ok(val) => return Some(OrderDataResult::OrderData(Box::new(val))), @@ -1554,7 +1554,7 @@ impl Iterator for ExecutionDataIterator { /// Returns the next [OrderDataResult]. Waits up to x seconds for next [OrderDataResult]. fn next(&mut self) -> Option { loop { - if let Some(mut message) = self.messages.next() { + if let Some(Response::Message(mut message)) = self.messages.next() { match message.message_type() { IncomingMessages::ExecutionData => match decoders::decode_execution_data(self.server_version, &mut message) { Ok(val) => return Some(ExecutionDataResult::ExecutionData(Box::new(val))), diff --git a/src/orders/tests.rs b/src/orders/tests.rs index 622f2055..9b08fa10 100644 --- a/src/orders/tests.rs +++ b/src/orders/tests.rs @@ -1,5 +1,4 @@ -use std::sync::RwLock; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use crate::contracts::{contract_samples, Contract, SecurityType}; use crate::stubs::MessageBusStub; @@ -8,7 +7,7 @@ use super::*; #[test] fn place_order() { - let message_bus = Arc::new(Mutex::new(MessageBusStub{ + let message_bus = Arc::new(MessageBusStub{ request_messages: RwLock::new(vec![]), response_messages: vec![ "5|13|76792991|TSLA|STK||0|?||SMART|USD|TSLA|NMS|BUY|100|MKT|0.0|0.0|DAY||DU1234567||0||100|1376327563|0|0|0||1376327563.0/DU1234567/100||||||||||0||-1|0||||||2147483647|0|0|0||3|0|0||0|0||0|None||0||||?|0|0||0|0||||||0|0|0|2147483647|2147483647|||0||IB|0|0||0|0|PreSubmitted|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308||||||0|0|0|None|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|0||||0|1|0|0|0|||0||".to_owned(), @@ -19,7 +18,7 @@ fn place_order() { "5|13|76792991|TSLA|STK||0|?||SMART|USD|TSLA|NMS|BUY|100|MKT|0.0|0.0|DAY||DU1234567||0||100|1376327563|0|0|0||1376327563.0/DU1234567/100||||||||||0||-1|0||||||2147483647|0|0|0||3|0|0||0|0||0|None||0||||?|0|0||0|0||||||0|0|0|2147483647|2147483647|||0||IB|0|0||0|0|Filled|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.0|||USD||0|0|0|None|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|1.7976931348623157E308|0||||0|1|0|0|0|||0||".to_owned(), "59|1|00025b46.63f8f39c.01.01|1.0|USD|1.7976931348623157E308|1.7976931348623157E308|||".to_owned(), ] - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -36,7 +35,7 @@ fn place_order() { let result = client.place_order(order_id, &contract, &order); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!( request_messages[0].encode().replace('\0', "|"), @@ -297,20 +296,20 @@ fn place_order() { #[test] fn cancel_order() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![ "3|41|Cancelled|0|100|0|71270927|0|0|100||0||".to_owned(), "4|2|41|202|Order Canceled - reason:||".to_owned(), ], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); let order_id = 41; let results = client.cancel_order(order_id, ""); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode(), "4\01\041\0"); @@ -339,16 +338,16 @@ fn cancel_order() { #[test] fn global_cancel() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![], - })); + }); let mut client = Client::stubbed(message_bus, server_versions::SIZE_RULES); let results = super::global_cancel(&mut client); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode(), "58\01\0"); assert!(results.is_ok(), "failed to cancel order: {}", results.err().unwrap()); @@ -356,16 +355,16 @@ fn global_cancel() { #[test] fn next_valid_order_id() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec!["9|1|43||".to_owned()], - })); + }); let mut client = Client::stubbed(message_bus, server_versions::SIZE_RULES); let results = super::next_valid_order_id(&mut client); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode(), "8\01\00\0"); @@ -375,20 +374,20 @@ fn next_valid_order_id() { #[test] fn completed_orders() { - let message_bus = Arc::new(Mutex::new(MessageBusStub{ + let message_bus = Arc::new(MessageBusStub{ request_messages: RwLock::new(vec![]), response_messages: vec![ "101|265598|AAPL|STK||0|?||SMART|USD|AAPL|NMS|BUY|0|MKT|0.0|0.0|DAY||DU1234567||0||1824933227|0|0|0|||||||||||0||-1||||||2147483647|0|0||3|0||0|None||0|0|0||0|0||||0|0|0|2147483647|2147483647||||IB|0|0||0|Filled|0|0|0|1.7976931348623157E308|1.7976931348623157E308|0|1|0||100|2147483647|0|Not an insider or substantial shareholder|0|0|9223372036854775807|20230306 12:28:30 America/Los_Angeles|Filled Size: 100|".to_owned(), "102|".to_owned(), ], - })); + }); let mut client = Client::stubbed(message_bus, server_versions::SIZE_RULES); let api_only = true; let results = super::completed_orders(&mut client, api_only); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode(), "99\01\0"); @@ -511,16 +510,16 @@ fn completed_orders() { #[test] fn open_orders() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec!["9|1|43||".to_owned()], - })); + }); let mut client = Client::stubbed(message_bus, server_versions::SIZE_RULES); let results = super::open_orders(&mut client); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode_simple(), "5|1|"); @@ -529,16 +528,16 @@ fn open_orders() { #[test] fn all_open_orders() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec!["9|1|43||".to_owned()], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); let results = client.all_open_orders(); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode_simple(), "16|1|"); @@ -547,17 +546,17 @@ fn all_open_orders() { #[test] fn auto_open_orders() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec!["9|1|43||".to_owned()], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); let api_only = true; let results = client.auto_open_orders(api_only); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!(request_messages[0].encode_simple(), "15|1|1|"); @@ -566,10 +565,10 @@ fn auto_open_orders() { #[test] fn executions() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec!["9|1|43||".to_owned()], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -584,7 +583,7 @@ fn executions() { }; let results = client.executions(filter); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!( request_messages[0].encode_simple(), @@ -597,10 +596,10 @@ fn executions() { #[test] fn encode_limit_order() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -610,7 +609,7 @@ fn encode_limit_order() { let results = client.place_order(order_id, &contract, &order); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!( request_messages[0].encode_simple(), @@ -622,10 +621,10 @@ fn encode_limit_order() { #[test] fn encode_combo_market_order() { - let message_bus = Arc::new(Mutex::new(MessageBusStub { + let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), response_messages: vec![], - })); + }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); @@ -635,7 +634,7 @@ fn encode_combo_market_order() { let results = client.place_order(order_id, &contract, &order); - let request_messages = client.message_bus.lock().unwrap().request_messages(); + let request_messages = client.message_bus.request_messages(); assert_eq!( request_messages[0].encode_simple(), diff --git a/src/stubs.rs b/src/stubs.rs index 14c1e0be..f8ebb1e0 100644 --- a/src/stubs.rs +++ b/src/stubs.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use crossbeam::channel; use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage}; -use crate::transport::{InternalSubscription, MessageBus, SubscriptionBuilder}; +use crate::transport::{InternalSubscription, MessageBus, Response, SubscriptionBuilder}; use crate::Error; pub(crate) struct MessageBusStub { @@ -19,49 +19,63 @@ impl MessageBus for MessageBusStub { self.request_messages.read().unwrap().clone() } - fn read_message(&mut self) -> Result { - Ok(ResponseMessage::default()) + fn send_request(&self, request_id: i32, message: &RequestMessage) -> Result { + mock_request(self, Some(request_id), None, message) } - fn write_message(&mut self, message: &RequestMessage) -> Result<(), Error> { - self.request_messages.write().unwrap().push(message.clone()); + fn cancel_subscription(&self, request_id: i32, packet: &RequestMessage) -> Result<(), Error> { + mock_request(self, Some(request_id), None, packet); Ok(()) } - fn send_request(&mut self, request_id: i32, message: &RequestMessage) -> Result { - mock_request(self, request_id, message) + fn send_order_request(&self, request_id: i32, message: &RequestMessage) -> Result { + mock_request(self, Some(request_id), None, message) } - fn send_order_request(&mut self, request_id: i32, message: &RequestMessage) -> Result { - mock_request(self, request_id, message) + fn cancel_order_subscription(&self, request_id: i32, packet: &RequestMessage) -> Result<(), Error> { + mock_request(self, Some(request_id), None, packet); + Ok(()) } - fn send_shared_request(&mut self, _message_id: OutgoingMessages, message: &RequestMessage) -> Result { - mock_global_request(self, message) + fn send_shared_request(&self, message_type: OutgoingMessages, message: &RequestMessage) -> Result { + mock_request(self, None, Some(message_type), message) } - fn write(&mut self, _packet: &str) -> Result<(), Error> { + fn cancel_shared_subscription(&self, message_type: OutgoingMessages, packet: &RequestMessage) -> Result<(), Error> { + mock_request(self, None, Some(message_type), packet)?; Ok(()) } - fn process_messages(&mut self, _server_version: i32) -> Result<(), Error> { - Ok(()) - } + // fn process_messages(&mut self, _server_version: i32) -> Result<(), Error> { + // Ok(()) + // } } -fn mock_request(stub: &mut MessageBusStub, _request_id: i32, message: &RequestMessage) -> Result { +fn mock_request( + stub: &MessageBusStub, + request_id: Option, + message_type: Option, + message: &RequestMessage, +) -> Result { stub.request_messages.write().unwrap().push(message.clone()); let (sender, receiver) = channel::unbounded(); let (s1, _r1) = channel::unbounded(); for message in &stub.response_messages { - sender.send(ResponseMessage::from(&message.replace('|', "\0"))).unwrap(); + let message = ResponseMessage::from(&message.replace('|', "\0")); + sender.send(Response::from(message)).unwrap(); } - let subscription = SubscriptionBuilder::new().shared_receiver(Arc::new(receiver)).signaler(s1).build(); + let mut subscription = SubscriptionBuilder::new().shared_receiver(Arc::new(receiver)).signaler(s1); + if let Some(request_id) = request_id { + subscription = subscription.request_id(request_id); + } + if let Some(message_type) = message_type { + subscription = subscription.message_type(message_type); + } - Ok(subscription) + Ok(subscription.build()) } fn mock_global_request(stub: &mut MessageBusStub, message: &RequestMessage) -> Result { @@ -70,7 +84,8 @@ fn mock_global_request(stub: &mut MessageBusStub, message: &RequestMessage) -> R let (sender, receiver) = channel::unbounded(); for message in &stub.response_messages { - sender.send(ResponseMessage::from(&message.replace('|', "\0"))).unwrap(); + let message = ResponseMessage::from(&message.replace('|', "\0")); + sender.send(Response::from(message)).unwrap(); } let subscription = SubscriptionBuilder::new().shared_receiver(Arc::new(receiver)).build(); diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 00000000..353d5d63 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1 @@ +pub fn assert_send_and_sync() {} diff --git a/src/transport.rs b/src/transport.rs index 14311485..d17635cd 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -3,16 +3,19 @@ //! and responses from TWS back to the Client. use std::collections::HashMap; -use std::io::{prelude::*, Cursor}; +use std::io::{prelude::*, Cursor, ErrorKind}; use std::net::TcpStream; -use std::sync::Mutex; -use std::sync::{Arc, RwLock}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex, RwLock}; use std::thread::{self, JoinHandle}; use std::time::Duration; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use crossbeam::channel::{self, Receiver, Sender}; -use log::{debug, error, info}; +use log::{debug, error, info, warn}; +use time::macros::format_description; +use time::OffsetDateTime; +use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt, Tz}; use crate::messages::{IncomingMessages, OutgoingMessages}; use crate::messages::{RequestMessage, ResponseMessage}; @@ -21,27 +24,29 @@ use recorder::MessageRecorder; mod recorder; +const MIN_SERVER_VERSION: i32 = 100; +const MAX_SERVER_VERSION: i32 = server_versions::HISTORICAL_SCHEDULE; +const MAX_RETRIES: i32 = 20; + pub(crate) trait MessageBus: Send + Sync { - // Reads the next available message from TWS - fn read_message(&mut self) -> Result; + // Sends formatted message to TWS and creates a reply channel by request id. + fn send_request(&self, request_id: i32, packet: &RequestMessage) -> Result; - // Sends a formatted packet TWS - fn write_message(&mut self, packet: &RequestMessage) -> Result<(), Error>; + // Sends formatted message to TWS and creates a reply channel by request id. + fn cancel_subscription(&self, request_id: i32, packet: &RequestMessage) -> Result<(), Error>; - // Sends raw data to TWS - fn write(&mut self, packet: &str) -> Result<(), Error>; + // Sends formatted message to TWS and creates a reply channel by message type. + fn send_shared_request(&self, message_id: OutgoingMessages, packet: &RequestMessage) -> Result; - // Sends formatted message to TWS and creates a reply channel by request id. - fn send_request(&mut self, request_id: i32, packet: &RequestMessage) -> Result; + // Sends formatted message to TWS and creates a reply channel by message type. + fn cancel_shared_subscription(&self, message_id: OutgoingMessages, packet: &RequestMessage) -> Result<(), Error>; // Sends formatted order specific message to TWS and creates a reply channel by order id. - fn send_order_request(&mut self, request_id: i32, packet: &RequestMessage) -> Result; + fn send_order_request(&self, request_id: i32, packet: &RequestMessage) -> Result; - // Sends formatted message to TWS and creates a reply channel by message type. - fn send_shared_request(&mut self, message_id: OutgoingMessages, packet: &RequestMessage) -> Result; + fn cancel_order_subscription(&self, request_id: i32, packet: &RequestMessage) -> Result<(), Error>; - // Starts a dedicated thread to process responses from TWS. - fn process_messages(&mut self, server_version: i32) -> Result<(), Error>; + fn ensure_shutdown(&self) {} // Testing interface. Tracks requests sent messages when Bus is stubbed. #[cfg(test)] @@ -50,14 +55,27 @@ pub(crate) trait MessageBus: Send + Sync { } } +#[derive(Clone, Debug)] +pub(crate) enum Response { + Message(ResponseMessage), + Cancelled, + Disconnected, +} + +impl From for Response { + fn from(val: ResponseMessage) -> Self { + Response::Message(val) + } +} + // For requests without an identifier, shared channels are created // to route request/response pairs based on message type. #[derive(Debug)] struct SharedChannels { // Maps an inbound reply to channel used to send responses. - senders: HashMap>>, + senders: HashMap>>, // Maps an outbound request to channel used to receive responses. - receivers: HashMap>>, + receivers: HashMap>>, } impl SharedChannels { @@ -91,7 +109,7 @@ impl SharedChannels { // Maps an outgoing message to incoming message(s) fn register(&mut self, outbound: OutgoingMessages, inbounds: &[IncomingMessages]) { - let (sender, receiver) = channel::unbounded::(); + let (sender, receiver) = channel::unbounded::(); self.receivers.insert(outbound, Arc::new(receiver)); @@ -103,7 +121,7 @@ impl SharedChannels { } // Get receiver for specified message type. Panics if receiver not found. - pub fn get_receiver(&self, message_type: OutgoingMessages) -> Arc> { + pub fn get_receiver(&self, message_type: OutgoingMessages) -> Arc> { let receiver = self.receivers.get(&message_type).unwrap_or_else(|| { panic!("unsupported request message {message_type:?}. check mapping in SharedChannels::new() located in transport.rs") }); @@ -112,7 +130,7 @@ impl SharedChannels { } // Get sender for specified message type. Panics if sender not found. - pub fn get_sender(&self, message_type: IncomingMessages) -> Arc> { + pub fn get_sender(&self, message_type: IncomingMessages) -> Arc> { let sender = self .senders .get(&message_type) @@ -135,106 +153,302 @@ pub enum Signal { #[derive(Debug)] pub struct TcpMessageBus { - reader: Arc, - writer: Arc>, - handles: Vec>, - requests: Arc>, - orders: Arc>, - recorder: MessageRecorder, - shared_channels: Arc, + connection: Connection, + handles: Mutex>>, + requests: SenderHash, + orders: SenderHash, + executions: SenderHash, + shared_channels: SharedChannels, signals_send: Sender, signals_recv: Receiver, + shutdown_requested: AtomicBool, } impl TcpMessageBus { - // establishes TCP connection to server - pub fn connect(connection_string: &str) -> Result { - let stream = TcpStream::connect(connection_string)?; - - let reader = Arc::new(stream.try_clone()?); - let writer = Arc::new(Mutex::new(stream)); - let requests = Arc::new(SenderHash::new()); - let orders = Arc::new(SenderHash::new()); - + pub fn new(connection: Connection) -> Result { let (signals_send, signals_recv) = channel::unbounded(); Ok(TcpMessageBus { - reader, - writer, - handles: Vec::default(), - requests, - orders, - recorder: MessageRecorder::new(), - shared_channels: Arc::new(SharedChannels::new()), + connection, + handles: Mutex::new(Vec::default()), + requests: SenderHash::new(), + orders: SenderHash::new(), + executions: SenderHash::new(), + shared_channels: SharedChannels::new(), signals_send, signals_recv, + shutdown_requested: AtomicBool::new(false), }) } + fn is_shutting_down(&self) -> bool { + self.shutdown_requested.load(Ordering::SeqCst) + } + + fn request_shutdown(&self) { + self.shutdown_requested.store(true, Ordering::Relaxed); + } + + fn reset(&self) {} + + fn clean_request(&self, request_id: i32) { + self.requests.remove(&request_id); + debug!("released request_id {}, requests.len()={}", request_id, self.requests.len()); + } + + fn clean_order(&self, order_id: i32) { + self.orders.remove(&order_id); + debug!("released order_id {}, orders.len()={}", order_id, self.orders.len()); + } + + fn read_message(&self) -> Result { + self.connection.read_message() + } + // Dispatcher thread reads messages from TWS and dispatches them to // appropriate channel. - fn start_dispatcher_thread(&mut self, server_version: i32) -> JoinHandle { - let reader = Arc::clone(&self.reader); - let requests = Arc::clone(&self.requests); - let recorder = self.recorder.clone(); - let orders = Arc::clone(&self.orders); - let shared_channels = Arc::clone(&self.shared_channels); - let executions = SenderHash::::new(); + fn start_dispatcher_thread(self: &Arc, server_version: i32) -> JoinHandle<()> { + let message_bus = Arc::clone(self); - thread::spawn(move || loop { - match read_packet(&reader) { - Ok(message) => { - recorder.record_response(&message); - dispatch_message(message, server_version, &requests, &orders, &shared_channels, &executions); - } - Err(err) => { - error!("error reading packet: {:?}", err); - continue; + const RECONNECT_ERRORS: &[ErrorKind] = &[ErrorKind::ConnectionReset, ErrorKind::UnexpectedEof]; + const RETRY_ERRORS: &[ErrorKind] = &[ErrorKind::Interrupted]; + + thread::spawn(move || { + let mut backoff = FibonacciBackoff::new(30); + let mut retry_attempt = 0; + + loop { + match message_bus.read_message() { + Ok(message) => { + message_bus.dispatch_message(server_version, message); + + backoff.reset(); + retry_attempt = 1; + } + Err(Error::Io(e)) if RECONNECT_ERRORS.contains(&e.kind()) => { + error!("error reading packet: {:?}", e); + // reset hashes + if let Err(e) = message_bus.connection.reconnect() { + error!("error reconnecting: {:?}", e); + message_bus.request_shutdown(); + return; + } + info!("reconnected"); + message_bus.reset(); + continue; + } + Err(Error::Io(e)) if RETRY_ERRORS.contains(&e.kind()) => { + error!("error reading packet: {:?}", e); + let next_delay = backoff.next_delay(); + if retry_attempt > MAX_RETRIES { + message_bus.request_shutdown(); + return; + } + error!("retry read attempt {retry_attempt} of {MAX_RETRIES}"); + thread::sleep(next_delay); + retry_attempt += 1; + continue; + } + Err(err) => { + error!("error reading packet: {:?}", err); + message_bus.request_shutdown(); + return; + } + }; + + if message_bus.is_shutting_down() { + return; } - }; + } }) } + fn dispatch_message(&self, server_version: i32, message: ResponseMessage) { + match message.message_type() { + IncomingMessages::Error => { + let request_id = message.peek_int(2).unwrap_or(-1); + + if request_id == UNSPECIFIED_REQUEST_ID { + error_event(server_version, message).unwrap(); + } else { + self.process_response(message); + } + } + IncomingMessages::OrderStatus + | IncomingMessages::OpenOrder + | IncomingMessages::OpenOrderEnd + | IncomingMessages::CompletedOrder + | IncomingMessages::CompletedOrdersEnd + | IncomingMessages::ExecutionData + | IncomingMessages::ExecutionDataEnd + | IncomingMessages::CommissionsReport => self.process_orders(message), + _ => self.process_response(message), + }; + } + + fn process_response(&self, message: ResponseMessage) { + let request_id = message.request_id().unwrap_or(-1); // pass in request id? + if self.requests.contains(&request_id) { + self.requests.send(&request_id, Response::Message(message)).unwrap(); + } else if self.orders.contains(&request_id) { + self.orders.send(&request_id, Response::Message(message)).unwrap(); + } else if self.shared_channels.contains_sender(message.message_type()) { + self.shared_channels + .get_sender(message.message_type()) + .send(Response::Message(message)) + .unwrap() + } else { + info!("no recipient found for: {:?}", message) + } + } + + fn process_orders(&self, message: ResponseMessage) { + match message.message_type() { + IncomingMessages::ExecutionData => { + match (message.order_id(), message.request_id()) { + // First check matching orders channel + (Some(order_id), _) if self.orders.contains(&order_id) => { + if let Err(e) = self.orders.send(&order_id, Response::Message(message)) { + error!("error routing message for order_id({order_id}): {e}"); + } + } + (_, Some(request_id)) if self.requests.contains(&request_id) => { + if let Some(sender) = self.requests.copy_sender(request_id) { + if let Some(execution_id) = message.execution_id() { + self.executions.insert(execution_id, sender); + } + } + + if let Err(e) = self.requests.send(&request_id, Response::Message(message)) { + error!("error routing message for request_id({request_id}): {e}"); + } + } + _ => { + error!("could not route message {message:?}"); + } + } + } + IncomingMessages::ExecutionDataEnd => { + match (message.order_id(), message.request_id()) { + // First check matching orders channel + (Some(order_id), _) if self.orders.contains(&order_id) => { + if let Err(e) = self.orders.send(&order_id, Response::from(message)) { + error!("error routing message for order_id({order_id}): {e}"); + } + } + (_, Some(request_id)) if self.requests.contains(&request_id) => { + if let Err(e) = self.requests.send(&request_id, Response::from(message)) { + error!("error routing message for request_id({request_id}): {e}"); + } + } + _ => { + error!("could not route message {message:?}"); + } + } + } + IncomingMessages::OpenOrder | IncomingMessages::OrderStatus => { + if let Some(order_id) = message.order_id() { + if self.orders.contains(&order_id) { + if let Err(e) = self.orders.send(&order_id, Response::from(message)) { + error!("error routing message for order_id({order_id}): {e}"); + } + } else if let Err(e) = self.shared_channels.get_sender(IncomingMessages::OpenOrder).send(Response::from(message)) { + error!("error sending IncomingMessages::OpenOrder: {e}"); + } + } + } + IncomingMessages::CompletedOrder => { + if let Err(e) = self.shared_channels.get_sender(message.message_type()).send(Response::from(message)) { + error!("error sending IncomingMessages::CompletedOrder: {e}"); + } + } + IncomingMessages::OpenOrderEnd => { + if let Err(e) = self.shared_channels.get_sender(message.message_type()).send(Response::from(message)) { + error!("error sending IncomingMessages::OpenOrderEnd: {e}"); + } + } + IncomingMessages::CompletedOrdersEnd => { + if let Err(e) = self.shared_channels.get_sender(message.message_type()).send(Response::from(message)) { + error!("error sending IncomingMessages::CompletedOrdersEnd: {e}"); + } + } + IncomingMessages::CommissionsReport => { + if let Some(execution_id) = message.execution_id() { + if let Err(e) = self.executions.send(&execution_id, Response::from(message)) { + error!("error sending commission report for execution {}: {}", execution_id, e); + } + } + } + _ => (), + } + } + // The cleanup thread receives signals as subscribers are dropped and // releases the sender channels - fn start_cleanup_thread(&mut self) -> JoinHandle { - let requests = Arc::clone(&self.requests); - let orders = Arc::clone(&self.orders); - let signal_recv = self.signals_recv.clone(); + fn start_cleanup_thread(self: &Arc) -> JoinHandle<()> { + let message_bus = Arc::clone(self); thread::spawn(move || loop { + let signal_recv = message_bus.signals_recv.clone(); + for signal in &signal_recv { match signal { Signal::Request(request_id) => { - requests.remove(&request_id); - debug!("released request_id {}, requests.len()={}", request_id, requests.len()); + message_bus.clean_request(request_id); } Signal::Order(order_id) => { - orders.remove(&order_id); - debug!("released order_id {}, orders.len()={}", order_id, requests.len()); + message_bus.clean_order(order_id); } } + + if message_bus.is_shutting_down() { + return; + } } }) } + + pub(crate) fn process_messages(self: &Arc, server_version: i32) -> Result<(), Error> { + let handle = self.start_dispatcher_thread(server_version); + self.add_join_handle(handle); + + let handle = self.start_cleanup_thread(); + self.add_join_handle(handle); + + Ok(()) + } + + fn add_join_handle(&self, handle: JoinHandle<()>) { + let mut handles = self.handles.lock().unwrap(); + handles.push(handle); + } + + pub fn join(&self) { + let mut handles = self.handles.lock().unwrap(); + while !handles.is_empty() { + if let Some(handle) = handles.pop() { + if let Err(e) = handle.join() { + warn!("could not join thread: {e:?}"); + } + } + } + } } const UNSPECIFIED_REQUEST_ID: i32 = -1; impl MessageBus for TcpMessageBus { - fn read_message(&mut self) -> Result { - read_packet(&self.reader) - } - - fn send_request(&mut self, request_id: i32, packet: &RequestMessage) -> Result { + fn send_request(&self, request_id: i32, packet: &RequestMessage) -> Result { let (sender, receiver) = channel::unbounded(); + let sender_copy = sender.clone(); self.requests.insert(request_id, sender); - self.write_message(packet)?; + self.connection.write_message(packet)?; let subscription = SubscriptionBuilder::new() .receiver(receiver) + .sender(sender_copy) .signaler(self.signals_send.clone()) .request_id(request_id) .build(); @@ -242,15 +456,29 @@ impl MessageBus for TcpMessageBus { Ok(subscription) } - fn send_order_request(&mut self, order_id: i32, message: &RequestMessage) -> Result { + fn cancel_subscription(&self, request_id: i32, message: &RequestMessage) -> Result<(), Error> { + self.connection.write_message(message)?; + + if let Err(e) = self.requests.send(&request_id, Response::Cancelled) { + info!("error sending cancel notification: {e}"); + } + + self.requests.remove(&request_id); + + Ok(()) + } + + fn send_order_request(&self, order_id: i32, message: &RequestMessage) -> Result { let (sender, receiver) = channel::unbounded(); + let sender_copy = sender.clone(); self.orders.insert(order_id, sender); - self.write_message(message)?; + self.connection.write_message(message)?; let subscription = SubscriptionBuilder::new() .receiver(receiver) + .sender(sender_copy) .signaler(self.signals_send.clone()) .order_id(order_id) .build(); @@ -258,93 +486,41 @@ impl MessageBus for TcpMessageBus { Ok(subscription) } - fn send_shared_request(&mut self, message_id: OutgoingMessages, message: &RequestMessage) -> Result { - self.write_message(message)?; + fn cancel_order_subscription(&self, request_id: i32, message: &RequestMessage) -> Result<(), Error> { + self.connection.write_message(message)?; - let shared_receiver = self.shared_channels.get_receiver(message_id); + if let Err(e) = self.orders.send(&request_id, Response::Cancelled) { + info!("error sending cancel notification: {e}"); + } - let subscription = SubscriptionBuilder::new().shared_receiver(shared_receiver).build(); + self.orders.remove(&request_id); - Ok(subscription) + Ok(()) } - fn write_message(&mut self, message: &RequestMessage) -> Result<(), Error> { - let data = message.encode(); - debug!("-> {data:?}"); + fn send_shared_request(&self, message_type: OutgoingMessages, message: &RequestMessage) -> Result { + self.connection.write_message(message)?; - let data = data.as_bytes(); - - let mut packet = Vec::with_capacity(data.len() + 4); + let shared_receiver = self.shared_channels.get_receiver(message_type); - packet.write_u32::(data.len() as u32)?; - packet.write_all(data)?; - - self.writer.lock()?.write_all(&packet)?; - - self.recorder.record_request(message); + let subscription = SubscriptionBuilder::new() + .shared_receiver(shared_receiver) + .message_type(message_type) + .build(); - Ok(()) + Ok(subscription) } - fn write(&mut self, data: &str) -> Result<(), Error> { - debug!("{data:?} ->"); - self.writer.lock()?.write_all(data.as_bytes())?; + fn cancel_shared_subscription(&self, _message_type: OutgoingMessages, message: &RequestMessage) -> Result<(), Error> { + self.connection.write_message(message)?; + // TODO send cancel Ok(()) } - fn process_messages(&mut self, server_version: i32) -> Result<(), Error> { - let handle = self.start_dispatcher_thread(server_version); - self.handles.push(handle); - - let handle = self.start_cleanup_thread(); - self.handles.push(handle); - - Ok(()) + fn ensure_shutdown(&self) { + self.join(); } -} - -fn dispatch_message( - message: ResponseMessage, - server_version: i32, - requests: &Arc>, - orders: &Arc>, - shared_channels: &Arc, - executions: &SenderHash, -) { - match message.message_type() { - IncomingMessages::Error => { - let request_id = message.peek_int(2).unwrap_or(-1); - - if request_id == UNSPECIFIED_REQUEST_ID { - error_event(server_version, message).unwrap(); - } else { - process_response(requests, orders, shared_channels, message); - } - } - IncomingMessages::OrderStatus - | IncomingMessages::OpenOrder - | IncomingMessages::OpenOrderEnd - | IncomingMessages::CompletedOrder - | IncomingMessages::CompletedOrdersEnd - | IncomingMessages::ExecutionData - | IncomingMessages::ExecutionDataEnd - | IncomingMessages::CommissionsReport => process_orders(message, requests, orders, executions, shared_channels), - _ => process_response(requests, orders, shared_channels, message), - }; -} - -fn read_packet(mut reader: &TcpStream) -> Result { - let message_size = read_header(reader)?; - let mut data = vec![0_u8; message_size]; - reader.read_exact(&mut data)?; - - let raw_string = String::from_utf8(data)?; - debug!("<- {:?}", raw_string); - - let packet = ResponseMessage::from(&raw_string); - - Ok(packet) } fn read_header(mut reader: &TcpStream) -> Result { @@ -386,111 +562,6 @@ fn error_event(server_version: i32, mut packet: ResponseMessage) -> Result<(), E } } -fn process_response( - requests: &Arc>, - orders: &Arc>, - shared_channels: &Arc, - message: ResponseMessage, -) { - let request_id = message.request_id().unwrap_or(-1); // pass in request id? - if requests.contains(&request_id) { - requests.send(&request_id, message).unwrap(); - } else if orders.contains(&request_id) { - orders.send(&request_id, message).unwrap(); - } else if shared_channels.contains_sender(message.message_type()) { - shared_channels.get_sender(message.message_type()).send(message).unwrap() - } else { - info!("no recipient found for: {:?}", message) - } -} - -fn process_orders( - message: ResponseMessage, - requests: &Arc>, - orders: &Arc>, - executions: &SenderHash, - shared_channels: &Arc, -) { - match message.message_type() { - IncomingMessages::ExecutionData => { - match (message.order_id(), message.request_id()) { - // First check matching orders channel - (Some(order_id), _) if orders.contains(&order_id) => { - if let Err(e) = orders.send(&order_id, message) { - error!("error routing message for order_id({order_id}): {e}"); - } - } - (_, Some(request_id)) if requests.contains(&request_id) => { - if let Some(sender) = requests.copy_sender(request_id) { - if let Some(execution_id) = message.execution_id() { - executions.insert(execution_id, sender); - } - } - - if let Err(e) = requests.send(&request_id, message) { - error!("error routing message for request_id({request_id}): {e}"); - } - } - _ => { - error!("could not route message {message:?}"); - } - } - } - IncomingMessages::ExecutionDataEnd => { - match (message.order_id(), message.request_id()) { - // First check matching orders channel - (Some(order_id), _) if orders.contains(&order_id) => { - if let Err(e) = orders.send(&order_id, message) { - error!("error routing message for order_id({order_id}): {e}"); - } - } - (_, Some(request_id)) if requests.contains(&request_id) => { - if let Err(e) = requests.send(&request_id, message) { - error!("error routing message for request_id({request_id}): {e}"); - } - } - _ => { - error!("could not route message {message:?}"); - } - } - } - IncomingMessages::OpenOrder | IncomingMessages::OrderStatus => { - if let Some(order_id) = message.order_id() { - if orders.contains(&order_id) { - if let Err(e) = orders.send(&order_id, message) { - error!("error routing message for order_id({order_id}): {e}"); - } - } else if let Err(e) = shared_channels.get_sender(IncomingMessages::OpenOrder).send(message) { - error!("error sending IncomingMessages::OpenOrder: {e}"); - } - } - } - IncomingMessages::CompletedOrder => { - if let Err(e) = shared_channels.get_sender(message.message_type()).send(message) { - error!("error sending IncomingMessages::CompletedOrder: {e}"); - } - } - IncomingMessages::OpenOrderEnd => { - if let Err(e) = shared_channels.get_sender(message.message_type()).send(message) { - error!("error sending IncomingMessages::OpenOrderEnd: {e}"); - } - } - IncomingMessages::CompletedOrdersEnd => { - if let Err(e) = shared_channels.get_sender(message.message_type()).send(message) { - error!("error sending IncomingMessages::CompletedOrdersEnd: {e}"); - } - } - IncomingMessages::CommissionsReport => { - if let Some(execution_id) = message.execution_id() { - if let Err(e) = executions.send(&execution_id, message) { - error!("error sending commission report for execution {}: {}", execution_id, e); - } - } - } - _ => (), - } -} - #[derive(Debug)] struct SenderHash { data: RwLock>>, @@ -545,16 +616,18 @@ impl SenderHash>, // requests with request ids receive responses via this channel - shared_receiver: Option>>, // this channel is for responses that share channel based on message type - signaler: Option>, // for client to signal termination - request_id: Option, // initiating request_id - order_id: Option, // initiating order_id + receiver: Option>, // requests with request ids receive responses via this channel + sender: Option>, // requests with request ids receive responses via this channel + shared_receiver: Option>>, // this channel is for responses that share channel based on message type + signaler: Option>, // for client to signal termination + pub(crate) request_id: Option, // initiating request id + pub(crate) order_id: Option, // initiating order id + pub(crate) message_type: Option, // initiating message type } impl InternalSubscription { // Blocks until next message become available. - pub(crate) fn next(&self) -> Option { + pub(crate) fn next(&self) -> Option { if let Some(receiver) = &self.receiver { Self::receive(receiver) } else if let Some(receiver) = &self.shared_receiver { @@ -565,7 +638,7 @@ impl InternalSubscription { } // Returns message if available or immediately returns None. - pub(crate) fn try_next(&self) -> Option { + pub(crate) fn try_next(&self) -> Option { if let Some(receiver) = &self.receiver { Self::try_receive(receiver) } else if let Some(receiver) = &self.shared_receiver { @@ -576,7 +649,7 @@ impl InternalSubscription { } // Waits for next message until specified timeout. - pub(crate) fn next_timeout(&self, timeout: Duration) -> Option { + pub(crate) fn next_timeout(&self, timeout: Duration) -> Option { if let Some(receiver) = &self.receiver { Self::timeout_receive(receiver, timeout) } else if let Some(receiver) = &self.shared_receiver { @@ -586,19 +659,24 @@ impl InternalSubscription { } } - pub(crate) fn cancel(&self) -> Result<(), Error> { - Ok(()) + pub(crate) fn cancel(&self) { + if let Some(sender) = &self.sender { + if let Err(e) = sender.send(Response::Cancelled) { + warn!("error sending cancel notification: {e}") + } + } + // TODO - shared sender } - fn receive(receiver: &Receiver) -> Option { + fn receive(receiver: &Receiver) -> Option { receiver.recv().ok() } - fn try_receive(receiver: &Receiver) -> Option { + fn try_receive(receiver: &Receiver) -> Option { receiver.try_recv().ok() } - fn timeout_receive(receiver: &Receiver, timeout: Duration) -> Option { + fn timeout_receive(receiver: &Receiver, timeout: Duration) -> Option { receiver.recv_timeout(timeout).ok() } } @@ -616,30 +694,39 @@ impl Drop for InternalSubscription { } pub(crate) struct SubscriptionBuilder { - receiver: Option>, - shared_receiver: Option>>, + receiver: Option>, + sender: Option>, + shared_receiver: Option>>, signaler: Option>, order_id: Option, request_id: Option, + message_type: Option, } impl SubscriptionBuilder { pub(crate) fn new() -> Self { Self { receiver: None, + sender: None, shared_receiver: None, signaler: None, order_id: None, request_id: None, + message_type: None, } } - pub(crate) fn receiver(mut self, receiver: Receiver) -> Self { + pub(crate) fn receiver(mut self, receiver: Receiver) -> Self { self.receiver = Some(receiver); self } - pub(crate) fn shared_receiver(mut self, shared_receiver: Arc>) -> Self { + pub(crate) fn sender(mut self, sender: Sender) -> Self { + self.sender = Some(sender); + self + } + + pub(crate) fn shared_receiver(mut self, shared_receiver: Arc>) -> Self { self.shared_receiver = Some(shared_receiver); self } @@ -659,22 +746,31 @@ impl SubscriptionBuilder { self } + pub(crate) fn message_type(mut self, message_type: OutgoingMessages) -> Self { + self.message_type = Some(message_type); + self + } + pub(crate) fn build(self) -> InternalSubscription { if let (Some(receiver), Some(signaler)) = (self.receiver, self.signaler) { InternalSubscription { receiver: Some(receiver), + sender: self.sender, shared_receiver: None, signaler: Some(signaler), request_id: self.request_id, order_id: self.order_id, + message_type: self.message_type, } } else if let Some(receiver) = self.shared_receiver { InternalSubscription { receiver: None, + sender: None, shared_receiver: Some(receiver), signaler: None, request_id: self.request_id, order_id: self.order_id, + message_type: self.message_type, } } else { panic!("bad configuration"); @@ -682,5 +778,306 @@ impl SubscriptionBuilder { } } +#[derive(Default, Clone, Debug)] +pub(crate) struct ConnectionMetadata { + next_order_id: i32, + pub(crate) client_id: i32, + pub(crate) server_version: i32, + pub(crate) managed_accounts: String, + pub(crate) connection_time: Option, + pub(crate) time_zone: Option<&'static Tz>, +} + +#[derive(Debug)] +pub(crate) struct Connection { + client_id: i32, + connection_url: String, + reader: Mutex, + writer: Mutex, + connection_metadata: Mutex, + max_retries: i32, + recorder: MessageRecorder, +} + +impl Connection { + pub fn connect(client_id: i32, connection_url: &str) -> Result { + let reader = TcpStream::connect(connection_url)?; + let writer = reader.try_clone()?; + + let connection = Self { + client_id, + connection_url: connection_url.into(), + reader: Mutex::new(reader), + writer: Mutex::new(writer), + connection_metadata: Mutex::new(ConnectionMetadata::default()), + max_retries: MAX_RETRIES, + recorder: MessageRecorder::new(), + }; + + connection.establish_connection()?; + + Ok(connection) + } + + pub fn connection_metadata(&self) -> ConnectionMetadata { + let metadata = self.connection_metadata.lock().unwrap(); + metadata.clone() + } + + pub fn reconnect(&self) -> Result<(), Error> { + let mut backoff = FibonacciBackoff::new(30); + + for i in 0..self.max_retries { + let next_delay = backoff.next_delay(); + info!("next reconnection attempt in {next_delay:#?}"); + + thread::sleep(next_delay); + + match TcpStream::connect(&self.connection_url) { + Ok(stream) => { + { + let mut reader = self.reader.lock()?; + let mut writer = self.writer.lock()?; + + *reader = stream.try_clone()?; + *writer = stream; + } + + info!("reconnected !!!"); + self.establish_connection()?; + + return Ok(()); + } + Err(e) => { + error!("reconnection attempt {i} of {} failed: {e}", self.max_retries); + } + } + } + + Err(Error::ConnectionFailed) + } + + fn establish_connection(&self) -> Result<(), Error> { + self.handshake()?; + self.start_api()?; + self.receive_account_info()?; + Ok(()) + } + + fn write(&self, data: &str) -> Result<(), Error> { + let mut writer = self.writer.lock()?; + writer.write_all(data.as_bytes())?; + Ok(()) + } + + fn write_message(&self, message: &RequestMessage) -> Result<(), Error> { + let mut writer = self.writer.lock()?; + + let data = message.encode(); + debug!("-> {data:?}"); + + let data = data.as_bytes(); + + let mut packet = Vec::with_capacity(data.len() + 4); + + packet.write_u32::(data.len() as u32)?; + packet.write_all(data)?; + + writer.write_all(&packet)?; + + self.recorder.record_request(message); + + Ok(()) + } + + fn read_message(&self) -> Result { + let mut reader = self.reader.lock()?; + + let message_size = read_header(&reader)?; + let mut data = vec![0_u8; message_size]; + + reader.read_exact(&mut data)?; + + let raw_string = String::from_utf8(data)?; + debug!("<- {:?}", raw_string); + + let message = ResponseMessage::from(&raw_string); + self.recorder.record_response(&message); + + Ok(message) + } + + // sends server handshake + fn handshake(&self) -> Result<(), Error> { + let prefix = "API\0"; + let version = format!("v{MIN_SERVER_VERSION}..{MAX_SERVER_VERSION}"); + + let packet = prefix.to_owned() + &encode_packet(&version); + self.write(&packet)?; + + let ack = self.read_message(); + + let mut connection_metadata = self.connection_metadata.lock()?; + + match ack { + Ok(mut response) => { + connection_metadata.server_version = response.next_int()?; + + let time = response.next_string()?; + (connection_metadata.connection_time, connection_metadata.time_zone) = parse_connection_time(time.as_str()); + } + Err(Error::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + return Err(Error::Simple(format!("The server may be rejecting connections from this host: {err}"))); + } + Err(err) => { + return Err(err); + } + } + Ok(()) + } + + // asks server to start processing messages + fn start_api(&self) -> Result<(), Error> { + const VERSION: i32 = 2; + + let prelude = &mut RequestMessage::default(); + + prelude.push_field(&OutgoingMessages::StartApi); + prelude.push_field(&VERSION); + prelude.push_field(&self.client_id); + + if self.server_version() > server_versions::OPTIONAL_CAPABILITIES { + prelude.push_field(&""); + } + + self.write_message(prelude)?; + + Ok(()) + } + + fn server_version(&self) -> i32 { + let connection_metadata = self.connection_metadata.lock().unwrap(); + connection_metadata.server_version + } + + // Fetches next order id and managed accounts. + fn receive_account_info(&self) -> Result<(), Error> { + let mut saw_next_order_id: bool = false; + let mut saw_managed_accounts: bool = false; + + let mut attempts = 0; + const MAX_ATTEMPTS: i32 = 100; + loop { + let mut message = self.read_message()?; + + match message.message_type() { + IncomingMessages::NextValidId => { + saw_next_order_id = true; + + message.skip(); // message type + message.skip(); // message version + + let mut connection_metadata = self.connection_metadata.lock()?; + connection_metadata.next_order_id = message.next_int()?; + } + IncomingMessages::ManagedAccounts => { + saw_managed_accounts = true; + + message.skip(); // message type + message.skip(); // message version + + let mut connection_metadata = self.connection_metadata.lock()?; + connection_metadata.managed_accounts = message.next_string()?; + } + IncomingMessages::Error => { + error!("message: {message:?}") + } + _ => info!("message: {message:?}"), + } + + attempts += 1; + if (saw_next_order_id && saw_managed_accounts) || attempts > MAX_ATTEMPTS { + break; + } + } + + Ok(()) + } +} + +struct FibonacciBackoff { + previous: u64, + current: u64, + max: u64, +} + +impl FibonacciBackoff { + fn new(max: u64) -> Self { + FibonacciBackoff { + previous: 0, + current: 1, + max, + } + } + + fn next_delay(&mut self) -> Duration { + let next = self.previous + self.current; + self.previous = self.current; + self.current = next; + + if next > self.max { + Duration::from_secs(self.max) + } else { + Duration::from_secs(next) + } + } + + fn reset(&mut self) { + self.previous = 0; + self.current = 1; + } +} + +// Parses following format: 20230405 22:20:39 PST +fn parse_connection_time(connection_time: &str) -> (Option, Option<&'static Tz>) { + let parts: Vec<&str> = connection_time.split(' ').collect(); + + let zones = timezones::find_by_name(parts[2]); + if zones.is_empty() { + error!("time zone not found for {}", parts[2]); + return (None, None); + } + + let timezone = zones[0]; + + let format = format_description!("[year][month][day] [hour]:[minute]:[second]"); + let date_str = format!("{} {}", parts[0], parts[1]); + let date = time::PrimitiveDateTime::parse(date_str.as_str(), format); + match date { + Ok(connected_at) => match connected_at.assume_timezone(timezone) { + OffsetResult::Some(date) => (Some(date), Some(timezone)), + _ => { + error!("error setting timezone"); + (None, Some(timezone)) + } + }, + Err(err) => { + error!("could not parse connection time from {date_str}: {err}"); + (None, Some(timezone)) + } + } +} + +fn encode_packet(message: &str) -> String { + let data = message.as_bytes(); + + let mut packet: Vec = Vec::with_capacity(data.len() + 4); + + packet.write_u32::(data.len() as u32).unwrap(); + packet.write_all(data).unwrap(); + + std::str::from_utf8(&packet).unwrap().into() +} + #[cfg(test)] mod tests; diff --git a/src/transport/tests.rs b/src/transport/tests.rs index 8b137891..eb3a0746 100644 --- a/src/transport/tests.rs +++ b/src/transport/tests.rs @@ -1 +1,40 @@ +use time::macros::datetime; +use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt}; +use crate::tests::assert_send_and_sync; + +use super::*; + +#[test] +fn test_thread_safe() { + assert_send_and_sync::(); + assert_send_and_sync::(); +} + +#[test] +fn test_parse_connection_time() { + let example = "20230405 22:20:39 PST"; + let (connection_time, _) = parse_connection_time(example); + + let la = timezones::db::america::LOS_ANGELES; + if let OffsetResult::Some(other) = datetime!(2023-04-05 22:20:39).assume_timezone(la) { + assert_eq!(connection_time, Some(other)); + } +} + +#[test] +fn test_fibonacci_backoff() { + let mut backoff = FibonacciBackoff::new(10); + + assert_eq!(backoff.next_delay(), Duration::from_secs(1)); + assert_eq!(backoff.next_delay(), Duration::from_secs(2)); + assert_eq!(backoff.next_delay(), Duration::from_secs(3)); + assert_eq!(backoff.next_delay(), Duration::from_secs(5)); + assert_eq!(backoff.next_delay(), Duration::from_secs(8)); + assert_eq!(backoff.next_delay(), Duration::from_secs(10)); + assert_eq!(backoff.next_delay(), Duration::from_secs(10)); + + backoff.reset(); + + assert_eq!(backoff.next_delay(), Duration::from_secs(1)); +}