Skip to content

Commit d3fe5c0

Browse files
committed
allow server to mutate shared context
1 parent 5522a32 commit d3fe5c0

File tree

2 files changed

+53
-56
lines changed

2 files changed

+53
-56
lines changed

tarpc/src/client.rs

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,19 @@ where
166166
})
167167
.await
168168
.map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?;
169-
response_guard.response().await
169+
170+
let (response_ctx, r) = response_guard.response().await?;
171+
172+
ctx.shared_context = response_ctx.shared_context;
173+
174+
Ok(r)
170175
}
171176
}
172177

173178
/// A server response that is completed by request dispatch when the corresponding response
174179
/// arrives off the wire.
175180
struct ResponseGuard<'a, Resp> {
176-
response: &'a mut oneshot::Receiver<Result<Resp, RpcError>>,
181+
response: &'a mut oneshot::Receiver<Result<(ClientContext, Resp), RpcError>>,
177182
cancellation: &'a RequestCancellation,
178183
request_id: u64,
179184
cancel: bool,
@@ -201,7 +206,7 @@ pub enum RpcError {
201206
}
202207

203208
impl<Resp> ResponseGuard<'_, Resp> {
204-
async fn response(mut self) -> Result<Resp, RpcError> {
209+
async fn response(mut self) -> Result<(ClientContext, Resp), RpcError> {
205210
let response = (&mut self.response).await;
206211
// Cancel drop logic once a response has been received.
207212
self.cancel = false;
@@ -280,7 +285,7 @@ pub struct RequestDispatch<Req, Resp, C> {
280285
/// Requests that were dropped.
281286
canceled_requests: CanceledRequests,
282287
/// Requests already written to the wire that haven't yet received responses.
283-
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
288+
in_flight_requests: InFlightRequests<Resp>,
284289
/// Configures limits to prevent unlimited resource usage.
285290
config: Config,
286291
/// Produces errors that can be sent in response to any unprocessed requests at the time
@@ -296,7 +301,7 @@ where
296301
{
297302
fn in_flight_requests<'a>(
298303
self: &'a mut Pin<&mut Self>,
299-
) -> &'a mut InFlightRequests<Result<Resp, RpcError>> {
304+
) -> &'a mut InFlightRequests<Resp> {
300305
self.as_mut().project().in_flight_requests
301306
}
302307

@@ -522,12 +527,10 @@ where
522527
let trace_context = ctx.trace_context;
523528
let deadline = ctx.deadline;
524529

525-
let client_context = context::ClientContext::new(ctx);
526-
527530
let request = ClientMessage::Request(Request {
528531
id: request_id,
529532
message: request,
530-
context: client_context,
533+
context: ClientContext::new(ctx),
531534
});
532535

533536
self.in_flight_requests()
@@ -580,7 +583,7 @@ where
580583
fn complete(mut self: Pin<&mut Self>, response: Response<ClientContext, Resp>) -> bool {
581584
if let Some(span) = self.in_flight_requests().complete_request(
582585
response.request_id,
583-
response.message.map_err(RpcError::Server),
586+
response.message.map_err(RpcError::Server).map(|m| (response.context, m)),
584587
) {
585588
let _entered = span.enter();
586589
tracing::debug!("ReceiveResponse");
@@ -688,11 +691,11 @@ where
688691
/// the lifecycle of the request.
689692
#[derive(Debug)]
690693
struct DispatchRequest<Req, Resp> {
691-
pub ctx: context::SharedContext,
694+
pub ctx: context::SharedContext, ///TODO: <-- this should be a &mut ClientContext
692695
pub span: Span,
693696
pub request_id: u64,
694697
pub request: Req,
695-
pub response_completion: oneshot::Sender<Result<Resp, RpcError>>,
698+
pub response_completion: oneshot::Sender<Result<(ClientContext, Resp), RpcError>>,
696699
}
697700

698701
#[cfg(test)]
@@ -752,7 +755,7 @@ mod tests {
752755
.await
753756
.unwrap();
754757
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
755-
assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp");
758+
assert_matches!(rx.try_recv(), Ok(Ok((_, resp))) if resp == "Resp");
756759
}
757760

758761
#[tokio::test]
@@ -774,12 +777,7 @@ mod tests {
774777
async fn dispatch_response_doesnt_cancel_after_complete() {
775778
let (cancellation, mut canceled_requests) = cancellations();
776779
let (tx, mut response) = oneshot::channel();
777-
tx.send(Ok(Response {
778-
request_id: 0,
779-
context: ClientContext::current(),
780-
message: Ok("well done"),
781-
}))
782-
.unwrap();
780+
tx.send(Ok((ClientContext::current(), "well done"))).unwrap();
783781
// resp's drop() is run, but should not send a cancel message.
784782
ResponseGuard {
785783
response: &mut response,
@@ -1116,37 +1114,11 @@ mod tests {
11161114
(Box::pin(dispatch), channel, server_channel)
11171115
}
11181116

1119-
async fn reserve_for_send<'a>(
1120-
channel: &'a mut Channel<String, String>,
1121-
response_completion: oneshot::Sender<Result<String, RpcError>>,
1122-
response: &'a mut oneshot::Receiver<Result<String, RpcError>>,
1123-
) -> impl FnOnce(&str) -> ResponseGuard<'a, String> {
1124-
let permit = channel.to_dispatch.reserve().await.unwrap();
1125-
|request| {
1126-
let request_id =
1127-
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
1128-
let request = DispatchRequest {
1129-
ctx: SharedContext::current(),
1130-
span: Span::current(),
1131-
request_id,
1132-
request: request.to_string(),
1133-
response_completion,
1134-
};
1135-
permit.send(request);
1136-
ResponseGuard {
1137-
response,
1138-
cancellation: &channel.cancellation,
1139-
request_id,
1140-
cancel: true,
1141-
}
1142-
}
1143-
}
1144-
11451117
async fn send_request<'a>(
11461118
channel: &'a mut Channel<String, String>,
11471119
request: &str,
1148-
response_completion: oneshot::Sender<Result<String, RpcError>>,
1149-
response: &'a mut oneshot::Receiver<Result<String, RpcError>>,
1120+
response_completion: oneshot::Sender<Result<(ClientContext, String), RpcError>>,
1121+
response: &'a mut oneshot::Receiver<Result<(ClientContext, String), RpcError>>,
11501122
) -> ResponseGuard<'a, String> {
11511123
let request_id =
11521124
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
@@ -1167,6 +1139,32 @@ mod tests {
11671139
response_guard
11681140
}
11691141

1142+
async fn reserve_for_send<'a>(
1143+
channel: &'a mut Channel<String, String>,
1144+
response_completion: oneshot::Sender<Result<(ClientContext, String), RpcError>>,
1145+
response: &'a mut oneshot::Receiver<Result<(ClientContext, String), RpcError>>,
1146+
) -> impl FnOnce(&str) -> ResponseGuard<'a, String> {
1147+
let permit = channel.to_dispatch.reserve().await.unwrap();
1148+
|request| {
1149+
let request_id =
1150+
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
1151+
let request = DispatchRequest {
1152+
ctx: SharedContext::current(),
1153+
span: Span::current(),
1154+
request_id,
1155+
request: request.to_string(),
1156+
response_completion,
1157+
};
1158+
permit.send(request);
1159+
ResponseGuard {
1160+
response,
1161+
cancellation: &channel.cancellation,
1162+
request_id,
1163+
cancel: true,
1164+
}
1165+
}
1166+
}
1167+
11701168
async fn send_response(
11711169
channel: &mut UnboundedChannel<
11721170
ClientMessage<ClientContext, String>,

tarpc/src/client/in_flight_requests.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
use crate::{
2-
trace,
3-
util::{Compact, TimeUntil},
4-
};
1+
use crate::{trace, util::{Compact, TimeUntil}};
52
use fnv::FnvHashMap;
63
use std::time::Instant;
74
use std::{
@@ -11,6 +8,8 @@ use std::{
118
use tokio::sync::oneshot;
129
use tokio_util::time::delay_queue::{self, DelayQueue};
1310
use tracing::Span;
11+
use crate::client::RpcError;
12+
use crate::context::ClientContext;
1413

1514
/// Requests already written to the wire that haven't yet received responses.
1615
#[derive(Debug)]
@@ -32,7 +31,7 @@ impl<Resp> Default for InFlightRequests<Resp> {
3231
struct RequestData<Res> {
3332
ctx: trace::Context,
3433
span: Span,
35-
response_completion: oneshot::Sender<Res>,
34+
response_completion: oneshot::Sender<Result<(ClientContext, Res), RpcError>>,
3635
/// The key to remove the timer for the request's deadline.
3736
deadline_key: delay_queue::Key,
3837
}
@@ -60,7 +59,7 @@ impl<Res> InFlightRequests<Res> {
6059
ctx: trace::Context,
6160
deadline: Instant,
6261
span: Span,
63-
response_completion: oneshot::Sender<Res>,
62+
response_completion: oneshot::Sender<Result<(ClientContext, Res), RpcError>>,
6463
) -> Result<(), AlreadyExistsError> {
6564
match self.request_data.entry(request_id) {
6665
hash_map::Entry::Vacant(vacant) => {
@@ -78,8 +77,8 @@ impl<Res> InFlightRequests<Res> {
7877
}
7978
}
8079

81-
/// Removes a request without aborting. Returns true iff the request was found.
82-
pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option<Span> {
80+
/// Removes a request without aborting. Returns true if the request was found.
81+
pub fn complete_request(&mut self, request_id: u64, result: Result<(ClientContext, Res), RpcError>) -> Option<Span> {
8382
if let Some(request_data) = self.request_data.remove(&request_id) {
8483
self.request_data.compact(0.1);
8584
self.deadlines.remove(&request_data.deadline_key);
@@ -97,7 +96,7 @@ impl<Res> InFlightRequests<Res> {
9796
/// Returns Spans for all completes requests.
9897
pub fn complete_all_requests<'a>(
9998
&'a mut self,
100-
mut result: impl FnMut() -> Res + 'a,
99+
mut result: impl FnMut() -> Result<(ClientContext, Res), RpcError> + 'a,
101100
) -> impl Iterator<Item = Span> + 'a {
102101
self.deadlines.clear();
103102
self.request_data.drain().map(move |(_, request_data)| {
@@ -123,7 +122,7 @@ impl<Res> InFlightRequests<Res> {
123122
pub fn poll_expired(
124123
&mut self,
125124
cx: &mut Context,
126-
expired_error: impl Fn() -> Res,
125+
expired_error: impl Fn() -> Result<(ClientContext, Res), RpcError>,
127126
) -> Poll<Option<u64>> {
128127
self.deadlines.poll_expired(cx).map(|expired| {
129128
let request_id = expired?.into_inner();

0 commit comments

Comments
 (0)