diff --git a/src/transport.rs b/src/transport.rs index c6c9303b..c3e3d470 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use std::io::{prelude::*, Cursor, ErrorKind}; use std::net::TcpStream; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use std::thread::{self, JoinHandle}; use std::time::Duration; @@ -157,7 +157,7 @@ pub enum Signal { #[derive(Debug)] pub struct TcpMessageBus { - connection: Arc, + connection: Arc>, handles: Vec>, requests: Arc>, orders: Arc>, @@ -177,7 +177,7 @@ impl TcpMessageBus { let (signals_send, signals_recv) = channel::unbounded(); Ok(TcpMessageBus { - connection: Arc::new(connection), + connection: Arc::new(RwLock::new(connection)), handles: Vec::default(), requests, orders, @@ -205,17 +205,22 @@ impl TcpMessageBus { const RETRY_ERRORS: &[ErrorKind] = &[ErrorKind::Interrupted]; thread::spawn(move || loop { - // connection.read_message() + let _connection = connection.read().unwrap(); - match read_packet(&connection.stream) { + match read_packet(&_connection.stream) { Ok(message) => { recorder.record_response(&message); dispatch_message(message, server_version, &requests, &orders, &shared_channels, &executions); } Err(Error::Io(e)) if RECONNECT_ERRORS.contains(&e.kind()) => { + let mut connection = connection.write().unwrap(); error!("error reading packet: {:?}", e); // reset hashes - // connection.reconnect() + if let Err(e) = connection.reconnect() { + error!("error reconnecting: {:?}", e); + shutdown_requested.store(true, Ordering::Relaxed); + return 0; + } } Err(Error::Io(e)) if RETRY_ERRORS.contains(&e.kind()) => { error!("error reading packet: {:?}", e); @@ -267,13 +272,13 @@ const UNSPECIFIED_REQUEST_ID: i32 = -1; impl MessageBus for TcpMessageBus { fn send_request(&mut self, request_id: i32, packet: &RequestMessage) -> Result { + let connection = self.connection.read()?; + let (sender, receiver) = channel::unbounded(); self.requests.insert(request_id, sender); - //FIXME - // write_message(&mut self.connection.stream, packet)?; - // self.connection.write_message(packet)?; + connection.write_message(packet)?; let subscription = SubscriptionBuilder::new() .receiver(receiver) @@ -284,19 +289,28 @@ impl MessageBus for TcpMessageBus { Ok(subscription) } - fn cancel_subscription(&mut self, request_id: i32, packet: &RequestMessage) -> Result<(), Error> { - // write_message(&self.connection.stream, packet)?; + fn cancel_subscription(&mut self, request_id: i32, message: &RequestMessage) -> Result<(), Error> { + let connection = self.connection.read()?; + + connection.write_message(message)?; + + if let Err(e) = self.requests.send(&request_id, Response::Cancelled) { + info!("error sending cancel notification: {e}"); + } + self.requests.remove(&request_id); + Ok(()) } fn send_order_request(&mut self, order_id: i32, message: &RequestMessage) -> Result { + let connection = self.connection.read()?; + let (sender, receiver) = channel::unbounded(); self.orders.insert(order_id, sender); - // FIXME - //self.write_message(message)?; + connection.write_message(message)?; let subscription = SubscriptionBuilder::new() .receiver(receiver) @@ -307,15 +321,24 @@ impl MessageBus for TcpMessageBus { Ok(subscription) } - fn cancel_order_subscription(&mut self, request_id: i32, packet: &RequestMessage) -> Result<(), Error> { - // write_message(&self.connection.stream, packet)?; + fn cancel_order_subscription(&mut self, request_id: i32, message: &RequestMessage) -> Result<(), Error> { + let connection = self.connection.read()?; + + connection.write_message(message)?; + + if let Err(e) = self.orders.send(&request_id, Response::Cancelled) { + info!("error sending cancel notification: {e}"); + } + self.orders.remove(&request_id); + Ok(()) } fn send_shared_request(&mut self, message_type: OutgoingMessages, message: &RequestMessage) -> Result { - // FIXME - //self.write_message(message)?; + let connection = self.connection.read()?; + + connection.write_message(message)?; let shared_receiver = self.shared_channels.get_receiver(message_type); @@ -327,8 +350,11 @@ impl MessageBus for TcpMessageBus { Ok(subscription) } - fn cancel_shared_subscription(&mut self, message_type: OutgoingMessages, packet: &RequestMessage) -> Result<(), Error> { - // write_message(&self.connection.stream, packet)?; + fn cancel_shared_subscription(&mut self, _message_type: OutgoingMessages, message: &RequestMessage) -> Result<(), Error> { + let connection = self.connection.read()?; + + connection.write_message(message)?; + // TODO send cancel Ok(()) } @@ -627,9 +653,9 @@ impl InternalSubscription { } } - pub(crate) fn cancel(&self) -> Result<(), Error> { - Ok(()) - } + // pub(crate) fn cancel(&self) -> Result<(), Error> { + // Ok(()) + // } fn receive(receiver: &Receiver) -> Option { receiver.recv().ok() @@ -745,6 +771,7 @@ pub(crate) struct Connection { pub(crate) client_id: i32, pub(crate) connection_url: String, stream: TcpStream, + writer: Mutex, pub(crate) server_version: i32, pub(crate) connection_time: Option, pub(crate) time_zone: Option<&'static Tz>, @@ -754,11 +781,13 @@ pub(crate) struct Connection { impl Connection { pub fn connect(client_id: i32, connection_url: &str) -> Result { let stream = TcpStream::connect(connection_url)?; + let writer = Mutex::new(stream.try_clone()?); let mut connection = Self { client_id, connection_url: connection_url.into(), stream, + writer, server_version: -1, connection_time: None, time_zone: None, @@ -791,7 +820,9 @@ impl Connection { Ok(()) } - fn write_message(&mut self, message: &RequestMessage) -> Result<(), Error> { + fn write_message(&self, message: &RequestMessage) -> Result<(), Error> { + let mut writer = self.writer.lock()?; + let data = message.encode(); debug!("-> {data:?}"); @@ -802,7 +833,7 @@ impl Connection { packet.write_u32::(data.len() as u32)?; packet.write_all(data)?; - self.stream.write_all(&packet)?; + writer.write_all(&packet)?; Ok(()) } @@ -951,45 +982,5 @@ fn encode_packet(message: &str) -> String { std::str::from_utf8(&packet).unwrap().into() } -fn write_message(stream: &mut TcpStream, message: &RequestMessage) -> Result<(), Error> { - let data = message.encode(); - debug!("-> {data:?}"); - - let data = data.as_bytes(); - - let mut packet = Vec::with_capacity(data.len() + 4); - - packet.write_u32::(data.len() as u32)?; - packet.write_all(data)?; - - stream.write_all(&packet)?; - - Ok(()) -} - -// fn write_message(&mut self, message: &RequestMessage) -> Result<(), Error> { -// let data = message.encode(); -// debug!("-> {data:?}"); - -// let data = data.as_bytes(); - -// let mut packet = Vec::with_capacity(data.len() + 4); - -// packet.write_u32::(data.len() as u32)?; -// packet.write_all(data)?; - -// self.writer.lock()?.write_all(&packet)?; - -// self.recorder.record_request(message); - -// Ok(()) -// } - -// fn write(&mut self, data: &str) -> Result<(), Error> { -// debug!("{data:?} ->"); -// self.writer.lock()?.write_all(data.as_bytes())?; -// Ok(()) -// } - #[cfg(test)] mod tests;