diff --git a/thrift/lib/cpp/protocol/CMakeLists.txt b/thrift/lib/cpp/protocol/CMakeLists.txt new file mode 100644 index 00000000000..0317ecef133 --- /dev/null +++ b/thrift/lib/cpp/protocol/CMakeLists.txt @@ -0,0 +1,65 @@ +cmake_minimum_required(VERSION 3.16) +project(fbthrift_protocol_fuzzer LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +set(THRIFT_FUZZ_SANITIZER_FLAGS "-fsanitize=fuzzer,address,undefined" CACHE STRING + "Sanitizers used when building fuzz_thrift") +separate_arguments(THRIFT_FUZZ_SANITIZER_FLAG_LIST NATIVE_COMMAND "${THRIFT_FUZZ_SANITIZER_FLAGS}") + +get_filename_component(THRIFT_SOURCE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../../../.." ABSOLUTE) + +set(GEN_INCLUDE_DIR "${CMAKE_CURRENT_BINARY_DIR}") +file(MAKE_DIRECTORY + "${GEN_INCLUDE_DIR}/thrift/lib/cpp" "${GEN_INCLUDE_DIR}/thrift/lib/cpp/util" + "${GEN_INCLUDE_DIR}/folly/lang" "${GEN_INCLUDE_DIR}/folly/portability" + "${GEN_INCLUDE_DIR}/folly/io/async") +file(WRITE "${GEN_INCLUDE_DIR}/thrift/lib/cpp/thrift_config.h" "#pragma once\n") +file(WRITE "${GEN_INCLUDE_DIR}/thrift/lib/cpp/util/VarintUtils.h" [=[ +#pragma once +#include +namespace apache::thrift::util { +constexpr uint32_t i32ToZigzag(int32_t n){return (static_cast(n)<<1)^static_cast(n>>31);} constexpr uint64_t i64ToZigzag(int64_t n){return (static_cast(n)<<1)^static_cast(n>>63);} inline int32_t zigzagToI32(uint32_t n){return static_cast((n>>1)^(~(n&1)+1));} inline int64_t zigzagToI64(uint64_t n){return static_cast((n>>1)^(~(n&1)+1));} +} +]=]) +file(WRITE "${GEN_INCLUDE_DIR}/folly/_compat.h" [=[ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define FOLLY_EXPORT +#define FOLLY_PUSH_WARNING +#define FOLLY_POP_WARNING +#define FOLLY_CLANG_DISABLE_WARNING(x) +#define FOLLY_PACK_ATTR __attribute__((__packed__)) +#define FOLLY_ALWAYS_INLINE inline __attribute__((__always_inline__)) +#define FOLLY_NOINLINE __attribute__((__noinline__)) +#if defined(__GNUC__) || defined(__clang__) +#define FOLLY_LIKELY(x) __builtin_expect(!!(x), 1) +#define FOLLY_UNLIKELY(x) __builtin_expect(!!(x), 0) +#else +#define FOLLY_LIKELY(x) (x) +#define FOLLY_UNLIKELY(x) (x) +#endif +#define LIKELY(x) FOLLY_LIKELY(x) +#define UNLIKELY(x) FOLLY_UNLIKELY(x) +namespace folly { +using StringPiece = std::string_view; template class Range{public: constexpr Range():b_(),e_(){} constexpr Range(Iter b, Iter e):b_(b),e_(e){} constexpr Iter begin()const{return b_;} constexpr Iter end()const{return e_;} private: Iter b_,e_;}; class fbstring:public std::string{public: using std::string::string; fbstring(const std::string& v):std::string(v){} fbstring(std::string&& v) noexcept:std::string(std::move(v)){} fbstring& operator=(const std::string& v){std::string::operator=(v);return *this;} fbstring& operator=(std::string&& v) noexcept{std::string::operator=(std::move(v));return *this;} std::string toStdString() const{return *this;}}; template > using fbvector = std::vector; template Out to(Args&&... args){std::ostringstream os; (os << ... << args); return Out(os.str());} template ::value>> constexpr auto to_underlying(Enum v) noexcept{return static_cast>(v);} template inline constexpr bool always_false = false; using std::bit_cast; template constexpr T swapBytes(T v){using U = std::make_unsigned_t; if constexpr (sizeof(T)==1) return v; if constexpr (sizeof(T)==2) return static_cast(__builtin_bswap16(static_cast(v))); if constexpr (sizeof(T)==4) return static_cast(__builtin_bswap32(static_cast(v))); return static_cast(__builtin_bswap64(static_cast(v)));} struct Endian{template static constexpr T big(T v){if constexpr (std::endian::native==std::endian::big) return v; return swapBytes(v);} template static constexpr T little(T v){if constexpr (std::endian::native==std::endian::little) return v; return swapBytes(v);}}; class AsyncSocketException: public std::exception{public: enum AsyncSocketExceptionType{UNKNOWN=0}; AsyncSocketException(AsyncSocketExceptionType t=UNKNOWN,std::string m="",int e=0):type_(t),msg_(std::move(m)),err_(e){} AsyncSocketExceptionType getType() const noexcept{return type_;} int getErrno() const noexcept{return err_;} const char* what() const noexcept override{return msg_.c_str();} private: AsyncSocketExceptionType type_; std::string msg_; int err_;}; +} +]=]) +foreach(hdr Conv.h FBString.h FBVector.h Likely.h Range.h Traits.h lang/Bits.h portability/Time.h portability/SysTime.h io/async/AsyncSocketException.h) + file(WRITE "${GEN_INCLUDE_DIR}/folly/${hdr}" "#pragma once\n#include \n") +endforeach() + +add_executable(fuzz_thrift fuzz_thrift.cpp) +target_include_directories(fuzz_thrift PRIVATE "${GEN_INCLUDE_DIR}" "${THRIFT_SOURCE_ROOT}") +target_compile_options(fuzz_thrift PRIVATE ${THRIFT_FUZZ_SANITIZER_FLAG_LIST}) +target_link_options(fuzz_thrift PRIVATE ${THRIFT_FUZZ_SANITIZER_FLAG_LIST}) diff --git a/thrift/lib/cpp/protocol/fuzz_thrift.cpp b/thrift/lib/cpp/protocol/fuzz_thrift.cpp new file mode 100644 index 00000000000..46c7e3d162c --- /dev/null +++ b/thrift/lib/cpp/protocol/fuzz_thrift.cpp @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace apache::thrift::protocol { + +[[noreturn]] void TProtocolException::throwInvalidSkipType(TType /*type*/) { throw TProtocolException(TProtocolException::INVALID_DATA, "invalid skip type while fuzzing protocol input"); } +[[noreturn]] void TProtocolException::throwExceededSizeLimit(size_t /*size*/, size_t /*limit*/) { throw TProtocolException(TProtocolException::SIZE_LIMIT, "size limit exceeded while fuzzing protocol input"); } + +} // namespace apache::thrift::protocol + +namespace apache::thrift::transport { +std::string TTransportException::getDefaultMessage( + TTransportExceptionType /*type*/, std::string&& message) { + return std::move(message); +} +} // namespace apache::thrift::transport + +namespace { +constexpr uint32_t kMaxFields = 256; +constexpr int32_t kStringSizeLimit = 1 << 20; +constexpr int32_t kContainerSizeLimit = 1 << 16; + +class FuzzMemoryTransport : public apache::thrift::transport::TTransport { + public: + FuzzMemoryTransport(const uint8_t* data, size_t size) : data_(data), size_(size), offset_(0) {} + + uint32_t read_virt(uint8_t* out, uint32_t len) override { + const size_t remaining = size_ - offset_; + if (remaining == 0) return 0; + const uint32_t toCopy = static_cast(std::min(len, remaining)); + if (toCopy != 0) { + std::memcpy(out, data_ + offset_, toCopy); + offset_ += toCopy; + } + return toCopy; + } + + uint32_t readAll_virt(uint8_t* out, uint32_t len) override { + uint32_t copied = 0; + while (copied < len) { + const uint32_t n = read_virt(out + copied, len - copied); + if (n == 0) throw std::runtime_error("truncated input"); + copied += n; + } + return copied; + } + + void write_virt(const uint8_t* /*buf*/, uint32_t /*len*/) override {} + + const uint8_t* borrow_virt(uint8_t* /*buf*/, uint32_t* len) override { + if (len == nullptr) return nullptr; + const size_t remaining = size_ - offset_; + if (remaining < *len) return nullptr; + *len = static_cast(remaining); + return data_ + offset_; + } + + void consume_virt(uint32_t len) override { + const size_t remaining = size_ - offset_; + if (len > remaining) throw std::runtime_error("consume past end"); + offset_ += len; + } + + private: + const uint8_t* data_; + size_t size_; + size_t offset_; +}; + +template +void parseStruct(Protocol& p) { + std::string name; + p.readStructBegin(name); + for (uint32_t i = 0; i < kMaxFields; ++i) { + apache::thrift::protocol::TType fieldType = apache::thrift::protocol::T_STOP; + int16_t fieldId = 0; + p.readFieldBegin(name, fieldType, fieldId); + if (fieldType == apache::thrift::protocol::T_STOP) break; + apache::thrift::protocol::skip(p, fieldType, 0); + p.readFieldEnd(); + } + p.readStructEnd(); +} + +template +void fuzzProtocol(const uint8_t* data, size_t size, bool parseAsMessage) { + FuzzMemoryTransport transport(data, size); + Protocol p(&transport); + p.setStringSizeLimit(kStringSizeLimit); + p.setContainerSizeLimit(kContainerSizeLimit); + if (parseAsMessage) { + std::string name; + apache::thrift::protocol::TMessageType messageType = apache::thrift::protocol::T_CALL; + int32_t seqId = 0; + p.readMessageBegin(name, messageType, seqId); + parseStruct(p); + p.readMessageEnd(); + return; + } + parseStruct(p); +} + +template +void runOne(const uint8_t* data, size_t size, bool parseAsMessage) { + try { + fuzzProtocol(data, size, parseAsMessage); + } catch (...) {} +} + +} // namespace + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + if (size <= 1) return 0; + const uint8_t mode = data[0]; + const uint8_t* payload = data + 1; + const size_t payloadSize = size - 1; + runOne>(payload, payloadSize, (mode & 0x1u) != 0); + runOne>(payload, payloadSize, (mode & 0x2u) != 0); + return 0; +}