Skip to content

Commit 0cb98c2

Browse files
committed
Make DeflateContext private and add Extensions container
1 parent 78322fe commit 0cb98c2

File tree

4 files changed

+70
-48
lines changed

4 files changed

+70
-48
lines changed

src/extensions/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,17 @@
22
// Only `permessage-deflate` is supported at the moment.
33

44
mod compression;
5-
pub use compression::deflate::{DeflateConfig, DeflateContext, DeflateError};
5+
use compression::deflate::DeflateContext;
6+
pub use compression::deflate::{DeflateConfig, DeflateError};
67
use http::HeaderValue;
78

9+
/// Container for configured extensions.
10+
#[derive(Debug, Default)]
11+
pub struct Extensions {
12+
// Per-Message Compression. Only `permessage-deflate` is supported.
13+
pub(crate) compression: Option<DeflateContext>,
14+
}
15+
816
/// Iterator of all extension offers/responses in `Sec-WebSocket-Extensions` values.
917
pub(crate) fn iter_all<'a>(
1018
values: impl Iterator<Item = &'a HeaderValue>,

src/handshake/client.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use super::{
1717
};
1818
use crate::{
1919
error::{Error, ProtocolError, Result, UrlError},
20-
extensions::{self, DeflateContext},
20+
extensions::{self, Extensions},
2121
protocol::{Role, WebSocket, WebSocketConfig},
2222
};
2323

@@ -85,7 +85,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
8585
StageResult::DoneReading { stream, result, tail } => {
8686
let (result, pmce) = self.verify_data.verify_response(result, &self.config)?;
8787
debug!("Client handshake done.");
88-
let websocket = WebSocket::from_partially_read_with_compression(
88+
let websocket = WebSocket::from_partially_read_with_extensions(
8989
stream,
9090
tail,
9191
Role::Client,
@@ -161,7 +161,7 @@ impl VerifyData {
161161
&self,
162162
response: Response,
163163
config: &Option<WebSocketConfig>,
164-
) -> Result<(Response, Option<DeflateContext>)> {
164+
) -> Result<(Response, Option<Extensions>)> {
165165
// 1. If the status code received from the server is not 101, the
166166
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
167167
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@@ -201,15 +201,15 @@ impl VerifyData {
201201
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
202202
return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
203203
}
204-
let mut pmce = None;
204+
let mut extensions = None;
205205
// 5. If the response includes a |Sec-WebSocket-Extensions| header
206206
// field and this header field indicates the use of an extension
207207
// that was not present in the client's handshake (the server has
208208
// indicated an extension not requested by the client), the client
209209
// MUST _Fail the WebSocket Connection_. (RFC 6455)
210-
let mut extensions = headers.get_all("Sec-WebSocket-Extensions").iter();
211-
if let Some(value) = extensions.next() {
212-
if extensions.next().is_some() {
210+
let mut extensions_values = headers.get_all("Sec-WebSocket-Extensions").iter();
211+
if let Some(value) = extensions_values.next() {
212+
if extensions_values.next().is_some() {
213213
return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse));
214214
}
215215

@@ -223,12 +223,15 @@ impl VerifyData {
223223
}
224224

225225
// Already had PMCE configured
226-
if pmce.is_some() {
226+
if extensions.is_some() {
227227
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
228228
name.to_string(),
229229
)));
230230
}
231-
pmce = Some(compression.accept_response(params)?);
231+
232+
extensions = Some(Extensions {
233+
compression: Some(compression.accept_response(params)?),
234+
});
232235
}
233236
} else if let Some((name, _)) = exts.next() {
234237
// The client didn't request anything, but got something
@@ -243,7 +246,7 @@ impl VerifyData {
243246
// the WebSocket Connection_. (RFC 6455)
244247
// TODO
245248

246-
Ok((response, pmce))
249+
Ok((response, extensions))
247250
}
248251
}
249252

src/handshake/server.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use super::{
2020
};
2121
use crate::{
2222
error::{Error, ProtocolError, Result},
23-
extensions,
23+
extensions::Extensions,
2424
protocol::{Role, WebSocket, WebSocketConfig},
2525
};
2626

@@ -203,8 +203,8 @@ pub struct ServerHandshake<S, C> {
203203
config: Option<WebSocketConfig>,
204204
/// Error code/flag. If set, an error will be returned after sending response to the client.
205205
error_response: Option<ErrorResponse>,
206-
// Negotiated Per-Message Compression Extension context for server.
207-
pmce: Option<extensions::DeflateContext>,
206+
// Negotiated extension context for server.
207+
extensions: Option<Extensions>,
208208
/// Internal stream type.
209209
_marker: PhantomData<S>,
210210
}
@@ -222,7 +222,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
222222
callback: Some(callback),
223223
config,
224224
error_response: None,
225-
pmce: None,
225+
extensions: None,
226226
_marker: PhantomData,
227227
},
228228
}
@@ -246,10 +246,10 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
246246

