Skip to content

support websocket compression #675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions include/crow/compression.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#ifdef CROW_ENABLE_COMPRESSION
#pragma once

#ifndef ASIO_STANDALONE
#define ASIO_STANDALONE
#endif
#include <asio.hpp>
#include <memory>
#include <string>
#include <zlib.h>

Expand Down Expand Up @@ -93,6 +98,152 @@ namespace crow

return inflated_string;
}

class Compressor
{
public:
Compressor(bool reset_before_compress, int window_bits, int level):
reset_before_compress_(reset_before_compress), window_bits_(window_bits)
{
stream_ = std::make_unique<z_stream>();
stream_->zalloc = 0;
stream_->zfree = 0;
stream_->opaque = 0;

::deflateInit2(stream_.get(),
level,
Z_DEFLATED,
-window_bits_,
8,
Z_DEFAULT_STRATEGY);
}

~Compressor()
{
::deflateEnd(stream_.get());
}

bool needs_reset() const
{
return reset_before_compress_;
}

int window_bits() const
{
return window_bits_;
}

std::string compress(const std::string& src)
{
if (reset_before_compress_)
{
::deflateReset(stream_.get());
}

stream_->next_in = reinterpret_cast<uint8_t*>(const_cast<char*>(src.c_str()));
stream_->avail_in = src.size();

constexpr const uint64_t bufferSize = 8192;
asio::streambuf buffer;
do
{
asio::streambuf::mutable_buffers_type chunk = buffer.prepare(bufferSize);

uint8_t* next_out = asio::buffer_cast<uint8_t*>(chunk);

stream_->next_out = next_out;
stream_->avail_out = bufferSize;

::deflate(stream_.get(), reset_before_compress_ ? Z_FINISH : Z_SYNC_FLUSH);

uint64_t outputSize = stream_->next_out - next_out;
buffer.commit(outputSize);
} while (stream_->avail_out == 0);

uint64_t buffer_size = buffer.size();
if (!reset_before_compress_)
{
buffer_size -= 4;
}

return std::string(asio::buffer_cast<const char*>(buffer.data()), buffer_size);
}

private:
std::unique_ptr<z_stream> stream_;

bool reset_before_compress_;
int window_bits_;
};

class Decompressor
{
public:
Decompressor(bool reset_before_decompress, int window_bits):
reset_before_decompress_(reset_before_decompress), window_bits_(window_bits)
{
stream_ = std::make_unique<z_stream>();
stream_->zalloc = 0;
stream_->zfree = 0;
stream_->opaque = 0;

::inflateInit2(stream_.get(), -window_bits_);
}

~Decompressor()
{
inflateEnd(stream_.get());
}

bool needs_reset() const
{
return reset_before_decompress_;
}

int window_bits() const
{
return window_bits_;
}

std::string decompress(std::string src)
{
if (reset_before_decompress_)
{
inflateReset(stream_.get());
}

src.push_back('\x00');
src.push_back('\x00');
src.push_back('\xff');
src.push_back('\xff');

stream_->next_in = reinterpret_cast<uint8_t*>(const_cast<char*>(src.c_str()));
stream_->avail_in = src.size();

constexpr const uint64_t bufferSize = 8192;
asio::streambuf buffer;
do
{
asio::streambuf::mutable_buffers_type chunk = buffer.prepare(bufferSize);

uint8_t* next_out = asio::buffer_cast<uint8_t*>(chunk);

stream_->next_out = next_out;
stream_->avail_out = bufferSize;

::inflate(stream_.get(), reset_before_decompress_ ? Z_FINISH : Z_SYNC_FLUSH);
buffer.commit(stream_->next_out - next_out);
} while (stream_->avail_out == 0);

return std::string(asio::buffer_cast<const char*>(buffer.data()), buffer.size());
}

private:
std::unique_ptr<z_stream> stream_;

bool reset_before_decompress_;
int window_bits_;
};
} // namespace compression
} // namespace crow

Expand Down
68 changes: 66 additions & 2 deletions include/crow/websocket.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#pragma once
#include <array>
#include <memory>
#include "crow/logging.h"
#include "crow/socket_adaptors.h"
#include "crow/http_request.h"
#include "crow/TinySHA1.hpp"
#include "crow/utility.h"
#include "crow/compression.h"

namespace crow
{
Expand Down Expand Up @@ -107,6 +109,17 @@ namespace crow
userdata(ud);
}

