diff --git a/examples/positions.rs b/examples/positions.rs index 0382ad1d..9b2eb2cc 100644 --- a/examples/positions.rs +++ b/examples/positions.rs @@ -1,4 +1,4 @@ -use ibapi::{accounts::PositionResponse, Client}; +use ibapi::{accounts::PositionUpdate, Client}; fn main() { let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); @@ -6,10 +6,10 @@ fn main() { let mut positions = client.positions().expect("request failed"); while let Some(position_update) = positions.next() { match position_update { - PositionResponse::Position(position) => { + PositionUpdate::Position(position) => { println!("{:4} {:4} @ {}", position.position, position.contract.symbol, position.average_cost) } - PositionResponse::PositionEnd => { + PositionUpdate::PositionEnd => { println!("PositionEnd"); // all positions received. could continue listening for new additions or cancel. positions.cancel(); diff --git a/examples/positions_multi.rs b/examples/positions_multi.rs new file mode 100644 index 00000000..efc69110 --- /dev/null +++ b/examples/positions_multi.rs @@ -0,0 +1,14 @@ +use std::env; + +use ibapi::Client; + +pub fn main() { + let account = env::var("IBKR_ACCOUNT").expect("Please set IBKR_ACCOUNT environment variable to an account ID"); + + let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); + + let subscription = client.positions_multi(Some(&account), None).expect("error requesting positions by model"); + for position in subscription { + println!("{position:?}") + } +} diff --git a/src/accounts.rs b/src/accounts.rs index 75cbb8ad..21115528 100644 --- a/src/accounts.rs +++ b/src/accounts.rs @@ -38,11 +38,8 @@ impl Subscribable for PnL { } fn cancel_message(_server_version: i32, request_id: Option) -> Result { - if let Some(request_id) = request_id { - encoders::encode_cancel_pnl(request_id) - } else { - Err(Error::Simple("Request id request to encode cancel pnl single".into())) - } + let request_id = request_id.expect("Request ID required to encode cancel pnl"); + encoders::encode_cancel_pnl(request_id) } } @@ -69,11 +66,8 @@ impl Subscribable for PnLSingle { } fn cancel_message(_server_version: i32, request_id: Option) -> Result { - if let Some(request_id) = request_id { - encoders::encode_cancel_pnl_single(request_id) - } else { - Err(Error::Simple("Request id request to encode cancel pnl single".into())) - } + let request_id = request_id.expect("Request ID required to encode cancel pnl single"); + encoders::encode_cancel_pnl_single(request_id) } } @@ -91,24 +85,24 @@ pub struct Position { #[allow(clippy::large_enum_variant)] #[derive(Clone, Debug)] -pub enum PositionResponse { +pub enum PositionUpdate { Position(Position), PositionEnd, } -impl From for PositionResponse { +impl From for PositionUpdate { fn from(val: Position) -> Self { - PositionResponse::Position(val) + PositionUpdate::Position(val) } } -impl Subscribable for PositionResponse { +impl Subscribable for PositionUpdate { const RESPONSE_MESSAGE_IDS: &[IncomingMessages] = &[IncomingMessages::Position, IncomingMessages::PositionEnd]; fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::Position => Ok(PositionResponse::Position(decoders::decode_position(message)?)), - IncomingMessages::PositionEnd => Ok(PositionResponse::PositionEnd), + IncomingMessages::Position => Ok(PositionUpdate::Position(decoders::decode_position(message)?)), + IncomingMessages::PositionEnd => Ok(PositionUpdate::PositionEnd), message => Err(Error::Simple(format!("unexpected message: {message:?}"))), } } @@ -118,6 +112,51 @@ impl Subscribable for PositionResponse { } } +#[allow(clippy::large_enum_variant)] +#[derive(Clone, Debug)] +pub enum PositionUpdateMulti { + Position(PositionMulti), + PositionEnd, +} + +impl From for PositionUpdateMulti { + fn from(val: PositionMulti) -> Self { + PositionUpdateMulti::Position(val) + } +} + +/// Portfolio's open positions. +#[derive(Debug, Clone, Default)] +pub struct PositionMulti { + /// The account holding the position. + account: String, + /// The model code holding the position. + model_code: String, + /// The position's Contract + contract: Contract, + /// The number of positions held. + position: f64, + /// The average cost of the position. + average_cost: f64, +} + +impl Subscribable for PositionUpdateMulti { + const RESPONSE_MESSAGE_IDS: &[IncomingMessages] = &[IncomingMessages::PositionMulti, IncomingMessages::PositionMultiEnd]; + + fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + match message.message_type() { + IncomingMessages::PositionMulti => Ok(PositionUpdateMulti::Position(decoders::decode_position_multi(message)?)), + IncomingMessages::PositionMultiEnd => Ok(PositionUpdateMulti::PositionEnd), + message => Err(Error::Simple(format!("unexpected message: {message:?}"))), + } + } + + fn cancel_message(_server_version: i32, request_id: Option) -> Result { + let request_id = request_id.expect("Request ID required to encode cancel positions multi"); + Ok(encoders::encode_cancel_positions_multi(request_id)?) + } +} + #[derive(Debug, Default)] pub struct FamilyCode { /// Account ID @@ -128,12 +167,11 @@ pub struct FamilyCode { // Subscribes to position updates for all accessible accounts. // All positions sent initially, and then only updates as positions change. -pub(crate) fn positions(client: &Client) -> Result, Error> { +pub(crate) fn positions(client: &Client) -> Result, Error> { client.check_server_version(server_versions::ACCOUNT_SUMMARY, "It does not support position requests.")?; - let message = encoders::encode_request_positions()?; - - let responses = client.send_shared_request(OutgoingMessages::RequestPositions, message)?; + let request = encoders::encode_request_positions()?; + let responses = client.send_shared_request(OutgoingMessages::RequestPositions, request)?; Ok(Subscription { client, @@ -143,7 +181,27 @@ pub(crate) fn positions(client: &Client) -> Result {} +impl SharesChannel for Subscription<'_, PositionUpdate> {} + +pub(crate) fn positions_multi<'a>( + client: &'a Client, + account: Option<&str>, + model_code: Option<&str>, +) -> Result, Error> { + client.check_server_version(server_versions::MODELS_SUPPORT, "It does not support positions multi requests.")?; + + let request_id = client.next_request_id(); + + let request = encoders::encode_request_positions_multi(request_id, account, model_code)?; + let responses = client.send_request(request_id, request)?; + + Ok(Subscription { + client, + request_id: Some(request_id), + responses, + phantom: PhantomData, + }) +} // Determine whether an account exists under an account family and find the account family code. pub(crate) fn family_codes(client: &Client) -> Result, Error> { @@ -176,7 +234,7 @@ pub(crate) fn pnl<'a>(client: &'a Client, account: &str, model_code: Option<&str Ok(Subscription { client, - request_id: None, + request_id: Some(request_id), responses, phantom: PhantomData, }) @@ -204,7 +262,7 @@ pub(crate) fn pnl_single<'a>( Ok(Subscription { client, - request_id: None, + request_id: Some(request_id), responses, phantom: PhantomData, }) diff --git a/src/accounts/decoders.rs b/src/accounts/decoders.rs index 3828e27b..e901942a 100644 --- a/src/accounts/decoders.rs +++ b/src/accounts/decoders.rs @@ -2,7 +2,7 @@ use crate::contracts::SecurityType; use crate::messages::ResponseMessage; use crate::{server_versions, Error}; -use super::{FamilyCode, PnL, PnLSingle, Position}; +use super::{FamilyCode, PnL, PnLSingle, Position, PositionMulti}; pub(crate) fn decode_position(message: &mut ResponseMessage) -> Result { message.skip(); // message type @@ -38,6 +38,35 @@ pub(crate) fn decode_position(message: &mut ResponseMessage) -> Result Result { + message.skip(); // message type + message.skip(); // message version + message.skip(); // request id + + let mut position = PositionMulti { + account: message.next_string()?, + ..Default::default() + }; + + position.contract.contract_id = message.next_int()?; + position.contract.symbol = message.next_string()?; + position.contract.security_type = SecurityType::from(&message.next_string()?); + position.contract.last_trade_date_or_contract_month = message.next_string()?; + position.contract.strike = message.next_double()?; + position.contract.right = message.next_string()?; + position.contract.multiplier = message.next_string()?; + position.contract.exchange = message.next_string()?; + position.contract.currency = message.next_string()?; + position.contract.local_symbol = message.next_string()?; + position.contract.trading_class = message.next_string()?; + + position.position = message.next_double()?; + position.average_cost = message.next_double()?; + position.model_code = message.next_string()?; + + Ok(position) +} + pub(crate) fn decode_family_codes(message: &mut ResponseMessage) -> Result, Error> { message.skip(); // message type diff --git a/src/accounts/decoders/tests.rs b/src/accounts/decoders/tests.rs index a2a1db47..10bb2804 100644 --- a/src/accounts/decoders/tests.rs +++ b/src/accounts/decoders/tests.rs @@ -29,6 +29,35 @@ fn test_decode_positions() { assert_eq!(position.average_cost, 196.77, "position.average_cost"); } +#[test] +fn test_decode_position_multi() { + let mut message = super::ResponseMessage::from("61\03\06\0DU1234567\076792991\0TSLA\0STK\0\00.0\0\0\0NASDAQ\0USD\0TSLA\0NMS\0500\0196.77\0"); + + let position = super::decode_position_multi(&mut message).expect("error decoding position multi"); + + assert_eq!(position.account, "DU1234567", "position.account"); + assert_eq!(position.contract.contract_id, 76792991, "position.contract.contract_id"); + assert_eq!(position.contract.symbol, "TSLA", "position.contract.symbol"); + assert_eq!( + position.contract.security_type, + super::SecurityType::Stock, + "position.contract.security_type" + ); + assert_eq!( + position.contract.last_trade_date_or_contract_month, "", + "position.contract.last_trade_date_or_contract_month" + ); + assert_eq!(position.contract.strike, 0.0, "position.contract.strike"); + assert_eq!(position.contract.right, "", "position.contract.right"); + assert_eq!(position.contract.multiplier, "", "position.contract.multiplier"); + assert_eq!(position.contract.exchange, "NASDAQ", "position.contract.exchange"); + assert_eq!(position.contract.currency, "USD", "position.contract.currency"); + assert_eq!(position.contract.local_symbol, "TSLA", "position.contract.local_symbol"); + assert_eq!(position.contract.trading_class, "NMS", "position.contract.trading_class"); + assert_eq!(position.position, 500.0, "position.position"); + assert_eq!(position.average_cost, 196.77, "position.average_cost"); +} + #[test] fn test_decode_family_codes() { let mut message = super::ResponseMessage::from("78\01\0*\0\0"); diff --git a/src/accounts/encoders.rs b/src/accounts/encoders.rs index 82cc0b50..735c0f11 100644 --- a/src/accounts/encoders.rs +++ b/src/accounts/encoders.rs @@ -10,6 +10,32 @@ pub(crate) fn encode_cancel_positions() -> Result { encode_simple(OutgoingMessages::CancelPositions, 1) } +pub(crate) fn encode_request_positions_multi(request_id: i32, account: Option<&str>, model_code: Option<&str>) -> Result { + let mut message = RequestMessage::new(); + + const VERSION: i32 = 1; + + message.push_field(&OutgoingMessages::RequestPositionsMulti); + message.push_field(&VERSION); + message.push_field(&request_id); + message.push_field(&account); + message.push_field(&model_code); + + Ok(message) +} + +pub(crate) fn encode_cancel_positions_multi(request_id: i32) -> Result { + let mut message = RequestMessage::new(); + + const VERSION: i32 = 1; + + message.push_field(&OutgoingMessages::CancelPositionsMulti); + message.push_field(&VERSION); + message.push_field(&request_id); + + Ok(message) +} + pub(crate) fn encode_request_family_codes() -> Result { encode_simple(OutgoingMessages::RequestFamilyCodes, 1) } @@ -20,12 +46,7 @@ pub(crate) fn encode_request_pnl(request_id: i32, account: &str, model_code: Opt message.push_field(&OutgoingMessages::RequestPnL); message.push_field(&request_id); message.push_field(&account); - - if let Some(model_code) = model_code { - message.push_field(&model_code); - } else { - message.push_field(&""); - } + message.push_field(&model_code); Ok(message) } diff --git a/src/accounts/encoders/tests.rs b/src/accounts/encoders/tests.rs index 559b445d..9d39b635 100644 --- a/src/accounts/encoders/tests.rs +++ b/src/accounts/encoders/tests.rs @@ -3,7 +3,7 @@ use crate::ToField; use super::*; #[test] -fn test_request_positions() { +fn test_encode_request_positions() { let message = super::encode_request_positions().expect("error encoding request"); assert_eq!(message[0], OutgoingMessages::RequestPositions.to_field(), "message.type"); @@ -11,7 +11,7 @@ fn test_request_positions() { } #[test] -fn test_cancel_positions() { +fn test_encode_cancel_positions() { let message = super::encode_cancel_positions().expect("error encoding request"); assert_eq!(message[0], OutgoingMessages::CancelPositions.to_field(), "message.type"); @@ -19,7 +19,35 @@ fn test_cancel_positions() { } #[test] -fn test_request_family_codes() { +fn test_encode_request_positions_multi() { + let request_id = 9000; + let version = 1; + let account = Some("U1234567"); + let model_code = Some("TARGET2024"); + + let message = super::encode_request_positions_multi(request_id, account, model_code).expect("error encoding request"); + + assert_eq!(message[0], OutgoingMessages::RequestPositionsMulti.to_field(), "message.type"); + assert_eq!(message[1], version.to_field(), "message.version"); + assert_eq!(message[2], request_id.to_field(), "message.request_id"); + assert_eq!(message[3], account.to_field(), "message.account"); + assert_eq!(message[4], model_code.to_field(), "message.model_code"); +} + +#[test] +fn test_encode_cancel_positions_multi() { + let request_id = 9000; + let version = 1; + + let message = super::encode_cancel_positions_multi(request_id).expect("error encoding request"); + + assert_eq!(message[0], OutgoingMessages::CancelPositionsMulti.to_field(), "message.type"); + assert_eq!(message[1], version.to_field(), "message.version"); + assert_eq!(message[2], request_id.to_field(), "message.request_id"); +} + +#[test] +fn test_encode_request_family_codes() { let message = super::encode_request_family_codes().expect("error encoding request"); assert_eq!(message[0], OutgoingMessages::RequestFamilyCodes.to_field(), "message.type"); diff --git a/src/accounts/tests.rs b/src/accounts/tests.rs index 8799446b..c509bf1c 100644 --- a/src/accounts/tests.rs +++ b/src/accounts/tests.rs @@ -15,10 +15,15 @@ fn test_pnl() { let model_code = Some("TARGET2024"); let _ = client.pnl(account, model_code).expect("request pnl failed"); + let _ = client.pnl(account, None).expect("request pnl failed"); let request_messages = client.message_bus.lock().unwrap().request_messages(); assert_eq!(request_messages[0].encode_simple(), "92|9000|DU1234567|TARGET2024|"); + assert_eq!(request_messages[1].encode_simple(), "93|9000|"); + + assert_eq!(request_messages[2].encode_simple(), "92|9001|DU1234567||"); + assert_eq!(request_messages[3].encode_simple(), "93|9001|"); } #[test] @@ -35,8 +40,54 @@ fn test_pnl_single() { let model_code = Some("TARGET2024"); let _ = client.pnl_single(account, contract_id, model_code).expect("request pnl failed"); + let _ = client.pnl_single(account, contract_id, None).expect("request pnl failed"); let request_messages = client.message_bus.lock().unwrap().request_messages(); assert_eq!(request_messages[0].encode_simple(), "94|9000|DU1234567|TARGET2024|1001|"); + assert_eq!(request_messages[1].encode_simple(), "95|9000|"); + + assert_eq!(request_messages[2].encode_simple(), "94|9001|DU1234567||1001|"); + assert_eq!(request_messages[3].encode_simple(), "95|9001|"); +} + +#[test] +fn test_positions() { + let message_bus = Arc::new(Mutex::new(MessageBusStub { + request_messages: RwLock::new(vec![]), + response_messages: vec![], + })); + + let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); + + let _ = client.positions().expect("request positions failed"); + + let request_messages = client.message_bus.lock().unwrap().request_messages(); + + assert_eq!(request_messages[0].encode_simple(), "61|1|"); + assert_eq!(request_messages[1].encode_simple(), "64|1|"); +} + +#[test] +fn test_positions_multi() { + let message_bus = Arc::new(Mutex::new(MessageBusStub { + request_messages: RwLock::new(vec![]), + response_messages: vec![], + })); + + let client = Client::stubbed(message_bus, server_versions::SIZE_RULES); + + let account = Some("DU1234567"); + let model_code = Some("TARGET2024"); + + let _ = client.positions_multi(account, model_code).expect("request positions failed"); + let _ = client.positions_multi(None, model_code).expect("request positions failed"); + + let request_messages = client.message_bus.lock().unwrap().request_messages(); + + assert_eq!(request_messages[0].encode_simple(), "74|1|9000|DU1234567|TARGET2024|"); + assert_eq!(request_messages[1].encode_simple(), "75|1|9000|"); + + assert_eq!(request_messages[2].encode_simple(), "74|1|9001||TARGET2024|"); + assert_eq!(request_messages[3].encode_simple(), "75|1|9001|"); } diff --git a/src/client.rs b/src/client.rs index 124c0b9a..920ffd8c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -11,7 +11,7 @@ use time::macros::format_description; use time::OffsetDateTime; use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt, Tz}; -use crate::accounts::{FamilyCode, PnL, PnLSingle, PositionResponse}; +use crate::accounts::{FamilyCode, PnL, PnLSingle, PositionMulti, PositionUpdate, PositionUpdateMulti}; use crate::contracts::Contract; use crate::errors::Error; use crate::market_data::historical; @@ -211,11 +211,52 @@ impl Client { // === Accounts === - /// Get current [Position](accounts::Position)s for all accessible accounts. - pub fn positions(&self) -> core::result::Result, Error> { + /// Subscribes to [PositionUpdate](accounts::PositionUpdate)s for all accessible accounts. + /// All positions sent initially, and then only updates as positions change. + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::Client; + /// use ibapi::accounts::PositionUpdate; + /// + /// let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); + /// let subscription = client.positions().expect("error requesting positions"); + /// for position_response in subscription { + /// match position_response { + /// PositionUpdate::Position(position) => println!("{position:?}"), + /// PositionUpdate::PositionEnd => println!("initial set of positions received"), + /// } + /// } + /// ``` + pub fn positions(&self) -> core::result::Result, Error> { accounts::positions(self) } + /// Subscribes to [PositionUpdateMulti](accounts::PositionUpdateMulti) updates for account and/or model. + /// Initially all positions are returned, and then updates are returned for any position changes in real time. + /// + /// # Arguments + /// * `account` - If an account Id is provided, only the account’s positions belonging to the specified model will be delivered. + /// * `model_code` - The code of the model’s positions we are interested in. + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::Client; + /// + /// let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); + /// + /// let account = "U1234567"; + /// let subscription = client.positions_multi(Some(account), None).expect("error requesting positions by model"); + /// for position in subscription { + /// println!("{position:?}") + /// } + /// ``` + pub fn positions_multi(&self, account: Option<&str>, model_code: Option<&str>) -> Result, Error> { + accounts::positions_multi(self, account, model_code) + } + /// Creates subscription for real time daily PnL and unrealized PnL updates. /// /// # Arguments diff --git a/src/lib.rs b/src/lib.rs index 3a083370..aa2fad80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -157,6 +157,12 @@ impl ToField for &str { } } +impl ToField for Option<&str> { + fn to_field(&self) -> String { + encode_option_field(self) + } +} + impl ToField for usize { fn to_field(&self) -> String { self.to_string()