@@ -89,10 +89,14 @@ use std::net::SocketAddr;
8989#[ cfg( feature = "websocket" ) ]
9090use std:: pin:: Pin ;
9191#[ cfg( feature = "websocket" ) ]
92+ use std:: task:: Context ;
93+ #[ cfg( feature = "websocket" ) ]
9294use std:: task:: { self , Poll } ;
9395
9496use bytes:: Bytes ;
9597#[ cfg( feature = "websocket" ) ]
98+ use futures_channel:: mpsc;
99+ #[ cfg( feature = "websocket" ) ]
96100use futures_util:: StreamExt ;
97101use futures_util:: { future, FutureExt , TryFutureExt } ;
98102use http:: {
@@ -102,15 +106,17 @@ use http::{
102106use serde:: Serialize ;
103107use serde_json;
104108#[ cfg( feature = "websocket" ) ]
105- use tokio:: sync:: { mpsc, oneshot} ;
106- #[ cfg( feature = "websocket" ) ]
107- use tokio_stream:: wrappers:: UnboundedReceiverStream ;
109+ use tokio:: sync:: oneshot;
108110
109111use crate :: filter:: Filter ;
112+ #[ cfg( feature = "websocket" ) ]
113+ use crate :: filters:: ws:: Message ;
110114use crate :: reject:: IsReject ;
111115use crate :: reply:: Reply ;
112116use crate :: route:: { self , Route } ;
113117use crate :: Request ;
118+ #[ cfg( feature = "websocket" ) ]
119+ use crate :: { Sink , Stream } ;
114120
115121use self :: inner:: OneOrTuple ;
116122
@@ -484,9 +490,8 @@ impl WsBuilder {
484490 F :: Error : IsReject + Send ,
485491 {
486492 let ( upgraded_tx, upgraded_rx) = oneshot:: channel ( ) ;
487- let ( wr_tx, wr_rx) = mpsc:: unbounded_channel ( ) ;
488- let wr_rx = UnboundedReceiverStream :: new ( wr_rx) ;
489- let ( rd_tx, rd_rx) = mpsc:: unbounded_channel ( ) ;
493+ let ( wr_tx, wr_rx) = mpsc:: unbounded ( ) ;
494+ let ( rd_tx, rd_rx) = mpsc:: unbounded ( ) ;
490495
491496 tokio:: spawn ( async move {
492497 use tokio_tungstenite:: tungstenite:: protocol;
@@ -546,7 +551,7 @@ impl WsBuilder {
546551 Ok ( m) => future:: ready ( !m. is_close ( ) ) ,
547552 } )
548553 . for_each ( move |item| {
549- rd_tx. send ( item) . expect ( "ws receive error" ) ;
554+ rd_tx. unbounded_send ( item) . expect ( "ws receive error" ) ;
550555 future:: ready ( ( ) )
551556 } ) ;
552557
@@ -573,13 +578,13 @@ impl WsClient {
573578
574579 /// Send a websocket message to the server.
575580 pub async fn send ( & mut self , msg : crate :: ws:: Message ) {
576- self . tx . send ( msg) . unwrap ( ) ;
581+ self . tx . unbounded_send ( msg) . unwrap ( ) ;
577582 }
578583
579584 /// Receive a websocket message from the server.
580585 pub async fn recv ( & mut self ) -> Result < crate :: filters:: ws:: Message , WsError > {
581586 self . rx
582- . recv ( )
587+ . next ( )
583588 . await
584589 . map ( |result| result. map_err ( WsError :: new) )
585590 . unwrap_or_else ( || {
@@ -591,7 +596,7 @@ impl WsClient {
591596 /// Assert the server has closed the connection.
592597 pub async fn recv_closed ( & mut self ) -> Result < ( ) , WsError > {
593598 self . rx
594- . recv ( )
599+ . next ( )
595600 . await
596601 . map ( |result| match result {
597602 Ok ( msg) => Err ( WsError :: new ( format ! ( "received message: {:?}" , msg) ) ) ,
@@ -602,6 +607,11 @@ impl WsClient {
602607 Ok ( ( ) )
603608 } )
604609 }
610+
611+ fn pinned_tx ( self : Pin < & mut Self > ) -> Pin < & mut mpsc:: UnboundedSender < crate :: ws:: Message > > {
612+ let this = Pin :: into_inner ( self ) ;
613+ Pin :: new ( & mut this. tx )
614+ }
605615}
606616
607617#[ cfg( feature = "websocket" ) ]
@@ -611,6 +621,51 @@ impl fmt::Debug for WsClient {
611621 }
612622}
613623
624+ #[ cfg( feature = "websocket" ) ]
625+ impl Sink < crate :: ws:: Message > for WsClient {
626+ type Error = WsError ;
627+
628+ fn poll_ready (
629+ self : Pin < & mut Self > ,
630+ context : & mut Context < ' _ > ,
631+ ) -> Poll < Result < ( ) , Self :: Error > > {
632+ self . pinned_tx ( ) . poll_ready ( context) . map_err ( WsError :: new)
633+ }
634+
635+ fn start_send ( self : Pin < & mut Self > , message : Message ) -> Result < ( ) , Self :: Error > {
636+ self . pinned_tx ( ) . start_send ( message) . map_err ( WsError :: new)
637+ }
638+
639+ fn poll_flush (
640+ self : Pin < & mut Self > ,
641+ context : & mut Context < ' _ > ,
642+ ) -> Poll < Result < ( ) , Self :: Error > > {
643+ self . pinned_tx ( ) . poll_flush ( context) . map_err ( WsError :: new)
644+ }
645+
646+ fn poll_close (
647+ self : Pin < & mut Self > ,
648+ context : & mut Context < ' _ > ,
649+ ) -> Poll < Result < ( ) , Self :: Error > > {
650+ self . pinned_tx ( ) . poll_close ( context) . map_err ( WsError :: new)
651+ }
652+ }
653+
654+ #[ cfg( feature = "websocket" ) ]
655+ impl Stream for WsClient {
656+ type Item = Result < crate :: ws:: Message , WsError > ;
657+
658+ fn poll_next ( self : Pin < & mut Self > , context : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
659+ let this = Pin :: into_inner ( self ) ;
660+ let rx = Pin :: new ( & mut this. rx ) ;
661+ match rx. poll_next ( context) {
662+ Poll :: Ready ( Some ( result) ) => Poll :: Ready ( Some ( result. map_err ( WsError :: new) ) ) ,
663+ Poll :: Ready ( None ) => Poll :: Ready ( None ) ,
664+ Poll :: Pending => Poll :: Pending ,
665+ }
666+ }
667+ }
668+
614669// ===== impl WsError =====
615670
616671#[ cfg( feature = "websocket" ) ]
0 commit comments