diff --git a/Cargo.lock b/Cargo.lock index b1894de..8f7a016 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1291,7 +1291,6 @@ name = "twirp-build" version = "0.8.0" dependencies = [ "prettyplease", - "proc-macro2", "prost-build", "quote", "syn", diff --git a/crates/twirp-build/Cargo.toml b/crates/twirp-build/Cargo.toml index 908c318..900843e 100644 --- a/crates/twirp-build/Cargo.toml +++ b/crates/twirp-build/Cargo.toml @@ -18,4 +18,3 @@ prost-build = "0.13" prettyplease = { version = "0.2" } quote = "1.0" syn = "2.0" -proc-macro2 = "1.0" diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index 0540c67..78216aa 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -166,19 +166,25 @@ impl prost_build::ServiceGenerator for ServiceGenerator { let mut client_methods = Vec::with_capacity(service.methods.len()); for m in &service.methods { let name = &m.name; + let name_request = format_ident!("{}_request", name); let input_type = &m.input_type; let output_type = &m.output_type; let request_path = format!("{}/{}", service.fqn, m.proto_name); client_trait_methods.push(quote! { - async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError>; + async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError> { + self.#name_request(req)?.send().await + } + }); + client_trait_methods.push(quote! { + fn #name_request(&self, req: #input_type) -> Result, twirp::ClientError>; }); client_methods.push(quote! { - async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError> { - self.request(#request_path, req).await + fn #name_request(&self, req: #input_type) -> Result, twirp::ClientError> { + self.request(#request_path, req) } - }) + }); } let client_trait = quote! { #[twirp::async_trait::async_trait] diff --git a/crates/twirp/src/client.rs b/crates/twirp/src/client.rs index 5f8ac5b..c38176b 100644 --- a/crates/twirp/src/client.rs +++ b/crates/twirp/src/client.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use std::vec; use async_trait::async_trait; +use http::{HeaderName, HeaderValue}; use reqwest::header::{InvalidHeaderValue, CONTENT_TYPE}; use reqwest::StatusCode; use thiserror::Error; @@ -155,23 +156,12 @@ impl Client { } } - /// Make an HTTP twirp request. - pub async fn request(&self, path: &str, body: I) -> Result + /// Executes a `Request`. + pub(super) async fn execute(&self, req: reqwest::Request) -> Result where - I: prost::Message, O: prost::Message + Default, { - let mut url = self.inner.base_url.join(path)?; - if let Some(host) = &self.host { - url.set_host(Some(host))? - }; - let path = url.path().to_string(); - let req = self - .http_client - .post(url) - .header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) - .body(serialize_proto_message(body)) - .build()?; + let path = req.url().path().to_string(); // Create and execute the middleware handlers let next = Next::new(&self.http_client, &self.inner.middlewares); @@ -204,6 +194,68 @@ impl Client { }), } } + + /// Start building a `Request` with a path and a request body. + /// + /// Returns a `RequestBuilder`, which will allow setting headers before sending. + pub fn request(&self, path: &str, body: I) -> Result> + where + I: prost::Message, + O: prost::Message + Default, + { + let mut url = self.inner.base_url.join(path)?; + if let Some(host) = &self.host { + url.set_host(Some(host))? + }; + + let req = self + .http_client + .post(url) + .header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) + .body(serialize_proto_message(body)); + Ok(RequestBuilder::new(self.clone(), req)) + } +} + +pub struct RequestBuilder +where + O: prost::Message + Default, +{ + client: Client, + inner: reqwest::RequestBuilder, + _input: std::marker::PhantomData, + _output: std::marker::PhantomData, +} + +impl RequestBuilder +where + O: prost::Message + Default, +{ + pub fn new(client: Client, inner: reqwest::RequestBuilder) -> Self { + Self { + client, + inner, + _input: std::marker::PhantomData, + _output: std::marker::PhantomData, + } + } + + /// Add a `Header` to this Request. + pub fn header(mut self, key: K, value: V) -> RequestBuilder + where + HeaderName: TryFrom, + >::Error: Into, + HeaderValue: TryFrom, + >::Error: Into, + { + self.inner = self.inner.header(key, value); + self + } + + pub async fn send(self) -> Result { + let req = self.inner.build()?; + self.client.execute(req).await + } } // This concept of reqwest middleware is taken pretty much directly from: diff --git a/crates/twirp/src/lib.rs b/crates/twirp/src/lib.rs index 5b66b2b..6cbbb52 100644 --- a/crates/twirp/src/lib.rs +++ b/crates/twirp/src/lib.rs @@ -10,7 +10,7 @@ pub mod test; #[doc(hidden)] pub mod details; -pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result}; +pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, RequestBuilder, Result}; pub use context::Context; pub use error::*; // many constructors like `invalid_argument()` pub use http::Extensions; diff --git a/crates/twirp/src/test.rs b/crates/twirp/src/test.rs index e80effd..7489a76 100644 --- a/crates/twirp/src/test.rs +++ b/crates/twirp/src/test.rs @@ -121,7 +121,7 @@ pub trait TestApiClient { #[async_trait] impl TestApiClient for Client { async fn ping(&self, req: PingRequest) -> Result { - self.request("test.TestAPI/Ping", req).await + self.request("test.TestAPI/Ping", req)?.send().await } async fn boom(&self, _req: PingRequest) -> Result { diff --git a/example/src/bin/client.rs b/example/src/bin/client.rs index 89c6e71..51b132a 100644 --- a/example/src/bin/client.rs +++ b/example/src/bin/client.rs @@ -38,6 +38,14 @@ pub async fn main() -> Result<(), GenericError> { .await; eprintln!("{:?}", resp); + let resp = client + .with_host("localhost") + .make_hat_request(MakeHatRequest { inches: 1 })? + .header("x-custom-header", "a") + .send() + .await?; + eprintln!("{:?}", resp); + Ok(()) } @@ -69,23 +77,39 @@ impl Middleware for PrintResponseHeaders { } } +// NOTE: This is just to demonstrate manually implementing the client trait. You don't need to do this as this code will +// be generated for you by twirp-build. +// +// This is here so that we can visualize changes to the generated client code #[allow(dead_code)] #[derive(Debug)] struct MockHaberdasherApiClient; #[async_trait] impl HaberdasherApiClient for MockHaberdasherApiClient { - async fn make_hat( + fn make_hat_request( &self, _req: MakeHatRequest, - ) -> Result { + ) -> Result, twirp::ClientError> { + todo!() + } + // implementing this one is optional + async fn make_hat(&self, _req: MakeHatRequest) -> Result { todo!() } + fn get_status_request( + &self, + _req: GetStatusRequest, + ) -> Result, twirp::ClientError> + { + todo!() + } + // implementing this one is optional async fn get_status( &self, _req: GetStatusRequest, - ) -> Result { + ) -> Result { todo!() } }