Skip to content

Commit 1948044

Browse files
authored
Implement Sink and Stream for WsClient (#907)
* Switch WsClient from tokio's unbounded channel to futures unbounded channel * Implement Sink for WsClient * Implement Stream for WsClient * Test sink and stream * Use WsError for the Sink implementation
1 parent 25eedf6 commit 1948044

File tree

3 files changed

+81
-10
lines changed

3 files changed

+81
-10
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ all-features = true
2020
async-compression = { version = "0.3.7", features = ["brotli", "deflate", "gzip", "tokio"], optional = true }
2121
bytes = "1.0"
2222
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
23+
futures-channel = { version = "0.3.17", features = ["sink"]}
2324
headers = "0.3"
2425
http = "0.2"
2526
hyper = { version = "0.14", features = ["stream", "server", "http1", "tcp", "client"] }

src/test.rs

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,14 @@ use std::net::SocketAddr;
8989
#[cfg(feature = "websocket")]
9090
use std::pin::Pin;
9191
#[cfg(feature = "websocket")]
92+
use std::task::Context;
93+
#[cfg(feature = "websocket")]
9294
use std::task::{self, Poll};
9395

9496
use bytes::Bytes;
9597
#[cfg(feature = "websocket")]
98+
use futures_channel::mpsc;
99+
#[cfg(feature = "websocket")]
96100
use futures_util::StreamExt;
97101
use futures_util::{future, FutureExt, TryFutureExt};
98102
use http::{
@@ -102,15 +106,17 @@ use http::{
102106
use serde::Serialize;
103107
use serde_json;
104108
#[cfg(feature = "websocket")]
105-
use tokio::sync::{mpsc, oneshot};
106-
#[cfg(feature = "websocket")]
107-
use tokio_stream::wrappers::UnboundedReceiverStream;
109+
use tokio::sync::oneshot;
108110

109111
use crate::filter::Filter;
112+
#[cfg(feature = "websocket")]
113+
use crate::filters::ws::Message;
110114
use crate::reject::IsReject;
111115
use crate::reply::Reply;
112116
use crate::route::{self, Route};
113117
use crate::Request;
118+
#[cfg(feature = "websocket")]
119+
use crate::{Sink, Stream};
114120

115121
use self::inner::OneOrTuple;
116122

@@ -484,9 +490,8 @@ impl WsBuilder {
484490
F::Error: IsReject + Send,
485491
{
486492
let (upgraded_tx, upgraded_rx) = oneshot::channel();
487-
let (wr_tx, wr_rx) = mpsc::unbounded_channel();
488-
let wr_rx = UnboundedReceiverStream::new(wr_rx);
489-
let (rd_tx, rd_rx) = mpsc::unbounded_channel();
493+
let (wr_tx, wr_rx) = mpsc::unbounded();
494+
let (rd_tx, rd_rx) = mpsc::unbounded();
490495

491496
tokio::spawn(async move {
492497
use tokio_tungstenite::tungstenite::protocol;
@@ -546,7 +551,7 @@ impl WsBuilder {
546551
Ok(m) => future::ready(!m.is_close()),
547552
})
548553
.for_each(move |item| {
549-
rd_tx.send(item).expect("ws receive error");
554+
rd_tx.unbounded_send(item).expect("ws receive error");
550555
future::ready(())
551556
});
552557

@@ -573,13 +578,13 @@ impl WsClient {
573578

574579
/// Send a websocket message to the server.
575580
pub async fn send(&mut self, msg: crate::ws::Message) {
576-
self.tx.send(msg).unwrap();
581+
self.tx.unbounded_send(msg).unwrap();
577582
}
578583

579584
/// Receive a websocket message from the server.
580585
pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> {
581586
self.rx
582-
.recv()
587+
.next()
583588
.await
584589
.map(|result| result.map_err(WsError::new))
585590
.unwrap_or_else(|| {
@@ -591,7 +596,7 @@ impl WsClient {
591596
/// Assert the server has closed the connection.
592597
pub async fn recv_closed(&mut self) -> Result<(), WsError> {
593598
self.rx
594-
.recv()
599+
.next()
595600
.await
596601
.map(|result| match result {
597602
Ok(msg) => Err(WsError::new(format!("received message: {:?}", msg))),
@@ -602,6 +607,11 @@ impl WsClient {
602607
Ok(())
603608
})
604609
}
610+
611+
fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender<crate::ws::Message>> {
612+
let this = Pin::into_inner(self);
613+
Pin::new(&mut this.tx)
614+
}
605615
}
606616

607617
#[cfg(feature = "websocket")]
@@ -611,6 +621,51 @@ impl fmt::Debug for WsClient {
611621
}
612622
}
613623

624+
#[cfg(feature = "websocket")]
625+
impl Sink<crate::ws::Message> for WsClient {
626+
type Error = WsError;
627+
628+
fn poll_ready(
629+
self: Pin<&mut Self>,
630+
context: &mut Context<'_>,
631+
) -> Poll<Result<(), Self::Error>> {
632+
self.pinned_tx().poll_ready(context).map_err(WsError::new)
633+
}
634+
635+
fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> {
636+
self.pinned_tx().start_send(message).map_err(WsError::new)
637+
}
638+
639+
fn poll_flush(
640+
self: Pin<&mut Self>,
641+
context: &mut Context<'_>,
642+
) -> Poll<Result<(), Self::Error>> {
643+
self.pinned_tx().poll_flush(context).map_err(WsError::new)
644+
}
645+
646+
fn poll_close(
647+
self: Pin<&mut Self>,
648+
context: &mut Context<'_>,
649+
) -> Poll<Result<(), Self::Error>> {
650+
self.pinned_tx().poll_close(context).map_err(WsError::new)
651+
}
652+
}
653+
654+
#[cfg(feature = "websocket")]
655+
impl Stream for WsClient {
656+
type Item = Result<crate::ws::Message, WsError>;
657+
658+
fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
659+
let this = Pin::into_inner(self);
660+
let rx = Pin::new(&mut this.rx);
661+
match rx.poll_next(context) {
662+
Poll::Ready(Some(result)) => Poll::Ready(Some(result.map_err(WsError::new))),
663+
Poll::Ready(None) => Poll::Ready(None),
664+
Poll::Pending => Poll::Pending,
665+
}
666+
}
667+
}
668+
614669
// ===== impl WsError =====
615670

616671
#[cfg(feature = "websocket")]

tests/ws.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,21 @@ async fn binary() {
8181
assert_eq!(msg.as_bytes(), &b"bonk"[..]);
8282
}
8383

84+
#[tokio::test]
85+
async fn wsclient_sink_and_stream() {
86+
let _ = pretty_env_logger::try_init();
87+
88+
let mut client = warp::test::ws()
89+
.handshake(ws_echo())
90+
.await
91+
.expect("handshake");
92+
93+
let message = warp::ws::Message::text("hello");
94+
SinkExt::send(&mut client, message.clone()).await.unwrap();
95+
let received_message = client.next().await.unwrap().unwrap();
96+
assert_eq!(message, received_message);
97+
}
98+
8499
#[tokio::test]
85100
async fn close_frame() {
86101
let _ = pretty_env_logger::try_init();

0 commit comments

Comments
 (0)