#ifdef CROW_ENABLE_COMPRESSION
std::string extensions_header = req.get_header_value("Sec-WebSocket-Extensions");
if (extensions_header.find("permessage-deflate") != std::string::npos)
{
const bool reset_compressor = extensions_header.find("server_no_context_takeover") != std::string::npos;
compressor_ = std::make_unique<compression::Compressor>(reset_compressor, compression::DEFLATE, Z_BEST_COMPRESSION);
const bool reset_decompressor = extensions_header.find("client_no_context_takeover") != std::string::npos;
decompressor_ = std::make_unique<compression::Decompressor>(reset_decompressor, compression::DEFLATE);
}
#endif

// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
// Sec-WebSocket-Version: 13
std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
Expand Down Expand Up @@ -186,13 +199,29 @@ namespace crow
/// Send a binary encoded message.
void send_binary(std::string msg) override
{
send_data(0x2, std::move(msg));
int opcode = 0x2;
#ifdef CROW_ENABLE_COMPRESSION
if (compressor_)
{
opcode += 0x40;
msg = compressor_->compress(msg);
}
#endif
send_data(opcode, std::move(msg));
}

/// Send a plaintext message.
void send_text(std::string msg) override
{
send_data(0x1, std::move(msg));
int opcode = 0x1;
#ifdef CROW_ENABLE_COMPRESSION
if (compressor_)
{
opcode += 0x40;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't find where this comes from. May I ask you to send a link to spec? ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I'll look for it. Might take some time. I'm currently very busy sadly.

msg = compressor_->compress(msg);
}
#endif
send_data(opcode, std::move(msg));
}

/// Send a close signal.
Expand Down Expand Up @@ -265,6 +294,19 @@ namespace crow
write_buffers_.emplace_back(header);
write_buffers_.emplace_back(std::move(hello));
write_buffers_.emplace_back(crlf);
#ifdef CROW_ENABLE_COMPRESSION
if (compressor_ && decompressor_)
{
write_buffers_.emplace_back(
"Sec-WebSocket-Extensions: permessage-deflate"
"; server_max_window_bits=" +
std::to_string(compressor_->window_bits()) +
"; client_max_window_bits=" + std::to_string(decompressor_->window_bits()) +
(compressor_->needs_reset() ? "; server_no_context_takeover" : "") +
(decompressor_->needs_reset() ? "; client_no_context_takeover" : ""));
write_buffers_.emplace_back(crlf);
}
#endif
write_buffers_.emplace_back(crlf);
do_write();
if (open_handler_)
Expand Down Expand Up @@ -528,6 +570,12 @@ namespace crow
return mini_header_ & 0x8000;
}

/// Check if payload is compressed
bool is_compressed()
{
return mini_header_ & 0x4000;
}

/// Extract the opcode from the header.
int opcode()
{
Expand Down Expand Up @@ -555,7 +603,11 @@ namespace crow
if (is_FIN())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe move this if to separate method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can do

{
if (message_handler_)
#ifdef CROW_ENABLE_COMPRESSION
message_handler_(*this, is_compressed() && decompressor_ ? decompressor_->decompress(message_) : message_, is_binary_);
#else
message_handler_(*this, message_, is_binary_);
#endif
message_.clear();
}
}
Expand All @@ -567,7 +619,11 @@ namespace crow
if (is_FIN())
{
if (message_handler_)
#ifdef CROW_ENABLE_COMPRESSION
message_handler_(*this, is_compressed() && decompressor_ ? decompressor_->decompress(message_) : message_, is_binary_);
#else
message_handler_(*this, message_, is_binary_);
#endif
message_.clear();
}
}
Expand All @@ -579,7 +635,11 @@ namespace crow
if (is_FIN())
{
if (message_handler_)
#ifdef CROW_ENABLE_COMPRESSION
message_handler_(*this, is_compressed() && decompressor_ ? decompressor_->decompress(message_) : message_, is_binary_);
#else
message_handler_(*this, message_, is_binary_);
#endif
message_.clear();
}
}
Expand Down Expand Up @@ -734,6 +794,10 @@ namespace crow

std::shared_ptr<void> anchor_ = std::make_shared<int>(); // Value is just for placeholding

#ifdef CROW_ENABLE_COMPRESSION
std::unique_ptr<compression::Compressor> compressor_;
std::unique_ptr<compression::Decompressor> decompressor_;
#endif
std::function<void(crow::websocket::connection&)> open_handler_;
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_;
std::function<void(crow::websocket::connection&, const std::string&)> close_handler_;
Expand Down