@@ -166,14 +166,19 @@ where
166166 } )
167167 . await
168168 . map_err ( |mpsc:: error:: SendError ( _) | RpcError :: Shutdown ) ?;
169- response_guard. response ( ) . await
169+
170+ let ( response_ctx, r) = response_guard. response ( ) . await ?;
171+
172+ ctx. shared_context = response_ctx. shared_context ;
173+
174+ Ok ( r)
170175 }
171176}
172177
173178/// A server response that is completed by request dispatch when the corresponding response
174179/// arrives off the wire.
175180struct ResponseGuard < ' a , Resp > {
176- response : & ' a mut oneshot:: Receiver < Result < Resp , RpcError > > ,
181+ response : & ' a mut oneshot:: Receiver < Result < ( ClientContext , Resp ) , RpcError > > ,
177182 cancellation : & ' a RequestCancellation ,
178183 request_id : u64 ,
179184 cancel : bool ,
@@ -201,7 +206,7 @@ pub enum RpcError {
201206}
202207
203208impl < Resp > ResponseGuard < ' _ , Resp > {
204- async fn response ( mut self ) -> Result < Resp , RpcError > {
209+ async fn response ( mut self ) -> Result < ( ClientContext , Resp ) , RpcError > {
205210 let response = ( & mut self . response ) . await ;
206211 // Cancel drop logic once a response has been received.
207212 self . cancel = false ;
@@ -280,7 +285,7 @@ pub struct RequestDispatch<Req, Resp, C> {
280285 /// Requests that were dropped.
281286 canceled_requests : CanceledRequests ,
282287 /// Requests already written to the wire that haven't yet received responses.
283- in_flight_requests : InFlightRequests < Result < Resp , RpcError > > ,
288+ in_flight_requests : InFlightRequests < Resp > ,
284289 /// Configures limits to prevent unlimited resource usage.
285290 config : Config ,
286291 /// Produces errors that can be sent in response to any unprocessed requests at the time
@@ -296,7 +301,7 @@ where
296301{
297302 fn in_flight_requests < ' a > (
298303 self : & ' a mut Pin < & mut Self > ,
299- ) -> & ' a mut InFlightRequests < Result < Resp , RpcError > > {
304+ ) -> & ' a mut InFlightRequests < Resp > {
300305 self . as_mut ( ) . project ( ) . in_flight_requests
301306 }
302307
@@ -522,12 +527,10 @@ where
522527 let trace_context = ctx. trace_context ;
523528 let deadline = ctx. deadline ;
524529
525- let client_context = context:: ClientContext :: new ( ctx) ;
526-
527530 let request = ClientMessage :: Request ( Request {
528531 id : request_id,
529532 message : request,
530- context : client_context ,
533+ context : ClientContext :: new ( ctx ) ,
531534 } ) ;
532535
533536 self . in_flight_requests ( )
@@ -580,7 +583,7 @@ where
580583 fn complete ( mut self : Pin < & mut Self > , response : Response < ClientContext , Resp > ) -> bool {
581584 if let Some ( span) = self . in_flight_requests ( ) . complete_request (
582585 response. request_id ,
583- response. message . map_err ( RpcError :: Server ) ,
586+ response. message . map_err ( RpcError :: Server ) . map ( |m| ( response . context , m ) ) ,
584587 ) {
585588 let _entered = span. enter ( ) ;
586589 tracing:: debug!( "ReceiveResponse" ) ;
@@ -688,11 +691,11 @@ where
688691/// the lifecycle of the request.
689692#[ derive( Debug ) ]
690693struct DispatchRequest < Req , Resp > {
691- pub ctx : context:: SharedContext ,
694+ pub ctx : context:: SharedContext , ///TODO: <-- this should be a &mut ClientContext
692695 pub span : Span ,
693696 pub request_id : u64 ,
694697 pub request : Req ,
695- pub response_completion : oneshot:: Sender < Result < Resp , RpcError > > ,
698+ pub response_completion : oneshot:: Sender < Result < ( ClientContext , Resp ) , RpcError > > ,
696699}
697700
698701#[ cfg( test) ]
@@ -752,7 +755,7 @@ mod tests {
752755 . await
753756 . unwrap ( ) ;
754757 assert_matches ! ( dispatch. as_mut( ) . poll( cx) , Poll :: Pending ) ;
755- assert_matches ! ( rx. try_recv( ) , Ok ( Ok ( resp) ) if resp == "Resp" ) ;
758+ assert_matches ! ( rx. try_recv( ) , Ok ( Ok ( ( _ , resp) ) ) if resp == "Resp" ) ;
756759 }
757760
758761 #[ tokio:: test]
@@ -774,12 +777,7 @@ mod tests {
774777 async fn dispatch_response_doesnt_cancel_after_complete ( ) {
775778 let ( cancellation, mut canceled_requests) = cancellations ( ) ;
776779 let ( tx, mut response) = oneshot:: channel ( ) ;
777- tx. send ( Ok ( Response {
778- request_id : 0 ,
779- context : ClientContext :: current ( ) ,
780- message : Ok ( "well done" ) ,
781- } ) )
782- . unwrap ( ) ;
780+ tx. send ( Ok ( ( ClientContext :: current ( ) , "well done" ) ) ) . unwrap ( ) ;
783781 // resp's drop() is run, but should not send a cancel message.
784782 ResponseGuard {
785783 response : & mut response,
@@ -1116,37 +1114,11 @@ mod tests {
11161114 ( Box :: pin ( dispatch) , channel, server_channel)
11171115 }
11181116
1119- async fn reserve_for_send < ' a > (
1120- channel : & ' a mut Channel < String , String > ,
1121- response_completion : oneshot:: Sender < Result < String , RpcError > > ,
1122- response : & ' a mut oneshot:: Receiver < Result < String , RpcError > > ,
1123- ) -> impl FnOnce ( & str ) -> ResponseGuard < ' a , String > {
1124- let permit = channel. to_dispatch . reserve ( ) . await . unwrap ( ) ;
1125- |request| {
1126- let request_id =
1127- u64:: try_from ( channel. next_request_id . fetch_add ( 1 , Ordering :: Relaxed ) ) . unwrap ( ) ;
1128- let request = DispatchRequest {
1129- ctx : SharedContext :: current ( ) ,
1130- span : Span :: current ( ) ,
1131- request_id,
1132- request : request. to_string ( ) ,
1133- response_completion,
1134- } ;
1135- permit. send ( request) ;
1136- ResponseGuard {
1137- response,
1138- cancellation : & channel. cancellation ,
1139- request_id,
1140- cancel : true ,
1141- }
1142- }
1143- }
1144-
11451117 async fn send_request < ' a > (
11461118 channel : & ' a mut Channel < String , String > ,
11471119 request : & str ,
1148- response_completion : oneshot:: Sender < Result < String , RpcError > > ,
1149- response : & ' a mut oneshot:: Receiver < Result < String , RpcError > > ,
1120+ response_completion : oneshot:: Sender < Result < ( ClientContext , String ) , RpcError > > ,
1121+ response : & ' a mut oneshot:: Receiver < Result < ( ClientContext , String ) , RpcError > > ,
11501122 ) -> ResponseGuard < ' a , String > {
11511123 let request_id =
11521124 u64:: try_from ( channel. next_request_id . fetch_add ( 1 , Ordering :: Relaxed ) ) . unwrap ( ) ;
@@ -1167,6 +1139,32 @@ mod tests {
11671139 response_guard
11681140 }
11691141
1142+ async fn reserve_for_send < ' a > (
1143+ channel : & ' a mut Channel < String , String > ,
1144+ response_completion : oneshot:: Sender < Result < ( ClientContext , String ) , RpcError > > ,
1145+ response : & ' a mut oneshot:: Receiver < Result < ( ClientContext , String ) , RpcError > > ,
1146+ ) -> impl FnOnce ( & str ) -> ResponseGuard < ' a , String > {
1147+ let permit = channel. to_dispatch . reserve ( ) . await . unwrap ( ) ;
1148+ |request| {
1149+ let request_id =
1150+ u64:: try_from ( channel. next_request_id . fetch_add ( 1 , Ordering :: Relaxed ) ) . unwrap ( ) ;
1151+ let request = DispatchRequest {
1152+ ctx : SharedContext :: current ( ) ,
1153+ span : Span :: current ( ) ,
1154+ request_id,
1155+ request : request. to_string ( ) ,
1156+ response_completion,
1157+ } ;
1158+ permit. send ( request) ;
1159+ ResponseGuard {
1160+ response,
1161+ cancellation : & channel. cancellation ,
1162+ request_id,
1163+ cancel : true ,
1164+ }
1165+ }
1166+ }
1167+
11701168 async fn send_response (
11711169 channel : & mut UnboundedChannel <
11721170 ClientMessage < ClientContext , String > ,
0 commit comments