1717 * Boston, MA 02111-1307, USA.
1818 */
1919
20+ use actix:: clock:: interval;
2021use actix:: prelude:: * ;
22+ use actix_web:: rt:: pin;
2123use actix_web:: web:: Bytes ;
2224use actix_web:: { get, rt, web, HttpRequest , HttpResponse } ;
2325use actix_web_validator:: Query ;
24- use actix_ws:: { AggregatedMessage , ProtocolError , Session } ;
26+ use actix_ws:: { AggregatedMessage , Session } ;
2527use db_connector:: models:: wg_keys:: WgKey ;
2628use diesel:: prelude:: * ;
29+ use futures_util:: future:: Either ;
30+ use futures_util:: StreamExt ;
2731use serde:: { Deserialize , Serialize } ;
2832use std:: collections:: HashMap ;
2933use std:: str:: FromStr ;
3034use std:: sync:: Arc ;
31- use std:: time:: Instant ;
35+ use std:: time:: { Duration , Instant } ;
3236use validator:: { Validate , ValidationError } ;
3337use futures_util:: lock:: Mutex ;
3438
@@ -44,6 +48,9 @@ use crate::{
4448 AppState , BridgeState ,
4549} ;
4650
51+ const HEARTBEAT_INTERVAL : Duration = Duration :: from_secs ( 5 ) ;
52+ const CLIENT_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
53+
4754#[ derive( Deserialize , Serialize , Validate ) ]
4855struct WsQuery {
4956 #[ validate( custom( function = validate_key_id) ) ]
@@ -106,12 +113,16 @@ impl WebClient {
106113 }
107114 }
108115
109- pub async fn handle_message ( & mut self , msg : Result < AggregatedMessage , ProtocolError > ) {
116+ pub async fn handle_message ( & mut self , msg : AggregatedMessage , last_heartbeat : & mut Instant ) {
110117 match msg {
111- Ok ( AggregatedMessage :: Ping ( msg) ) => {
118+ AggregatedMessage :: Ping ( msg) => {
112119 self . session . pong ( & msg) . await . unwrap ( ) ;
120+ * last_heartbeat = Instant :: now ( ) ;
121+ } ,
122+ AggregatedMessage :: Pong ( _) => {
123+ * last_heartbeat = Instant :: now ( ) ;
113124 } ,
114- Ok ( AggregatedMessage :: Binary ( msg) ) => {
125+ AggregatedMessage :: Binary ( msg) => {
115126 let peer_sock_addr = {
116127 let meta = RemoteConnMeta {
117128 charger_id : self . charger_id . clone ( ) ,
@@ -324,7 +335,7 @@ async fn start_ws(
324335 bridge_state. port_discovery . clone ( ) ,
325336 ) . await ?;
326337
327- let ( resp, session, stream) = actix_ws:: handle ( & req, stream) ?;
338+ let ( resp, mut session, stream) = actix_ws:: handle ( & req, stream) ?;
328339 let mut stream = stream. aggregate_continuations ( )
329340 . max_continuation_size ( 2_usize . pow ( 20 ) ) ;
330341
@@ -351,14 +362,31 @@ async fn start_ws(
351362 state,
352363 bridge_state,
353364 keys. connection_no ,
354- session,
365+ session. clone ( ) ,
355366 ) . await ;
356- while let Some ( msg) = stream. recv ( ) . await {
357- if let Ok ( AggregatedMessage :: Close ( _) ) = & msg {
358- log:: info!( "/ws close connection" ) ;
359- break ;
367+
368+ let mut last_heartbeat = Instant :: now ( ) ;
369+ let mut interval = interval ( HEARTBEAT_INTERVAL ) ;
370+ loop {
371+ let tick = interval. tick ( ) ;
372+ pin ! ( tick) ;
373+
374+ match futures_util:: future:: select ( stream. next ( ) , tick) . await {
375+ Either :: Left ( ( Some ( Ok ( AggregatedMessage :: Close ( _) ) ) , _) ) => break ,
376+ Either :: Left ( ( Some ( Ok ( msg) ) , _) ) => client. handle_message ( msg, & mut last_heartbeat) . await ,
377+ Either :: Left ( ( Some ( err) , _) ) => {
378+ log:: error!( "Websocket Error during connection: {:?}" , err) ;
379+ break ;
380+ } ,
381+ Either :: Left ( ( None , _) ) => break ,
382+ Either :: Right ( _) => {
383+ if Instant :: now ( ) . duration_since ( last_heartbeat) > CLIENT_TIMEOUT {
384+ log:: debug!( "Client quietly quit." ) ;
385+ break ;
386+ }
387+ let _ = session. ping ( b"" ) . await ;
388+ }
360389 }
361- client. handle_message ( msg) . await ;
362390 }
363391 client. stop ( ) . await ;
364392 } ) ;
0 commit comments