-
Hello! #[derive(PartialEq, Debug, Default, Readable, Writable)]
pub struct ProtocolVersion {
major: u8,
minor: u8,
patch: u8,
}
/// Length-prefixed byte string: a `u32` byte count (`count`) followed by the
/// payload (`data`), serialized with speedy in that field order.
///
/// `#[speedy(length = count)]` makes deserialization read exactly `count`
/// bytes into `data`, so `count` must always equal `data.len()`.
#[derive(Clone, Default, PartialEq, Debug, Readable, Writable)]
pub struct Text4b {
    /// Number of bytes in `data`; `set_content` keeps the two in sync.
    count: u32,
    /// Raw payload bytes. Because this field is `pub` while `count` is not,
    /// mutating it directly leaves `count` stale — prefer `set_content`.
    #[speedy(length = count)]
    pub data: Vec<u8>,
}
impl Text4b {
    /// Size in bytes of the serialized form: the `u32` length prefix plus the payload.
    pub fn get_size(&self) -> usize {
        std::mem::size_of_val(&self.count) + self.data.len()
    }

    /// Replaces the payload with the UTF-8 bytes of `text` and updates the
    /// length prefix to match.
    ///
    /// # Panics
    /// Panics if `text` is longer than `u32::MAX` bytes, which the 4-byte
    /// length prefix cannot represent. On panic, `self` is left unchanged.
    pub fn set_content(&mut self, text: String) {
        // `into_bytes` reuses the String's allocation instead of copying it.
        let bytes = text.into_bytes();
        let count: u32 = bytes
            .len()
            .try_into()
            .expect("Text4b content must fit in a u32 length prefix");
        self.count = count;
        self.data = bytes;
    }
}
impl fmt::Display for Text4b {
    /// Writes the payload as text.
    ///
    /// Invalid UTF-8 sequences are rendered as U+FFFD (the replacement
    /// character). Returning `fmt::Error` here instead would be wrong:
    /// `fmt::Error` signals that the underlying writer failed, and
    /// `to_string()` / `format!` panic when a `Display` impl returns it.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&String::from_utf8_lossy(&self.data))
    }
}
/// Initial authentication request: a protocol `Header`, a `ProtocolVersion`,
/// and login/secret credentials as length-prefixed `Text4b` strings.
///
/// Fields are serialized by speedy in declaration order, so the order below is
/// part of the wire format — do not reorder.
///
/// NOTE(review): the derived `Debug` prints `secret` (its raw bytes); avoid
/// logging this struct with `{:?}`, or replace the derive with a redacting impl.
#[derive(PartialEq, Debug, Readable, Writable)]
pub struct InitAuthRequest {
    pub header: Header,
    pub v: ProtocolVersion,
    pub login: Text4b,
    pub secret: Text4b,
}
/// A connected client, as seen by the task that serves it.
struct Peer {
    /// Framed TLS connection to the client, encoded/decoded with `AuthProtocolCodec`.
    lines: Framed<TlsStream<TcpStream>, AuthProtocolCodec>,
    /// Receiving half of the channel whose sender is stored in `Shared::peers`.
    rx: Rx,
}

