Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ all-features = true
async-compression = { version = "0.3.7", features = ["brotli", "deflate", "gzip", "tokio"], optional = true }
bytes = "1.0"
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
futures-channel = { version = "0.3.17", features = ["sink"]}
headers = "0.3"
http = "0.2"
hyper = { version = "0.14", features = ["stream", "server", "http1", "tcp", "client"] }
Expand Down
75 changes: 65 additions & 10 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ use std::net::SocketAddr;
#[cfg(feature = "websocket")]
use std::pin::Pin;
#[cfg(feature = "websocket")]
use std::task::Context;
#[cfg(feature = "websocket")]
use std::task::{self, Poll};

use bytes::Bytes;
#[cfg(feature = "websocket")]
use futures_channel::mpsc;
#[cfg(feature = "websocket")]
use futures_util::StreamExt;
use futures_util::{future, FutureExt, TryFutureExt};
use http::{
Expand All @@ -102,15 +106,17 @@ use http::{
use serde::Serialize;
use serde_json;
#[cfg(feature = "websocket")]
use tokio::sync::{mpsc, oneshot};
#[cfg(feature = "websocket")]
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio::sync::oneshot;

use crate::filter::Filter;
#[cfg(feature = "websocket")]
use crate::filters::ws::Message;
use crate::reject::IsReject;
use crate::reply::Reply;
use crate::route::{self, Route};
use crate::Request;
#[cfg(feature = "websocket")]
use crate::{Sink, Stream};

use self::inner::OneOrTuple;

Expand Down Expand Up @@ -484,9 +490,8 @@ impl WsBuilder {
F::Error: IsReject + Send,
{
let (upgraded_tx, upgraded_rx) = oneshot::channel();
let (wr_tx, wr_rx) = mpsc::unbounded_channel();
let wr_rx = UnboundedReceiverStream::new(wr_rx);
let (rd_tx, rd_rx) = mpsc::unbounded_channel();
let (wr_tx, wr_rx) = mpsc::unbounded();
let (rd_tx, rd_rx) = mpsc::unbounded();

tokio::spawn(async move {
use tokio_tungstenite::tungstenite::protocol;
Expand Down Expand Up @@ -546,7 +551,7 @@ impl WsBuilder {
Ok(m) => future::ready(!m.is_close()),
})
.for_each(move |item| {
rd_tx.send(item).expect("ws receive error");
rd_tx.unbounded_send(item).expect("ws receive error");
future::ready(())
});

Expand All @@ -573,13 +578,13 @@ impl WsClient {

/// Send a websocket message to the server.
pub async fn send(&mut self, msg: crate::ws::Message) {
self.tx.send(msg).unwrap();
self.tx.unbounded_send(msg).unwrap();
}

/// Receive a websocket message from the server.
pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> {
self.rx
.recv()
.next()
.await
.map(|result| result.map_err(WsError::new))
.unwrap_or_else(|| {
Expand All @@ -591,7 +596,7 @@ impl WsClient {
/// Assert the server has closed the connection.
pub async fn recv_closed(&mut self) -> Result<(), WsError> {
self.rx
.recv()
.next()
.await
.map(|result| match result {
Ok(msg) => Err(WsError::new(format!("received message: {:?}", msg))),
Expand All @@ -602,6 +607,11 @@ impl WsClient {
Ok(())
})
}

fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender<crate::ws::Message>> {
let this = Pin::into_inner(self);
Pin::new(&mut this.tx)
}
}

#[cfg(feature = "websocket")]
Expand All @@ -611,6 +621,51 @@ impl fmt::Debug for WsClient {
}
}

#[cfg(feature = "websocket")]
impl Sink<crate::ws::Message> for WsClient {
type Error = WsError;

fn poll_ready(
self: Pin<&mut Self>,
context: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.pinned_tx().poll_ready(context).map_err(WsError::new)
}

fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> {
self.pinned_tx().start_send(message).map_err(WsError::new)
}

fn poll_flush(
self: Pin<&mut Self>,
context: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.pinned_tx().poll_flush(context).map_err(WsError::new)
}

fn poll_close(
self: Pin<&mut Self>,
context: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.pinned_tx().poll_close(context).map_err(WsError::new)
}
}

#[cfg(feature = "websocket")]
impl Stream for WsClient {
type Item = Result<crate::ws::Message, WsError>;

fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = Pin::into_inner(self);
let rx = Pin::new(&mut this.rx);
match rx.poll_next(context) {
Poll::Ready(Some(result)) => Poll::Ready(Some(result.map_err(WsError::new))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}

// ===== impl WsError =====

#[cfg(feature = "websocket")]
Expand Down
15 changes: 15 additions & 0 deletions tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,21 @@ async fn binary() {
assert_eq!(msg.as_bytes(), &b"bonk"[..]);
}

#[tokio::test]
async fn wsclient_sink_and_stream() {
let _ = pretty_env_logger::try_init();

let mut client = warp::test::ws()
.handshake(ws_echo())
.await
.expect("handshake");

let message = warp::ws::Message::text("hello");
SinkExt::send(&mut client, message.clone()).await.unwrap();
let received_message = client.next().await.unwrap().unwrap();
assert_eq!(message, received_message);
}

#[tokio::test]
async fn close_frame() {
let _ = pretty_env_logger::try_init();
Expand Down