From 9c445efa4069be6d2e01cb5825fc075b56cfa201 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Fri, 8 Nov 2024 23:32:40 -0800 Subject: [PATCH 1/5] support serialization --- src/market_data/historical.rs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/market_data/historical.rs b/src/market_data/historical.rs index ac3b1484..8fcccb0e 100644 --- a/src/market_data/historical.rs +++ b/src/market_data/historical.rs @@ -2,6 +2,7 @@ use std::collections::VecDeque; use std::fmt::{Debug, Display}; use log::{error, warn}; +use serde::{Deserialize, Serialize}; use time::{Date, OffsetDateTime}; use crate::contracts::Contract; @@ -15,7 +16,7 @@ mod encoders; mod tests; /// Bar describes the historical data bar. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Copy, Serialize, Deserialize)] pub struct Bar { /// The bar's date and time (either as a yyyymmss hh:mm:ss formatted string or as system time according to the request). Time zone is the TWS time zone chosen on login. // pub time: OffsetDateTime, @@ -36,7 +37,7 @@ pub struct Bar { pub count: i32, } -#[derive(Clone, Debug, Copy)] +#[derive(Clone, Debug, Copy, PartialEq, Serialize, Deserialize)] pub enum BarSize { Sec, Sec5, @@ -91,7 +92,7 @@ impl ToField for BarSize { } } -#[derive(Clone, Debug, Copy)] +#[derive(Clone, Debug, Copy, PartialEq, Serialize, Deserialize)] pub struct Duration { value: i32, unit: char, @@ -166,20 +167,20 @@ impl ToDuration for i32 { } } -#[derive(Debug)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] pub struct HistogramEntry { pub price: f64, pub size: i32, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct HistoricalData { pub start: OffsetDateTime, pub end: OffsetDateTime, pub bars: Vec, } -#[derive(Debug)] +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct Schedule { pub start: OffsetDateTime, pub end: OffsetDateTime, @@ -187,7 +188,7 @@ pub struct Schedule { pub sessions: Vec, } -#[derive(Debug)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] pub struct Session { pub reference: Date, pub start: OffsetDateTime, @@ -195,7 +196,7 @@ pub struct Session { } /// The historical tick's description. Used when requesting historical tick data with whatToShow = MIDPOINT -#[derive(Debug)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] pub struct TickMidpoint { /// timestamp of the historical tick. pub timestamp: OffsetDateTime, @@ -206,7 +207,7 @@ pub struct TickMidpoint { } /// The historical tick's description. Used when requesting historical tick data with whatToShow = BID_ASK. -#[derive(Debug)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] pub struct TickBidAsk { /// Timestamp of the historical tick. pub timestamp: OffsetDateTime, @@ -222,14 +223,14 @@ pub struct TickBidAsk { pub size_ask: i32, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] pub struct TickAttributeBidAsk { pub bid_past_low: bool, pub ask_past_high: bool, } /// The historical last tick's description. Used when requesting historical tick data with whatToShow = TRADES. -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] pub struct TickLast { /// Timestamp of the historical tick. pub timestamp: OffsetDateTime, @@ -245,7 +246,7 @@ pub struct TickLast { pub special_conditions: String, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)] pub struct TickAttributeLast { pub past_limit: bool, pub unreported: bool, From 9577a84c6dc97ee45254b6bc3af61e8aa4f741b3 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Sat, 9 Nov 2024 09:11:58 -0800 Subject: [PATCH 2/5] refactor subscription use --- src/market_data/historical.rs | 38 +++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/market_data/historical.rs b/src/market_data/historical.rs index 8fcccb0e..cf005323 100644 --- a/src/market_data/historical.rs +++ b/src/market_data/historical.rs @@ -302,13 +302,14 @@ pub(crate) fn head_timestamp(client: &Client, contract: &Contract, what_to_show: let request_id = client.next_request_id(); let request = encoders::encode_request_head_timestamp(request_id, contract, what_to_show, use_rth)?; - let subscription = client.send_request(request_id, request)?; - if let Some(Ok(mut message)) = subscription.next() { - decoders::decode_head_timestamp(&mut message) - } else { - Err(Error::Simple("did not receive head timestamp message".into())) + match subscription.next() { + Some(Ok(mut message)) if message.message_type() == IncomingMessages::HeadTimestamp => Ok(decoders::decode_head_timestamp(&mut message)?), + Some(Ok(message)) => Err(Error::UnexpectedResponse(message)), + Some(Err(Error::ConnectionReset)) => head_timestamp(client, contract, what_to_show, use_rth), + Some(Err(e)) => Err(e), + None => Err(Error::UnexpectedEndOfStream), } } @@ -352,20 +353,23 @@ pub(crate) fn historical_data( let subscription = client.send_request(request_id, request)?; - if let Some(Ok(mut message)) = subscription.next() { - let time_zone = if let Some(tz) = client.time_zone { - tz - } else { - warn!("server timezone unknown. assuming UTC, but that may be incorrect!"); - time_tz::timezones::db::UTC - }; - match message.message_type() { - IncomingMessages::HistoricalData => decoders::decode_historical_data(client.server_version, time_zone, &mut message), - IncomingMessages::Error => Err(Error::Simple(message.peek_string(4))), - _ => Err(Error::Simple(format!("unexpected message: {:?}", message.message_type()))), + match subscription.next() { + Some(Ok(mut message)) if message.message_type() == IncomingMessages::HistoricalData => { + Ok(decoders::decode_historical_data(client.server_version, time_zone(client), &mut message)?) } + Some(Ok(message)) => Err(Error::UnexpectedResponse(message)), + Some(Err(Error::ConnectionReset)) => historical_data(client, contract, end_date, duration, bar_size, what_to_show, use_rth), + Some(Err(e)) => Err(e), + None => Err(Error::UnexpectedEndOfStream), + } +} + +fn time_zone(client: &Client) -> &time_tz::Tz { + if let Some(tz) = client.time_zone { + tz } else { - Err(Error::Simple("did not receive historical data response".into())) + warn!("server timezone unknown. assuming UTC, but that may be incorrect!"); + time_tz::timezones::db::UTC } } From dfdd6e3fb7e4e3ff8d05cb680b0c8542a3489d64 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Sat, 9 Nov 2024 15:05:15 -0800 Subject: [PATCH 3/5] refactor into subscription --- src/client.rs | 18 +-- src/market_data/historical.rs | 217 ++++++++++++++++++++++++---- src/market_data/historical/tests.rs | 2 +- src/messages.rs | 2 +- 4 files changed, 198 insertions(+), 41 deletions(-) diff --git a/src/client.rs b/src/client.rs index c88aa32e..f75df282 100644 --- a/src/client.rs +++ b/src/client.rs @@ -928,7 +928,7 @@ impl Client { number_of_ticks: i32, use_rth: bool, ignore_size: bool, - ) -> Result, Error> { + ) -> Result, Error> { historical::historical_ticks_bid_ask(self, contract, start, end, number_of_ticks, use_rth, ignore_size) } @@ -968,7 +968,7 @@ impl Client { end: Option, number_of_ticks: i32, use_rth: bool, - ) -> Result, Error> { + ) -> Result, Error> { historical::historical_ticks_mid_point(self, contract, start, end, number_of_ticks, use_rth) } @@ -1008,7 +1008,7 @@ impl Client { end: Option, number_of_ticks: i32, use_rth: bool, - ) -> Result, Error> { + ) -> Result, Error> { historical::historical_ticks_trade(self, contract, start, end, number_of_ticks, use_rth) } @@ -1650,7 +1650,7 @@ pub struct Subscription<'a, T: DataStream> { cancelled: AtomicBool, subscription: InternalSubscription, response_context: ResponseContext, - error: Arc>>, + error: Mutex>, } // Extra metadata that might be need @@ -1672,7 +1672,7 @@ impl<'a, T: DataStream> Subscription<'a, T> { response_context: context, phantom: PhantomData, cancelled: AtomicBool::new(false), - error: Arc::new(Mutex::new(None)), + error: Mutex::new(None), } } else if let Some(order_id) = subscription.order_id { Subscription { @@ -1684,7 +1684,7 @@ impl<'a, T: DataStream> Subscription<'a, T> { response_context: context, phantom: PhantomData, cancelled: AtomicBool::new(false), - error: Arc::new(Mutex::new(None)), + error: Mutex::new(None), } } else if let Some(message_type) = subscription.message_type { Subscription { @@ -1696,7 +1696,7 @@ impl<'a, T: DataStream> Subscription<'a, T> { response_context: context, phantom: PhantomData, cancelled: AtomicBool::new(false), - error: Arc::new(Mutex::new(None)), + error: Mutex::new(None), } } else { panic!("unsupported internal subscription: {:?}", subscription) @@ -1740,8 +1740,6 @@ impl<'a, T: DataStream> Subscription<'a, T> { /// * `Some(T)` - The next available item from the subscription /// * `None` - If the subscription has ended or encountered an error pub fn next(&self) -> Option { - self.clear_error(); - match self.process_response(self.subscription.next()) { Some(val) => Some(val), None => match self.error() { @@ -1755,6 +1753,8 @@ impl<'a, T: DataStream> Subscription<'a, T> { } fn process_response(&self, response: Option>) -> Option { + self.clear_error(); + match response { Some(Ok(message)) => self.process_message(message), Some(Err(e)) => { diff --git a/src/market_data/historical.rs b/src/market_data/historical.rs index cf005323..bf4ac277 100644 --- a/src/market_data/historical.rs +++ b/src/market_data/historical.rs @@ -1,7 +1,9 @@ use std::collections::VecDeque; use std::fmt::{Debug, Display}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Mutex; -use log::{error, warn}; +use log::{debug, error, warn}; use serde::{Deserialize, Serialize}; use time::{Date, OffsetDateTime}; @@ -407,14 +409,14 @@ pub(crate) fn historical_schedule( let subscription = client.send_request(request_id, request)?; - if let Some(Ok(mut message)) = subscription.next() { - match message.message_type() { - IncomingMessages::HistoricalSchedule => decoders::decode_historical_schedule(&mut message), - IncomingMessages::Error => Err(Error::Simple(message.peek_string(4))), - _ => Err(Error::Simple(format!("unexpected message: {:?}", message.message_type()))), + match subscription.next() { + Some(Ok(mut message)) if message.message_type() == IncomingMessages::HistoricalSchedule => { + Ok(decoders::decode_historical_schedule(&mut message)?) } - } else { - Err(Error::Simple("did not receive historical schedule response".into())) + Some(Ok(message)) => Err(Error::UnexpectedResponse(message)), + Some(Err(Error::ConnectionReset)) => historical_schedule(client, contract, end_date, duration), + Some(Err(e)) => Err(e), + None => Err(Error::UnexpectedEndOfStream), } } @@ -426,7 +428,7 @@ pub(crate) fn historical_ticks_bid_ask( number_of_ticks: i32, use_rth: bool, ignore_size: bool, -) -> Result, Error> { +) -> Result, Error> { client.check_server_version(server_versions::HISTORICAL_TICKS, "It does not support historical ticks request.")?; let request_id = client.next_request_id(); @@ -443,7 +445,7 @@ pub(crate) fn historical_ticks_bid_ask( let messages = client.send_request(request_id, message)?; - Ok(TickIterator::new(messages)) + Ok(TickSubscription::new(messages)) } pub(crate) fn historical_ticks_mid_point( @@ -453,7 +455,7 @@ pub(crate) fn historical_ticks_mid_point( end: Option, number_of_ticks: i32, use_rth: bool, -) -> Result, Error> { +) -> Result, Error> { client.check_server_version(server_versions::HISTORICAL_TICKS, "It does not support historical ticks request.")?; let request_id = client.next_request_id(); @@ -461,7 +463,7 @@ pub(crate) fn historical_ticks_mid_point( let messages = client.send_request(request_id, message)?; - Ok(TickIterator::new(messages)) + Ok(TickSubscription::new(messages)) } pub(crate) fn historical_ticks_trade( @@ -471,7 +473,7 @@ pub(crate) fn historical_ticks_trade( end: Option, number_of_ticks: i32, use_rth: bool, -) -> Result, Error> { +) -> Result, Error> { client.check_server_version(server_versions::HISTORICAL_TICKS, "It does not support historical ticks request.")?; let request_id = client.next_request_id(); @@ -479,7 +481,7 @@ pub(crate) fn historical_ticks_trade( let messages = client.send_request(request_id, message)?; - Ok(TickIterator::new(messages)) + Ok(TickSubscription::new(messages)) } pub(crate) fn histogram_data(client: &Client, contract: &Contract, use_rth: bool, period: BarSize) -> Result, Error> { @@ -497,7 +499,7 @@ pub(crate) fn histogram_data(client: &Client, contract: &Contract, use_rth: bool } } -pub(crate) trait TickDecoder { +pub trait TickDecoder { fn decode(message: &mut ResponseMessage) -> Result<(Vec, bool), Error>; fn message_type() -> IncomingMessages; } @@ -529,31 +531,144 @@ impl TickDecoder for TickMidpoint { } } -pub(crate) struct TickIterator> { - done: bool, +pub struct TickSubscription> { + done: AtomicBool, messages: InternalSubscription, - buffer: VecDeque, + buffer: Mutex>, + error: Mutex>, } -impl> TickIterator { +impl> TickSubscription { fn new(messages: InternalSubscription) -> Self { Self { - done: false, + done: false.into(), messages, - buffer: VecDeque::new(), + buffer: Mutex::new(VecDeque::new()), + error: Mutex::new(None), + } + } + + pub fn next(&self) -> Option { + self.clear_error(); + + if let Some(message) = self.next_buffered() { + return Some(message); + } + + if self.done.load(Ordering::Relaxed) { + return None; + } + + match self.messages.next() { + Some(Ok(message)) if message.message_type() == T::message_type() => { + self.fill_buffer(message); + self.next() + } + Some(Ok(message)) => { + debug!("unexpected message: {:?}", message); + self.next() + } + Some(Err(e)) => { + self.set_error(e); + None + } + None => None, } } + + pub fn try_next(&self) -> Option { + self.clear_error(); + + if let Some(message) = self.next_buffered() { + return Some(message); + } + + if self.done.load(Ordering::Relaxed) { + return None; + } + + match self.messages.try_next() { + Some(Ok(message)) if message.message_type() == T::message_type() => { + self.fill_buffer(message); + self.try_next() + } + Some(Ok(message)) => { + debug!("unexpected message: {:?}", message); + self.try_next() + } + Some(Err(e)) => { + self.set_error(e); + None + } + None => None, + } + } + + pub fn next_timeout(&self, duration: std::time::Duration) -> Option { + self.clear_error(); + + if let Some(message) = self.next_buffered() { + return Some(message); + } + + if self.done.load(Ordering::Relaxed) { + return None; + } + + match self.messages.next_timeout(duration) { + Some(Ok(message)) if message.message_type() == T::message_type() => { + self.fill_buffer(message); + self.next_timeout(duration) + } + Some(Ok(message)) => { + debug!("unexpected message: {:?}", message); + self.next_timeout(duration) + } + Some(Err(e)) => { + self.set_error(e); + None + } + None => None, + } + } + + fn next_buffered(&self) -> Option { + let mut buffer = self.buffer.lock().unwrap(); + buffer.pop_front() + } + + fn set_error(&self, e: Error) { + let mut error = self.error.lock().unwrap(); + *error = Some(e); + } + + fn clear_error(&self) { + let mut error = self.error.lock().unwrap(); + *error = None; + } + + fn fill_buffer(&self, mut message: ResponseMessage) { + let mut buffer = self.buffer.lock().unwrap(); + + let (ticks, done) = T::decode(&mut message).unwrap(); + + buffer.append(&mut ticks.into()); + self.done.store(done, Ordering::Relaxed); + } } -impl + Debug> Iterator for TickIterator { +impl + Debug> Iterator for TickSubscription { type Item = T; fn next(&mut self) -> Option { - if !self.buffer.is_empty() { - return self.buffer.pop_front(); + { + let mut buffer = self.buffer.lock().unwrap(); + if !buffer.is_empty() { + return buffer.pop_front(); + } } - if self.done { + if self.done.load(Ordering::Relaxed) { return None; } @@ -561,17 +676,19 @@ impl + Debug> Iterator for TickIterator { match self.messages.next() { Some(Ok(mut message)) => { if message.message_type() == Self::Item::message_type() { + let mut buffer = self.buffer.lock().unwrap(); + let (ticks, done) = Self::Item::decode(&mut message).unwrap(); - self.buffer.append(&mut ticks.into()); - self.done = done; + buffer.append(&mut ticks.into()); + self.done.store(done, Ordering::Relaxed); - if self.buffer.is_empty() && self.done { + if buffer.is_empty() && self.done.load(Ordering::Relaxed) { return None; } - if !self.buffer.is_empty() { - return self.buffer.pop_front(); + if !buffer.is_empty() { + return buffer.pop_front(); } } else if message.message_type() == IncomingMessages::Error { error!("error reading ticks: {:?}", message.peek_string(4)); @@ -586,3 +703,43 @@ impl + Debug> Iterator for TickIterator { } } } + +/// An iterator that yields items as they become available, blocking if necessary. +pub struct TickSubscriptionIter<'a, T: TickDecoder> { + subscription: &'a TickSubscription, +} + +impl<'a, T: TickDecoder> Iterator for TickSubscriptionIter<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.subscription.next() + } +} + +/// An iterator that yields items if they are available, without waiting. +pub struct TickSubscriptionTryIter<'a, T: TickDecoder> { + subscription: &'a TickSubscription, +} + +impl<'a, T: TickDecoder> Iterator for TickSubscriptionTryIter<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.subscription.try_next() + } +} + +/// An iterator that waits for the specified timeout duration for available data. +pub struct TickSubscriptionTimeoutIter<'a, T: TickDecoder> { + subscription: &'a TickSubscription, + timeout: std::time::Duration, +} + +impl<'a, T: TickDecoder> Iterator for TickSubscriptionTimeoutIter<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.subscription.next_timeout(self.timeout) + } +} diff --git a/src/market_data/historical/tests.rs b/src/market_data/historical/tests.rs index 03a2fc28..90cb03d0 100644 --- a/src/market_data/historical/tests.rs +++ b/src/market_data/historical/tests.rs @@ -12,7 +12,7 @@ use super::*; fn test_head_timestamp() { let message_bus = Arc::new(MessageBusStub { request_messages: RwLock::new(vec![]), - response_messages: vec!["9|9000|1678323335|".to_owned()], + response_messages: vec!["88|9000|1678323335|".to_owned()], }); let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); diff --git a/src/messages.rs b/src/messages.rs index b27addbb..dcc4a731 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -392,7 +392,7 @@ impl Index for RequestMessage { } #[derive(Clone, Default, Debug)] -pub(crate) struct ResponseMessage { +pub struct ResponseMessage { pub i: usize, pub fields: Vec, } From 0a3d31ecf721afcaf971216f0c223ca65f722312 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Sat, 9 Nov 2024 15:38:44 -0800 Subject: [PATCH 4/5] implement into iterator --- src/market_data/historical.rs | 108 ++++++++++++++++------------------ 1 file changed, 52 insertions(+), 56 deletions(-) diff --git a/src/market_data/historical.rs b/src/market_data/historical.rs index bf4ac277..d0b23579 100644 --- a/src/market_data/historical.rs +++ b/src/market_data/historical.rs @@ -3,7 +3,7 @@ use std::fmt::{Debug, Display}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Mutex; -use log::{debug, error, warn}; +use log::{debug, warn}; use serde::{Deserialize, Serialize}; use time::{Date, OffsetDateTime}; @@ -500,35 +500,32 @@ pub(crate) fn histogram_data(client: &Client, contract: &Contract, use_rth: bool } pub trait TickDecoder { + const MESSAGE_TYPE: IncomingMessages; fn decode(message: &mut ResponseMessage) -> Result<(Vec, bool), Error>; - fn message_type() -> IncomingMessages; } impl TickDecoder for TickBidAsk { + const MESSAGE_TYPE: IncomingMessages = IncomingMessages::HistoricalTickBidAsk; + fn decode(message: &mut ResponseMessage) -> Result<(Vec, bool), Error> { decoders::decode_historical_ticks_bid_ask(message) } - fn message_type() -> IncomingMessages { - IncomingMessages::HistoricalTickBidAsk - } } impl TickDecoder for TickLast { + const MESSAGE_TYPE: IncomingMessages = IncomingMessages::HistoricalTickLast; + fn decode(message: &mut ResponseMessage) -> Result<(Vec, bool), Error> { decoders::decode_historical_ticks_last(message) } - fn message_type() -> IncomingMessages { - IncomingMessages::HistoricalTickLast - } } impl TickDecoder for TickMidpoint { + const MESSAGE_TYPE: IncomingMessages = IncomingMessages::HistoricalTick; + fn decode(message: &mut ResponseMessage) -> Result<(Vec, bool), Error> { decoders::decode_historical_ticks_mid_point(message) } - fn message_type() -> IncomingMessages { - IncomingMessages::HistoricalTick - } } pub struct TickSubscription> { @@ -548,6 +545,21 @@ impl> TickSubscription { } } + pub fn iter(&self) -> TickSubscriptionIter { + TickSubscriptionIter { subscription: self } + } + + pub fn try_iter(&self) -> TickSubscriptionTryIter { + TickSubscriptionTryIter { subscription: self } + } + + pub fn timeout_iter(&self, duration: std::time::Duration) -> TickSubscriptionTimeoutIter { + TickSubscriptionTimeoutIter { + subscription: self, + timeout: duration, + } + } + pub fn next(&self) -> Option { self.clear_error(); @@ -560,7 +572,7 @@ impl> TickSubscription { } match self.messages.next() { - Some(Ok(message)) if message.message_type() == T::message_type() => { + Some(Ok(message)) if message.message_type() == T::MESSAGE_TYPE => { self.fill_buffer(message); self.next() } @@ -588,7 +600,7 @@ impl> TickSubscription { } match self.messages.try_next() { - Some(Ok(message)) if message.message_type() == T::message_type() => { + Some(Ok(message)) if message.message_type() == T::MESSAGE_TYPE => { self.fill_buffer(message); self.try_next() } @@ -616,7 +628,7 @@ impl> TickSubscription { } match self.messages.next_timeout(duration) { - Some(Ok(message)) if message.message_type() == T::message_type() => { + Some(Ok(message)) if message.message_type() == T::MESSAGE_TYPE => { self.fill_buffer(message); self.next_timeout(duration) } @@ -657,59 +669,34 @@ impl> TickSubscription { } } -impl + Debug> Iterator for TickSubscription { +/// An iterator that yields items as they become available, blocking if necessary. +pub struct TickSubscriptionIter<'a, T: TickDecoder> { + subscription: &'a TickSubscription, +} + +impl<'a, T: TickDecoder> Iterator for TickSubscriptionIter<'a, T> { type Item = T; fn next(&mut self) -> Option { - { - let mut buffer = self.buffer.lock().unwrap(); - if !buffer.is_empty() { - return buffer.pop_front(); - } - } + self.subscription.next() + } +} - if self.done.load(Ordering::Relaxed) { - return None; - } +impl<'a, T: TickDecoder> IntoIterator for &'a TickSubscription { + type Item = T; + type IntoIter = TickSubscriptionIter<'a, T>; - loop { - match self.messages.next() { - Some(Ok(mut message)) => { - if message.message_type() == Self::Item::message_type() { - let mut buffer = self.buffer.lock().unwrap(); - - let (ticks, done) = Self::Item::decode(&mut message).unwrap(); - - buffer.append(&mut ticks.into()); - self.done.store(done, Ordering::Relaxed); - - if buffer.is_empty() && self.done.load(Ordering::Relaxed) { - return None; - } - - if !buffer.is_empty() { - return buffer.pop_front(); - } - } else if message.message_type() == IncomingMessages::Error { - error!("error reading ticks: {:?}", message.peek_string(4)); - return None; - } else { - error!("unexpected message: {:?}", message) - } - } - // TODO enumerate - _ => return None, - } - } + fn into_iter(self) -> Self::IntoIter { + self.iter() } } /// An iterator that yields items as they become available, blocking if necessary. -pub struct TickSubscriptionIter<'a, T: TickDecoder> { - subscription: &'a TickSubscription, +pub struct TickSubscriptionOwnedIter> { + subscription: TickSubscription, } -impl<'a, T: TickDecoder> Iterator for TickSubscriptionIter<'a, T> { +impl> Iterator for TickSubscriptionOwnedIter { type Item = T; fn next(&mut self) -> Option { @@ -717,6 +704,15 @@ impl<'a, T: TickDecoder> Iterator for TickSubscriptionIter<'a, T> { } } +impl> IntoIterator for TickSubscription { + type Item = T; + type IntoIter = TickSubscriptionOwnedIter; + + fn into_iter(self) -> Self::IntoIter { + TickSubscriptionOwnedIter { subscription: self } + } +} + /// An iterator that yields items if they are available, without waiting. pub struct TickSubscriptionTryIter<'a, T: TickDecoder> { subscription: &'a TickSubscription, From 317bd765523d6a8e3982df31ab86124c0844b3fc Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Sat, 9 Nov 2024 21:13:48 -0800 Subject: [PATCH 5/5] cleanup duplications --- src/market_data/historical.rs | 228 +++++++++++++++------------------- src/transport.rs | 2 +- 2 files changed, 99 insertions(+), 131 deletions(-) diff --git a/src/market_data/historical.rs b/src/market_data/historical.rs index d0b23579..f74ab76d 100644 --- a/src/market_data/historical.rs +++ b/src/market_data/historical.rs @@ -9,7 +9,7 @@ use time::{Date, OffsetDateTime}; use crate::contracts::Contract; use crate::messages::{IncomingMessages, RequestMessage, ResponseMessage}; -use crate::transport::InternalSubscription; +use crate::transport::{InternalSubscription, Response}; use crate::{server_versions, Client, Error, ToField}; mod decoders; @@ -339,30 +339,32 @@ pub(crate) fn historical_data( )?; } - let request_id = client.next_request_id(); - let request = encoders::encode_request_historical_data( - client.server_version(), - request_id, - contract, - end_date, - duration, - bar_size, - what_to_show, - use_rth, - false, - Vec::::default(), - )?; + loop { + let request_id = client.next_request_id(); + let request = encoders::encode_request_historical_data( + client.server_version(), + request_id, + contract, + end_date, + duration, + bar_size, + what_to_show, + use_rth, + false, + Vec::::default(), + )?; - let subscription = client.send_request(request_id, request)?; + let subscription = client.send_request(request_id, request)?; - match subscription.next() { - Some(Ok(mut message)) if message.message_type() == IncomingMessages::HistoricalData => { - Ok(decoders::decode_historical_data(client.server_version, time_zone(client), &mut message)?) + match subscription.next() { + Some(Ok(mut message)) if message.message_type() == IncomingMessages::HistoricalData => { + return decoders::decode_historical_data(client.server_version, time_zone(client), &mut message) + } + Some(Ok(message)) => return Err(Error::UnexpectedResponse(message)), + Some(Err(Error::ConnectionReset)) => continue, + Some(Err(e)) => return Err(e), + None => return Err(Error::UnexpectedEndOfStream), } - Some(Ok(message)) => Err(Error::UnexpectedResponse(message)), - Some(Err(Error::ConnectionReset)) => historical_data(client, contract, end_date, duration, bar_size, what_to_show, use_rth), - Some(Err(e)) => Err(e), - None => Err(Error::UnexpectedEndOfStream), } } @@ -393,30 +395,32 @@ pub(crate) fn historical_schedule( "It does not support requesting of historical schedule.", )?; - let request_id = client.next_request_id(); - let request = encoders::encode_request_historical_data( - client.server_version(), - request_id, - contract, - end_date, - duration, - BarSize::Day, - Some(WhatToShow::Schedule), - true, - false, - Vec::::default(), - )?; + loop { + let request_id = client.next_request_id(); + let request = encoders::encode_request_historical_data( + client.server_version(), + request_id, + contract, + end_date, + duration, + BarSize::Day, + Some(WhatToShow::Schedule), + true, + false, + Vec::::default(), + )?; - let subscription = client.send_request(request_id, request)?; + let subscription = client.send_request(request_id, request)?; - match subscription.next() { - Some(Ok(mut message)) if message.message_type() == IncomingMessages::HistoricalSchedule => { - Ok(decoders::decode_historical_schedule(&mut message)?) + match subscription.next() { + Some(Ok(mut message)) if message.message_type() == IncomingMessages::HistoricalSchedule => { + return decoders::decode_historical_schedule(&mut message) + } + Some(Ok(message)) => return Err(Error::UnexpectedResponse(message)), + Some(Err(Error::ConnectionReset)) => continue, + Some(Err(e)) => return Err(e), + None => return Err(Error::UnexpectedEndOfStream), } - Some(Ok(message)) => Err(Error::UnexpectedResponse(message)), - Some(Err(Error::ConnectionReset)) => historical_schedule(client, contract, end_date, duration), - Some(Err(e)) => Err(e), - None => Err(Error::UnexpectedEndOfStream), } } @@ -432,7 +436,7 @@ pub(crate) fn historical_ticks_bid_ask( client.check_server_version(server_versions::HISTORICAL_TICKS, "It does not support historical ticks request.")?; let request_id = client.next_request_id(); - let message = encoders::encode_request_historical_ticks( + let request = encoders::encode_request_historical_ticks( request_id, contract, start, @@ -442,10 +446,9 @@ pub(crate) fn historical_ticks_bid_ask( use_rth, ignore_size, )?; + let subscription = client.send_request(request_id, request)?; - let messages = client.send_request(request_id, message)?; - - Ok(TickSubscription::new(messages)) + Ok(TickSubscription::new(subscription)) } pub(crate) fn historical_ticks_mid_point( @@ -459,11 +462,10 @@ pub(crate) fn historical_ticks_mid_point( client.check_server_version(server_versions::HISTORICAL_TICKS, "It does not support historical ticks request.")?; let request_id = client.next_request_id(); - let message = encoders::encode_request_historical_ticks(request_id, contract, start, end, number_of_ticks, WhatToShow::MidPoint, use_rth, false)?; - - let messages = client.send_request(request_id, message)?; + let request = encoders::encode_request_historical_ticks(request_id, contract, start, end, number_of_ticks, WhatToShow::MidPoint, use_rth, false)?; + let subscription = client.send_request(request_id, request)?; - Ok(TickSubscription::new(messages)) + Ok(TickSubscription::new(subscription)) } pub(crate) fn historical_ticks_trade( @@ -477,25 +479,26 @@ pub(crate) fn historical_ticks_trade( client.check_server_version(server_versions::HISTORICAL_TICKS, "It does not support historical ticks request.")?; let request_id = client.next_request_id(); - let message = encoders::encode_request_historical_ticks(request_id, contract, start, end, number_of_ticks, WhatToShow::Trades, use_rth, false)?; - - let messages = client.send_request(request_id, message)?; + let request = encoders::encode_request_historical_ticks(request_id, contract, start, end, number_of_ticks, WhatToShow::Trades, use_rth, false)?; + let subscription = client.send_request(request_id, request)?; - Ok(TickSubscription::new(messages)) + Ok(TickSubscription::new(subscription)) } pub(crate) fn histogram_data(client: &Client, contract: &Contract, use_rth: bool, period: BarSize) -> Result, Error> { client.check_server_version(server_versions::REQ_HISTOGRAM, "It does not support histogram data requests.")?; - let request_id = client.next_request_id(); - let message = encoders::encode_request_histogram_data(request_id, contract, use_rth, period)?; - - let subscription = client.send_request(request_id, message)?; + loop { + let request_id = client.next_request_id(); + let request = encoders::encode_request_histogram_data(request_id, contract, use_rth, period)?; + let subscription = client.send_request(request_id, request)?; - match subscription.next() { - Some(Ok(mut message)) => decoders::decode_histogram_data(&mut message), - Some(Err(e)) => Err(e), - None => Ok(Vec::new()), + match subscription.next() { + Some(Ok(mut message)) => return decoders::decode_histogram_data(&mut message), + Some(Err(Error::ConnectionReset)) => continue, + Some(Err(e)) => return Err(e), + None => return Ok(Vec::new()), + } } } @@ -561,86 +564,60 @@ impl> TickSubscription { } pub fn next(&self) -> Option { - self.clear_error(); - - if let Some(message) = self.next_buffered() { - return Some(message); - } - - if self.done.load(Ordering::Relaxed) { - return None; - } - - match self.messages.next() { - Some(Ok(message)) if message.message_type() == T::MESSAGE_TYPE => { - self.fill_buffer(message); - self.next() - } - Some(Ok(message)) => { - debug!("unexpected message: {:?}", message); - self.next() - } - Some(Err(e)) => { - self.set_error(e); - None - } - None => None, - } + self.next_helper(|| self.messages.next()) } pub fn try_next(&self) -> Option { - self.clear_error(); + self.next_helper(|| self.messages.try_next()) + } - if let Some(message) = self.next_buffered() { - return Some(message); - } + pub fn next_timeout(&self, duration: std::time::Duration) -> Option { + self.next_helper(|| self.messages.next_timeout(duration)) + } - if self.done.load(Ordering::Relaxed) { - return None; - } + fn next_helper(&self, next_response: F) -> Option + where + F: Fn() -> Option, + { + self.clear_error(); - match self.messages.try_next() { - Some(Ok(message)) if message.message_type() == T::MESSAGE_TYPE => { - self.fill_buffer(message); - self.try_next() + loop { + if let Some(message) = self.next_buffered() { + return Some(message); } - Some(Ok(message)) => { - debug!("unexpected message: {:?}", message); - self.try_next() + + if self.done.load(Ordering::Relaxed) { + return None; } - Some(Err(e)) => { - self.set_error(e); - None + + match self.fill_buffer(next_response()) { + Ok(()) => continue, + Err(()) => return None, } - None => None, } } - pub fn next_timeout(&self, duration: std::time::Duration) -> Option { - self.clear_error(); + fn fill_buffer(&self, response: Option) -> Result<(), ()> { + match response { + Some(Ok(mut message)) if message.message_type() == T::MESSAGE_TYPE => { + let mut buffer = self.buffer.lock().unwrap(); - if let Some(message) = self.next_buffered() { - return Some(message); - } + let (ticks, done) = T::decode(&mut message).unwrap(); - if self.done.load(Ordering::Relaxed) { - return None; - } + buffer.append(&mut ticks.into()); + self.done.store(done, Ordering::Relaxed); - match self.messages.next_timeout(duration) { - Some(Ok(message)) if message.message_type() == T::MESSAGE_TYPE => { - self.fill_buffer(message); - self.next_timeout(duration) + Ok(()) } Some(Ok(message)) => { debug!("unexpected message: {:?}", message); - self.next_timeout(duration) + Ok(()) } Some(Err(e)) => { self.set_error(e); - None + Err(()) } - None => None, + None => Err(()), } } @@ -658,15 +635,6 @@ impl> TickSubscription { let mut error = self.error.lock().unwrap(); *error = None; } - - fn fill_buffer(&self, mut message: ResponseMessage) { - let mut buffer = self.buffer.lock().unwrap(); - - let (ticks, done) = T::decode(&mut message).unwrap(); - - buffer.append(&mut ticks.into()); - self.done.store(done, Ordering::Relaxed); - } } /// An iterator that yields items as they become available, blocking if necessary. diff --git a/src/transport.rs b/src/transport.rs index 7d073fc3..e0d2ef86 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -55,7 +55,7 @@ pub(crate) trait MessageBus: Send + Sync { } } -type Response = Result; +pub(crate) type Response = Result; // For requests without an identifier, shared channels are created // to route request/response pairs based on message type.