247247
let mut response = create_response(&result)?;
248248
if let Some(config) = &self.config {
249-
let extensions = result.headers().get_all("Sec-WebSocket-Extensions").iter();
250-
if let Some((agreed, pmce)) = config.accept_offers(extensions) {
251-
self.pmce = Some(pmce);
249+
let values = result.headers().get_all("Sec-WebSocket-Extensions").iter();
250+
if let Some((agreed, extensions)) = config.accept_offers(values) {
252251
response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
252+
self.extensions = Some(extensions);
253253
}
254254
}
255255

@@ -292,11 +292,11 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
292292
return Err(Error::Http(err));
293293
} else {
294294
debug!("Server handshake done.");
295-
let websocket = WebSocket::from_raw_socket_with_compression(
295+
let websocket = WebSocket::from_raw_socket_with_extensions(
296296
stream,
297297
Role::Server,
298298
self.config,
299-
self.pmce.take(),
299+
self.extensions.take(),
300300
);
301301
ProcessingResult::Done(websocket)
302302
}

src/protocol/mod.rs

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use self::{
2323
};
2424
use crate::{
2525
error::{Error, ProtocolError, Result},
26-
extensions::{self, DeflateContext},
26+
extensions::{self, Extensions},
2727
util::NonBlockingResult,
2828
};
2929

@@ -81,15 +81,14 @@ impl WebSocketConfig {
8181
self.compression.map(|c| c.generate_offer())
8282
}
8383

84-
// TODO Replace `DeflateContext` with something more general
85-
// This can be used with `WebSocket::from_raw_socket_with_compression` for integration.
86-
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
84+
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
85+
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
8786
pub fn accept_offers<'a>(
8887
&'a self,
8988
extensions: impl Iterator<Item = &'a HeaderValue>,
90-
) -> Option<(HeaderValue, DeflateContext)> {
89+
) -> Option<(HeaderValue, Extensions)> {
9190
if let Some(compression) = &self.compression {
92-
let extensions = crate::extensions::iter_all(extensions);
91+
let extensions = extensions::iter_all(extensions);
9392
let offers =
9493
extensions.filter_map(
9594
|(k, v)| {
@@ -100,7 +99,12 @@ impl WebSocketConfig {
10099
}
101100
},
102101
);
103-
compression.accept_offer(offers)
102+
103+
// To support more extensions, store extension context in `Extensions` and
104+
// concatenate negotiation responses from each extension.
105+
compression
106+
.accept_offer(offers)
107+
.map(|(agreed, pmce)| (agreed, Extensions { compression: Some(pmce) }))
104108
} else {
105109
None
106110
}
@@ -130,14 +134,14 @@ impl<Stream> WebSocket<Stream> {
130134
}
131135

132136
/// Convert a raw socket into a WebSocket without performing a handshake.
133-
pub fn from_raw_socket_with_compression(
137+
pub fn from_raw_socket_with_extensions(
134138
stream: Stream,
135139
role: Role,
136140
config: Option<WebSocketConfig>,
137-
pmce: Option<DeflateContext>,
141+
extensions: Option<Extensions>,
138142
) -> Self {
139143
let mut context = WebSocketContext::new(role, config);
140-
context.pmce = pmce;
144+
context.extensions = extensions;
141145
WebSocket { socket: stream, context }
142146
}
143147

@@ -158,17 +162,17 @@ impl<Stream> WebSocket<Stream> {
158162
}
159163
}
160164

161-
pub(crate) fn from_partially_read_with_compression(
165+
pub(crate) fn from_partially_read_with_extensions(
162166
stream: Stream,
163167
part: Vec<u8>,
164168
role: Role,
165169
config: Option<WebSocketConfig>,
166-
pmce: Option<DeflateContext>,
170+
extensions: Option<Extensions>,
167171
) -> Self {
168172
WebSocket {
169173
socket: stream,
170-
context: WebSocketContext::from_partially_read_with_compression(
171-
part, role, config, pmce,
174+
context: WebSocketContext::from_partially_read_with_extensions(
175+
part, role, config, extensions,
172176
),
173177
}
174178
}
@@ -306,8 +310,8 @@ pub struct WebSocketContext {
306310
pong: Option<Frame>,
307311
/// The configuration for the websocket session.
308312
config: WebSocketConfig,
309-
/// Per-Message Compression Extension. Only deflate is supported at the moment.
310-
pub(crate) pmce: Option<extensions::DeflateContext>,
313+
// Container for extensions.
314+
pub(crate) extensions: Option<Extensions>,
311315
}
312316

313317
impl WebSocketContext {
@@ -321,7 +325,7 @@ impl WebSocketContext {
321325
send_queue: VecDeque::new(),
322326
pong: None,
323327
config: config.unwrap_or_else(WebSocketConfig::default),
324-
pmce: None,
328+
extensions: None,
325329
}
326330
}
327331

@@ -333,15 +337,15 @@ impl WebSocketContext {
333337
}
334338
}
335339

336-
pub(crate) fn from_partially_read_with_compression(
340+
pub(crate) fn from_partially_read_with_extensions(
337341
part: Vec<u8>,
338342
role: Role,
339343
config: Option<WebSocketConfig>,
340-
pmce: Option<DeflateContext>,
344+
extensions: Option<Extensions>,
341345
) -> Self {
342346
WebSocketContext {
343347
frame: FrameCodec::from_partially_read(part),
344-
pmce,
348+
extensions,
345349
..WebSocketContext::new(role, config)
346350
}
347351
}
@@ -447,11 +451,12 @@ impl WebSocketContext {
447451
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
448452
let opcode = OpCode::Data(opdata);
449453
let is_final = true;
450-
let frame = if let Some(pmce) = self.pmce.as_mut() {
451-
Frame::compressed_message(pmce.compress(&data)?, opcode, is_final)
452-
} else {
453-
Frame::message(data, opcode, is_final)
454-
};
454+
let frame =
455+
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
456+
Frame::compressed_message(pmce.compress(&data)?, opcode, is_final)
457+
} else {
458+
Frame::message(data, opcode, is_final)
459+
};
455460
Ok(frame)
456461
}
457462

@@ -533,7 +538,7 @@ impl WebSocketContext {
533538
// Connection_.
534539
let is_compressed = {
535540
let hdr = frame.header();
536-
if (hdr.rsv1 && self.pmce.is_none()) || hdr.rsv2 || hdr.rsv3 {
541+
if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 {
537542
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
538543
}
539544

@@ -606,8 +611,9 @@ impl WebSocketContext {
606611
if let Some(ref mut msg) = self.incomplete {
607612
let data = if msg.compressed() {
608613
// `msg.compressed` is only set when compression is enabled so it's safe to unwrap
609-
self.pmce
614+
self.extensions
610615
.as_mut()
616+
.and_then(|x| x.compression.as_mut())
611617
.unwrap()
612618
.decompress(frame.into_data(), fin)?
613619
} else {
@@ -637,8 +643,9 @@ impl WebSocketContext {
637643
};
638644
let mut m = IncompleteMessage::new(message_type, is_compressed);
639645
let data = if is_compressed {
640-
self.pmce
646+
self.extensions
641647
.as_mut()
648+
.and_then(|x| x.compression.as_mut())
642649
.unwrap()
643650
.decompress(frame.into_data(), fin)?
644651
} else {
@@ -729,6 +736,10 @@ impl WebSocketContext {
729736
trace!("Sending frame: {:?}", frame);
730737
self.frame.write_frame(stream, frame).check_connection_reset(self.state)
731738
}
739+
740+
fn has_compression(&self) -> bool {
741+
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
742+
}
732743
}
733744

734745
/// The current connection state.

0 commit comments

Comments
 (0)