impl Peer {
    /// Create a new instance of `Peer`.
    ///
    /// Opens an unbounded channel, registers its sending half in
    /// `state.peers` under the client's socket address (replacing any previous
    /// entry for that address), and keeps the receiving half in the new `Peer`.
    ///
    /// # Errors
    /// Fails if the peer address of the underlying TCP socket cannot be read.
    async fn new(
        state: Arc<Mutex<Shared>>,
        lines: Framed<TlsStream<TcpStream>, AuthProtocolCodec>,
    ) -> std::io::Result<Peer> {
        let client_addr = lines.get_ref().get_ref().0.peer_addr()?;
        let (sender, receiver) = mpsc::unbounded_channel();
        {
            let mut shared = state.lock().await;
            shared.peers.insert(client_addr, sender);
        }
        Ok(Self { lines, rx: receiver })
    }
}
impl Decoder for AuthProtocolCodec {
type Item = BytesMut;
type Error = io::Error;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
if buf.len() < 8 { //NOTE: header size = 8 bytes
return Ok(None);
}
let len: usize = buf.len();
let message_header = match frames::Header::read_from_buffer(&buf[..8]) {
Ok(message_header) => message_header,
Err(_e) => {
return Err(errorize("first chunk: body size error"));
}
};
if !message_header.is_req_ok() || message_header.body_size == 0 {
return Err(errorize("first chunk: body size error"));
}
let header_size = size_of::<Header>() as u32;
let estimated_frame_size = (message_header.body_size + header_size) as usize;
// NOTE: https://docs.rs/tokio-util/latest/tokio_util/codec/index.html#the-decoder-trait
return match len.cmp(&estimated_frame_size) {
Ordering::Less => Ok(None),
Ordering::Equal => {
tracing::info!(
"Equal estimated_frame_size: {}, buf len: {}",
estimated_frame_size,
len
);
Ok(Some(buf.to_owned()))
}
Ordering::Greater => {
tracing::info!(
"Greater estimated_frame_size: {}, buf len: {}",
estimated_frame_size,
len
);
Ok(Some(buf.split_to(len)))
}
};
}
async fn process_plain_stream(
processor: Arc<Mutex<FrameProcessor>>,
state: Arc<Mutex<Shared>>,
mut peer: Peer,
addr: SocketAddr,
) -> Result<(), Box<dyn Error>> {
loop {
tokio::select! {
// A message was received from a peer. Send it to the current user.
Some(msg) = peer.rx.recv() => {
tracing::warn!("recv msg: {:?}, len: {}", msg, msg.len());
peer.lines.send(msg).await?; //NOTE: send to peer
}
result = peer.lines.next() => match result {
Some(Ok(msg)) => {
let frames;
{
let addr = peer.lines.get_ref().get_ref().0.peer_addr().map_err(|e| { format!("Cannot get peer_addr. {}", e) })?;
tracing::info!("New frame from peer: {:?}, bytes: {}", addr, msg.len());
let mut processor = processor.lock().await;
let data: Vec<u8> = msg.into();
frames = processor.process(addr, &data).map_err(|e| { format!("Cannot data process because: {} for peer: {:?}", e, addr)})?;
drop(processor);
}
for f in frames
{
let msg = Bytes::from(f.bytes);
match f.peer
{
Some(p) => {
{
let mut state = state.lock().await;
tracing::warn!("send_to_peer msg.len: {}", msg.len());
state.send_to_peer(p, &msg).await;
}
},
None => {
tracing::warn!("send to: {:?}, msg len: {:?}", peer.lines, msg.len());
peer.lines.send(msg).await?;
}
};
}
}
Some(Err(e)) if e.kind() == ErrorKind::UnexpectedEof => {
match peer.lines.get_ref().get_ref().0.peer_addr()
{
Ok(addr) => {
tracing::info!("EOF. Disconnecting... peer: {:?}", addr);
},
Err(_e) => {
tracing::info!("EOF. Disconnecting...");
}
}
}
// An error occurred.
Some(Err(e)) => {
match peer.lines.get_ref().get_ref().0.peer_addr()
{
Ok(addr) => {
tracing::error!("error = {:?}, peer: {:?}", e, addr);
},
Err(_e) => {
tracing::error!("error = {:?}", e);
}
}
// tracing::error!("error = {:?}", e);
}
// The stream has been exhausted.
None => {
match peer.lines.get_ref().get_ref().0.peer_addr()
{
Ok(addr) => {
tracing::info!("Disconnecting... peer: {:?}", addr);
},
Err(_e) => {
tracing::info!("Disconnecting...");
}
}
break
}
}
}
}
// If this section is reached it means that the client was disconnected!
// Let's let everyone still connected know about it.
{
let mut processor = processor.lock().await;
processor.remove_peer(&addr);
let mut state = state.lock().await;
state.peers.remove(&addr);
tracing::info!("Disconnected: {:?}", addr);
}
Ok(())
} Log:
|
Beta Was this translation helpful? Give feedback.
Answered by
Darksonn
May 23, 2023
Replies: 1 comment
-
The |
Beta Was this translation helpful? Give feedback.
0 replies
Answer selected by
hanusek
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The `buf.to_owned()` call copies the bytes but doesn't clear `buf`, so the same frame is decoded again on the next call. Instead, I recommend handling both branches with `split_to`.