Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use unordered_set for whitelist lookups #21

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@

#define DB_LOCK_TIMEOUT 50

namespace {
std::ostream& operator<<(mysqlpp::quote_type1 m, const peerid_t &peer_id) {
return (m << std::string(std::begin(peer_id), std::end(peer_id)));
}
}

mysql::mysql(config * conf) : u_active(false), t_active(false), p_active(false), s_active(false), tok_active(false) {
logger = spdlog::get("logger");
load_config(conf);
Expand Down Expand Up @@ -210,17 +216,24 @@ void mysql::load_tokens(torrent_list &torrents) {
}


void mysql::load_whitelist(std::vector<std::string> &whitelist) {
void mysql::load_whitelist(std::unordered_set<peerid_t> &whitelist) {
mysqlpp::Query query = conn.query("SELECT peer_id FROM xbt_client_whitelist;");
try {
mysqlpp::StoreQueryResult res = query.store();
size_t num_rows = res.num_rows();
std::lock_guard<std::mutex> wl_lock(whitelist_mutex);
whitelist.clear();
for (size_t i = 0; i<num_rows; i++) {
std::string peer_id;
res[i][0].to_string(peer_id);
whitelist.push_back(peer_id);
auto &cell = res[i][0];
peerid_t id;
const auto expected_len = id.size();
if (cell.size() == expected_len) {
std::copy(std::begin(cell), std::end(cell), std::begin(id));
whitelist.insert(std::move(id));
} else {
logger->warn("Peer ID length in row " + std::to_string(i) +
" not equal to " + std::to_string(expected_len) + ", ignoring");
}
}
} catch (const mysqlpp::BadQuery &er) {
logger->error("Query error in load_whitelist: " + std::string(er.what()));
Expand Down Expand Up @@ -255,7 +268,7 @@ void mysql::record_torrent(const std::string &record) {
update_torrent_buffer += record;
}

void mysql::record_peer(const std::string &record, const std::string &ip, const std::string &peer_id, const std::string &useragent) {
void mysql::record_peer(const std::string &record, const std::string &ip, const peerid_t &peer_id, const std::string &useragent) {
if (update_heavy_peer_buffer != "") {
update_heavy_peer_buffer += ",";
}
Expand All @@ -264,7 +277,7 @@ void mysql::record_peer(const std::string &record, const std::string &ip, const

update_heavy_peer_buffer += q.str();
}
void mysql::record_peer(const std::string &record, const std::string &peer_id) {
void mysql::record_peer(const std::string &record, const peerid_t &peer_id) {
if (update_light_peer_buffer != "") {
update_light_peer_buffer += ",";
}
Expand Down
8 changes: 5 additions & 3 deletions src/db.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
#include <spdlog/spdlog.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <queue>
#include <mutex>
#include "config.h"
#include "ocelot.h"

class mysql {
private:
Expand Down Expand Up @@ -64,13 +66,13 @@ class mysql {
bool connected();
void load_torrents(torrent_list &torrents);
void load_users(user_list &users);
void load_whitelist(std::vector<std::string> &whitelist);
void load_whitelist(std::unordered_set<peerid_t> &whitelist);

void record_user(const std::string &record); // (id,uploaded_change,downloaded_change)
void record_torrent(const std::string &record); // (id,seeders,leechers,snatched_change,balance)
void record_snatch(const std::string &record, const std::string &ip); // (uid,fid,tstamp)
void record_peer(const std::string &record, const std::string &ip, const std::string &peer_id, const std::string &useragent); // (uid,fid,active,peerid,useragent,ip,uploaded,downloaded,upspeed,downspeed,left,timespent,announces,tstamp)
void record_peer(const std::string &record, const std::string &peer_id); // (fid,peerid,timespent,announces,tstamp)
void record_peer(const std::string &record, const std::string &ip, const peerid_t &peer_id, const std::string &useragent); // (uid,fid,active,peerid,useragent,ip,uploaded,downloaded,upspeed,downspeed,left,timespent,announces,tstamp)
void record_peer(const std::string &record, const peerid_t &peer_id); // (fid,peerid,timespent,announces,tstamp)
void record_token(const std::string &record);

void flush();
Expand Down
32 changes: 0 additions & 32 deletions src/misc_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,6 @@ std::string inttostr(const int i) {
return str;
}

std::string hex_decode(const std::string &in) {
std::string out;
out.reserve(20);
unsigned int in_length = in.length();
for (unsigned int i = 0; i < in_length; i++) {
unsigned char x = '0';
if (in[i] == '%' && (i + 2) < in_length) {
i++;
if (in[i] >= 'a' && in[i] <= 'f') {
x = static_cast<unsigned char>((in[i]-87) << 4);
} else if (in[i] >= 'A' && in[i] <= 'F') {
x = static_cast<unsigned char>((in[i]-55) << 4);
} else if (in[i] >= '0' && in[i] <= '9') {
x = static_cast<unsigned char>((in[i]-48) << 4);
}

i++;
if (in[i] >= 'a' && in[i] <= 'f') {
x += static_cast<unsigned char>(in[i]-87);
} else if (in[i] >= 'A' && in[i] <= 'F') {
x += static_cast<unsigned char>(in[i]-55);
} else if (in[i] >= '0' && in[i] <= '9') {
x += static_cast<unsigned char>(in[i]-48);
}
} else {
x = in[i];
}
out.push_back(x);
}
return out;
}

std::string bintohex(const std::string &in) {
std::string out;
size_t length = in.length();
Expand Down
48 changes: 47 additions & 1 deletion src/misc_functions.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,57 @@
#ifndef MISC_FUNCTIONS__H
#define MISC_FUNCTIONS__H
#include <string>
#include <array>

int32_t strtoint32(const std::string& str);
int64_t strtoint64(const std::string& str);
std::string inttostr(int i);
std::string hex_decode(const std::string &in);

inline std::uint8_t hexchar_to_bin(char in) {
auto out = static_cast<std::uint8_t>(in);
if (in >= 'a' && in <= 'f') {
return out - 'a' + 10;
} else if (in >= 'A' && in <= 'F') {
return out - 'A' + 10;
} else if (in >= '0' && in <= '9') {
return out - '0';
} else {
return '0';
}
}

template<typename Oiter, typename F>
inline bool hex_decode_impl(const std::string& in, Oiter out, F out_is_end)
{
unsigned int i;
for (i = 0; i < in.length() && !out_is_end(out); i++, out++) {
unsigned char x = '0';
if (in[i] == '%' && (i + 2) < in.length()) {
x = (hexchar_to_bin(in[i + 1]) << 4) | hexchar_to_bin(in[i + 2]);
i += 2;
} else {
x = in[i];
}
*out = x;
}
return (i == in.length());
}

template<std::size_t N>
inline bool hex_decode(std::array<std::uint8_t, N> &out, const std::string &in) {
auto end = std::end(out);
return hex_decode_impl(in, std::begin(out),
[=](typename std::array<std::uint8_t, N>::iterator it) {
return it == end; });
}

inline std::string hex_decode(const std::string &in) {
std::string out;
out.reserve(20);
hex_decode_impl(in, std::back_inserter(out),
[](std::back_insert_iterator<std::string>) { return false; });
return out;
}
std::string bintohex(const std::string &in);

#endif
3 changes: 2 additions & 1 deletion src/ocelot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "db.h"
#include "worker.h"
#include "events.h"
#include <unordered_set>

static connection_mother *mother;
static worker *work;
Expand Down Expand Up @@ -157,7 +158,7 @@ int main(int argc, char **argv) {

user_list users_list;
torrent_list torrents_list;
std::vector<std::string> whitelist;
std::unordered_set<peerid_t> whitelist;
db->load_users(users_list);
db->load_torrents(torrents_list);
db->load_whitelist(whitelist);
Expand Down
36 changes: 35 additions & 1 deletion src/ocelot.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,32 @@
#include <memory>
#include <atomic>
#include <time.h>
#include <array>
#include <cstdint>

#include <spdlog/fmt/ostr.h>

typedef uint32_t torid_t;
typedef uint32_t userid_t;

struct peerid_t : public std::array<std::uint8_t, 20> {
template<typename OStream>
friend OStream &operator<<(OStream &os, const peerid_t &peerid)
{
fmt::print(os, "{:.20}", peerid.data());
return os;
}
};

namespace std {
template <> struct hash<peerid_t> {
std::size_t operator()(const peerid_t &prid) const noexcept {
auto ptr = reinterpret_cast<const std::uint32_t *>(prid.data());
return ptr[0] ^ ptr[1] ^ ptr[2] ^ ptr[3] ^ ptr[4];
}
};
}

class user;
typedef std::shared_ptr<user> user_ptr;

Expand Down Expand Up @@ -87,7 +109,19 @@ typedef struct {

typedef std::unordered_map<std::string, torrent> torrent_list;
typedef std::unordered_map<std::string, user_ptr> user_list;
typedef std::unordered_map<std::string, std::string> params_type;

struct params_type : public std::unordered_map<std::string, std::string> {
template<std::size_t N>
bool get_array(std::array<std::uint8_t, N> &out, const std::string &key) const {
auto it = find(key);
if (it != end() && it->second.length() == N) {
std::copy(std::begin(it->second), std::end(it->second), std::begin(out));
return true;
} else {
return false;
}
}
};

struct stats_t {
std::atomic<uint32_t> open_connections;
Expand Down
69 changes: 31 additions & 38 deletions src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "user.h"

//---------- Worker - does stuff with input
worker::worker(config * conf_obj, torrent_list &torrents, user_list &users, std::vector<std::string> &_whitelist, mysql * db_obj, site_comm * sc) :
worker::worker(config * conf_obj, torrent_list &torrents, user_list &users, std::unordered_set<peerid_t> &_whitelist, mysql * db_obj, site_comm * sc) :
conf(conf_obj), db(db_obj), s_comm(sc), torrents_list(torrents), users_list(users), whitelist(_whitelist), status(OPEN), reaper_active(false)
{
logger = spdlog::get("logger");
Expand Down Expand Up @@ -295,30 +295,25 @@ std::string worker::announce(const std::string &input, torrent &tor, user_ptr &u
if (peer_id_iterator == params.end()) {
return error("No peer ID", client_opts);
}
const std::string peer_id = hex_decode(peer_id_iterator->second);
if (peer_id.length() != 20) {
peerid_t peer_id;
if (!hex_decode(peer_id, peer_id_iterator->second)) {
return error("Invalid peer ID", client_opts);
}

std::unique_lock<std::mutex> wl_lock(db->whitelist_mutex);
if (whitelist.size() > 0) {
bool found = false; // Found client in whitelist?
for (unsigned int i = 0; i < whitelist.size(); i++) {
if (peer_id.compare(0, whitelist[i].length(), whitelist[i]) == 0) {
found = true;
break;
{
std::unique_lock<std::mutex> wl_lock(db->whitelist_mutex);
if (!whitelist.empty()) {
auto it = whitelist.find(peer_id);
if (it == std::end(whitelist)) {
return error("Your client is not on the whitelist", client_opts);
}
}
if (!found) {
return error("Your client is not on the whitelist", client_opts);
}
}
wl_lock.unlock();

std::stringstream peer_key_stream;
peer_key_stream << peer_id[12 + (tor.id & 7)] // "Randomize" the element order in the peer map by prefixing with a peer id byte
<< userid // Include user id in the key to lower chance of peer id collisions
<< peer_id;
<< userid; // Include user id in the key to lower chance of peer id collisions
peer_key_stream.write(reinterpret_cast<const char*>(peer_id.data()), peer_id.size());
const std::string peer_key(peer_key_stream.str());

if (params["event"] == "completed") {
Expand Down Expand Up @@ -952,32 +947,30 @@ std::string worker::update(params_type &params, client_opts_t &client_opts) {
logger->info("Updated user " + passkey);
}
} else if (params["action"] == "add_whitelist") {
std::string peer_id = params["peer_id"];
std::lock_guard<std::mutex> wl_lock(db->whitelist_mutex);
whitelist.push_back(peer_id);
logger->info("Whitelisted " + peer_id);
peerid_t peer_id;
if (params.get_array(peer_id, "peer_id")) {
std::lock_guard<std::mutex> wl_lock(db->whitelist_mutex);
whitelist.insert(peer_id);
logger->info("Whitelisted {}", peer_id);
}
} else if (params["action"] == "remove_whitelist") {
std::string peer_id = params["peer_id"];
std::lock_guard<std::mutex> wl_lock(db->whitelist_mutex);
for (unsigned int i = 0; i < whitelist.size(); i++) {
if (whitelist[i].compare(peer_id) == 0) {
whitelist.erase(whitelist.begin() + i);
break;
}
peerid_t peer_id;
if (params.get_array(peer_id, "peer_id")) {
std::lock_guard<std::mutex> wl_lock(db->whitelist_mutex);
whitelist.erase(peer_id);
logger->info("De-whitelisted {}", peer_id);
}
logger->info("De-whitelisted " + peer_id);
} else if (params["action"] == "edit_whitelist") {
std::string new_peer_id = params["new_peer_id"];
std::string old_peer_id = params["old_peer_id"];
std::lock_guard<std::mutex> wl_lock(db->whitelist_mutex);
for (unsigned int i = 0; i < whitelist.size(); i++) {
if (whitelist[i].compare(old_peer_id) == 0) {
whitelist.erase(whitelist.begin() + i);
break;
}
peerid_t new_peer_id, old_peer_id;
if (params.get_array(new_peer_id, "new_peer_id") &&
params.get_array(old_peer_id, "old_peer_id")) {
std::lock_guard<std::mutex> wl_lock(db->whitelist_mutex);
whitelist.erase(old_peer_id);
whitelist.insert(new_peer_id);
logger->info("Edited whitelist item from {} to {}", old_peer_id, new_peer_id);
} else {
logger->warn("edit_whitelist request received with invalid parameters");
}
whitelist.push_back(new_peer_id);
logger->info("Edited whitelist item from " + old_peer_id + " to " + new_peer_id);
} else if (params["action"] == "update_announce_interval") {
const std::string interval = params["new_announce_interval"];
conf->set("announce_interval", interval);
Expand Down
5 changes: 3 additions & 2 deletions src/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <vector>
#include <list>
#include <unordered_map>
#include <unordered_set>
#include <iostream>
#include <mutex>
#include <ctime>
Expand All @@ -21,7 +22,7 @@ class worker {
site_comm * s_comm;
torrent_list &torrents_list;
user_list &users_list;
std::vector<std::string> &whitelist;
std::unordered_set<peerid_t> &whitelist;
std::unordered_map<std::string, del_message> del_reasons;
tracker_status status;
bool reaper_active;
Expand All @@ -46,7 +47,7 @@ class worker {
inline bool peer_is_visible(user_ptr &u, peer *p);

public:
worker(config * conf_obj, torrent_list &torrents, user_list &users, std::vector<std::string> &_whitelist, mysql * db_obj, site_comm * sc);
worker(config * conf_obj, torrent_list &torrents, user_list &users, std::unordered_set<peerid_t> &_whitelist, mysql * db_obj, site_comm * sc);
void reload_config(config * conf);
std::string work(const std::string &input, std::string &ip, client_opts_t &client_opts);
std::string announce(const std::string &input, torrent &tor, user_ptr &u, params_type &params, params_type &headers, std::string &ip, client_opts_t &client_opts);
Expand Down
Loading