From 9da634ef12551dd3b27cccbb591ca220c64287dd Mon Sep 17 00:00:00 2001
From: Dov Reshef <reshef.dov@gmail.com>
Date: Tue, 5 May 2020 23:10:44 +0300
Subject: [PATCH] Add support for post_fields without buffer copy

---
 src/easy/handle.rs  | 58 +++++++++++++++++++++++++++++++++++---------
 src/easy/handler.rs | 16 +++++++++++-
 tests/easy.rs       | 59 +++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 121 insertions(+), 12 deletions(-)

diff --git a/src/easy/handle.rs b/src/easy/handle.rs
index 3152bb39c0..40ed48ac35 100644
--- a/src/easy/handle.rs
+++ b/src/easy/handle.rs
@@ -90,9 +90,14 @@ pub struct Easy {
 /// The callbacks attached to a `Transfer` are only active for that one transfer
 /// object, and they allow to elide both the `Send` and `'static` bounds to
 /// close over stack-local information.
+///
+/// Likewise, the post_data attached to a `Transfer` are only active for that one
+/// transfer object, and they allow to elide both the `Send` and `'static` bounds
+/// to close over stack-local information.
 pub struct Transfer<'easy, 'data> {
     easy: &'easy mut Easy,
-    data: Box<Callbacks<'data>>,
+    callbacks: Box<Callbacks<'data>>,
+    is_postfields: Cell<bool>,
 }
 
 pub struct EasyData {
@@ -719,6 +724,11 @@ impl Easy {
         self.inner.post_fields_copy(data)
     }
 
+    /// Same as [`Easy2::post_field`](struct.Easy2.html#method.post_field)
+    pub fn post_fields(&mut self, data: &'static [u8]) -> Result<(), Error> {
+        self.inner.post_fields(data)
+    }
+
     /// Same as [`Easy2::post_field_size`](struct.Easy2.html#method.post_field_size)
     pub fn post_field_size(&mut self, size: u64) -> Result<(), Error> {
         self.inner.post_field_size(size)
@@ -1216,8 +1226,9 @@ impl Easy {
     pub fn transfer<'data, 'easy>(&'easy mut self) -> Transfer<'easy, 'data> {
         assert!(!self.inner.get_ref().running.get());
         Transfer {
-            data: Box::new(Callbacks::default()),
+            callbacks: Box::new(Callbacks::default()),
             easy: self,
+            is_postfields: Cell::new(false),
         }
     }
 
@@ -1379,7 +1390,7 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
     where
         F: FnMut(&[u8]) -> Result<usize, WriteError> + 'data,
     {
-        self.data.write = Some(Box::new(f));
+        self.callbacks.write = Some(Box::new(f));
         Ok(())
     }
 
@@ -1389,7 +1400,7 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
     where
         F: FnMut(&mut [u8]) -> Result<usize, ReadError> + 'data,
     {
-        self.data.read = Some(Box::new(f));
+        self.callbacks.read = Some(Box::new(f));
         Ok(())
     }
 
@@ -1399,7 +1410,7 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
     where
         F: FnMut(SeekFrom) -> SeekResult + 'data,
     {
-        self.data.seek = Some(Box::new(f));
+        self.callbacks.seek = Some(Box::new(f));
         Ok(())
     }
 
@@ -1409,7 +1420,7 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
     where
         F: FnMut(f64, f64, f64, f64) -> bool + 'data,
     {
-        self.data.progress = Some(Box::new(f));
+        self.callbacks.progress = Some(Box::new(f));
         Ok(())
     }
 
@@ -1419,7 +1430,7 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
     where
         F: FnMut(*mut c_void) -> Result<(), Error> + Send + 'data,
     {
-        self.data.ssl_ctx = Some(Box::new(f));
+        self.callbacks.ssl_ctx = Some(Box::new(f));
         Ok(())
     }
 
@@ -1429,7 +1440,7 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
     where
         F: FnMut(InfoType, &[u8]) + 'data,
     {
-        self.data.debug = Some(Box::new(f));
+        self.callbacks.debug = Some(Box::new(f));
         Ok(())
     }
 
@@ -1439,7 +1450,7 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
     where
         F: FnMut(&[u8]) -> bool + 'data,
     {
-        self.data.header = Some(Box::new(f));
+        self.callbacks.header = Some(Box::new(f));
         Ok(())
     }
 
@@ -1454,7 +1465,7 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
         // This should be ok, however, because `do_perform` checks for recursive
         // invocations of `perform` and disallows them. Our type also isn't
         // `Sync`.
-        inner.borrowed.set(&*self.data as *const _ as *mut _);
+        inner.borrowed.set(&*self.callbacks as *const _ as *mut _);
 
         // Make sure to reset everything back to the way it was before when
         // we're done.
@@ -1466,7 +1477,21 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
         }
         let _reset = Reset(&inner.borrowed);
 
-        self.easy.do_perform()
+        let res = self.easy.do_perform();
+
+        // restore configuration
+        if self.is_postfields.get() {
+            self.easy
+                .inner
+                .setopt_ptr(curl_sys::CURLOPT_POSTFIELDS, 0 as *const _)
+                .expect("Failed to reset post_field_size");
+            self.easy
+                .inner
+                .setopt_ptr(curl_sys::CURLOPT_POSTFIELDS, ptr::null() as *const _)
+                .expect("Failed to set postfields to null");
+            self.is_postfields.set(false);
+        }
+        res
     }
 
     /// Same as `Easy::unpause_read`.
@@ -1478,6 +1503,17 @@ impl<'easy, 'data> Transfer<'easy, 'data> {
     pub fn unpause_write(&self) -> Result<(), Error> {
         self.easy.unpause_write()
     }
+
+    /// Similar to [`Easy2::post_field`](struct.Easy2.html#method.post_field) just
+    /// takes a non `'static` lifetime corresponding to the lifetime of this transfer.
+    pub fn post_fields(&mut self, data: &'data [u8]) -> Result<(), Error> {
+        // Set the length before the pointer so libcurl knows how much to read
+        self.is_postfields.set(true);
+        self.easy.inner.post_field_size(data.len() as u64)?;
+        self.easy
+            .inner
+            .setopt_ptr(curl_sys::CURLOPT_POSTFIELDS, data.as_ptr() as *const _)
+    }
 }
 
 impl<'easy, 'data> fmt::Debug for Transfer<'easy, 'data> {
diff --git a/src/easy/handler.rs b/src/easy/handler.rs
index f6c4e02818..29c1944845 100644
--- a/src/easy/handler.rs
+++ b/src/easy/handler.rs
@@ -1209,6 +1209,16 @@ impl<H> Easy2<H> {
         self.setopt_ptr(curl_sys::CURLOPT_COPYPOSTFIELDS, data.as_ptr() as *const _)
     }
 
+    /// Configures the data that will be uploaded as part of a POST.
+    ///
+    /// By default this option is not set and corresponds to
+    /// `CURLOPT_POSTFIELDS`.
+    pub fn post_fields(&mut self, data: &'static [u8]) -> Result<(), Error> {
+        // Set the length before the pointer so libcurl knows how much to read
+        self.post_field_size(data.len() as u64)?;
+        self.setopt_ptr(curl_sys::CURLOPT_POSTFIELDS, data.as_ptr() as *const _)
+    }
+
     /// Configures the size of data that's going to be uploaded as part of a
     /// POST operation.
     ///
@@ -2827,7 +2837,11 @@ impl<H> Easy2<H> {
         self.setopt_ptr(opt, val.as_ptr())
     }
 
-    fn setopt_ptr(&self, opt: curl_sys::CURLoption, val: *const c_char) -> Result<(), Error> {
+    pub(crate) fn setopt_ptr(
+        &self,
+        opt: curl_sys::CURLoption,
+        val: *const c_char,
+    ) -> Result<(), Error> {
         unsafe { self.cvt(curl_sys::curl_easy_setopt(self.inner.handle, opt, val)) }
     }
 
diff --git a/tests/easy.rs b/tests/easy.rs
index 27e8505346..f2c97aa297 100644
--- a/tests/easy.rs
+++ b/tests/easy.rs
@@ -515,6 +515,32 @@ fn post3() {
     t!(h.perform());
 }
 
+#[test]
+fn post4() {
+    let s = Server::new();
+    s.receive(
+        "\
+         POST / HTTP/1.1\r\n\
+         Host: 127.0.0.1:$PORT\r\n\
+         Accept: */*\r\n\
+         Content-Length: 5\r\n\
+         Content-Type: application/x-www-form-urlencoded\r\n\
+         \r\n\
+         data\n",
+    );
+    s.send(
+        "\
+         HTTP/1.1 200 OK\r\n\
+         \r\n",
+    );
+
+    let mut h = handle();
+    t!(h.url(&s.url("/")));
+    t!(h.post(true));
+    t!(h.post_fields(b"data\n"));
+    t!(h.perform());
+}
+
 #[test]
 fn referer() {
     let s = Server::new();
@@ -769,6 +795,39 @@ b",
     t!(h.transfer.borrow().perform());
 }
 
+#[test]
+fn transfer_post_fields() {
+    let s = Server::new();
+    s.receive(
+        "\
+         POST / HTTP/1.1\r\n\
+         Host: 127.0.0.1:$PORT\r\n\
+         Accept: */*\r\n\
+         Content-Length: 5\r\n\
+         Content-Type: application/x-www-form-urlencoded\r\n\
+         \r\n\
+         data\n",
+    );
+    s.send(
+        "\
+         HTTP/1.1 200 OK\r\n\
+         \r\n",
+    );
+
+    fn do_transfer_post(e: &mut Easy, data: &[u8]) {
+        let mut transfer = e.transfer();
+        t!(transfer.post_fields(data));
+        t!(transfer.perform());
+    }
+
+    let mut h = handle();
+    t!(h.url(&s.url("/")));
+    t!(h.post(true));
+    let mut data = Vec::new();
+    data.extend_from_slice(b"data\n");
+    do_transfer_post(&mut h, &data);
+}
+
 #[test]
 fn perform_in_perform_is_bad() {
     let s = Server::new();