Skip to content

Commit

Permalink
added RwLock
Browse files Browse the repository at this point in the history
  • Loading branch information
Wil Boayue committed Oct 14, 2024
1 parent 635752f commit f09879a
Showing 1 changed file with 55 additions and 64 deletions.
119 changes: 55 additions & 64 deletions src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::io::{prelude::*, Cursor, ErrorKind};
use std::net::TcpStream;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use std::sync::{Arc, Mutex, RwLock};
use std::thread::{self, JoinHandle};
use std::time::Duration;

Expand Down Expand Up @@ -157,7 +157,7 @@ pub enum Signal {

#[derive(Debug)]
pub struct TcpMessageBus {
connection: Arc<Connection>,
connection: Arc<RwLock<Connection>>,
handles: Vec<JoinHandle<i32>>,
requests: Arc<SenderHash<i32, Response>>,
orders: Arc<SenderHash<i32, Response>>,
Expand All @@ -177,7 +177,7 @@ impl TcpMessageBus {
let (signals_send, signals_recv) = channel::unbounded();

Ok(TcpMessageBus {
connection: Arc::new(connection),
connection: Arc::new(RwLock::new(connection)),
handles: Vec::default(),
requests,
orders,
Expand Down Expand Up @@ -205,17 +205,22 @@ impl TcpMessageBus {
const RETRY_ERRORS: &[ErrorKind] = &[ErrorKind::Interrupted];

thread::spawn(move || loop {
// connection.read_message()
let _connection = connection.read().unwrap();

match read_packet(&connection.stream) {
match read_packet(&_connection.stream) {
Ok(message) => {
recorder.record_response(&message);
dispatch_message(message, server_version, &requests, &orders, &shared_channels, &executions);
}
Err(Error::Io(e)) if RECONNECT_ERRORS.contains(&e.kind()) => {
let mut connection = connection.write().unwrap();
error!("error reading packet: {:?}", e);
// reset hashes
// connection.reconnect()
if let Err(e) = connection.reconnect() {
error!("error reconnecting: {:?}", e);
shutdown_requested.store(true, Ordering::Relaxed);
return 0;
}
}
Err(Error::Io(e)) if RETRY_ERRORS.contains(&e.kind()) => {
error!("error reading packet: {:?}", e);
Expand Down Expand Up @@ -267,13 +272,13 @@ const UNSPECIFIED_REQUEST_ID: i32 = -1;

impl MessageBus for TcpMessageBus {
fn send_request(&mut self, request_id: i32, packet: &RequestMessage) -> Result<InternalSubscription, Error> {
let connection = self.connection.read()?;

let (sender, receiver) = channel::unbounded();

self.requests.insert(request_id, sender);

//FIXME
// write_message(&mut self.connection.stream, packet)?;
// self.connection.write_message(packet)?;
connection.write_message(packet)?;

let subscription = SubscriptionBuilder::new()
.receiver(receiver)
Expand All @@ -284,19 +289,28 @@ impl MessageBus for TcpMessageBus {
Ok(subscription)
}

fn cancel_subscription(&mut self, request_id: i32, packet: &RequestMessage) -> Result<(), Error> {
// write_message(&self.connection.stream, packet)?;
fn cancel_subscription(&mut self, request_id: i32, message: &RequestMessage) -> Result<(), Error> {
let connection = self.connection.read()?;

connection.write_message(message)?;

if let Err(e) = self.requests.send(&request_id, Response::Cancelled) {
info!("error sending cancel notification: {e}");
}

self.requests.remove(&request_id);

Ok(())
}

fn send_order_request(&mut self, order_id: i32, message: &RequestMessage) -> Result<InternalSubscription, Error> {
let connection = self.connection.read()?;

let (sender, receiver) = channel::unbounded();

self.orders.insert(order_id, sender);

// FIXME
//self.write_message(message)?;
connection.write_message(message)?;

let subscription = SubscriptionBuilder::new()
.receiver(receiver)
Expand All @@ -307,15 +321,24 @@ impl MessageBus for TcpMessageBus {
Ok(subscription)
}

fn cancel_order_subscription(&mut self, request_id: i32, packet: &RequestMessage) -> Result<(), Error> {
// write_message(&self.connection.stream, packet)?;
fn cancel_order_subscription(&mut self, request_id: i32, message: &RequestMessage) -> Result<(), Error> {
let connection = self.connection.read()?;

connection.write_message(message)?;

if let Err(e) = self.orders.send(&request_id, Response::Cancelled) {
info!("error sending cancel notification: {e}");
}

self.orders.remove(&request_id);

Ok(())
}

fn send_shared_request(&mut self, message_type: OutgoingMessages, message: &RequestMessage) -> Result<InternalSubscription, Error> {
// FIXME
//self.write_message(message)?;
let connection = self.connection.read()?;

connection.write_message(message)?;

let shared_receiver = self.shared_channels.get_receiver(message_type);

Expand All @@ -327,8 +350,11 @@ impl MessageBus for TcpMessageBus {
Ok(subscription)
}

fn cancel_shared_subscription(&mut self, message_type: OutgoingMessages, packet: &RequestMessage) -> Result<(), Error> {
// write_message(&self.connection.stream, packet)?;
fn cancel_shared_subscription(&mut self, _message_type: OutgoingMessages, message: &RequestMessage) -> Result<(), Error> {
let connection = self.connection.read()?;

connection.write_message(message)?;
// TODO send cancel
Ok(())
}

Expand Down Expand Up @@ -627,9 +653,9 @@ impl InternalSubscription {
}
}

pub(crate) fn cancel(&self) -> Result<(), Error> {
Ok(())
}
// pub(crate) fn cancel(&self) -> Result<(), Error> {
// Ok(())
// }

fn receive(receiver: &Receiver<Response>) -> Option<Response> {
receiver.recv().ok()
Expand Down Expand Up @@ -745,6 +771,7 @@ pub(crate) struct Connection {
pub(crate) client_id: i32,
pub(crate) connection_url: String,
stream: TcpStream,
writer: Mutex<TcpStream>,
pub(crate) server_version: i32,
pub(crate) connection_time: Option<OffsetDateTime>,
pub(crate) time_zone: Option<&'static Tz>,
Expand All @@ -754,11 +781,13 @@ pub(crate) struct Connection {
impl Connection {
pub fn connect(client_id: i32, connection_url: &str) -> Result<Self, Error> {
let stream = TcpStream::connect(connection_url)?;
let writer = Mutex::new(stream.try_clone()?);

let mut connection = Self {
client_id,
connection_url: connection_url.into(),
stream,
writer,
server_version: -1,
connection_time: None,
time_zone: None,
Expand Down Expand Up @@ -791,7 +820,9 @@ impl Connection {
Ok(())
}

fn write_message(&mut self, message: &RequestMessage) -> Result<(), Error> {
fn write_message(&self, message: &RequestMessage) -> Result<(), Error> {
let mut writer = self.writer.lock()?;

let data = message.encode();
debug!("-> {data:?}");

Expand All @@ -802,7 +833,7 @@ impl Connection {
packet.write_u32::<BigEndian>(data.len() as u32)?;
packet.write_all(data)?;

self.stream.write_all(&packet)?;
writer.write_all(&packet)?;

Ok(())
}
Expand Down Expand Up @@ -951,45 +982,5 @@ fn encode_packet(message: &str) -> String {
std::str::from_utf8(&packet).unwrap().into()
}

fn write_message(stream: &mut TcpStream, message: &RequestMessage) -> Result<(), Error> {
let data = message.encode();
debug!("-> {data:?}");

let data = data.as_bytes();

let mut packet = Vec::with_capacity(data.len() + 4);

packet.write_u32::<BigEndian>(data.len() as u32)?;
packet.write_all(data)?;

stream.write_all(&packet)?;

Ok(())
}

// fn write_message(&mut self, message: &RequestMessage) -> Result<(), Error> {
// let data = message.encode();
// debug!("-> {data:?}");

// let data = data.as_bytes();

// let mut packet = Vec::with_capacity(data.len() + 4);

// packet.write_u32::<BigEndian>(data.len() as u32)?;
// packet.write_all(data)?;

// self.writer.lock()?.write_all(&packet)?;

// self.recorder.record_request(message);

// Ok(())
// }

// fn write(&mut self, data: &str) -> Result<(), Error> {
// debug!("{data:?} ->");
// self.writer.lock()?.write_all(data.as_bytes())?;
// Ok(())
// }

#[cfg(test)]
mod tests;

0 comments on commit f09879a

Please sign in to comment.