Skip to content

Commit cf3c983

Browse files
committed
optimistic message sending tested
1 parent b52fbcc commit cf3c983

File tree

6 files changed

+168
-69
lines changed

6 files changed

+168
-69
lines changed

src/client/legacy/connect/proxy/socks/mod.rs

+28-14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ pub use v5::{SocksV5, SocksV5Error};
44
mod v4;
55
pub use v4::{SocksV4, SocksV4Error};
66

7+
use bytes::BytesMut;
8+
79
use hyper::rt::Read;
810

911
#[derive(Debug)]
@@ -25,6 +27,7 @@ pub enum SocksError<C> {
2527
#[derive(Debug)]
2628
pub enum ParsingError {
2729
Incomplete,
30+
WouldOverflow,
2831
Other,
2932
}
3033

@@ -33,24 +36,33 @@ pub enum SerializeError {
3336
WouldOverflow,
3437
}
3538

36-
async fn read_message<T, M, C>(mut conn: &mut T, buf: &mut [u8]) -> Result<M, SocksError<C>>
39+
async fn read_message<T, M, C>(mut conn: &mut T, buf: &mut BytesMut) -> Result<M, SocksError<C>>
3740
where
3841
T: Read + Unpin,
39-
M: for<'a> TryFrom<&'a [u8], Error = ParsingError>,
42+
M: for<'a> TryFrom<&'a mut BytesMut, Error = ParsingError>,
4043
{
41-
let mut n = 0;
4244
loop {
43-
let read = crate::rt::read(&mut conn, buf).await?;
44-
45-
if read == 0 {
46-
return Err(
47-
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "unexpected eof").into(),
48-
);
49-
}
50-
51-
n += read;
52-
match M::try_from(&buf[..n]) {
53-
Err(ParsingError::Incomplete) => continue,
45+
let n = unsafe {
46+
let spare = &mut *(buf.spare_capacity_mut() as *mut _ as *mut [u8]);
47+
let n = crate::rt::read(&mut conn, spare).await?;
48+
buf.set_len(buf.len() + n);
49+
n
50+
};
51+
52+
match M::try_from(buf) {
53+
Err(ParsingError::Incomplete) => {
54+
if n == 0 {
55+
if buf.spare_capacity_mut().len() == 0 {
56+
return Err(SocksError::Parsing(ParsingError::WouldOverflow));
57+
} else {
58+
return Err(std::io::Error::new(
59+
std::io::ErrorKind::UnexpectedEof,
60+
"unexpected eof",
61+
)
62+
.into());
63+
}
64+
}
65+
}
5466
Err(err) => return Err(err.into()),
5567
Ok(res) => return Ok(res),
5668
}
@@ -78,6 +90,8 @@ impl<C> std::fmt::Display for SocksError<C> {
7890
}
7991
}
8092

93+
impl<C: std::fmt::Debug + std::fmt::Display> std::error::Error for SocksError<C> {}
94+
8195
impl<C> From<std::io::Error> for SocksError<C> {
8296
fn from(err: std::io::Error) -> Self {
8397
Self::Io(err)

src/client/legacy/connect/proxy/socks/v4/messages.rs

+4-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::super::{ParsingError, SerializeError};
22

3-
use bytes::{Buf, BufMut};
3+
use bytes::{Buf, BufMut, BytesMut};
44
use std::net::SocketAddrV4;
55

66
/// +-----+-----+----+----+----+----+----+----+-------------+------+------------+------+
@@ -80,24 +80,19 @@ impl Request<'_> {
8080
}
8181
}
8282

83-
impl TryFrom<&[u8]> for Response {
83+
impl TryFrom<&mut BytesMut> for Response {
8484
type Error = ParsingError;
8585

86-
fn try_from(mut buf: &[u8]) -> Result<Self, Self::Error> {
87-
println!("===");
88-
println!("{buf:?}");
89-
println!("===");
90-
86+
fn try_from(buf: &mut BytesMut) -> Result<Self, Self::Error> {
9187
if buf.remaining() < 8 {
9288
return Err(ParsingError::Incomplete);
9389
}
9490

95-
if buf.get_u8() != 0x04 {
91+
if buf.get_u8() != 0x00 {
9692
return Err(ParsingError::Other);
9793
}
9894

9995
let status = buf.get_u8().try_into()?;
100-
10196
let _addr = {
10297
let port = buf.get_u16();
10398
let mut ip = [0; 4];

src/client/legacy/connect/proxy/socks/v4/mod.rs

+21-9
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ use http::Uri;
1414
use hyper::rt::{Read, Write};
1515
use tower_service::Service;
1616

17+
use bytes::BytesMut;
18+
1719
use pin_project_lite::pin_project;
1820

1921
/// TODO
20-
#[derive(Debug)]
22+
#[derive(Debug, Clone)]
2123
pub struct SocksV4<C> {
2224
inner: C,
2325
config: SocksConfig,
@@ -61,6 +63,15 @@ impl<C> SocksV4<C> {
6163
config: SocksConfig::new(proxy_dst),
6264
}
6365
}
66+
67+
/// Resolve domain names locally on the client, rather than on the proxy server.
68+
///
69+
/// Disabled by default as local resolution of domain names can be detected as a
70+
/// DNS leak.
71+
pub fn local_dns(mut self, local_dns: bool) -> Self {
72+
self.config.local_dns = local_dns;
73+
self
74+
}
6475
}
6576

6677
impl SocksConfig {
@@ -84,9 +95,7 @@ impl SocksConfig {
8495
Ok(IpAddr::V6(_)) => return Err(SocksV4Error::IpV6.into()),
8596
Ok(IpAddr::V4(ip)) => Address::Socket(SocketAddrV4::new(ip.into(), port)),
8697
Err(_) => {
87-
if !self.local_dns {
88-
Address::Domain(host, port)
89-
} else {
98+
if self.local_dns {
9099
(host, port)
91100
.to_socket_addrs()?
92101
.find_map(|s| {
@@ -97,19 +106,22 @@ impl SocksConfig {
97106
}
98107
})
99108
.ok_or(super::SocksError::DnsFailure)?
109+
} else {
110+
Address::Domain(host, port)
100111
}
101112
}
102113
};
103114

104-
let mut buf = vec![0; 512];
115+
let mut send_buf = BytesMut::with_capacity(1024);
116+
let mut recv_buf = BytesMut::with_capacity(1024);
105117

106118
// Send Request
107119
let req = Request(&address);
108-
let n = req.write_to_buf(&mut buf[..])?;
109-
crate::rt::write_all(&mut conn, &buf[..n]).await?;
120+
let n = req.write_to_buf(&mut send_buf)?;
121+
crate::rt::write_all(&mut conn, &send_buf[..n]).await?;
110122

111123
// Read Response
112-
let res: Response = super::read_message(&mut conn, &mut buf).await?;
124+
let res: Response = super::read_message(&mut conn, &mut recv_buf).await?;
113125
if res.0 == Status::Success {
114126
Ok(conn)
115127
} else {
@@ -138,7 +150,7 @@ where
138150
let connecting = self.inner.call(config.proxy.clone());
139151

140152
let fut = async move {
141-
let port = dst.port().ok_or(super::SocksError::MissingPort)?.as_u16();
153+
let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
142154
let host = dst
143155
.host()
144156
.ok_or(super::SocksError::MissingHost)?

src/client/legacy/connect/proxy/socks/v5/messages.rs

+19-19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::super::{ParsingError, SerializeError};
22

3-
use bytes::{Buf, BufMut};
3+
use bytes::{Buf, BufMut, BytesMut};
44
use std::net::SocketAddr;
55

66
/// +----+----------+----------+
@@ -79,8 +79,8 @@ pub enum Status {
7979
}
8080

8181
impl NegotiationReq<'_> {
82-
pub fn write_to_buf<B: BufMut>(&self, mut buf: B) -> Result<usize, SerializeError> {
83-
if buf.remaining_mut() < 3 {
82+
pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result<usize, SerializeError> {
83+
if buf.capacity() - buf.len() < 3 {
8484
return Err(SerializeError::WouldOverflow);
8585
}
8686

@@ -92,10 +92,10 @@ impl NegotiationReq<'_> {
9292
}
9393
}
9494

95-
impl TryFrom<&[u8]> for NegotiationRes {
95+
impl TryFrom<&mut BytesMut> for NegotiationRes {
9696
type Error = ParsingError;
9797

98-
fn try_from(mut buf: &[u8]) -> Result<Self, ParsingError> {
98+
fn try_from(buf: &mut BytesMut) -> Result<Self, ParsingError> {
9999
if buf.remaining() < 2 {
100100
return Err(ParsingError::Incomplete);
101101
}
@@ -110,8 +110,8 @@ impl TryFrom<&[u8]> for NegotiationRes {
110110
}
111111

112112
impl AuthenticationReq<'_> {
113-
pub fn write_to_buf<B: BufMut>(&self, mut buf: B) -> Result<usize, SerializeError> {
114-
if buf.remaining_mut() < 3 + self.0.len() + self.1.len() {
113+
pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result<usize, SerializeError> {
114+
if buf.capacity() - buf.len() < 3 + self.0.len() + self.1.len() {
115115
return Err(SerializeError::WouldOverflow);
116116
}
117117

@@ -127,10 +127,10 @@ impl AuthenticationReq<'_> {
127127
}
128128
}
129129

130-
impl TryFrom<&[u8]> for AuthenticationRes {
130+
impl TryFrom<&mut BytesMut> for AuthenticationRes {
131131
type Error = ParsingError;
132132

133-
fn try_from(mut buf: &[u8]) -> Result<Self, ParsingError> {
133+
fn try_from(buf: &mut BytesMut) -> Result<Self, ParsingError> {
134134
if buf.remaining() < 2 {
135135
return Err(ParsingError::Incomplete);
136136
}
@@ -148,14 +148,14 @@ impl TryFrom<&[u8]> for AuthenticationRes {
148148
}
149149

150150
impl ProxyReq<'_> {
151-
pub fn write_to_buf<B: BufMut>(&self, mut buf: B) -> Result<usize, SerializeError> {
151+
pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result<usize, SerializeError> {
152152
let addr_len = match self.0 {
153153
Address::Socket(SocketAddr::V4(_)) => 1 + 4 + 2,
154154
Address::Socket(SocketAddr::V6(_)) => 1 + 16 + 2,
155155
Address::Domain(ref domain, _) => 1 + 1 + domain.len() + 2,
156156
};
157157

158-
if buf.remaining_mut() < 3 + addr_len {
158+
if buf.capacity() - buf.len() < 3 + addr_len {
159159
return Err(SerializeError::WouldOverflow);
160160
}
161161

@@ -168,10 +168,10 @@ impl ProxyReq<'_> {
168168
}
169169
}
170170

171-
impl TryFrom<&[u8]> for ProxyRes {
171+
impl TryFrom<&mut BytesMut> for ProxyRes {
172172
type Error = ParsingError;
173173

174-
fn try_from(mut buf: &[u8]) -> Result<Self, ParsingError> {
174+
fn try_from(buf: &mut BytesMut) -> Result<Self, ParsingError> {
175175
if buf.remaining() < 2 {
176176
return Err(ParsingError::Incomplete);
177177
}
@@ -197,10 +197,10 @@ impl TryFrom<&[u8]> for ProxyRes {
197197
}
198198

199199
impl Address {
200-
pub fn write_to_buf<B: BufMut>(&self, mut buf: B) -> Result<usize, SerializeError> {
200+
pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result<usize, SerializeError> {
201201
match self {
202202
Self::Socket(SocketAddr::V4(v4)) => {
203-
if buf.remaining_mut() < 1 + 4 + 2 {
203+
if buf.capacity() - buf.len() < 1 + 4 + 2 {
204204
return Err(SerializeError::WouldOverflow);
205205
}
206206

@@ -212,7 +212,7 @@ impl Address {
212212
}
213213

214214
Self::Socket(SocketAddr::V6(v6)) => {
215-
if buf.remaining_mut() < 1 + 16 + 2 {
215+
if buf.capacity() - buf.len() < 1 + 16 + 2 {
216216
return Err(SerializeError::WouldOverflow);
217217
}
218218

@@ -224,7 +224,7 @@ impl Address {
224224
}
225225

226226
Self::Domain(domain, port) => {
227-
if buf.remaining_mut() < 1 + 1 + domain.len() + 2 {
227+
if buf.capacity() - buf.len() < 1 + 1 + domain.len() + 2 {
228228
return Err(SerializeError::WouldOverflow);
229229
}
230230

@@ -239,10 +239,10 @@ impl Address {
239239
}
240240
}
241241

242-
impl TryFrom<&[u8]> for Address {
242+
impl TryFrom<&mut BytesMut> for Address {
243243
type Error = ParsingError;
244244

245-
fn try_from(mut buf: &[u8]) -> Result<Self, Self::Error> {
245+
fn try_from(buf: &mut BytesMut) -> Result<Self, Self::Error> {
246246
if buf.remaining() < 2 {
247247
return Err(ParsingError::Incomplete);
248248
}

0 commit comments

Comments
 (0)